

``` python
import warnings
warnings.filterwarnings('ignore')

from fastai.vision.all import *
from fasterbench.benchmark import evaluate_cpu_speed, get_model_size, get_num_parameters
```

<!-- WARNING: THIS FILE WAS AUTOGENERATED! DO NOT EDIT! -->

``` python
import torch.nn as nn
import torch
class dfus_block(nn.Module):
    """Dense fusion block that appends 32 freshly computed channels to its input.

    The ``dim``-channel input is first squeezed to 128 channels by a 1x1
    convolution. Two parallel branches ("up" and "down"), each a 3x3
    convolution followed by a 1x1 convolution, then produce 32 + 16 channels
    apiece. The four branch maps (96 channels in total) are fused to 32
    channels, which are concatenated after the untouched input.
    """

    def __init__(self, dim):
        super().__init__()
        # Attribute names (including the historical ``conv_fution`` spelling)
        # and their registration order define the state_dict keys, so they
        # are kept exactly as they were.
        self.conv1 = nn.Conv2d(dim, 128, kernel_size=1, stride=1, padding=0, bias=False)

        self.conv_up1 = nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_up2 = nn.Conv2d(32, 16, kernel_size=1, stride=1, padding=0, bias=False)

        self.conv_down1 = nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_down2 = nn.Conv2d(32, 16, kernel_size=1, stride=1, padding=0, bias=False)

        self.conv_fution = nn.Conv2d(96, 32, kernel_size=1, stride=1, padding=0, bias=False)

        # One shared in-place ReLU is applied after every convolution.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        x: [b,dim,h,w]
        return out: [b,dim+32,h,w]  (x followed by the 32 fused channels)
        """
        squeezed = self.relu(self.conv1(x))

        up_wide = self.relu(self.conv_up1(squeezed))
        up_narrow = self.relu(self.conv_up2(up_wide))

        down_wide = self.relu(self.conv_down1(squeezed))
        down_narrow = self.relu(self.conv_down2(down_wide))

        # Channel order of the concatenation matters for conv_fution's weights.
        branches = (up_wide, up_narrow, down_wide, down_narrow)
        fused = self.relu(self.conv_fution(torch.cat(branches, dim=1)))

        return torch.cat((x, fused), dim=1)

class ddfn(nn.Module):
    """Two-branch 128-channel stem followed by a chain of ``dfus_block`` modules.

    The stem runs two parallel branches ("up" and "down"), each a 3x3
    convolution followed by a 1x1 convolution with 32 output channels, and
    concatenates the four resulting maps into 128 channels. That tensor is
    then passed through ``num_blocks`` dfus_block modules, the i-th of which
    is built for ``128 + 32 * i`` input channels.
    """

    def __init__(self, dim, num_blocks=78):
        super().__init__()

        self.conv_up1 = nn.Conv2d(dim, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_up2 = nn.Conv2d(32, 32, kernel_size=1, stride=1, padding=0, bias=False)

        self.conv_down1 = nn.Conv2d(dim, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_down2 = nn.Conv2d(32, 32, kernel_size=1, stride=1, padding=0, bias=False)

        # Each block's input is 32 channels wider than the previous block's.
        self.dfus_blocks = nn.Sequential(
            *(dfus_block(dim=128 + 32 * idx) for idx in range(num_blocks))
        )

        # One shared in-place ReLU is applied after every stem convolution.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        x: [b,dim,h,w]
        return out: [b,c_out,h,w], where c_out is whatever the dfus_blocks
        chain produces from the 128-channel stem (128 when the chain is empty)
        """
        up_first = self.relu(self.conv_up1(x))
        up_second = self.relu(self.conv_up2(up_first))

        down_first = self.relu(self.conv_down1(x))
        down_second = self.relu(self.conv_down2(down_first))

        stem = torch.cat((up_first, up_second, down_first, down_second), dim=1)
        return self.dfus_blocks(stem)

class HSCNN_Plus(nn.Module):
    """HSCNN+ network: a ``ddfn`` feature trunk followed by a 1x1 output projection.

    The projection is built for ``128 + 32 * num_blocks`` input channels and
    maps them to ``out_channels`` without a bias term.
    """

    def __init__(self, in_channels=3, out_channels=31, num_blocks=30):
        super().__init__()

        self.ddfn = ddfn(dim=in_channels, num_blocks=num_blocks)

        # Width of the trunk output that the final projection consumes.
        trunk_channels = 128 + 32 * num_blocks
        self.conv_out = nn.Conv2d(
            trunk_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False
        )

    def forward(self, x):
        """
        x: [b,in_channels,h,w]
        return out: [b,out_channels,h,w]
        """
        return self.conv_out(self.ddfn(x))
```

``` python
# def get_dls(size, bs):
#     path = URLs.IMAGENETTE_160
#     source = untar_data(path)
#     blocks=(ImageBlock, CategoryBlock)
#     tfms = [RandomResizedCrop(size, min_scale=0.35), FlipItem(0.5)]
#     batch_tfms = [Normalize.from_stats(*imagenet_stats)]

#     csv_file = 'noisy_imagenette.csv'
#     inp = pd.read_csv(source/csv_file)
#     dblock = DataBlock(blocks=blocks,
#                splitter=ColSplitter(),
#                get_x=ColReader('path', pref=source),
#                get_y=ColReader(f'noisy_labels_0'),
#                item_tfms=tfms,
#                batch_tfms=batch_tfms)

#     return dblock.dataloaders(inp, path=source, bs=bs)
```

``` python
# size, bs = 128, 32
# dls = get_dls(size, bs)
```

``` python
# Location of the HSCNN+ checkpoint that is loaded into the model further down.
model_path = (
    '/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/test_develop_code/model_zoo/hscnn_plus.pth'
)
# Dataset root kept as a plain string (the DataBlock cells build their own Path).
data_root = '/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/dataset/'
```

``` python
# model_path = Path('/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/test_challenge_code/model_zoo/hscnn_plus.pth')

# path = '/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/dataset/Train_RGB/'
```

``` python
from fastai.vision.all import *
from pathlib import Path
import torch

# Dataset root, plus the Test_RGB sub-folder inside it (adjust both to your layout).
path = Path('/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/dataset/')
val_path = path.joinpath('Test_RGB')
```

``` python
# from fastai.vision.all import *

# Root of the dataset. get_image_files is handed this folder below, so the
# images it finds under it (train and test sub-folders alike) all end up in the
# DataBlock and are then split randomly.
path = Path('/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/dataset/')

# DataBlock for an image-to-image task whose target is the input image itself.
# No get_x/get_y are passed: each ImageBlock opens the file path produced by
# get_items on its own, so the former `lambda f: PILImage.create(f)` getters were
# redundant, and being lambdas they also prevented the DataLoaders from being
# pickled.
data_block = DataBlock(
    blocks=(ImageBlock, ImageBlock),  # Both input and target are images
    get_items=get_image_files,  # Collect the image files under `path`
    splitter=RandomSplitter(valid_pct=0.2),  # Random 80/20 train/validation split
    item_tfms=Resize(64),  # Resize every item to 64x64
)

# Create DataLoaders
dls = data_block.dataloaders(path, bs=1)  # Adjust batch size based on memory limits
```

``` python
# #| eval: false
# dls = data_block.dataloaders(val_path, bs=5)  # Use the appropriate batch size
# dls.show_batch()
```

``` python
# Grab a batch from the training DataLoader
x, y = dls.one_batch()

# Check the shape of inputs and outputs; with the image-to-image DataBlock both
# are image batches, so the two shapes are expected to match.
# NOTE(review): the recorded output shows a batch of 5 while the DataLoaders cell
# uses bs=1 -- the output looks stale; re-run to confirm.
print("Input (x) shape:", x.shape)
print("Target (y) shape:", y.shape)
```

    Input (x) shape: torch.Size([5, 3, 64, 64])
    Target (y) shape: torch.Size([5, 3, 64, 64])

``` python
# # | eval: false
# model = HSCNN_Plus()
# checkpoint = torch.load(model_path)
# if 'state_dict' in checkpoint:
#     model.load_state_dict(checkpoint['state_dict'])
# else:
#     model.load_state_dict(checkpoint)
# model.eval()
```

``` python
# print(model)
# print(torch.load(model_path).keys())
```

``` python
# Build HSCNN+ with its default configuration so it can receive the checkpoint weights.
model = HSCNN_Plus()

# Read the checkpoint, mapping its tensors to the GPU when one is available and
# to the CPU otherwise.
# NOTE(review): torch.load unpickles the file, so model_path should only ever
# point at a checkpoint from a trusted source.
checkpoint = torch.load(
    model_path,
    map_location=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),
)

# The file may wrap the weights under a 'state_dict' key or be the weights themselves.
model.load_state_dict(checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint)

# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()
```

    HSCNN_Plus(
      (ddfn): ddfn(
        (conv_up1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (conv_up2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (conv_down1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (conv_down2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dfus_blocks): Sequential(
          (0): dfus_block(
            (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (1): dfus_block(
            (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (2): dfus_block(
            (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (3): dfus_block(
            (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (4): dfus_block(
            (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (5): dfus_block(
            (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (6): dfus_block(
            (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (7): dfus_block(
            (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (8): dfus_block(
            (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (9): dfus_block(
            (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (10): dfus_block(
            (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (11): dfus_block(
            (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (12): dfus_block(
            (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (13): dfus_block(
            (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (14): dfus_block(
            (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (15): dfus_block(
            (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (16): dfus_block(
            (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (17): dfus_block(
            (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (18): dfus_block(
            (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (19): dfus_block(
            (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (20): dfus_block(
            (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (21): dfus_block(
            (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (22): dfus_block(
            (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (23): dfus_block(
            (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (24): dfus_block(
            (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (25): dfus_block(
            (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (26): dfus_block(
            (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (27): dfus_block(
            (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (28): dfus_block(
            (conv1): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
          (29): dfus_block(
            (conv1): Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (relu): ReLU(inplace=True)
          )
        )
        (relu): ReLU(inplace=True)
      )
      (conv_out): Conv2d(1088, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
    )

``` python
from torch.nn import MSELoss

# Create the Learner with MSE Loss
# NOTE(review): at this point `model` is the default HSCNN_Plus(), whose
# out_channels default is 31, while the batches printed earlier have 3-channel
# targets -- confirm the shapes agree before calling fit on this learner.
learn = Learner(dls, model, loss_func=MSELoss())
```

``` python
from torch.nn import MSELoss
# Build a fresh, much smaller HSCNN+ (RGB in, RGB out) and train it from its
# random initialisation. This rebinds `model`, so the checkpoint-loaded network
# is no longer reachable under that name.
model = HSCNN_Plus(in_channels=3, out_channels=3, num_blocks=5)  # 5 blocks instead of the default 30
# NOTE(review): the model is placed on the CPU here; confirm whether the Learner
# moves it to the DataLoaders' device when training starts.
learn = Learner(dls, model.to('cpu'), loss_func=MSELoss())

# 5 epochs of one-cycle training with a peak learning rate of 1e-4.
learn.fit_one_cycle(5, lr_max=1e-4)
# learn.fit_one_cycle(4, 1e-4)
# Run inference on validation set
# preds, targs = learn.get_preds(dl=dls.valid)  # Get predictions
```

<style>
    /* Turns off some styling */
    progress {
        /* gets rid of default border in Firefox and Opera. */
        border: none;
        /* Needs to be in here for Safari polyfill so background images work as expected. */
        background-size: auto;
    }
    progress:not([value]), progress:not([value])::-webkit-progress-bar {
        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);
    }
    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {
        background: #F44336;
    }
</style>

<table class="dataframe" data-quarto-postprocess="true" data-border="1">
<thead>
<tr style="text-align: left;">
<th data-quarto-table-cell-role="th">epoch</th>
<th data-quarto-table-cell-role="th">train_loss</th>
<th data-quarto-table-cell-role="th">valid_loss</th>
<th data-quarto-table-cell-role="th">time</th>
</tr>
</thead>
<tbody>
<tr>
<td>0</td>
<td>0.047585</td>
<td>0.043253</td>
<td>00:00</td>
</tr>
<tr>
<td>1</td>
<td>0.045987</td>
<td>0.039191</td>
<td>00:00</td>
</tr>
<tr>
<td>2</td>
<td>0.043653</td>
<td>0.035421</td>
<td>00:00</td>
</tr>
<tr>
<td>3</td>
<td>0.041533</td>
<td>0.033387</td>
<td>00:00</td>
</tr>
<tr>
<td>4</td>
<td>0.039822</td>
<td>0.032988</td>
<td>00:00</td>
</tr>
</tbody>
</table>

``` python
# from fastai.callback.all import GradientAccumulation

# # Set gradient accumulation steps to effectively multiply your batch size by this factor
# accumulation_steps = 8  # Adjust based on your needs and memory constraints

# # Create the Learner with gradient accumulation and mixed precision
# learn = Learner(dls, model, loss_func=MSELoss(), cbs=[GradientAccumulation(n_acc=accumulation_steps)]).to_fp16()
# learn.fit_one_cycle(5, lr_max=1e-4)
```

``` python
# files = get_image_files(path)

# def label_func(f): return f[0].isupper()

# dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(128),bs=32)
```

``` python
import torch
# Ask PyTorch to release the GPU memory its caching allocator holds but is not
# using, so the following cells start with as much free GPU memory as possible.
torch.cuda.empty_cache()
```

``` python
import os
# Configure PyTorch's CUDA allocator to use expandable segments.
# NOTE(review): torch has already been imported and used by earlier cells;
# allocator settings are presumably read when CUDA memory is first set up, so
# confirm this assignment still takes effect this late (setting it before the
# first CUDA use is the safe order).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
```

``` python
from torch.nn import MSELoss  # Mean squared error: per-pixel regression loss for image-to-image tasks

# Re-create the Learner around the current `model` object (the small network
# trained above), again with MSE loss.
learn = Learner(dls, model, loss_func=MSELoss())
```

``` python
# learn = Learner(dls, model, metrics=[accuracy])
```

``` python
# Report the size of the learner's current model using the fasterbench helpers.
num_parameters = get_num_parameters(learn.model)
disk_size = get_model_size(learn.model)
# disk_size is divided by 1e6 and labelled MB, i.e. it is treated as a byte
# count -- presumably what get_model_size returns; confirm in fasterbench.
print(f"Model Size: {disk_size / 1e6:.2f} MB (disk), {num_parameters} parameters")
```

    Model Size: 2.08 MB (disk), 516640 parameters

``` python
# Put the learner's model in eval mode on the CPU ahead of the CPU speed benchmark.
model = learn.model.eval().to('cpu')
# Fetch a batch to benchmark with.
# NOTE(review): the batch comes from `dls` and stays on whatever device the
# DataLoaders use -- confirm it is on the CPU before timing a CPU model.
x,y = dls.one_batch()
```

``` python
# Time the model on a single image: x[0][None] takes the first sample of the
# batch and re-adds the batch axis. The first element of the helper's result is
# printed with an "ms" suffix (presumably the mean latency in milliseconds --
# confirm in fasterbench).
print(f'Inference Speed: {evaluate_cpu_speed(learn.model, x[0][None])[0]:.2f}ms')
```

    Inference Speed: 15.21ms

``` python
# Draw another batch and print the input/target shapes as a final sanity check
# that both sides of the image-to-image task have the same dimensions.
x, y = dls.one_batch()
print("Input Shape:", x.shape)
print("Target Shape:", y.shape)
```

    Input Shape: torch.Size([5, 3, 64, 64])
    Target Shape: torch.Size([5, 3, 64, 64])

------------------------------------------------------------------------

<br>

## **Knowledge Distillation**

<br>

<blockquote>

<pre><b><i> KnowledgeDistillationCallback(teacher.model, loss) </i></b></pre>

<p style="font-size: 15px">

<i> You only need to pass the teacher's model and a distillation loss to the
callback. Behind the scenes, FasterAI takes care of training your model
with knowledge distillation. </i>
</p>

</blockquote>

<br>

``` python
from fasterai.distill.all import *
```

``` python
import torch

# Release cached-but-unused GPU memory before the distillation experiments.
torch.cuda.empty_cache()
```

``` python
import os
# Configure PyTorch's CUDA allocator to use expandable segments.
# NOTE(review): this is set after torch has already been used in this session;
# confirm the allocator still honours it at this point (setting it before the
# first CUDA use is the safe order).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
```

``` python
# sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6  # Total trainable parameters in millions
```

``` python
# import torch

# print(f"Allocated memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB")
# print(f"Cached memory: {torch.cuda.memory_reserved() / 1024 ** 2:.2f} MB")
```

``` python
# !nvidia-smi
```

``` python
# !kill -9 58089
```

``` python
from torch.nn import MSELoss
# Build the teacher for knowledge distillation: a fresh 5-block HSCNN+ (RGB in,
# RGB out) trained from its random initialisation. The checkpoint weights
# loaded earlier are not used here, and `model` is rebound once more.
model = HSCNN_Plus(in_channels=3, out_channels=3, num_blocks=5)  # 5 blocks instead of the default 30
# NOTE(review): the model is placed on the CPU here; confirm whether the Learner
# moves it to the DataLoaders' device when training starts.
teacher = Learner(dls, model.to('cpu'), loss_func=MSELoss())

# 10 epochs of one-cycle training with a peak learning rate of 1e-4.
teacher.fit_one_cycle(10, lr_max=1e-4)
# learn.fit_one_cycle(4, 1e-4)
```

<style>
    /* Turns off some styling */
    progress {
        /* gets rid of default border in Firefox and Opera. */
        border: none;
        /* Needs to be in here for Safari polyfill so background images work as expected. */
        background-size: auto;
    }
    progress:not([value]), progress:not([value])::-webkit-progress-bar {
        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);
    }
    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {
        background: #F44336;
    }
</style>

<table class="dataframe" data-quarto-postprocess="true" data-border="1">
<thead>
<tr style="text-align: left;">
<th data-quarto-table-cell-role="th">epoch</th>
<th data-quarto-table-cell-role="th">train_loss</th>
<th data-quarto-table-cell-role="th">valid_loss</th>
<th data-quarto-table-cell-role="th">time</th>
</tr>
</thead>
<tbody>
<tr>
<td>0</td>
<td>0.042094</td>
<td>0.039655</td>
<td>00:00</td>
</tr>
<tr>
<td>1</td>
<td>0.041489</td>
<td>0.037667</td>
<td>00:00</td>
</tr>
<tr>
<td>2</td>
<td>0.040237</td>
<td>0.034148</td>
<td>00:00</td>
</tr>
<tr>
<td>3</td>
<td>0.038269</td>
<td>0.029985</td>
<td>00:00</td>
</tr>
<tr>
<td>4</td>
<td>0.035919</td>
<td>0.025685</td>
<td>00:00</td>
</tr>
<tr>
<td>5</td>
<td>0.033408</td>
<td>0.021808</td>
<td>00:00</td>
</tr>
<tr>
<td>6</td>
<td>0.030831</td>
<td>0.018923</td>
<td>00:00</td>
</tr>
<tr>
<td>7</td>
<td>0.028445</td>
<td>0.017224</td>
<td>00:00</td>
</tr>
<tr>
<td>8</td>
<td>0.026424</td>
<td>0.016530</td>
<td>00:00</td>
</tr>
<tr>
<td>9</td>
<td>0.024772</td>
<td>0.016417</td>
<td>00:00</td>
</tr>
</tbody>
</table>

``` python
from fastai.vision.all import *
from fastai.callback.all import *
from fastai.vision.models.unet import DynamicUnet
from torchvision.models import resnet18

# Step 1: Define the student model: a U-Net built on a ResNet-18 encoder.
# `[:-2]` drops the last two child modules of the torchvision ResNet-18 (its
# pooling layer and fully connected head), keeping only the convolutional
# feature extractor as the encoder.
encoder = nn.Sequential(*list(resnet18(pretrained=True).children())[:-2])
# n_out=3 matches the out_channels=3 the teacher is constructed with;
# img_size=(64, 64) is the input size the U-Net is built for.
# NOTE(review): the size report printed for this model further down shows
# ~31M parameters, so confirm this is the intended "tiny" student.
student_model = DynamicUnet(encoder, n_out=3, img_size=(64, 64))

# Step 2: Define the Learner for the student model
# Set a suitable loss function for image-to-image tasks like MSELoss
student = Learner(
    dls, 
    student_model, 
    loss_func=MSELoss()#, 
    # metrics=[PSNR()]  # PSNR (Peak Signal-to-Noise Ratio) can be useful for image quality
)

# Step 3: Initialize the KnowledgeDistillationCallback
# Assuming `teacher` is the pre-trained HSCNN_Plus model
# NOTE(review): in the recorded run below, train_loss (~200 -> ~18) is orders
# of magnitude above valid_loss, and valid_loss rises after epoch 3. Confirm
# that SoftTarget is an appropriate distillation loss for this regression
# setup.
kd_cb = KnowledgeDistillationCallback(teacher.model, SoftTarget)

# Step 4: Train the student model with knowledge distillation
# (10 epochs, one-cycle schedule, max lr 1e-4, distillation callback attached).
student.fit_one_cycle(10, 1e-4, cbs=kd_cb)
```

<style>
    /* Turns off some styling */
    progress {
        /* gets rid of default border in Firefox and Opera. */
        border: none;
        /* Needs to be in here for Safari polyfill so background images work as expected. */
        background-size: auto;
    }
    progress:not([value]), progress:not([value])::-webkit-progress-bar {
        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);
    }
    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {
        background: #F44336;
    }
</style>

<table class="dataframe" data-quarto-postprocess="true" data-border="1">
<thead>
<tr style="text-align: left;">
<th data-quarto-table-cell-role="th">epoch</th>
<th data-quarto-table-cell-role="th">train_loss</th>
<th data-quarto-table-cell-role="th">valid_loss</th>
<th data-quarto-table-cell-role="th">time</th>
</tr>
</thead>
<tbody>
<tr>
<td>0</td>
<td>206.853897</td>
<td>4.708280</td>
<td>00:01</td>
</tr>
<tr>
<td>1</td>
<td>117.306244</td>
<td>4.507239</td>
<td>00:01</td>
</tr>
<tr>
<td>2</td>
<td>78.637085</td>
<td>3.295578</td>
<td>00:01</td>
</tr>
<tr>
<td>3</td>
<td>57.473248</td>
<td>2.667634</td>
<td>00:01</td>
</tr>
<tr>
<td>4</td>
<td>44.227245</td>
<td>3.118005</td>
<td>00:01</td>
</tr>
<tr>
<td>5</td>
<td>35.264557</td>
<td>3.569392</td>
<td>00:01</td>
</tr>
<tr>
<td>6</td>
<td>28.856447</td>
<td>3.880894</td>
<td>00:01</td>
</tr>
<tr>
<td>7</td>
<td>24.112537</td>
<td>4.008141</td>
<td>00:01</td>
</tr>
<tr>
<td>8</td>
<td>20.488949</td>
<td>4.021497</td>
<td>00:01</td>
</tr>
<tr>
<td>9</td>
<td>17.672688</td>
<td>4.249494</td>
<td>00:01</td>
</tr>
</tbody>
</table>

``` python
# Report the distilled student's footprint using the fasterbench helpers.
# get_model_size's result is divided by 1e6 for the "MB" figure, so it is
# presumably a byte count - TODO confirm against fasterbench.
num_parameters = get_num_parameters(student.model)
disk_size = get_model_size(student.model)
print(f"Model Size: {disk_size / 1e6:.2f} MB (disk), {num_parameters} parameters")
```

    Model Size: 124.56 MB (disk), 31113108 parameters

------------------------------------------------------------------------

<br>

## Quantization

``` python
from fasterai.quantize.quantize_callback import *
```

``` python
# Continue training the teacher for 5 more epochs at a lower peak learning
# rate (1e-5) with fasterai's QuantizeCallback attached.
# NOTE(review): presumably the callback leaves `teacher.model` quantized, since
# later cells look for torch.ao quantized modules in it - confirm in fasterai.
teacher.fit_one_cycle(5, 1e-5, cbs=QuantizeCallback())
```

<style>
    /* Turns off some styling */
    progress {
        /* gets rid of default border in Firefox and Opera. */
        border: none;
        /* Needs to be in here for Safari polyfill so background images work as expected. */
        background-size: auto;
    }
    progress:not([value]), progress:not([value])::-webkit-progress-bar {
        background: repeating-linear-gradient(45deg, #7e7e7e, #7e7e7e 10px, #5c5c5c 10px, #5c5c5c 20px);
    }
    .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {
        background: #F44336;
    }
</style>

<table class="dataframe" data-quarto-postprocess="true" data-border="1">
<thead>
<tr style="text-align: left;">
<th data-quarto-table-cell-role="th">epoch</th>
<th data-quarto-table-cell-role="th">train_loss</th>
<th data-quarto-table-cell-role="th">valid_loss</th>
<th data-quarto-table-cell-role="th">time</th>
</tr>
</thead>
<tbody>
<tr>
<td>0</td>
<td>0.018622</td>
<td>0.017505</td>
<td>00:01</td>
</tr>
<tr>
<td>1</td>
<td>0.018870</td>
<td>0.017797</td>
<td>00:00</td>
</tr>
<tr>
<td>2</td>
<td>0.019117</td>
<td>0.018241</td>
<td>00:00</td>
</tr>
<tr>
<td>3</td>
<td>0.019400</td>
<td>0.018464</td>
<td>00:00</td>
</tr>
<tr>
<td>4</td>
<td>0.019566</td>
<td>0.018610</td>
<td>00:00</td>
</tr>
</tbody>
</table>

``` python
# Time the teacher model on a single sample on CPU.
# x[0][None] takes the first item of `x` and adds a leading axis.
# NOTE(review): `x` is defined outside this cell; presumably a batch drawn
# from `dls` - confirm.
# Element [0] of evaluate_cpu_speed's result is what is reported as ms here.
print(f'Inference Speed: {evaluate_cpu_speed(teacher.model, x[0][None])[0]:.2f}ms')
```

    Inference Speed: 11.60ms

``` python
def count_parameters_quantized(model):
    """Count weight and bias elements of the conv/linear layers in ``model``.

    Only ``Conv2d`` and ``Linear`` layers are counted, in both their float
    (``torch.nn``) and quantized (``torch.ao.nn.quantized``) forms; every
    other module type contributes nothing to the total.

    Float layers expose ``weight``/``bias`` as attributes, whereas quantized
    layers expose them as methods. Both forms are handled, so the function
    works on float, quantized, and partially quantized models.

    Args:
        model: The ``torch.nn.Module`` whose sub-modules are inspected.

    Returns:
        int: Total number of weight and bias elements in the counted layers.
    """
    # Imported locally so the quantized classes resolve even when the
    # ``torch.ao.nn.quantized`` sub-package has not been loaded yet.
    from torch.ao.nn.quantized.modules.conv import Conv2d as QuantizedConv2d
    from torch.ao.nn.quantized.modules.linear import Linear as QuantizedLinear

    counted_types = (
        torch.nn.Conv2d,
        torch.nn.Linear,
        QuantizedConv2d,
        QuantizedLinear,
    )

    def _resolve(attr):
        # Quantized modules: bound method returning the tensor (or None).
        # Float modules: the Parameter itself (or None), which is not callable.
        return attr() if callable(attr) else attr

    total_params = 0
    for module in model.modules():
        if not isinstance(module, counted_types):
            continue

        total_params += _resolve(module.weight).numel()

        bias = _resolve(module.bias)
        if bias is not None:
            total_params += bias.numel()
    return total_params
```

``` python
# Report the teacher's footprint after the QuantizeCallback run.
# count_parameters_quantized counts only conv/linear weight and bias elements,
# so this figure is not directly comparable with get_num_parameters' result.
# get_model_size's result is divided by 1e6 for the "MB" figure.
num_parameters = count_parameters_quantized(teacher.model)
disk_size = get_model_size(teacher.model)
print(f"Model Size: {disk_size / 1e6:.2f} MB (disk), {num_parameters:,} parameters")
```

    Model Size: 0.59 MB (disk), 514,976 parameters
