Vision Transformers & Multimodal Models
Transformers, originally designed for NLP, have revolutionized computer vision. This lesson covers the key architectures that bridge vision and language.
Vision Transformer (ViT)
Core Idea
Treat an image as a sequence of patches, just like a sentence is a sequence of tokens:

1. Patch Embedding: Split the image into fixed-size patches (e.g., 16x16 pixels). Each patch is flattened and linearly projected to an embedding dimension.
   - A 224x224 image with 16x16 patches produces 196 tokens (14x14 grid)
   - Each token: 16 * 16 * 3 = 768 pixel values, projected to the embedding dimension (e.g., 768)
2. Position Encoding: Add learnable position embeddings to each patch token. Without these, the model has no notion of spatial arrangement.
3. [CLS] Token: Prepend a special classification token (like BERT). After the transformer encoder, this token's representation is used for classification.
4. Transformer Encoder: Standard transformer blocks (multi-head self-attention + MLP + LayerNorm). Every patch attends to every other patch.
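The patch-embedding arithmetic in step 1 can be checked with a short numpy sketch; the random projection matrix here is just a stand-in for the learned linear layer:

```python
import numpy as np

# A 224x224 RGB image, patchified into 16x16 patches
img = np.random.rand(224, 224, 3)
P = 16
grid = 224 // P  # 14 patches per side

# Split into non-overlapping patches and flatten each one
patches = img.reshape(grid, P, grid, P, 3)         # (14, 16, 14, 16, 3)
patches = patches.transpose(0, 2, 1, 3, 4)         # (14, 14, 16, 16, 3)
patches = patches.reshape(grid * grid, P * P * 3)  # (196, 768)

# Linear projection to the embedding dimension (learned in a real ViT)
W = np.random.rand(P * P * 3, 768)
tokens = patches @ W  # (196, 768): 196 tokens, each of dimension 768

print(patches.shape, tokens.shape)
```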
ViT Variants
| Model | Layers | Hidden | Heads | Params |
|---|---|---|---|---|
| ViT-S (Small) | 12 | 384 | 6 | 22M |
| ViT-B (Base) | 12 | 768 | 12 | 86M |
| ViT-L (Large) | 24 | 1024 | 16 | 307M |
| ViT-H (Huge) | 32 | 1280 | 16 | 632M |
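The parameter counts in the table are dominated by the transformer blocks: each block has roughly 4·H² attention weights plus 8·H² MLP weights (with MLP ratio 4), so the total is about 12·L·H². A back-of-envelope check (my own approximation, ignoring embeddings and the classification head):

```python
def approx_params(layers, hidden):
    # Per block: ~4*H^2 (QKV + output projection) + ~8*H^2 (two MLP layers, ratio 4)
    return layers * 12 * hidden ** 2

for name, L, H in [("ViT-S", 12, 384), ("ViT-B", 12, 768),
                   ("ViT-L", 24, 1024), ("ViT-H", 32, 1280)]:
    print(f"{name}: ~{approx_params(L, H) / 1e6:.0f}M")
```

The estimates (~21M, ~85M, ~302M, ~629M) land close to the table's official counts, confirming where the parameters live.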
Key Insight
ViT requires large-scale pretraining (ImageNet-21k or JFT-300M) to outperform CNNs. With small datasets, CNNs still win because their inductive biases (translation equivariance, locality) compensate for limited data.

DeiT (Data-efficient Image Transformers)
Showed that ViT can be trained effectively on ImageNet-1k alone with the right training recipe: heavy data augmentation (RandAugment, Mixup, CutMix), strong regularization (e.g., stochastic depth), and knowledge distillation from a CNN teacher via a dedicated distillation token.

Self-Attention in Vision: Global vs Local

In ViT, every patch attends to every other patch, so attention cost grows quadratically with the number of tokens. Later architectures such as Swin Transformer restrict attention to local windows to scale to high-resolution inputs.
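A rough way to compare global and windowed attention is to count the multiply-accumulates in the attention maps. This sketch assumes a 56x56 token grid (e.g., a 224x224 image with 4x4 patches, as in Swin) and 7x7 windows; the numbers are illustrative, not from any paper:

```python
def global_attn_ops(n_tokens, dim):
    # QK^T and attn @ V each cost n_tokens^2 * dim multiply-accumulates
    return 2 * n_tokens ** 2 * dim

def window_attn_ops(n_tokens, dim, window=7):
    # Attention restricted to non-overlapping window x window blocks
    n_windows = n_tokens // window ** 2
    return n_windows * 2 * (window ** 2) ** 2 * dim

n, d = 56 * 56, 96  # 56x56 token grid, per-head-group dim 96 (illustrative)
print(f"global:   {global_attn_ops(n, d):,} MACs")
print(f"windowed: {window_attn_ops(n, d):,} MACs")
print(f"ratio:    {global_attn_ops(n, d) / window_attn_ops(n, d):.0f}x")
```

The ratio works out to N / window², so windowing buys more as resolution (and hence token count) grows.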
DINO / DINOv2 (Self-Supervised Vision)
DINO (Self-DIstillation with NO labels) learns visual features without any labeled data:
DINO Training
1. Create two augmented views of the same image
2. Pass one through the student network, the other through the teacher (an exponential moving average of the student)
3. Train the student to match the teacher's output distribution
4. No labels needed: the model learns from image structure alone

Why DINO Features Are Special

DINO's attention maps localize objects without any supervision, and its features perform strongly on k-NN classification and retrieval with no fine-tuning.
DINOv2 (2023)
Meta's improved version, trained on 142M curated images (LVD-142M). Its features transfer well to classification, retrieval, segmentation, and depth estimation, often without fine-tuning.

CLIP (Contrastive Language-Image Pre-training)
OpenAI's CLIP (2021) learns to connect images and text.

Architecture

Two encoders trained jointly: an image encoder (a ViT or ResNet) and a text encoder (a transformer), each followed by a linear projection into a shared embedding space where similarity is measured.
Contrastive Training
Trained on 400 million image-text pairs from the internet:

1. Encode a batch of N images and N texts
2. Compute the N x N cosine similarity matrix between all image-text pairs
3. The diagonal entries (matching pairs) should have high similarity
4. Off-diagonal entries (non-matching pairs) should have low similarity
5. A symmetric cross-entropy loss pushes matching pairs together

Zero-Shot Classification
CLIP enables classification without any training on the target dataset:

1. Create text prompts for each class: "a photo of a {class name}"
2. Encode all prompts with the text encoder
3. Encode the image with the image encoder
4. Predict the class whose text embedding is most similar to the image embedding

This works because CLIP learned a rich alignment between visual concepts and language!
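Mechanically, steps 2-4 reduce to a cosine-similarity argmax in the shared embedding space. A minimal numpy sketch, with random vectors standing in for the encoder outputs (CLIP ViT-B projects to 512 dimensions):

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-ins for encoder outputs; a real pipeline would use CLIP's encoders
image_emb = rng.standard_normal(512)
text_embs = rng.standard_normal((3, 512))  # one embedding per class prompt
labels = ["a photo of a cat", "a photo of a dog", "a photo of a car"]

# L2-normalize so a dot product equals cosine similarity
image_emb /= np.linalg.norm(image_emb)
text_embs /= np.linalg.norm(text_embs, axis=1, keepdims=True)

sims = text_embs @ image_emb          # cosine similarity per class
pred = labels[int(np.argmax(sims))]   # most similar prompt wins
print(pred, sims)
```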
Multimodal Models
LLaVA (Large Language and Vision Assistant)
Connects a vision encoder to a large language model:

1. CLIP vision encoder extracts image features
2. A projection layer maps visual features into the LLM's token space
3. The LLM (e.g., Llama, Vicuna) processes both visual and text tokens
4. Enables visual question answering, image captioning, and visual reasoning

GPT-4V / GPT-4o
OpenAI's multimodal models that natively understand images. GPT-4V added image input to GPT-4; GPT-4o processes text, images, and audio within a single model.

Other Notable Models

- BLIP-2: bridges frozen image encoders and frozen LLMs with a lightweight Q-Former
- Flamingo: DeepMind's few-shot vision-language model using gated cross-attention
- Gemini: Google's natively multimodal model family
Zero-Shot Classification with CLIP
The most practical application of CLIP for everyday use:
```python
# ==============================================================
# CLIP Zero-Shot Classification
# pip install transformers pillow
# ==============================================================
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Load CLIP
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load an image
image = Image.open("photo.jpg")  # any image

# Define candidate labels (no training needed!)
labels = [
    "a photo of a cat",
    "a photo of a dog",
    "a photo of a bird",
    "a photo of a car",
    "a photo of a building",
    "a photo of food",
    "a photo of a person",
]

# Compute similarities (no gradients needed for inference)
inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# Get probabilities
logits_per_image = outputs.logits_per_image  # [1, num_labels]
probs = logits_per_image.softmax(dim=1)

# Print results
print("Zero-shot classification results:")
for label, prob in sorted(zip(labels, probs[0].tolist()),
                          key=lambda x: -x[1]):
    print(f"  {label}: {prob:.1%}")
```
```python
# ==============================================================
# DINOv2 Feature Extraction
# pip install transformers pillow scikit-learn
# ==============================================================
import torch
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load DINOv2
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
model = AutoModel.from_pretrained("facebook/dinov2-base")
model.eval()

def extract_features(image_path):
    """Extract DINOv2 features from an image."""
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the [CLS] token as the image representation
    features = outputs.last_hidden_state[:, 0, :]  # [1, 768]
    return features.numpy().flatten()

# Extract features from multiple images
image_paths = ["cat1.jpg", "cat2.jpg", "dog1.jpg", "car1.jpg"]
features = [extract_features(p) for p in image_paths]
features = np.stack(features)

# Compute pairwise similarity
sim_matrix = cosine_similarity(features)
print("Pairwise cosine similarity:")
for i, name_i in enumerate(image_paths):
    for j, name_j in enumerate(image_paths):
        if i < j:
            print(f"  {name_i} vs {name_j}: {sim_matrix[i, j]:.3f}")

# Use features for nearest-neighbor retrieval
query_features = extract_features("query.jpg")
similarities = cosine_similarity([query_features], features)[0]
ranked = np.argsort(-similarities)
print("\nMost similar images to query:")
for idx in ranked:
    print(f"  {image_paths[idx]}: {similarities[idx]:.3f}")
```
```python
# ==============================================================
# ViT from scratch (simplified) in PyTorch
# ==============================================================
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split image into patches and project to embedding dimension."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Conv2d with kernel_size=stride=patch_size acts as patch extraction + projection
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, C, H, W] -> [B, embed_dim, H/P, W/P] -> [B, num_patches, embed_dim]
        return self.proj(x).flatten(2).transpose(1, 2)


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Pre-norm architecture
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]    # self-attention + residual
        x = x + self.mlp(self.norm2(x))  # MLP + residual
        return x


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size,
                                          in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # [CLS] token and position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )

        # Transformer blocks
        self.blocks = nn.Sequential(
            *[TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
              for _ in range(depth)]
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        B = x.shape[0]
        # Patch embedding
        x = self.patch_embed(x)  # [B, num_patches, embed_dim]

        # Prepend [CLS] token
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)  # [B, num_patches + 1, embed_dim]

        # Add position embeddings
        x = x + self.pos_embed

        # Transformer encoder
        x = self.blocks(x)
        x = self.norm(x)

        # Classification from [CLS] token
        return self.head(x[:, 0])


# Test
vit = VisionTransformer(
    img_size=224, patch_size=16, num_classes=100,
    embed_dim=384, depth=12, num_heads=6  # ViT-S config
)
x = torch.randn(2, 3, 224, 224)
out = vit(x)
print(f"Input: {x.shape}")     # [2, 3, 224, 224]
print(f"Output: {out.shape}")  # [2, 100]
print(f"Parameters: {sum(p.numel() for p in vit.parameters()):,}")  # ~21.7M
```