Knowledge Distillation Implementation 3/3

3. Self-Distillation
Author: Leila Mozaffari
Published: October 10, 2024

Self-distillation is a technique in which a model distills knowledge into itself. The same network serves as both teacher and student: the outputs of deeper layers are used to guide shallower layers of the same model. Rather than relying on a separately trained teacher, the network uses its own intermediate outputs as targets for knowledge transfer, which acts as a form of self-regularization.

In this implementation, instead of the classical teacher-student setup with two networks, a single ResNet18 plays both roles: its deepest stage (layer4) provides the teacher signal, and a shallower stage (layer2) is trained to match it alongside the usual classification loss. Self-distillation is often reported to improve generalization without requiring a separate, external teacher.

References

  • J. Gou, B. Yu, S. J. Maybank, and D. Tao, “Knowledge Distillation: A Survey,” arXiv:2006.05525, May 20, 2021. doi: 10.48550/arXiv.2006.05525. https://arxiv.org/abs/2006.05525

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import time


# Define image transformations for training and validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Load Imagenette2-320 dataset
data_dir = './data/imagenette2-320/imagenette2-320'
image_datasets = {x: datasets.ImageFolder(root=f"{data_dir}/{x}", transform=data_transforms[x])
                  for x in ['train', 'val']}
# Shuffle only the training split; the validation order does not affect accuracy
dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=(x == 'train'), num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# Set the device to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Define Models

The student is a pre-trained ResNet18 whose final fully connected layer is replaced to output the 10 Imagenette classes. A 1x1 convolution then reduces the number of channels from 512 (the layer4 output in ResNet18) to 128 (the layer2 output), so that the feature maps of these two layers can be compared in the distillation loss.

# Load ResNet18 with ImageNet pre-trained weights (used as the student)
student_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# Modify the last layer to match the number of classes in Imagenette (10 classes)
num_ftrs_student = student_model.fc.in_features
student_model.fc = nn.Linear(num_ftrs_student, 10)

# Move the model to the appropriate device
student_model = student_model.to(device)


# Define a 1x1 convolution to match the number of channels
conv1x1 = nn.Conv2d(512, 128, kernel_size=1).to(device)
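A 1x1 convolution mixes channels independently at every spatial position, so it behaves like a linear layer applied to each pixel. The sketch below uses random tensors with the 7x7 spatial size that ResNet18's layer4 produces for a 224x224 input; it checks the output shape, confirms the per-pixel equivalence with an einsum, and counts the projection's parameters. The names `proj` and `feats` are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same projection as in the text: 512 input channels -> 128 output channels
proj = nn.Conv2d(512, 128, kernel_size=1)

# Illustrative batch of layer4-sized feature maps: (N, C, H, W) = (2, 512, 7, 7)
feats = torch.randn(2, 512, 7, 7)
out = proj(feats)
print(out.shape)  # torch.Size([2, 128, 7, 7])

# Same result as a per-pixel matrix multiply; the weight has shape (128, 512, 1, 1)
w = proj.weight.squeeze(-1).squeeze(-1)  # (128, 512)
manual = torch.einsum('oc,nchw->nohw', w, feats) + proj.bias.view(1, -1, 1, 1)
print(torch.allclose(out, manual, atol=1e-4))  # True

# Learnable parameters: 512 * 128 weights + 128 biases
print(sum(p.numel() for p in proj.parameters()))  # 65664
```

Because these 65,664 parameters are learnable, the projection only becomes meaningful if the optimizer updates it along with the model.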

Self-Distillation Strategy

Feature Extraction Hooks

  • Hooks: Forward hooks are used to extract the intermediate feature maps from layer2 and layer4 of ResNet18 during the forward pass. The extracted features are stored in the intermediate_features dictionary.

# Dictionary that stores the intermediate feature maps captured during the forward pass
intermediate_features = {}

def save_output(name):
    # Returns a forward hook that stores the module's output under `name`
    def hook(module, inputs, output):
        intermediate_features[name] = output
    return hook

# Register forward hooks on the last BasicBlock of layer2 and of layer4
hook_layer2 = student_model.layer2[1].register_forward_hook(save_output("layer2"))
hook_layer4 = student_model.layer4[1].register_forward_hook(save_output("layer4"))

Custom Self-Distillation Loss

  • Cross-Entropy Loss (ce_loss): Computes the standard classification loss between the output logits and the ground-truth labels.
  • MSE Loss (distillation_loss): Encourages the shallower layer2 feature maps to match those of the deeper layer4 by applying Mean Squared Error.
  • 1x1 Convolution: Before applying MSE, the deeper layer's output (layer4) is passed through a 1x1 convolution to match the number of channels with layer2.
  • Interpolation: The layer2 feature map is bilinearly resized to the spatial size of the layer4 features (from 28x28 down to 7x7 for 224x224 inputs).
  • Alpha: Weights the two terms; the total loss is alpha * distillation_loss + (1 - alpha) * ce_loss.
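Written out, the loss computed by the class below is as follows. The notation is introduced here for illustration: $F_2$ and $F_4$ are the layer2 and layer4 feature maps, $P$ is the 1x1 convolution, $D$ is the bilinear resize, $z_n$ are the logits and $y_n$ the label of sample $n$, $N$ is the batch size, $K = 10$ classes, and $C = 128$, $H = W = 7$ assume 224x224 inputs. Both terms use PyTorch's default mean reduction.

```latex
\begin{aligned}
\mathcal{L}_{\text{distill}} &= \frac{1}{N C H W} \sum_{n,c,h,w} \big( D(F_2)_{n,c,h,w} - P(F_4)_{n,c,h,w} \big)^2 \\
\mathcal{L}_{\text{CE}} &= -\frac{1}{N} \sum_{n=1}^{N} \log \frac{\exp(z_{n,y_n})}{\sum_{k=1}^{K} \exp(z_{n,k})} \\
\mathcal{L} &= \alpha \, \mathcal{L}_{\text{distill}} + (1 - \alpha) \, \mathcal{L}_{\text{CE}}
\end{aligned}
```

With $\alpha = 0.5$, as used in the training run, the two terms are weighted equally.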

class SelfDistillationLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super(SelfDistillationLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()
        self.alpha = alpha  # Weight for self-distillation

    def forward(self, student_logits, labels, student_intermediate, teacher_intermediate):
        # Standard cross-entropy loss on the output logits
        ce_loss = self.ce_loss(student_logits, labels)
        
        # Apply the global 1x1 convolution (conv1x1, defined above) to teacher_intermediate to match the number of channels
        teacher_intermediate_reduced = conv1x1(teacher_intermediate)
        
        # Resize the student_intermediate feature map to match teacher_intermediate's spatial size
        student_intermediate_resized = F.interpolate(student_intermediate, size=teacher_intermediate_reduced.shape[2:], mode='bilinear', align_corners=False)
        
        # Self-distillation loss (MSE between resized student intermediate outputs and teacher intermediate outputs)
        distillation_loss = self.mse_loss(student_intermediate_resized, teacher_intermediate_reduced)
        
        # Combine the two losses
        return self.alpha * distillation_loss + (1 - self.alpha) * ce_loss
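One possible variant, shown as a sketch rather than the post's implementation, makes the 1x1 projection a submodule of the loss instead of reading the global `conv1x1`. Its weights then appear in `criterion.parameters()` and can be handed to the optimizer together with the model's parameters. The class name `SelfDistillationLossWithProjection` and the random test tensors are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfDistillationLossWithProjection(nn.Module):
    """Hypothetical variant: the 1x1 projection lives inside the loss module."""

    def __init__(self, alpha=0.5, teacher_channels=512, student_channels=128):
        super().__init__()
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()
        # Learnable projection from the deep (layer4) channels to the shallow (layer2) channels
        self.proj = nn.Conv2d(teacher_channels, student_channels, kernel_size=1)

    def forward(self, logits, labels, student_feat, teacher_feat):
        ce = self.ce_loss(logits, labels)
        teacher_reduced = self.proj(teacher_feat)
        # Downsample the shallow feature map to the deep map's spatial size
        student_resized = F.interpolate(student_feat, size=teacher_reduced.shape[2:],
                                        mode='bilinear', align_corners=False)
        distill = self.mse_loss(student_resized, teacher_reduced)
        return self.alpha * distill + (1 - self.alpha) * ce

# Usage with random tensors shaped like ResNet18 outputs for 224x224 inputs
torch.manual_seed(0)
criterion = SelfDistillationLossWithProjection(alpha=0.5)
logits = torch.randn(4, 10, requires_grad=True)
labels = torch.randint(0, 10, (4,))
layer2 = torch.randn(4, 128, 28, 28, requires_grad=True)
layer4 = torch.randn(4, 512, 7, 7, requires_grad=True)

loss = criterion(logits, labels, layer2, layer4)
loss.backward()
print(loss.dim())                              # 0 (a scalar)
print(criterion.proj.weight.grad is not None)  # True: the projection receives gradients
```

With this variant, the optimizer would be built from both parameter sets, for example `optim.SGD(list(student_model.parameters()) + list(criterion.parameters()), lr=0.01, momentum=0.9)`.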

Training Loop

This function trains the student model using the self-distillation loss:

  • Training Loop: Iterates over the training data for a specified number of epochs (25 by default).
  • Forward Pass: For each batch, the model computes the outputs, and intermediate features are extracted via hooks.
  • Loss Computation: The total loss is computed by combining the cross-entropy loss and the self-distillation loss.
  • Backward Pass: The loss is used to perform backpropagation and update the model’s weights.
  • Accuracy: Tracks the training accuracy for each epoch and keeps a copy of the weights from the epoch with the highest training accuracy.

import copy

def train_student(model, dataloaders, criterion, optimizer, num_epochs=25):
    since = time.time()

    # state_dict() returns references to the live tensors, so keep a deep copy
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Set the model to training mode
        model.train()

        running_loss = 0.0
        running_corrects = 0

        # Iterate over the training data
        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass; the hooks store the layer2/layer4 outputs as a side effect
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            # Get the intermediate feature maps captured by the hooks
            layer2_features = intermediate_features['layer2']
            layer4_features = intermediate_features['layer4']

            # Compute the combined cross-entropy and self-distillation loss
            loss = criterion(outputs, labels, layer2_features, layer4_features)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Accumulate statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

        # Calculate epoch loss and accuracy on the training set
        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / dataset_sizes['train']

        print(f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Deep copy the weights if this epoch has the best training accuracy so far
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model
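Keeping the "best" weights requires a real copy: `model.state_dict()` returns tensors that share storage with the live parameters, so a plain reference changes whenever the optimizer updates the model in place. The small sketch below uses a stand-in `nn.Linear` (an illustrative model, not the ResNet18) to compare a plain reference with a `copy.deepcopy` snapshot.

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(2, 1)

ref = model.state_dict()                      # shares storage with the live parameters
snapshot = copy.deepcopy(model.state_dict())  # independent copy

# Simulate an in-place optimizer update
with torch.no_grad():
    model.weight.add_(1.0)

print(torch.equal(ref['weight'], model.weight))       # True: the reference followed the update
print(torch.equal(snapshot['weight'], model.weight))  # False: the snapshot kept the old values
```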

Optimizer and Training

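# Define the optimizer; include conv1x1 so the 1x1 projection is trained along with the model
optimizer = optim.SGD(list(student_model.parameters()) + list(conv1x1.parameters()), lr=0.01, momentum=0.9)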

# Define the self-distillation loss function
criterion = SelfDistillationLoss(alpha=0.5)

# Train the student model using self-distillation
trained_student = train_student(student_model, dataloaders, criterion, optimizer, num_epochs=5)
Epoch 0/4
----------
Loss: 0.4253 Acc: 0.8775
Epoch 1/4
----------
Loss: 0.2328 Acc: 0.9135
Epoch 2/4
----------
Loss: 0.2063 Acc: 0.9139
Epoch 3/4
----------
Loss: 0.1861 Acc: 0.9177
Epoch 4/4
----------
Loss: 0.1656 Acc: 0.9272
Training complete in 48m 9s
Best Acc: 0.9272

Evaluation

This function evaluates the trained model on the validation set, computing the accuracy by comparing predictions to ground truth labels.

def evaluate_model(model, dataloaders):
    model.eval()
    running_corrects = 0

    for inputs, labels in dataloaders['val']:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)

    accuracy = running_corrects.double() / dataset_sizes['val']
    print(f'Validation Accuracy: {accuracy:.4f}')

# Evaluate the trained student model
evaluate_model(trained_student, dataloaders)
Validation Accuracy: 0.9610
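The forward hooks stay attached after training, so they keep writing feature maps into the dictionary during evaluation and any later inference. `register_forward_hook` returns a `torch.utils.hooks.RemovableHandle`, and calling `.remove()` on it detaches the hook. A minimal sketch on a small stand-in model (the `nn.Sequential` and the `captured` dict are illustrative, not part of the post):

```python
import torch
import torch.nn as nn

# Small stand-in model; in the post this would be the trained ResNet18
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)).eval()

captured = {}
handle = model[0].register_forward_hook(
    lambda module, inputs, output: captured.update({"hidden": output})
)

with torch.no_grad():
    model(torch.randn(3, 4))
print("hidden" in captured)  # True: the hook fired

# Detach the hook and clear what it stored
handle.remove()
captured.clear()

with torch.no_grad():
    model(torch.randn(3, 4))
print("hidden" in captured)  # False: the hook no longer runs
```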