import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import time
import copy
# Define transforms for data augmentation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
# Load the Imagenette2-320 dataset
data_dir = './data/imagenette2-320/imagenette2-320'
image_datasets = {x: datasets.ImageFolder(root=f"{data_dir}/{x}", transform=data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: DataLoader(image_datasets[x], batch_size=32, shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# Check device availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2. Hint-Based Distillation (Intermediate Feature Matching)
Reference:
- Romero, A., Ballas, N., Kahou, S. E., Chassang, A., Gatta, C., & Bengio, Y. (2015). FitNets: Hints for Thin Deep Nets. arXiv preprint arXiv:1412.6550. Retrieved from https://arxiv.org/abs/1412.6550
Define Teacher and Student Models
# Load pre-trained teacher (ResNet50) and student (ResNet18)
# (newer torchvision versions replace pretrained=True with the weights= argument)
teacher_model = models.resnet50(pretrained=True)
student_model = models.resnet18(pretrained=True)

# Adjust final layers to match the number of classes in Imagenette (10 classes)
num_ftrs_teacher = teacher_model.fc.in_features
teacher_model.fc = nn.Linear(num_ftrs_teacher, 10)

num_ftrs_student = student_model.fc.in_features
student_model.fc = nn.Linear(num_ftrs_student, 10)

# Move models to the appropriate device (GPU if available)
teacher_model = teacher_model.to(device)
student_model = student_model.to(device)

# Set teacher model to evaluation mode (as it is not being trained)
teacher_model.eval()

# Define a 1x1 convolution to match the teacher's feature-map channels to the student's
# (the teacher's layer3 outputs 1024 channels, the student's layer3 outputs 256 channels)
conv_teacher_to_student = nn.Conv2d(1024, 256, kernel_size=1).to(device)
Extract Intermediate Feature Representations
We need to extract intermediate features from both the teacher and the student models. One way to achieve this is to use forward hooks in PyTorch to capture activations at specific layers. Here we capture the output of the third residual stage (layer3) in both models.
# A forward hook has the signature (module, input, output); this minimal version
# simply returns the output unchanged (shown for reference, not used below)
def extract_features(module, input, output):
    return output

# Storage for intermediate features from the third residual stage (layer3)
teacher_features = []
student_features = []

# Factory that builds a forward hook which appends the layer output to the given list
def register_hooks(model, features_storage):
    def hook(module, input, output):
        features_storage.append(output)
    return hook

# Register hook to extract features from the teacher model (last block of layer3)
teacher_model.layer3[5].register_forward_hook(register_hooks(teacher_model, teacher_features))

# Register hook to extract features from the student model (last block of layer3)
student_model.layer3[1].register_forward_hook(register_hooks(student_model, student_features))
<torch.utils.hooks.RemovableHandle at 0x21cb76104c0>
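As an optional sanity check (this snippet is not part of the original notebook; it assumes 224x224 inputs and the hooks registered above), a dummy batch can be pushed through both models to confirm that the captured layer3 feature maps have the expected shapes:

# Hypothetical sanity check: run one dummy batch and inspect the hooked feature maps
with torch.no_grad():
    dummy = torch.randn(2, 3, 224, 224).to(device)   # fake batch of two RGB images
    teacher_features.clear()
    student_features.clear()
    teacher_model(dummy)
    student_model(dummy)

# For 224x224 inputs, ResNet50's layer3 yields 1024 channels and ResNet18's layer3 yields 256,
# both at 14x14 spatial resolution, so the 1x1 adapter only reconciles the channel counts
print(teacher_features[0].shape)                          # expected: torch.Size([2, 1024, 14, 14])
print(student_features[0].shape)                          # expected: torch.Size([2, 256, 14, 14])
print(conv_teacher_to_student(teacher_features[0]).shape) # expected: torch.Size([2, 256, 14, 14])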
Define the Custom Distillation Loss
We now define a custom loss function that combines the following terms (a compact formula follows the list):
- Cross-Entropy Loss on the student’s hard predictions against the ground truth labels.
- KL Divergence Loss between the teacher’s and student’s soft logits (output of final layer).
- Feature Matching Loss (e.g., L2 loss) between the intermediate feature maps of the teacher and student.
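Putting these pieces together (using the variable names from the implementation below, with alpha, beta, and T = temperature as the constructor arguments), the combined objective is:

loss = alpha * T^2 * KL(student_soft || teacher_soft)
     + (1 - alpha) * CE(student_logits, labels)
     + beta * MSE(student_features, teacher_features)

where student_soft and teacher_soft are the temperature-scaled (log-)softmax outputs. The T^2 factor compensates for the 1/T^2 scaling that the temperature introduces into the soft-target gradients, keeping the distillation term on a comparable scale to the other losses.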
class HintBasedDistillationLoss(nn.Module):
    def __init__(self, temperature=3.0, alpha=0.5, beta=0.5):
        super(HintBasedDistillationLoss, self).__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta
        self.kl_div_loss = nn.KLDivLoss(reduction='batchmean')
        self.ce_loss = nn.CrossEntropyLoss()
        self.l2_loss = nn.MSELoss()

    def forward(self, student_logits, teacher_logits, student_features, teacher_features, labels):
        # Soft targets: apply temperature scaling to teacher and student outputs
        teacher_soft = torch.softmax(teacher_logits / self.temperature, dim=1)
        student_soft = torch.log_softmax(student_logits / self.temperature, dim=1)

        # Distillation loss (KL divergence between student's and teacher's softened outputs)
        distillation_loss = self.kl_div_loss(student_soft, teacher_soft) * (self.temperature ** 2)

        # Cross-entropy loss (between student predictions and true labels)
        student_loss = self.ce_loss(student_logits, labels)

        # Feature matching loss (L2 loss between teacher and student feature maps)
        feature_loss = self.l2_loss(student_features, teacher_features)

        # Combined loss
        return self.alpha * distillation_loss + (1.0 - self.alpha) * student_loss + self.beta * feature_loss
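As a further optional smoke test (again not part of the original notebook; the tensor shapes are illustrative), the loss can be called on random tensors to verify that it returns a single scalar:

# Hypothetical smoke test for HintBasedDistillationLoss on random tensors
_criterion = HintBasedDistillationLoss(temperature=3.0, alpha=0.5, beta=0.5)
_student_logits = torch.randn(4, 10)                  # batch of 4, 10 classes
_teacher_logits = torch.randn(4, 10)
_student_feat = torch.randn(4, 256, 14, 14)           # student layer3-shaped features
_teacher_feat_resized = torch.randn(4, 256, 14, 14)   # teacher features after the 1x1 adapter
_labels = torch.randint(0, 10, (4,))
print(_criterion(_student_logits, _teacher_logits, _student_feat, _teacher_feat_resized, _labels))
# expected: a 0-dim tensor (scalar loss)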
Implement the Training Loop
def train_student(teacher_model, student_model, dataloaders, criterion, optimizer, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(student_model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Set student model to training mode
        student_model.train()

        running_loss = 0.0
        running_corrects = 0

        # Iterate over the data
        for inputs, labels in dataloaders['train']:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Clear gradients for the student model
            optimizer.zero_grad()

            # Clear the feature lists before every forward pass
            teacher_features.clear()  # Clear saved teacher features
            student_features.clear()  # Clear saved student features

            # Forward pass through teacher (for soft labels and features)
            with torch.no_grad():  # Disable gradients for teacher model
                teacher_logits = teacher_model(inputs)
            teacher_feature = teacher_features[0]  # Intermediate feature captured by the hook

            # Forward pass through student
            student_logits = student_model(inputs)
            student_feature = student_features[0]  # Intermediate feature captured by the hook

            # Apply 1x1 convolution to match the teacher's feature-map channels to the student's
            teacher_feature_resized = conv_teacher_to_student(teacher_feature)

            # Compute loss
            loss = criterion(student_logits, teacher_logits, student_feature, teacher_feature_resized, labels)

            # Backward pass and optimization step on the student's parameters
            loss.backward()
            optimizer.step()

            # Accumulate running statistics
            _, preds = torch.max(student_logits, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / dataset_sizes['train']

        print(f'Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Deep copy the best weights so later updates do not overwrite the snapshot
        if epoch_acc > best_acc:
            best_acc = epoch_acc
            best_model_wts = copy.deepcopy(student_model.state_dict())

    # Training complete
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best Acc: {best_acc:.4f}')

    # Load best model weights
    student_model.load_state_dict(best_model_wts)
    return student_model
Training and Evaluation
# Define optimizer over the student's parameters
# (note: conv_teacher_to_student is not included here, so the 1x1 adapter keeps its
# initial weights; adding its parameters to the optimizer would train it as well)
optimizer = optim.SGD(student_model.parameters(), lr=0.01, momentum=0.9)

# Define hint-based distillation loss
criterion = HintBasedDistillationLoss(temperature=3.0, alpha=0.5, beta=0.5)

# Train the student model
trained_student = train_student(teacher_model, student_model, dataloaders, criterion, optimizer, num_epochs=5)
Epoch 0/4
----------
Loss: 0.6921 Acc: 0.8747
Epoch 1/4
----------
Loss: 0.6239 Acc: 0.9241
Epoch 2/4
----------
Loss: 0.6072 Acc: 0.9267
Epoch 3/4
----------
Loss: 0.5918 Acc: 0.9378
Epoch 4/4
----------
Loss: 0.5790 Acc: 0.9413
Training complete in 98m 5s
Best Acc: 0.9413
Evaluation
def evaluate_model(model, dataloaders):
    model.eval()  # Set to evaluation mode
    running_corrects = 0

    for inputs, labels in dataloaders['val']:
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():  # No need to compute gradients during evaluation
            outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels.data)

    accuracy = running_corrects.double() / dataset_sizes['val']
    print(f'Validation Accuracy: {accuracy:.4f}')

# Evaluate trained student model
evaluate_model(trained_student, dataloaders)
Validation Accuracy: 0.9735
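Since the goal of distillation is to fold the teacher's accuracy into a much smaller network, a natural follow-up (a minimal sketch, not from the original notebook; the file name is ours) is to compare parameter counts and persist the distilled student:

# Hypothetical follow-up: compare model sizes and save the distilled student
teacher_params = sum(p.numel() for p in teacher_model.parameters())
student_params = sum(p.numel() for p in student_model.parameters())
print(f'Teacher (ResNet50) parameters: {teacher_params:,}')   # roughly 23.5M with the 10-class head
print(f'Student (ResNet18) parameters: {student_params:,}')   # roughly 11.2M with the 10-class head
torch.save(trained_student.state_dict(), 'student_resnet18_hint_distilled.pth')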