Vision Transformers & Multimodal Models

ViT, DINO, CLIP, and multimodal models bridging vision and language


Transformers, originally designed for NLP, have revolutionized computer vision. This lesson covers the key architectures that bridge vision and language.

Vision Transformer (ViT)

Core Idea

Treat an image as a sequence of patches, just like a sentence is a sequence of tokens:

1. Patch Embedding: Split the image into fixed-size patches (e.g., 16x16 pixels). Each patch is flattened and linearly projected to an embedding dimension.
   - A 224x224 image with 16x16 patches produces 196 tokens (14x14 grid)
   - Each token covers 16 * 16 * 3 = 768 pixel values, projected to the embedding dimension (e.g., 768)
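The patch arithmetic above can be checked directly. A minimal sketch in plain PyTorch (the dimensions mirror the ViT-B/16 numbers quoted here):

```python
import torch
import torch.nn.functional as F

img_size, patch_size, channels = 224, 16, 3
grid = img_size // patch_size                    # 14 patches per side
num_patches = grid * grid                        # 196 tokens
patch_dim = patch_size * patch_size * channels   # 768 raw pixel values per patch

x = torch.randn(1, channels, img_size, img_size)
# unfold extracts non-overlapping 16x16 patches: [1, patch_dim, num_patches]
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size).transpose(1, 2)

# Linear projection to the embedding dimension (here also 768)
proj = torch.nn.Linear(patch_dim, 768)
tokens = proj(patches)
print(grid, num_patches, patch_dim, tokens.shape)  # 14 196 768 torch.Size([1, 196, 768])
```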

2. Position Encoding: Add learnable position embeddings to each patch token. Without these, the model has no notion of spatial arrangement.

3. [CLS] Token: Prepend a special classification token (like BERT). After the transformer encoder, this token's representation is used for classification.

4. Transformer Encoder: Standard transformer blocks (multi-head self-attention + MLP + LayerNorm). Every patch attends to every other patch.

ViT Variants

| Model         | Layers | Hidden | Heads | Params |
|---------------|--------|--------|-------|--------|
| ViT-S (Small) | 12     | 384    | 6     | 22M    |
| ViT-B (Base)  | 12     | 768    | 12    | 86M    |
| ViT-L (Large) | 24     | 1024   | 16    | 307M   |
| ViT-H (Huge)  | 32     | 1280   | 16    | 632M   |

Key Insight

ViT requires large-scale pretraining (ImageNet-21k or JFT-300M) to outperform CNNs. With small datasets, CNNs still win because their inductive biases (translation equivariance, locality) help with limited data.

DeiT (Data-efficient Image Transformers)

DeiT showed that ViT can be trained effectively on ImageNet-1k alone, given the right training recipe:
  • Strong augmentation (RandAugment, Mixup, CutMix)
  • Regularization (stochastic depth, label smoothing, repeated augmentation)
  • Knowledge distillation: A CNN teacher guides the ViT student via a distillation token

Self-Attention in Vision: Global vs Local

In CNNs, each layer has a limited receptive field (e.g., 3x3 or 5x5), so global context builds up gradually over many layers. In ViT, self-attention connects every patch to every other patch in a single layer. This gives immediate global context, but at O(n^2) cost, where n is the number of patches. This is why ViT excels at capturing long-range dependencies (e.g., relating a person's face to their shoes) but struggles with efficiency on high-resolution images.
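To make the quadratic cost concrete, here is a small arithmetic sketch (assuming 16x16 patches) of how the attention matrix grows with resolution:

```python
# Token count and attention-matrix size vs. image resolution (16x16 patches)
patch = 16
for side in (224, 448, 896):
    n = (side // patch) ** 2   # number of patch tokens
    entries = n * n            # entries in one self-attention matrix
    print(f"{side}x{side}: {n:5d} tokens, {entries:>10,} attention entries")
# Doubling the resolution quadruples the token count
# and multiplies the attention cost by 16.
```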

DINO / DINOv2 (Self-Supervised Vision)

DINO (Self-DIstillation with NO labels) learns visual features without any labeled data:

DINO Training

1. Create two augmented views of the same image
2. Pass one through the student network and the other through the teacher (an exponential moving average of the student)
3. Train the student to match the teacher's output distribution
4. No labels needed — the model learns from image structure alone
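The loop above can be sketched in a few lines. This is a toy illustration only: the linear layers stand in for the actual ViT backbones, the temperatures are illustrative, and DINO's output-centering is reduced to a scalar here.

```python
import torch
import torch.nn.functional as F

# Toy student/teacher: identical architectures; dims are illustrative
student = torch.nn.Linear(32, 8)
teacher = torch.nn.Linear(32, 8)
teacher.load_state_dict(student.state_dict())
for p in teacher.parameters():
    p.requires_grad_(False)  # teacher is never trained by gradients

def dino_loss(student_out, teacher_out, tps=0.1, tpt=0.04, center=0.0):
    # Teacher distribution: centered, sharpened (lower temperature), no gradient
    t = F.softmax((teacher_out - center) / tpt, dim=-1).detach()
    # Student matches it via cross-entropy on its log-probabilities
    s = F.log_softmax(student_out / tps, dim=-1)
    return -(t * s).sum(dim=-1).mean()

# Two augmented "views" of the same batch (here just noisy copies)
x = torch.randn(4, 32)
view1 = x + 0.1 * torch.randn_like(x)
view2 = x + 0.1 * torch.randn_like(x)

loss = dino_loss(student(view1), teacher(view2))
loss.backward()  # only the student receives gradients

# EMA teacher update: teacher <- m * teacher + (1 - m) * student
m = 0.996
with torch.no_grad():
    for pt, ps in zip(teacher.parameters(), student.parameters()):
        pt.mul_(m).add_(ps, alpha=1 - m)
print(f"loss = {loss.item():.3f}")
```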

Why DINO Features Are Special

  • Attention maps naturally highlight object boundaries (emergent segmentation!)
  • Features are highly transferable to downstream tasks
  • Nearest-neighbor retrieval works remarkably well

DINOv2 (2023)

Meta's improved version, trained on 142M curated images:
  • State-of-the-art self-supervised features
  • Works as a strong frozen backbone (no fine-tuning needed for many tasks)
  • Excellent for: classification, segmentation, depth estimation, retrieval
  • Available in ViT-S/B/L/g variants

CLIP (Contrastive Language-Image Pre-training)

OpenAI's CLIP (2021) learns to connect images and text:

Architecture

  • Image Encoder: ViT or ResNet that maps images to a shared embedding space
  • Text Encoder: Transformer that maps text to the same embedding space

Contrastive Training

CLIP is trained on 400 million image-text pairs from the internet:

1. Encode a batch of N images and N texts
2. Compute the N x N cosine similarity matrix between all image-text pairs
3. The diagonal entries (matching pairs) should have high similarity
4. Off-diagonal entries (non-matching pairs) should have low similarity
5. A symmetric cross-entropy loss pushes matching pairs together
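Those five steps translate almost directly into a loss function. A minimal sketch (random embeddings stand in for the encoder outputs; the 0.07 temperature is a commonly cited CLIP value):

```python
import torch
import torch.nn.functional as F

def clip_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric contrastive loss over an N x N similarity matrix (sketch)."""
    # L2-normalize so dot products are cosine similarities
    img = F.normalize(image_emb, dim=-1)
    txt = F.normalize(text_emb, dim=-1)
    logits = img @ txt.t() / temperature            # [N, N] similarity matrix
    targets = torch.arange(len(img))                # diagonal = matching pairs
    loss_i = F.cross_entropy(logits, targets)       # image -> text direction
    loss_t = F.cross_entropy(logits.t(), targets)   # text -> image direction
    return (loss_i + loss_t) / 2

N, d = 8, 64
loss = clip_loss(torch.randn(N, d), torch.randn(N, d))
print(f"loss = {loss.item():.3f}")
```

Cross-entropy against the diagonal targets is what "pushes matching pairs together": it rewards high similarity on the diagonal and penalizes it everywhere else.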

Zero-Shot Classification

CLIP enables classification without any training on the target dataset:

1. Create text prompts for each class: "a photo of a {class name}"
2. Encode all prompts with the text encoder
3. Encode the image with the image encoder
4. Predict the class whose text embedding is most similar to the image embedding

This works because CLIP learned a rich alignment between visual concepts and language!

Multimodal Models

LLaVA (Large Language and Vision Assistant)

LLaVA connects a vision encoder to a large language model:

1. A CLIP vision encoder extracts image features
2. A projection layer maps the visual features into the LLM's token space
3. The LLM (e.g., Llama, Vicuna) processes both visual and text tokens
4. Together, these enable visual question answering, image captioning, and visual reasoning
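The projection step is the glue in this design, and it is small enough to sketch. The dimensions below are assumptions for illustration (a CLIP ViT-L/14-style encoder emitting 576 patch tokens of width 1024, and a 7B-class LLM with hidden size 4096); the two-layer MLP follows the LLaVA-1.5-style projector shape.

```python
import torch
import torch.nn as nn

# Illustrative dimensions (assumed, not taken from any specific checkpoint)
vision_dim, llm_dim, num_patch_tokens = 1024, 4096, 576

# Projector: a small MLP from the vision feature space to the LLM token space
projector = nn.Sequential(
    nn.Linear(vision_dim, llm_dim),
    nn.GELU(),
    nn.Linear(llm_dim, llm_dim),
)

image_features = torch.randn(1, num_patch_tokens, vision_dim)  # from the vision encoder
visual_tokens = projector(image_features)                      # [1, 576, 4096]

# The visual tokens are concatenated with the embedded text tokens
# and the combined sequence is fed to the LLM
text_tokens = torch.randn(1, 32, llm_dim)  # stand-in for embedded prompt tokens
llm_input = torch.cat([visual_tokens, text_tokens], dim=1)
print(llm_input.shape)  # torch.Size([1, 608, 4096])
```

Because only the projector (and optionally the LLM) is trained, the frozen vision encoder's features are reused cheaply across tasks.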

GPT-4V / GPT-4o

OpenAI's multimodal models natively understand images:

  • Describe, analyze, and reason about images
  • Read text in images (OCR)
  • Understand charts, diagrams, and code screenshots
  • Combine vision with general reasoning capabilities

Other Notable Models

  • Gemini: Google's natively multimodal model (text, image, audio, video)
  • Claude 3: Anthropic's model with vision understanding
  • Florence-2: Microsoft's unified vision model (detection, segmentation, captioning, and OCR in one model)

Zero-Shot Classification with CLIP

The most practical application of CLIP for everyday use:

```python
# ==============================================================
# CLIP Zero-Shot Classification
# pip install torch transformers pillow
# ==============================================================
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Load CLIP
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Load an image
image = Image.open("photo.jpg").convert("RGB")  # any image

# Define candidate labels (no training needed!)
labels = [
    "a photo of a cat",
    "a photo of a dog",
    "a photo of a bird",
    "a photo of a car",
    "a photo of a building",
    "a photo of food",
    "a photo of a person",
]

# Compute similarities
inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    outputs = model(**inputs)

# Get probabilities
logits_per_image = outputs.logits_per_image  # [1, num_labels]
probs = logits_per_image.softmax(dim=1)

# Print results, highest probability first
print("Zero-shot classification results:")
for label, prob in sorted(zip(labels, probs[0].tolist()), key=lambda x: -x[1]):
    print(f"  {label}: {prob:.1%}")
```
```python
# ==============================================================
# DINOv2 Feature Extraction
# pip install torch transformers pillow scikit-learn
# ==============================================================
import torch
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Load DINOv2
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
model = AutoModel.from_pretrained("facebook/dinov2-base")
model.eval()

def extract_features(image_path):
    """Extract DINOv2 features from an image."""
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Use the [CLS] token as the image representation
    features = outputs.last_hidden_state[:, 0, :]  # [1, 768]
    return features.numpy().flatten()

# Extract features from multiple images
image_paths = ["cat1.jpg", "cat2.jpg", "dog1.jpg", "car1.jpg"]
features = np.stack([extract_features(p) for p in image_paths])

# Compute pairwise similarity
sim_matrix = cosine_similarity(features)
print("Pairwise cosine similarity:")
for i, name_i in enumerate(image_paths):
    for j, name_j in enumerate(image_paths):
        if i < j:
            print(f"  {name_i} vs {name_j}: {sim_matrix[i, j]:.3f}")

# Use features for nearest-neighbor retrieval
query_features = extract_features("query.jpg")
similarities = cosine_similarity([query_features], features)[0]
ranked = np.argsort(-similarities)
print("\nMost similar images to query:")
for idx in ranked:
    print(f"  {image_paths[idx]}: {similarities[idx]:.3f}")
```
```python
# ==============================================================
# ViT from scratch (simplified) in PyTorch
# ==============================================================
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split image into patches and project to embedding dimension."""
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # Conv2d with kernel_size=stride=patch_size acts as patch extraction + projection
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: [B, C, H, W] -> [B, embed_dim, H/P, W/P] -> [B, num_patches, embed_dim]
        return self.proj(x).flatten(2).transpose(1, 2)


class TransformerBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads,
                                          dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, int(embed_dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(embed_dim * mlp_ratio), embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Pre-norm architecture
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]    # self-attention + residual
        x = x + self.mlp(self.norm2(x))  # MLP + residual
        return x


class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size,
                                          in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # [CLS] token and position embeddings
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )

        # Transformer blocks
        self.blocks = nn.Sequential(
            *[TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
              for _ in range(depth)]
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        B = x.shape[0]
        # Patch embedding
        x = self.patch_embed(x)  # [B, num_patches, embed_dim]

        # Prepend [CLS] token
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls, x], dim=1)  # [B, num_patches + 1, embed_dim]

        # Add position embeddings
        x = x + self.pos_embed

        # Transformer encoder
        x = self.blocks(x)
        x = self.norm(x)

        # Classification from [CLS] token
        return self.head(x[:, 0])

# Test
vit = VisionTransformer(
    img_size=224, patch_size=16, num_classes=100,
    embed_dim=384, depth=12, num_heads=6  # ViT-S config
)
x = torch.randn(2, 3, 224, 224)
out = vit(x)
print(f"Input: {x.shape}")     # [2, 3, 224, 224]
print(f"Output: {out.shape}")  # [2, 100]
print(f"Parameters: {sum(p.numel() for p in vit.parameters()):,}")  # ~21.7M
```

When to Use What

- **Image classification**: Use a pretrained ViT or EfficientNet with fine-tuning
- **Zero-shot classification**: Use CLIP (no training data needed for target classes)
- **Feature extraction / retrieval**: Use DINOv2 (strong general-purpose features)
- **Visual QA / captioning**: Use LLaVA, GPT-4V, or Gemini
- **Object detection + segmentation**: Use YOLO or Florence-2

General rule: Start with the simplest approach that works. CLIP zero-shot is often surprisingly good and requires zero training.