
Model Optimization

Quantization (INT8, INT4, dynamic vs static), pruning, knowledge distillation, and neural architecture search for efficiency


Model Optimization for Edge Deployment

Deploying ML models to edge devices — mobile phones, IoT sensors, embedded systems — requires models that are small, fast, and energy-efficient. A state-of-the-art model that runs on a GPU cluster is useless if it cannot run on a phone in real time. This lesson covers the four pillars of model optimization: quantization, pruning, knowledge distillation, and efficient architecture design.

The Edge AI Challenge

Edge devices have extreme constraints: a smartphone has ~4-8GB RAM (shared with OS and apps), a microcontroller may have only 256KB. Inference must complete in milliseconds, and power consumption determines battery life. Model optimization is not optional for edge deployment — it is the entire challenge.
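To make the memory constraint concrete, a back-of-the-envelope estimate helps (the parameter count below is approximate, and the arithmetic ignores activations and runtime overhead):

```python
def model_size_mb(n_params: int, bytes_per_param: int) -> float:
    """Rough model footprint: parameters x bytes per parameter."""
    return n_params * bytes_per_param / 1e6

# ResNet-50 has roughly 25.6M parameters.
resnet50_params = 25_600_000

fp32_size = model_size_mb(resnet50_params, 4)  # 32-bit floats
int8_size = model_size_mb(resnet50_params, 1)  # 8-bit integers

print(f"FP32: {fp32_size:.1f} MB, INT8: {int8_size:.1f} MB")
# Neither fits a 256KB microcontroller; at INT8 such a device can
# hold at most ~256,000 parameters of weights.
```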

Quantization

Quantization reduces the precision of model weights and activations from 32-bit floating point (FP32) to lower-bit representations (INT8, INT4, or even binary). This reduces model size, memory usage, and inference latency.

Why Quantization Works

Neural networks are robust to noise — small perturbations in weight values have minimal impact on output quality. Quantization exploits this by mapping the continuous FP32 range to a discrete set of lower-precision values.
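A minimal sketch of this mapping: affine (asymmetric) quantization derives a scale and zero point from a tensor's observed range, then rounds each value to the nearest representable integer:

```python
def quantize_int8(values, qmin=-128, qmax=127):
    """Affine quantization of a list of floats to INT8."""
    lo, hi = min(values), max(values)
    scale = (hi - lo) / (qmax - qmin)          # FP32 units per integer step
    zero_point = round(qmin - lo / scale)      # integer that represents 0.0
    q = [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    """Map INT8 values back to (approximate) FP32."""
    return [(qi - zero_point) * scale for qi in q]

weights = [-0.42, 0.0, 0.17, 0.91]
q, scale, zp = quantize_int8(weights)
restored = dequantize(q, scale, zp)
# Absent clipping, each restored value differs from the original
# by at most scale / 2 -- the "noise" the network must tolerate.
```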

Quantization Types

| Type            | Description           | Accuracy Loss    | Speed Gain |
|-----------------|-----------------------|------------------|------------|
| FP32 (baseline) | Standard 32-bit float | None             | 1x         |
| FP16 / BF16     | Half precision        | Minimal          | ~2x        |
| INT8            | 8-bit integer         | Small (1-2%)     | ~2-4x      |
| INT4            | 4-bit integer         | Moderate (2-5%)  | ~4-8x      |
| Binary/Ternary  | 1-2 bit               | Significant      | ~10-32x    |

Dynamic vs Static Quantization

  • Dynamic Quantization: Weights are quantized ahead of time, but activations are quantized on-the-fly during inference. Simplest to apply — often just one line of code.
  • Static Quantization: Both weights and activations are quantized using calibration data. Requires a calibration step but produces faster models.
  • Quantization-Aware Training (QAT): Simulates quantization during training so the model learns to be robust to it. Best accuracy but requires retraining.
```python
import torch
import torch.nn as nn
import time
import os

# --- Define a simple model ---
class SimpleClassifier(nn.Module):
    def __init__(self, input_dim=784, hidden=256, output=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, hidden)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden, output)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)

model = SimpleClassifier()
model.eval()

# --- Dynamic Quantization (easiest) ---
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear},        # Quantize all Linear layers
    dtype=torch.qint8,  # Use INT8
)

# --- Compare model sizes ---
def get_model_size(model):
    """Get model size in MB."""
    torch.save(model.state_dict(), "/tmp/model.pt")
    size = os.path.getsize("/tmp/model.pt") / 1e6
    os.remove("/tmp/model.pt")
    return size

original_size = get_model_size(model)
quantized_size = get_model_size(quantized_model)

print(f"Original model size:  {original_size:.2f} MB")
print(f"Quantized model size: {quantized_size:.2f} MB")
print(f"Compression ratio:    {original_size / quantized_size:.1f}x")

# --- Compare inference speed ---
dummy_input = torch.randn(1, 784)
n_iters = 1000

start = time.time()
for _ in range(n_iters):
    with torch.no_grad():
        model(dummy_input)
original_time = (time.time() - start) / n_iters * 1000

start = time.time()
for _ in range(n_iters):
    with torch.no_grad():
        quantized_model(dummy_input)
quantized_time = (time.time() - start) / n_iters * 1000

print(f"\nOriginal inference:  {original_time:.3f} ms")
print(f"Quantized inference: {quantized_time:.3f} ms")
print(f"Speedup:             {original_time / quantized_time:.1f}x")
```
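The calibration step that static quantization adds can be sketched with PyTorch's eager-mode API. The `QuantReadyClassifier` below is a hypothetical variant of the model above with quantization stubs inserted; the `"fbgemm"` backend targets x86 CPUs and may not be available on every platform:

```python
import torch
import torch.nn as nn
from torch.ao.quantization import (
    QuantStub, DeQuantStub, get_default_qconfig, prepare, convert
)

class QuantReadyClassifier(nn.Module):
    """Classifier with stubs marking where tensors enter/leave INT8."""
    def __init__(self, input_dim=784, hidden=256, output=10):
        super().__init__()
        self.quant = QuantStub()      # FP32 -> INT8 at the input
        self.fc1 = nn.Linear(input_dim, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, output)
        self.dequant = DeQuantStub()  # INT8 -> FP32 at the output

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return self.dequant(x)

model = QuantReadyClassifier().eval()
model.qconfig = get_default_qconfig("fbgemm")  # x86 server/desktop backend

# 1. Insert observers, 2. run representative data through them,
# 3. convert observed ranges into fixed INT8 scales and zero points.
prepared = prepare(model)
for _ in range(32):                  # calibration pass with sample batches
    prepared(torch.randn(8, 784))
static_quantized = convert(prepared)

out = static_quantized(torch.randn(1, 784))
print(out.shape)  # torch.Size([1, 10])
```

Unlike dynamic quantization, activation ranges are now fixed ahead of time, so no range estimation happens at inference.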

Pruning

Pruning removes unnecessary weights or neurons from a model, making it smaller and faster without significant accuracy loss.

Unstructured Pruning

Sets individual weights to zero based on magnitude (smallest weights are least important). Creates sparse weight matrices.

  • Easy to apply
  • High compression ratios possible (up to 90% sparsity)
  • Requires sparse computation support for actual speedup
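A minimal sketch of magnitude-based unstructured pruning on a single layer, using PyTorch's built-in `torch.nn.utils.prune` utilities:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(256, 256)

# Zero out the 90% of weights with the smallest magnitude (L1 criterion).
prune.l1_unstructured(layer, name="weight", amount=0.9)

sparsity = (layer.weight == 0).float().mean().item()
print(f"Sparsity: {sparsity:.0%}")  # ~90%

# Make the pruning permanent: remove the mask and bake in the zeros.
prune.remove(layer, "weight")
```

Note that the weight tensor keeps its original shape; the zeros only translate into speedups on runtimes with sparse kernel support.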
Structured Pruning

Removes entire filters, channels, or layers rather than individual weights. Produces dense, smaller models that run faster on standard hardware.

  • Directly reduces computation
  • Compatible with all hardware
  • Lower compression ratios than unstructured
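Structured pruning can be sketched with the same utilities. Note that this only zeroes whole filters; physically shrinking the tensors (and the next layer's input channels) is a separate model-surgery step:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

conv = nn.Conv2d(64, 128, kernel_size=3)

# Remove 50% of output filters (dim=0), ranked by their L2 norm (n=2).
prune.ln_structured(conv, name="weight", amount=0.5, n=2, dim=0)

# Entire filters are now zero and can later be deleted outright.
zero_filters = (conv.weight.abs().sum(dim=(1, 2, 3)) == 0).sum().item()
print(f"Zeroed filters: {zero_filters} / 128")  # 64 / 128
```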
Knowledge Distillation

Knowledge distillation trains a small "student" model to mimic a large "teacher" model. The student learns from the teacher's soft probability outputs (which contain richer information than hard labels).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Knowledge Distillation ---

class TeacherModel(nn.Module):
    """Large, accurate model (e.g., ResNet-152)."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class StudentModel(nn.Module):
    """Small, fast model for edge deployment."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 64)
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=4.0, alpha=0.7):
    """Combined distillation and classification loss.

    Args:
        student_logits: Raw output from student model
        teacher_logits: Raw output from teacher model
        labels: True class labels
        temperature: Softmax temperature (higher = softer distributions)
        alpha: Weight for distillation loss vs classification loss
    """
    # Soft targets from teacher (with temperature)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    soft_student = F.log_softmax(student_logits / temperature, dim=1)

    # KL divergence between student and teacher soft outputs
    distill_loss = F.kl_div(
        soft_student, soft_teacher, reduction="batchmean"
    ) * (temperature ** 2)

    # Standard cross-entropy with true labels
    hard_loss = F.cross_entropy(student_logits, labels)

    # Weighted combination
    return alpha * distill_loss + (1 - alpha) * hard_loss


# --- Training loop (simplified) ---
teacher = TeacherModel()
student = StudentModel()
optimizer = torch.optim.Adam(student.parameters(), lr=0.001)

# Simulate training
teacher.eval()  # Teacher is frozen
student.train()

for epoch in range(5):
    # Simulated batch
    x = torch.randn(32, 784)
    labels = torch.randint(0, 10, (32,))

    with torch.no_grad():
        teacher_logits = teacher(x)

    student_logits = student(x)
    loss = distillation_loss(student_logits, teacher_logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}: loss = {loss.item():.4f}")

# Compare sizes
teacher_params = sum(p.numel() for p in teacher.parameters())
student_params = sum(p.numel() for p in student.parameters())
print(f"\nTeacher parameters: {teacher_params:,}")
print(f"Student parameters: {student_params:,}")
print(f"Compression: {teacher_params / student_params:.1f}x")
```

Neural Architecture Search (NAS) for Efficiency

Rather than manually designing efficient architectures, NAS automates the search for architectures that optimize for both accuracy and efficiency:

  • MobileNet (Google): Depthwise separable convolutions — 8-9x fewer parameters than standard convolutions
  • EfficientNet (Google): Compound scaling of depth, width, and resolution
  • Once-for-All (OFA): Train one "supernet" that can be sliced to any target hardware constraint

The key insight: the best architecture depends on the target hardware. A model optimized for a GPU is very different from one optimized for a phone CPU or a microcontroller.
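The MobileNet figure can be checked directly by counting parameters (the layer sizes below are arbitrary examples, not MobileNet's actual configuration):

```python
import torch.nn as nn

cin, cout, k = 64, 128, 3

# Standard convolution: every filter sees every input channel.
standard = nn.Conv2d(cin, cout, kernel_size=k, padding=1)

# Depthwise separable: per-channel spatial filtering, then 1x1 mixing.
depthwise_separable = nn.Sequential(
    nn.Conv2d(cin, cin, kernel_size=k, padding=1, groups=cin),  # depthwise
    nn.Conv2d(cin, cout, kernel_size=1),                        # pointwise
)

def n_params(module):
    return sum(p.numel() for p in module.parameters())

print(n_params(standard), n_params(depthwise_separable))
print(f"Reduction: {n_params(standard) / n_params(depthwise_separable):.1f}x")
# 73856 vs 8960 parameters -- roughly an 8x reduction.
```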

Combine Techniques for Maximum Compression

In practice, optimization techniques stack: train a large teacher model, distill to a small student, apply quantization-aware training, then prune. A model that starts at 500MB can often be reduced to under 5MB while retaining 95%+ of accuracy — a 100x compression ratio.
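A rough sanity check of how the factors compound (the individual stage factors below are illustrative, not measurements):

```python
# Hypothetical pipeline: compression factors multiply across stages.
stages = {
    "distillation (teacher -> student)": 10.0,
    "INT8 quantization (FP32 -> INT8)":   4.0,
    "structured pruning (fewer channels)": 2.5,
}

size_mb = 500.0
for stage, factor in stages.items():
    size_mb /= factor
    print(f"{stage:38s} -> {size_mb:6.1f} MB")
# 500 MB / (10 * 4 * 2.5) = 5 MB, i.e. a 100x overall reduction.
```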