
Vision Transformers for Image Classification: A Deep Dive
A guide to training and using the ViT architecture from an applied perspective



- Capturing real-world variability: A diverse dataset helps ensure that our model can handle the wide range of images it might encounter in real-world scenarios. This includes variations in lighting conditions, angles, resolutions, and subject compositions.
- Mitigating bias: In sensitive classification tasks, it's crucial to have a dataset that represents different demographics, contexts, and edge cases. This helps prevent the model from learning and perpetuating harmful biases.
- Improving generalization: A well-rounded dataset challenges the model to learn robust features rather than superficial correlations, leading to better performance on unseen data.
- Legal and ethical considerations: When dealing with sensitive content, it's important to ensure that the dataset is collected and used in compliance with relevant regulations and ethical guidelines.

- Image resizing: ViTs typically expect input images of a fixed size (e.g., 224x224 pixels). We resize each image to this size, preserving aspect ratio by padding or center-cropping as needed.
- Patch extraction: Unlike CNNs, ViTs process images as sequences of patches. We implement a patching mechanism that divides each image into fixed-size patches (e.g., 16x16 pixels), which will serve as the input to our model.
- Normalization: We normalize the pixel values of our images to a standard range (typically -1 to 1 or 0 to 1) to ensure consistent inputs to the model.
- Data augmentation: To increase the effective size of our dataset and improve model robustness, we apply various augmentation techniques such as random horizontal flips, slight rotations, and color jittering. However, we're careful to ensure these augmentations don't alter the semantic content of sensitive images.
- Tokenization: We convert our image patches into the format expected by the ViT model, typically by flattening each patch and applying a learned linear projection (a minimal sketch of these preprocessing steps follows this list).
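
To make these steps concrete, here is a minimal sketch of the pipeline in PyTorch and torchvision. The image size, patch size, embedding dimension, and augmentation settings are illustrative assumptions rather than values from our actual training runs; in practice, a library such as Hugging Face transformers performs the patch embedding inside the ViT model itself.

```python
import torch
import torch.nn as nn
from torchvision import transforms

# Illustrative assumptions: 224x224 inputs, 16x16 patches, 768-dim tokens
IMG_SIZE, PATCH_SIZE, EMBED_DIM = 224, 16, 768

# Resize, augment, and normalize to [-1, 1]; the augmentations are kept
# mild so they don't alter the semantic content of sensitive images
train_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

def patchify(images: torch.Tensor) -> torch.Tensor:
    """Split a batch (B, C, H, W) into flattened, non-overlapping patches
    of shape (B, num_patches, C * PATCH_SIZE * PATCH_SIZE)."""
    b, c, _, _ = images.shape
    patches = images.unfold(2, PATCH_SIZE, PATCH_SIZE)
    patches = patches.unfold(3, PATCH_SIZE, PATCH_SIZE)
    patches = patches.permute(0, 2, 3, 1, 4, 5)
    return patches.reshape(b, -1, c * PATCH_SIZE * PATCH_SIZE)

# Tokenization: a learned linear projection maps each flattened patch
# to a token embedding
to_tokens = nn.Linear(3 * PATCH_SIZE * PATCH_SIZE, EMBED_DIM)

images = torch.randn(8, 3, IMG_SIZE, IMG_SIZE)  # stand-in for a real batch
tokens = to_tokens(patchify(images))            # shape: (8, 196, 768)
```
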
- Assessing class distribution: We begin by analyzing the distribution of classes in our dataset. In sensitive content classification, it's common to have a significant imbalance, with non-sensitive content typically outnumbering sensitive content.
- Oversampling minority classes: To address imbalance, we implement oversampling techniques for the minority class (sensitive content). This could involve simple replication or more advanced methods like SMOTE (Synthetic Minority Over-sampling Technique) adapted for image data.
- Undersampling majority classes: In conjunction with oversampling, we may also undersample the majority class to achieve a more balanced distribution.
- Class weighting: We adjust the loss function to assign higher weights to the minority class, ensuring the model pays more attention to these less frequent but critical examples (see the sketch after this list).
- Stratified sampling: When splitting our data into training, validation, and test sets, we use stratified sampling to maintain the class distribution across all sets.
- Data generation: For extremely sensitive or rare categories, we explore the possibility of generating synthetic data using techniques like style transfer or GANs, always ensuring the synthetic data adheres to ethical guidelines.
- Monitoring and iterative refinement: Throughout the training process, we continuously monitor the model's performance across all classes, refining our balancing strategies as needed.
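
As one illustration of the weighting and splitting steps above, the sketch below wires them up with scikit-learn and PyTorch. The 9:1 label distribution and the two-class setup are hypothetical.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split

# Hypothetical labels: 0 = non-sensitive (majority), 1 = sensitive (minority)
labels = np.array([0] * 9000 + [1] * 1000)
indices = np.arange(len(labels))

# Stratified split preserves the 9:1 class ratio in both subsets
train_idx, test_idx = train_test_split(
    indices, test_size=0.2, stratify=labels, random_state=42
)

# Inverse-frequency class weights (n_samples / (n_classes * count))
# make the loss pay more attention to the minority class
counts = np.bincount(labels[train_idx])
class_weights = torch.tensor(len(train_idx) / (2.0 * counts), dtype=torch.float)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

# Alternatively, oversample the minority class at the batch level;
# pass this sampler to a DataLoader over the training subset
sample_weights = class_weights[torch.as_tensor(labels[train_idx])]
sampler = torch.utils.data.WeightedRandomSampler(
    weights=sample_weights, num_samples=len(train_idx), replacement=True
)
```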

- Precision and Recall: These are crucial for sensitive content detection. We want high precision to avoid false positives (incorrectly flagging safe content), and high recall to catch as much sensitive content as possible.
- F1 Score: The harmonic mean of precision and recall, giving us a balanced view of both in a single metric.
- Area Under the ROC Curve (AUC-ROC): This helps us understand how well our model distinguishes between classes across various threshold settings.
- Matthews Correlation Coefficient (MCC): Particularly useful for imbalanced datasets, MCC gives a balanced measure of the quality of binary classifications.
- We implemented a weighted F1 score that accounted for class imbalance, giving more importance to the minority class (sensitive content).
- We created a custom threshold-finding function. Instead of using the default 0.5 threshold for binary classification, we searched for the optimal threshold that balanced precision and recall for our specific use case.
- We also implemented a function to assess the model's confidence calibration, checking that its predicted probabilities align well with actual outcomes (a sketch of these utilities follows this list).
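
The sketch below shows how this metric suite and the custom utilities might look with scikit-learn. The function names, the grid search over thresholds, and the simplified bin-based calibration measure are our illustrative choices (with scikit-learn's support-weighted F1 standing in for the custom weighted score), not the exact implementation from the project.

```python
import numpy as np
from sklearn.metrics import (
    precision_recall_fscore_support, roc_auc_score, matthews_corrcoef
)

def evaluate(y_true, y_prob, threshold=0.5):
    """Metric suite for binary sensitive-content detection; y_prob is the
    predicted probability of the positive (sensitive) class."""
    y_pred = (y_prob >= threshold).astype(int)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0
    )
    return {
        "precision": precision,
        "recall": recall,
        "weighted_f1": f1,
        "auc_roc": roc_auc_score(y_true, y_prob),
        "mcc": matthews_corrcoef(y_true, y_pred),
    }

def find_best_threshold(y_true, y_prob, grid=np.linspace(0.05, 0.95, 91)):
    """Search for the decision threshold that maximizes weighted F1
    instead of defaulting to 0.5."""
    scores = [evaluate(y_true, y_prob, t)["weighted_f1"] for t in grid]
    return float(grid[int(np.argmax(scores))])

def expected_calibration_error(y_true, y_prob, n_bins=10):
    """Simplified calibration check: within equal-width probability bins,
    compare the mean predicted probability to the observed positive rate."""
    bins = np.clip((y_prob * n_bins).astype(int), 0, n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        mask = bins == b
        if mask.any():
            confidence = y_prob[mask].mean()
            observed = y_true[mask].mean()  # fraction of positives in the bin
            ece += mask.mean() * abs(observed - confidence)
    return ece
```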

- Confusion Matrix: We used seaborn to create a heatmap of our confusion matrix, giving us a clear view of where our model was succeeding and where it was struggling (a minimal version appears after this list).
- ROC and Precision-Recall Curves: These plots helped us visualize the trade-off between true positive rate and false positive rate, and between precision and recall, respectively.
- Sample Predictions: We created a grid of sample images with their true labels, predicted labels, and confidence scores. This was particularly helpful for understanding the types of images our model found challenging.
- Attention Maps: Leveraging the interpretability of ViTs, we visualized attention maps for sample predictions, showing which parts of the image the model focused on for its decisions (see the sketch below).
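
For reference, here is a minimal version of the confusion-matrix heatmap, using seaborn and scikit-learn; the class names are placeholders. scikit-learn's RocCurveDisplay and PrecisionRecallDisplay can produce the two curve plots in similarly few lines.

```python
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

def plot_confusion(y_true, y_pred, class_names=("safe", "sensitive")):
    """Render the confusion matrix as an annotated heatmap."""
    cm = confusion_matrix(y_true, y_pred)
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted label")
    plt.ylabel("True label")
    plt.show()
```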
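
And a sketch of pulling an attention map out of a Hugging Face ViT. The checkpoint name and input file are placeholders, and averaging the last layer's heads is one simple choice; techniques such as attention rollout combine attention across all layers instead.

```python
import matplotlib.pyplot as plt
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

MODEL_NAME = "google/vit-base-patch16-224"  # illustrative checkpoint
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(MODEL_NAME)
model.eval()

image = Image.open("example.jpg").convert("RGB")  # hypothetical input
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# Last layer's attention: (batch, heads, seq, seq). Average over heads,
# then take the CLS token's attention to the 14x14 grid of image patches.
attn = outputs.attentions[-1].mean(dim=1)[0, 0, 1:]
attention_map = attn.reshape(14, 14).numpy()

plt.imshow(attention_map, cmap="viridis")
plt.title("CLS attention, last layer")
plt.show()
```
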
Any opinions in this post are those of the individual author and may not reflect the opinions of AWS.