import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import time
import copy
# Define image transformations for training and validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
# Load Imagenette2-320 dataset
data_dir = './data/imagenette2-320/imagenette2-320'
image_datasets = {x: datasets.ImageFolder(root=f"{data_dir}/{x}", transform=data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# Set the device to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Knowledge Distillation Implementation 3/3
3. Self-Distillation
Self-distillation is a technique in which a model distills knowledge into itself: the same network acts as both teacher and student, with the outputs of deeper layers guiding earlier layers. The idea is that the student learns from its own intermediate representations, which provide an additional supervisory signal alongside the ground-truth labels.
Here, instead of the classical teacher-student setup, we use a single ResNet18 whose deeper features supervise its shallower ones. Self-distillation often yields a model that generalizes better without requiring an external teacher.
References
- J. Gou, B. Yu, S. J. Maybank, and D. Tao, "Knowledge Distillation: A Survey," arXiv:2006.05525, 2021. doi: 10.48550/arXiv.2006.05525. https://arxiv.org/abs/2006.05525
Define Models
This 1x1 convolution layer reduces the number of channels from 512 (layer4 output in ResNet18) to 128 (layer2 output) to ensure compatibility when comparing the feature maps from these two layers in the distillation process.
# Load the pre-trained ResNet18 model (used as the student)
student_model = models.resnet18(pretrained=True)

# Modify the last layer to match the number of classes in Imagenette (10 classes)
num_ftrs_student = student_model.fc.in_features
student_model.fc = nn.Linear(num_ftrs_student, 10)

# Move the model to the appropriate device
student_model = student_model.to(device)

# Define a 1x1 convolution to match the number of channels
conv1x1 = nn.Conv2d(512, 128, kernel_size=1).to(device)
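As a quick, optional sanity check (not part of the original pipeline), the dummy tensor below mimics a layer4 output for a 224x224 input, where ResNet18 produces 512 channels at 7x7 resolution; conv1x1 reduces it to 128 channels:
# Optional shape check with a dummy tensor (illustrative; assumes 224x224 inputs)
with torch.no_grad():
    dummy_layer4 = torch.randn(1, 512, 7, 7, device=device)  # layer4-like feature map
    print(conv1x1(dummy_layer4).shape)                        # -> torch.Size([1, 128, 7, 7])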
Self-Distillation Strategy
Feature Extraction Hooks
- Hooks: Forward hooks are used to extract the intermediate feature maps from layer2 and layer4 of ResNet18 during the forward pass. The extracted features are stored in the intermediate_features dictionary.
# A named forward hook would look like this (the registrations below use lambdas instead)
def get_intermediate_features(module, input, output):
    return output

# Dictionary to store the intermediate features during the forward pass
intermediate_features = {}

# Register forward hooks to capture features from the desired layers
student_model.layer2[1].register_forward_hook(lambda m, i, o: intermediate_features.update({"layer2": o}))
student_model.layer4[1].register_forward_hook(lambda m, i, o: intermediate_features.update({"layer4": o}))
<torch.utils.hooks.RemovableHandle at 0x2bb98f80f10>
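As an optional check (an addition, not in the original notebook), running one dummy batch through the model makes the hooks fire, so we can confirm what ends up in intermediate_features; the expected shapes assume 224x224 inputs:
# Optional: trigger the hooks once with a dummy batch and inspect the captured features
student_model.eval()
with torch.no_grad():
    _ = student_model(torch.randn(2, 3, 224, 224, device=device))
for name, feat in intermediate_features.items():
    print(name, tuple(feat.shape))
# Expected: layer2 (2, 128, 28, 28) and layer4 (2, 512, 7, 7)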
Custom Self-Distillation Loss
- Cross-Entropy Loss (ce_loss): Computes the loss between the predicted class labels (logits) and the ground truth labels.
- MSE Loss (distillation_loss): Encourages the student model to match its intermediate feature maps (layer2) with those of the deeper layer (layer4) by applying Mean Squared Error.
- 1x1 Convolution: Before applying MSE, the deeper layer’s output (layer4) is passed through a 1x1 convolution to match the number of channels with layer2.
- Interpolation: The student’s layer2 feature map is resized to match the spatial dimensions of the deeper layer4 features.
- Alpha: Balances the distillation loss against the cross-entropy loss; the combined objective is written out below.
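In equation form, writing F_layer2 and F_layer4 for the hooked feature maps, the objective implemented below is

$$\mathcal{L} = \alpha \,\mathrm{MSE}\!\left(\mathrm{resize}(F_{\text{layer2}}),\ \mathrm{conv}_{1\times 1}(F_{\text{layer4}})\right) + (1-\alpha)\,\mathrm{CE}(\text{logits},\ \text{labels})$$

where resize denotes bilinear interpolation of the layer2 features to the spatial size of the reduced layer4 features.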
class SelfDistillationLoss(nn.Module):
    def __init__(self, alpha=0.5):
        super(SelfDistillationLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()
        self.alpha = alpha  # Weight for self-distillation

    def forward(self, student_logits, labels, student_intermediate, teacher_intermediate):
        # Standard cross-entropy loss on the output logits
        ce_loss = self.ce_loss(student_logits, labels)

        # Apply the 1x1 convolution to teacher_intermediate to match the number of channels
        teacher_intermediate_reduced = conv1x1(teacher_intermediate)

        # Resize the student_intermediate feature map to match teacher_intermediate's spatial size
        student_intermediate_resized = F.interpolate(student_intermediate,
                                                     size=teacher_intermediate_reduced.shape[2:],
                                                     mode='bilinear', align_corners=False)

        # Self-distillation loss (MSE between resized student features and reduced teacher features)
        distillation_loss = self.mse_loss(student_intermediate_resized, teacher_intermediate_reduced)

        # Combine the two losses
        return self.alpha * distillation_loss + (1 - self.alpha) * ce_loss
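Before training, a quick smoke test of the loss on random tensors (shapes chosen to mirror the hook outputs above) can catch channel or size mismatches early; this check is an addition, not part of the original training code:
# Optional smoke test of SelfDistillationLoss with random tensors
dummy_logits = torch.randn(2, 10, device=device)
dummy_labels = torch.randint(0, 10, (2,), device=device)
dummy_layer2 = torch.randn(2, 128, 28, 28, device=device)
dummy_layer4 = torch.randn(2, 512, 7, 7, device=device)
print(SelfDistillationLoss(alpha=0.5)(dummy_logits, dummy_labels, dummy_layer2, dummy_layer4))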
Training Loop
This function trains the student model using the self-distillation loss:
- Training Loop: Iterates over the training data for a specified number of epochs (25 by default).
- Forward Pass: For each batch, the model computes the outputs, and intermediate features are extracted via hooks.
- Loss Computation: The total loss is computed by combining the cross-entropy loss and the self-distillation loss.
- Backward Pass: The loss is used to perform backpropagation and update the model’s weights.
- Accuracy: Tracks the accuracy for each epoch, and the best model weights are saved.
def train_student(model, dataloaders, criterion, optimizer, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Set the model to training mode
        model.train()

        running_loss = 0.0
        running_corrects = 0

        # Iterate over the data
        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass through the student model
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            # Get the intermediate feature maps captured by the hooks
            layer2_features = intermediate_features['layer2']
            layer4_features = intermediate_features['layer4']

            # Compute the loss (self-distillation loss)
            loss = criterion(outputs, labels, layer2_features, layer4_features)

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            # Accumulate statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        # Calculate epoch loss and accuracy
        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / dataset_sizes['train']

        print(f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Deep copy the model weights if they are the best so far
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best Acc: {best_acc:.4f}')

    # Load the best model weights before returning
    model.load_state_dict(best_model_wts)
    return model
Optimizer and Training
# Define optimizer
optimizer = optim.SGD(student_model.parameters(), lr=0.01, momentum=0.9)

# Define the self-distillation loss function
criterion = SelfDistillationLoss(alpha=0.5)

# Train the student model using self-distillation
trained_student = train_student(student_model, dataloaders, criterion, optimizer, num_epochs=5)
Epoch 0/4
----------
Loss: 0.4253 Acc: 0.8775
Epoch 1/4
----------
Loss: 0.2328 Acc: 0.9135
Epoch 2/4
----------
Loss: 0.2063 Acc: 0.9139
Epoch 3/4
----------
Loss: 0.1861 Acc: 0.9177
Epoch 4/4
----------
Loss: 0.1656 Acc: 0.9272
Training complete in 48m 9s
Best Acc: 0.9272
Evaluation
This function evaluates the trained model on the validation set, computing the accuracy by comparing predictions to ground truth labels.
def evaluate_model(model, dataloaders):
    model.eval()
    running_corrects = 0

    for inputs, labels in dataloaders['val']:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data)

    accuracy = running_corrects.double() / dataset_sizes['val']
    print(f'Validation Accuracy: {accuracy:.4f}')
# Evaluate the trained student model
evaluate_model(trained_student, dataloaders)
Validation Accuracy: 0.9610
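If you want to reuse the distilled student later, its weights can be saved with torch.save (the file name below is just a placeholder):
# Save the trained student's weights for later use (path is a placeholder)
torch.save(trained_student.state_dict(), 'resnet18_self_distilled.pth')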