Knowledge Distillation Implementation 2/3

Tech
2. Hint-Based Distillation
Author

Leila Mozaffari

Published

October 10, 2024

2. Hint-Based Distillation (Intermediate Feature Matching)

In hint-based distillation, the student is trained not only to match the teacher's softened output distribution, but also to reproduce an intermediate ("hint") feature map of the teacher, with a small adaptation layer bridging the mismatch in channel dimensions.

Reference:

  • Romero, A., Ballas, N., Kahou, S. E., Chassang, A., Gatta, C., & Bengio, Y. (2015). FitNets: Hints for Thin Deep Nets. arXiv preprint arXiv:1412.6550. Retrieved from https://arxiv.org/abs/1412.6550
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import time
import copy

# Define transforms for data augmentation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load the Imagenette2-320 dataset
data_dir = './data/imagenette2-320/imagenette2-320'
image_datasets = {x: datasets.ImageFolder(root=f"{data_dir}/{x}", transform=data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=(x == 'train'), num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# Check device availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Define Teacher and Student Models

# Load pre-trained teacher (ResNet50) and student (ResNet18)
teacher_model = models.resnet50(pretrained=True)
student_model = models.resnet18(pretrained=True)

# Adjust final layers to match number of classes in Imagenette (10 classes)
num_ftrs_teacher = teacher_model.fc.in_features
teacher_model.fc = nn.Linear(num_ftrs_teacher, 10)
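# Note: this new 10-class head is randomly initialized; only the backbone carries ImageNet-pretrained
# weights. In practice the teacher would typically be fine-tuned on the target dataset (or loaded
# from a trained checkpoint) before being used to supervise the student.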

num_ftrs_student = student_model.fc.in_features
student_model.fc = nn.Linear(num_ftrs_student, 10)

# Move models to the appropriate device (GPU if available)
teacher_model = teacher_model.to(device)
student_model = student_model.to(device)

# Set teacher model to evaluation mode (as it is not being trained)
teacher_model.eval()
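# eval() only switches layers such as batch norm and dropout to inference behavior; gradients for
# the teacher are additionally disabled with torch.no_grad() during its forward pass in the training loop.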

# Define 1x1 convolution to match the dimensions between teacher and student feature maps
# Assuming teacher's layer3 outputs 1024 channels and student's layer3 outputs 256 channels

conv_teacher_to_student = nn.Conv2d(1024, 256, kernel_size=1).to(device)
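# This 1x1 convolution plays the role of the FitNets "regressor" (applied here on the teacher side):
# it projects the teacher's 1024-channel layer3 output down to 256 channels so it can be compared
# with the student's layer3 output. Its parameters are trained jointly with the student (see the
# optimizer definition below).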

Extract Intermediate Feature Representations

We need to extract intermediate features from both the teacher and the student models. One way to achieve this is to use forward hooks in PyTorch to capture activations at specific layers. Here we extract features from the output of layer3 (the third residual stage) in both models.

# Lists to store the intermediate activations captured by forward hooks
teacher_features = []
student_features = []

# Hook factory: returns a forward hook that appends a layer's output to the given list
def register_hooks(features_storage):
    def hook(module, input, output):
        features_storage.append(output)
    return hook

# Register a hook on the last block of the teacher's layer3 (outputs 1024 channels)
teacher_model.layer3[5].register_forward_hook(register_hooks(teacher_features))

# Register a hook on the last block of the student's layer3 (outputs 256 channels)
student_model.layer3[1].register_forward_hook(register_hooks(student_features))
<torch.utils.hooks.RemovableHandle at 0x21cb76104c0>
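
As a quick sanity check (a sketch, assuming 224x224 inputs as produced by the transforms above), a dummy forward pass shows that the hooks capture a 1024-channel feature map from the teacher and a 256-channel map from the student at the same spatial resolution:

# Optional sanity check: run a dummy batch through both models and inspect the captured features
with torch.no_grad():
    dummy = torch.randn(2, 3, 224, 224).to(device)
    teacher_model(dummy)
    student_model(dummy)
print(teacher_features[-1].shape)  # expected: torch.Size([2, 1024, 14, 14])
print(student_features[-1].shape)  # expected: torch.Size([2, 256, 14, 14])
teacher_features.clear()
student_features.clear()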

Define the Custom Distillation Loss

We now define a custom loss function that combines three terms (written out below):

    1. Cross-Entropy Loss on the student's hard predictions against the ground-truth labels.
    2. KL Divergence Loss between the teacher's and the student's softened logits (outputs of the final layer).
    3. Feature Matching Loss (an L2 loss) between the intermediate feature maps of the teacher and the student.
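
Putting these together (with temperature T, weights α and β, student and teacher logits z_s and z_t, feature maps F_s and F_t, ground-truth labels y, and the 1x1 adaptation convolution r), the combined objective implemented below is:

    L = α · T² · KL( softmax(z_t / T) ‖ softmax(z_s / T) ) + (1 − α) · CE(z_s, y) + β · ‖F_s − r(F_t)‖²

The factor T² keeps the contribution of the soft-target term roughly constant as the temperature changes, following the usual knowledge-distillation convention.
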
class HintBasedDistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.5, beta=0.5):
        super(HintBasedDistillationLoss, self).__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta
        self.kl_div_loss = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss()
        self.l2_loss = nn.MSELoss()

    def forward(self, student_logits, teacher_logits, student_features, teacher_features, labels):
        # Soft targets: apply temperature scaling to teacher outputs
        teacher_soft = torch.softmax(teacher_logits / self.temperature, dim=1)
        student_soft = torch.log_softmax(student_logits / self.temperature, dim=1)

        # Distillation loss (KL divergence between student and teacher's softened outputs)
        distillation_loss = self.kl_div_loss(student_soft, teacher_soft) * (self.temperature ** 2)

        # Cross entropy loss (between student predictions and true labels)
        student_loss = self.ce_loss(student_logits, labels)

        # Feature matching loss (L2 loss between teacher and student feature maps)
        feature_loss = self.l2_loss(student_features, teacher_features)

        # Combined loss
        return self.alpha * distillation_loss + (1.0 - self.alpha) * student_loss + self.beta * feature_loss
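
As a quick usage check (a sketch with random tensors whose shapes match the layer3 feature maps and 10-class logits used here), the loss can be called directly:

# Hypothetical shape check for the combined loss (random data, batch of 4, 10 classes)
check_criterion = HintBasedDistillationLoss(temperature=3.0, alpha=0.5, beta=0.5)
s_logits = torch.randn(4, 10)
t_logits = torch.randn(4, 10)
s_feat = torch.randn(4, 256, 14, 14)
t_feat = torch.randn(4, 256, 14, 14)   # teacher features after the 1x1 adaptation conv
y = torch.randint(0, 10, (4,))
print(check_criterion(s_logits, t_logits, s_feat, t_feat, y))  # a single scalar tensor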

Implement the Training Loop

def train_student(teacher_model, student_model, dataloaders, criterion, optimizer, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(student_model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Set student model to training mode
        student_model.train()

        running_loss = 0.0
        running_corrects = 0

        # Iterate over the data
        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Clear gradients for student model
            optimizer.zero_grad()

            # Clear the feature lists before every forward pass
            teacher_features.clear()  # Clear saved teacher features
            student_features.clear()  # Clear saved student features

            # Forward pass through teacher (for soft labels and features)
            with torch.no_grad():  # Disable gradients for teacher model
                teacher_logits = teacher_model(inputs)
                teacher_feature = teacher_features[0]  # Extract intermediate feature from teacher

            # Forward pass through student
            student_logits = student_model(inputs)
            student_feature = student_features[0]  # Extract intermediate feature from student

            # Apply 1x1 convolution to match teacher's feature map dimensions to student's
            teacher_feature_resized = conv_teacher_to_student(teacher_feature)

            # Compute loss
            loss = criterion(student_logits, teacher_logits, student_feature, teacher_feature_resized, labels)

            # Backward pass and optimization
            loss.backward()  # Compute gradients only for the student model
            optimizer.step()

            # Compute running statistics
            _, preds = torch.max(student_logits, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / dataset_sizes['train']

        print(f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Deep copy the model weights if this is the best accuracy so far
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(student_model.state_dict())

    # Training complete
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best Acc: {best_acc:.4f}')

    # Load best model weights
    student_model.load_state_dict(best_model_wts)
    return student_model

Training and Evaluation

# Define optimizer (the 1x1 adaptation convolution is optimized together with the student,
# so that the feature-matching projection is learned rather than a fixed random mapping)
optimizer = optim.SGD(
    list(student_model.parameters()) + list(conv_teacher_to_student.parameters()),
    lr=0.01, momentum=0.9
)

# Define hint-based distillation loss
criterion = HintBasedDistillationLoss(temperature=3.0, alpha=0.5, beta=0.5)

# Train the student model
trained_student = train_student(teacher_model, student_model, dataloaders, criterion, optimizer, num_epochs=5)
Epoch 0/4
----------
Loss: 0.6921 Acc: 0.8747
Epoch 1/4
----------
Loss: 0.6239 Acc: 0.9241
Epoch 2/4
----------
Loss: 0.6072 Acc: 0.9267
Epoch 3/4
----------
Loss: 0.5918 Acc: 0.9378
Epoch 4/4
----------
Loss: 0.5790 Acc: 0.9413
Training complete in 98m 5s
Best Acc: 0.9413

Evaluation

def evaluate_model(model, dataloaders):
    model.eval()  # Set to evaluation mode
    running_corrects = 0

    for inputs, labels in dataloaders['val']:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad(): # No need to compute gradients during evaluation
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)

    accuracy = running_corrects.double() / dataset_sizes['val']
    print(f'Validation Accuracy: {accuracy:.4f}')

# Evaluate trained student model
evaluate_model(trained_student, dataloaders)
Validation Accuracy: 0.9735
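
Finally, you will likely want to keep the distilled weights for later use. A minimal sketch (the file name is arbitrary):

# Save the distilled student's weights
torch.save(trained_student.state_dict(), 'student_resnet18_hint_distilled.pth')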