Computer Vision & Audio with PyTorch

Image classification, transfer learning, audio processing, and data augmentation pipelines


PyTorch's ecosystem includes torchvision for computer vision and torchaudio for audio processing. These libraries provide pre-trained models, standard datasets, and common transforms — so you don't have to build everything from scratch.

torchvision: Datasets and Transforms

torchvision provides standard image datasets and a composable transform pipeline:

```python
from torchvision import datasets, transforms

# --- Composable transforms ---
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),          # Randomly crop and resize to 224x224
    transforms.RandomHorizontalFlip(p=0.5),     # 50% chance of horizontal flip
    transforms.ColorJitter(                     # Randomly alter brightness, contrast, etc.
        brightness=0.2, contrast=0.2,
        saturation=0.2, hue=0.1
    ),
    transforms.RandomRotation(15),              # Rotate up to 15 degrees
    transforms.ToTensor(),                      # Convert PIL Image -> float tensor [0, 1]
    transforms.Normalize(                       # Normalize with ImageNet stats
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Validation transforms: no augmentation, just resize and normalize
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Standard datasets
cifar10 = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
# Note: ImageNet cannot be auto-downloaded; the archives must already be in root
imagenet = datasets.ImageNet(root="./data/imagenet", split="train", transform=train_transform)

# Load a custom image folder: expects structure data/class1/img.jpg, data/class2/img.jpg, ...
custom_dataset = datasets.ImageFolder(
    root="./data/my_images",
    transform=train_transform,
)
print(f"Classes: {custom_dataset.classes}")  # ['cats', 'dogs', ...]
print(f"Class->Idx: {custom_dataset.class_to_idx}")
```

Data Augmentation

Data augmentation artificially increases the diversity of your training set by applying random transformations (flips, rotations, color changes, crops) to each image at training time. This helps the model generalize better and reduces overfitting — especially when you have limited data. Important: augmentation is applied ONLY during training, never during validation or testing.

Pre-trained Models and Transfer Learning

torchvision includes models pre-trained on ImageNet (1.2M images, 1000 classes). Transfer learning uses these pre-trained weights as a starting point for your own task:

```python
import torch
import torch.nn as nn
from torchvision import models

# --- Load a pre-trained ResNet-50 ---
# New API (torchvision 0.13+): use the weights parameter
weights = models.ResNet50_Weights.IMAGENET1K_V2
model = models.resnet50(weights=weights)

# Legacy API (deprecated since torchvision 0.13; removed in newer releases):
# model = models.resnet50(pretrained=True)

# --- Transfer learning strategy ---
# Step 1: Freeze all pre-trained layers
for param in model.parameters():
    param.requires_grad = False

# Step 2: Replace the final fully connected layer for your task
num_classes = 5  # Your number of classes
model.fc = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.fc.in_features, num_classes),
)
# The new layer's parameters have requires_grad=True by default

# Step 3: Train only the new layer (much faster!)
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

# --- Fine-tuning: unfreeze some layers later ---
# After training the head, optionally unfreeze deeper layers
for param in model.layer4.parameters():
    param.requires_grad = True

# Use a lower learning rate for the pre-trained layers
optimizer = torch.optim.Adam([
    {"params": model.layer4.parameters(), "lr": 1e-5},  # Pre-trained: low LR
    {"params": model.fc.parameters(), "lr": 1e-3},      # New head: higher LR
])

# Other available models
vgg16 = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
efficientnet = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
vit = models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)  # Vision Transformer
```

Transfer Learning Workflow

1. Load a pre-trained model (ResNet, EfficientNet, ViT, etc.)
2. Freeze all parameters with requires_grad = False
3. Replace the final classification head for your number of classes
4. Train only the head for a few epochs
5. (Optional) Unfreeze deeper layers and fine-tune with a very small learning rate

This works because early layers learn generic features (edges, textures) that transfer to any image task. Only the final layers need task-specific training.
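The freeze-then-replace steps can be sketched with a tiny stand-in network (a small MLP, not a real ResNet, so it runs anywhere) to show that only the new head stays trainable:

```python
import torch.nn as nn

# Stand-in "backbone + head": a tiny MLP instead of a real ResNet,
# just to show what freezing does to the trainable parameter count.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

for param in model.parameters():  # Freeze everything
    param.requires_grad = False

model[2] = nn.Linear(16, 2)       # New head: requires_grad=True by default

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(trainable, total)  # 34 178 -- only the head's 16*2 + 2 parameters train
```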

torchaudio: Audio Processing

torchaudio provides tools for loading, transforming, and classifying audio:

```python
import torch
import torchaudio
import torchaudio.transforms as T

# --- Load an audio file ---
waveform, sample_rate = torchaudio.load("audio.wav")
print(f"Waveform shape: {waveform.shape}")   # (channels, num_samples)
print(f"Sample rate: {sample_rate}")         # e.g., 16000

# --- Resample to a standard rate ---
resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)

# --- Compute a spectrogram ---
spectrogram_transform = T.Spectrogram(
    n_fft=1024,
    win_length=1024,
    hop_length=512,
)
spectrogram = spectrogram_transform(waveform)
print(f"Spectrogram shape: {spectrogram.shape}")  # (channels, freq_bins, time_frames)

# --- Mel spectrogram (most common for ML) ---
mel_transform = T.MelSpectrogram(
    sample_rate=16000,
    n_fft=1024,
    hop_length=512,
    n_mels=64,          # Number of mel filter banks
)
mel_spec = mel_transform(waveform)
print(f"Mel spectrogram shape: {mel_spec.shape}")  # (channels, n_mels, time_frames)

# Convert to log scale (decibels) — better for neural networks
amplitude_to_db = T.AmplitudeToDB()
log_mel_spec = amplitude_to_db(mel_spec)

# --- MFCC (Mel-Frequency Cepstral Coefficients) ---
mfcc_transform = T.MFCC(
    sample_rate=16000,
    n_mfcc=13,
    melkwargs={"n_mels": 64, "n_fft": 1024},
)
mfcc = mfcc_transform(waveform)

# --- Audio augmentation (applied to spectrograms, not raw waveforms) ---
time_stretch = T.TimeStretch(n_freq=513)  # Speed up/slow down a complex spectrogram; n_freq = n_fft // 2 + 1
freq_masking = T.FrequencyMasking(freq_mask_param=15)  # Mask random frequency bands (SpecAugment)
time_masking = T.TimeMasking(time_mask_param=35)       # Mask random time steps (SpecAugment)
```

What is a Mel Spectrogram?

A mel spectrogram represents audio as an image-like 2D matrix: one axis is time, the other is frequency (on the mel scale, which mimics human perception). Neural networks process mel spectrograms just like images — you can even use CNNs and pre-trained image models on them. The mel scale compresses higher frequencies, reflecting how humans are better at distinguishing low frequencies than high frequencies.
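One common form of the Hz-to-mel mapping (the HTK formula) makes this compression concrete; pure stdlib, no audio libraries needed:

```python
import math

def hz_to_mel(f_hz):
    # HTK mel-scale formula: m = 2595 * log10(1 + f / 700)
    return 2595.0 * math.log10(1.0 + f_hz / 700.0)

# By construction, 1000 Hz maps to (almost exactly) 1000 mel
print(round(hz_to_mel(1000)))  # 1000

# The same 4000 Hz gap spans fewer and fewer mels as frequency rises,
# mirroring how human pitch resolution degrades at high frequencies
print(round(hz_to_mel(4000) - hz_to_mel(0)))     # 2146
print(round(hz_to_mel(8000) - hz_to_mel(4000)))  # 694
```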

Audio Classification Pipeline

You can classify audio by converting to mel spectrograms and using a CNN:

```python
import torch
import torch.nn as nn
import torchaudio.transforms as T

class AudioClassifier(nn.Module):
    """Simple CNN for audio classification on mel spectrograms."""
    def __init__(self, n_mels=64, num_classes=10):
        super().__init__()
        self.mel_transform = T.MelSpectrogram(
            sample_rate=16000, n_fft=1024, hop_length=512, n_mels=n_mels
        )
        self.db_transform = T.AmplitudeToDB()

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)),  # Fixed output size regardless of input length
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, waveform):
        # waveform: (batch, 1, num_samples)
        mel = self.mel_transform(waveform)   # (batch, 1, n_mels, time)
        mel = self.db_transform(mel)
        features = self.cnn(mel)
        return self.classifier(features)

# Usage
model = AudioClassifier(n_mels=64, num_classes=10)
dummy_audio = torch.randn(8, 1, 16000)  # 8 clips, 1 second each at 16kHz
output = model(dummy_audio)
print(f"Output shape: {output.shape}")  # (8, 10)
```

Transfer Learning for Audio

A powerful technique is to convert audio to mel spectrograms (as 2D images) and use a pre-trained image model like ResNet. Treat the spectrogram as a 1-channel image and modify the first conv layer: model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3). This often outperforms training from scratch.