Image Classification Deep Dive
Image classification is the task of assigning a label to an entire image from a predefined set of categories. It is one of the foundational problems in computer vision and the benchmark that drove much of the deep learning revolution.
The Evolution of CNN Architectures
The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) was the proving ground for deep learning in vision. Let's trace the architectural evolution:
LeNet-5 (1998)
Yann LeCun's pioneering architecture for handwritten digit recognition.
AlexNet (2012)
The architecture that ignited the deep learning revolution by winning ILSVRC 2012 with a massive margin.
VGGNet (2014)
Showed that deeper networks built from small 3x3 filters outperform shallower networks with large filters.
GoogLeNet / Inception (2014)
Introduced the Inception module: parallel convolutions at multiple scales whose outputs are concatenated.
ResNet (2015)
One of the most influential architectures in deep learning. Introduced residual connections, in which each block learns a residual function F(x) that is added back to its input: y = F(x) + x.
Residual Connections: Why They Work
The identity shortcut gives gradients an unimpeded path through the network: since y = F(x) + x, the gradient with respect to x always contains an identity term, so it cannot vanish no matter how many blocks are stacked. This is what made networks with 100+ layers trainable.
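The idea is small enough to sketch directly. Below is a minimal PyTorch residual block, not the exact ResNet implementation; the channel count and input size are illustrative:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Minimal residual block: output = ReLU(F(x) + x)."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Identity shortcut: the "+ x" term passes gradients through unchanged
        return torch.relu(out + x)

block = ResidualBlock(64)
x = torch.randn(1, 64, 8, 8)
y = block(x)  # same shape as the input, by construction
```

Because the shortcut is a pure addition, the block can at worst learn F(x) = 0 and behave as an identity mapping, which is why adding more blocks rarely hurts.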
EfficientNet (2019)
Introduced compound scaling: scaling depth, width, and input resolution together with fixed coefficients, rather than scaling any one dimension alone.
ConvNeXt (2022)
"A ConvNet for the 2020s": modernized ResNet with design choices borrowed from Vision Transformers (patchify stem, large depthwise kernels, LayerNorm, GELU) to show that pure ConvNets remain competitive.
Data Augmentation
Data augmentation artificially increases the effective size of your training set. It is one of the most important techniques for building robust classifiers.
Standard Augmentations
| Augmentation | Description | When to Use |
|---|---|---|
| Random Horizontal Flip | Mirror the image left-right | Most natural image tasks |
| Random Crop | Crop a random region and resize | Almost always |
| Color Jitter | Randomly adjust brightness, contrast, saturation, hue | When lighting varies |
| Random Rotation | Rotate by a random angle | When orientation varies |
| Random Erasing / Cutout | Mask out a random rectangle | Occlusion robustness |
Advanced Augmentations
Beyond the standard transforms, three techniques appear in most modern recipes: Mixup blends pairs of images and their labels with a coefficient drawn from a Beta distribution; CutMix pastes a rectangular patch from one image onto another and mixes the labels in proportion to the patch area; RandAugment applies a random sequence of transforms controlled by a single global magnitude parameter.
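Mixup in particular can be sketched without any framework code. The version below is a toy sketch on flat feature vectors with one-hot labels; real implementations mix whole image batches the same way:

```python
import random

def mixup(x1, y1, x2, y2, alpha=0.2):
    """Mix two (features, one-hot label) pairs with lam ~ Beta(alpha, alpha)."""
    lam = random.betavariate(alpha, alpha)
    # Convex combination of both the inputs and the labels
    x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x, y, lam

x, y, lam = mixup([1.0, 0.0, 2.0], [1, 0], [0.0, 1.0, 0.0], [0, 1])
# y interpolates the two one-hot labels, so it still sums to 1
```

With a small `alpha` (0.2 is a common default), the Beta distribution concentrates near 0 and 1, so most mixed samples stay close to one of the originals.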
Augmentation Best Practice
Match augmentation strength to your data: flips and crops are safe defaults for natural images, but transformations that change the label (vertical flips for digits, hue shifts for traffic lights) hurt. Apply augmentation to the training set only; validation and test images should receive deterministic preprocessing.
Training Strategies
Modern training recipes combine several techniques for maximum performance:
Learning Rate Warm-up
Start with a very small learning rate and linearly increase it to the target LR over the first few epochs. This stabilizes training, especially with large batch sizes.
Cosine Annealing
After warm-up, decay the learning rate along a cosine curve. This gives a smooth, gradual decay, and in the "warm restarts" variant (SGDR) the schedule can be restarted periodically:

$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{t}{T}\pi\right)\right)$$
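The combined warm-up plus cosine schedule can be computed directly. A standalone sketch, where the epoch counts and LR bounds are illustrative defaults:

```python
import math

def lr_at(epoch, *, eta_max=1e-3, eta_min=1e-6, warmup=5, total=100):
    """Linear warm-up for `warmup` epochs, then cosine decay toward eta_min."""
    if epoch < warmup:
        # Linear ramp from eta_max/warmup up to eta_max
        return eta_max * (epoch + 1) / warmup
    t = epoch - warmup        # progress into the cosine phase
    T = total - warmup        # length of the cosine phase
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t / T))

# lr_at(4) and lr_at(5) both equal eta_max: warm-up ends where cosine begins
```

Plotting `lr_at` over 100 epochs reproduces the familiar ramp-then-cosine shape used in the training scripts later in this section.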
Label Smoothing
Instead of hard targets (0 or 1), use soft targets: with smoothing ε = 0.1 and K classes, each wrong class gets ε/K and the correct class gets 1 − ε + ε/K (for K = 10: 0.01 per wrong class, 0.91 for the correct one). This prevents the model from becoming overconfident and improves generalization.
Other Key Techniques
Gradient clipping and decoupled weight decay (AdamW) round out the recipe; both appear in the training script below, along with keeping an exponential moving average of the weights in some setups.
A complete CIFAR-10 training recipe in PyTorch, combining the techniques above:

```python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

# ==============================================================
# Data augmentation pipeline for CIFAR-10
# ==============================================================
train_transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    T.RandAugment(num_ops=2, magnitude=9),  # RandAugment
    T.ToTensor(),
    T.Normalize(mean=[0.4914, 0.4822, 0.4465],
                std=[0.2470, 0.2435, 0.2616]),
    T.RandomErasing(p=0.25),  # Cutout-style augmentation (on tensors)
])

val_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.4914, 0.4822, 0.4465],
                std=[0.2470, 0.2435, 0.2616]),
])

# Load CIFAR-10
train_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=train_transform
)
val_dataset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=val_transform
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=2
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=256, shuffle=False, num_workers=2
)

# ==============================================================
# Model: ResNet-18 adapted for CIFAR-10 (32x32 images)
# ==============================================================
model = torchvision.models.resnet18(num_classes=10)
# Replace the 7x7 stride-2 first conv for small images
model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.maxpool = nn.Identity()  # Remove maxpool for small images

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# ==============================================================
# Optimizer + LR schedule with warm-up + cosine annealing
# ==============================================================
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.05)
num_epochs = 100
warmup_epochs = 5

warmup_scheduler = LinearLR(
    optimizer, start_factor=0.01, total_iters=warmup_epochs
)
cosine_scheduler = CosineAnnealingLR(
    optimizer, T_max=num_epochs - warmup_epochs, eta_min=1e-6
)
scheduler = SequentialLR(
    optimizer, schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_epochs]
)

# Label smoothing built into CrossEntropyLoss
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# ==============================================================
# Training loop
# ==============================================================
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    scheduler.step()

    train_acc = 100.0 * correct / total
    lr = optimizer.param_groups[0]["lr"]

    # Validation
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = outputs.max(1)
            val_total += labels.size(0)
            val_correct += predicted.eq(labels).sum().item()

    val_acc = 100.0 * val_correct / val_total
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"LR: {lr:.6f} | "
              f"Train Acc: {train_acc:.1f}% | "
              f"Val Acc: {val_acc:.1f}%")
```

And an equivalent recipe in TensorFlow / Keras:

```python
# ==============================================================
# TensorFlow / Keras implementation for comparison
# ==============================================================
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Data augmentation layers (built into tf.keras, active only in training)
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomTranslation(0.1, 0.1),
    layers.RandomZoom(0.1),
    layers.RandomContrast(0.2),
])

# Load CIFAR-10
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# One-hot encode labels: label smoothing requires CategoricalCrossentropy,
# which expects one-hot targets (SparseCategoricalCrossentropy has no
# label_smoothing argument)
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# Build a simple ResNet-style model for CIFAR-10
def residual_block(x, filters, stride=1):
    shortcut = x
    x = layers.Conv2D(filters, 3, strides=stride, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)
    x = layers.Conv2D(filters, 3, padding="same")(x)
    x = layers.BatchNormalization()(x)
    if stride != 1 or shortcut.shape[-1] != filters:
        # Projection shortcut when shape changes
        shortcut = layers.Conv2D(filters, 1, strides=stride)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)
    x = layers.Add()([x, shortcut])
    x = layers.ReLU()(x)
    return x

inputs = keras.Input(shape=(32, 32, 3))
x = data_augmentation(inputs)  # augmentation as a layer
x = layers.Conv2D(64, 3, padding="same")(x)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
for _ in range(3):
    x = residual_block(x, 64)
x = residual_block(x, 128, stride=2)
for _ in range(2):
    x = residual_block(x, 128)
x = residual_block(x, 256, stride=2)
for _ in range(2):
    x = residual_block(x, 256)
x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(10, activation="softmax")(x)

model = keras.Model(inputs, outputs)

# Cosine decay schedule with warm-up
# (warmup_target/warmup_steps require TF >= 2.13)
warmup_steps = 5 * (len(x_train) // 128)
total_steps = 100 * (len(x_train) // 128)
lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,       # warm-up starts from ~0
    decay_steps=total_steps - warmup_steps,
    alpha=1e-3,                      # final LR = alpha * peak = 1e-6
    warmup_target=1e-3,
    warmup_steps=warmup_steps,
)

model.compile(
    optimizer=keras.optimizers.AdamW(
        learning_rate=lr_schedule, weight_decay=0.05
    ),
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
    metrics=["accuracy"],
)

model.fit(x_train, y_train, epochs=100, batch_size=128,
          validation_data=(x_test, y_test), verbose=2)
```