# create environment: conda create -n mst python=3.8
# conda activate mst
# cd MST-plus-plus/
MST++: Multi-stage Spectral-wise Transformer for Efficient Spectral Reconstruction (CVPRW 2022)
This work won the NTIRE 2022 Spectral Reconstruction Challenge.
GitHub link: https://github.com/caiyuanhao1998/MST-plus-plus/tree/master
This repo is a baseline and toolbox containing 11 image restoration algorithms for spectral reconstruction.
Paper link: https://arxiv.org/abs/2204.07908
Dataset: https://github.com/boazarad/ARAD_1K
Hyperspectral imaging (HSI) captures the spectral information of a scene across many wavelengths, making it far richer than a regular RGB image. HSI has widespread applications in fields like medical imaging, remote sensing, and environmental monitoring. Traditional methods for acquiring HSI require expensive hardware and are time-intensive, which limits their use in dynamic or real-time scenarios. This work aims to generate high-quality HSIs directly from RGB images. Existing approaches based on convolutional neural networks (CNNs) perform well but face two main challenges:
- They struggle to model long-range dependencies.
- They fail to effectively capture self-similarity within the spectral data.
In order to address these limitations, the authors introduce MST++, a Transformer-based framework for spectral reconstruction.
Pipeline of the MST++ Model Architecture
- MST++: Multi-stage Spectral-wise Transformer
- SST: Single-stage Spectral-wise Transformer
- SAB: Spectral-wise Attention Block
- FFN: Feed Forward Network
- S-MSA: Spectral-wise Multi-head Self-attention
In the code below, MST_Plus_Plus stacks several SSTs (the MST class), each SST is a U-shaped encoder-decoder built from SABs (the MSAB class), and each SAB pairs an S-MSA layer (MS_MSA) with an FFN (FeedForward).
Dataset:
ARAD_1K, the dataset provided by the NTIRE 2022 Spectral Reconstruction Challenge, contains 1000 RGB-HSI pairs. The dataset is split into train, valid, and test subsets in the proportion 18:1:1 (900/50/50). Each HSI has a spatial size of 482×512 and 31 spectral bands from 400 nm to 700 nm.
The dataset contains 950 publicly available images and 50 "test" images, which have only been released as RGB or Multi-Spectral-Filter-Array "RAW" files.
Publicly available images have been published as:
- 31 channel hyperspectral images in the 400-700nm range
- 16 channel multi-spectral images in the 400-1000nm range
1. Create Environment:
Python 3 (Anaconda is recommended)
NVIDIA GPU + CUDA
Python packages: see requirements.txt
2. Data Preparation:
Download the datasets from the main GitHub repo:
- training spectral images (Google Drive / Baidu Disk, code: mst1): size 19 GB
- training RGB images (Google Drive / Baidu Disk): size 22 MB
- validation spectral images (Google Drive / Baidu Disk): size 1 GB
- validation RGB images (Google Drive / Baidu Disk): size 1 MB
- testing RGB images (Google Drive / Baidu Disk): size 1 MB
Directory of the dataset:
Place the training spectral images and validation spectral images in /MST-plus-plus/dataset/Train_Spec/.
Place the training RGB images and validation RGB images in /MST-plus-plus/dataset/Train_RGB/.
Place the testing RGB images in /MST-plus-plus/dataset/Test_RGB/.
The repo is then organized as follows:
|--MST-plus-plus
    |--test_challenge_code
    |--test_develop_code
    |--train_code
    |--dataset
        |--Train_Spec
            |--ARAD_1K_0001.mat
            |--ARAD_1K_0002.mat
            ...
            |--ARAD_1K_0950.mat
        |--Train_RGB
            |--ARAD_1K_0001.jpg
            |--ARAD_1K_0002.jpg
            ...
            |--ARAD_1K_0950.jpg
        |--Test_RGB
            |--ARAD_1K_0951.jpg
            |--ARAD_1K_0952.jpg
            ...
            |--ARAD_1K_1000.jpg
        |--split_txt
            |--train_list.txt
            |--valid_list.txt
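As a quick sanity check after downloading (a sketch, not from the repo; paths assume the tree above):

# Verify the expected file counts in each dataset folder.
import os
root = './MST-plus-plus/dataset'
for sub, ext, expected in [('Train_Spec', '.mat', 950),
                           ('Train_RGB', '.jpg', 950),
                           ('Test_RGB', '.jpg', 50)]:
    n = len([f for f in os.listdir(os.path.join(root, sub)) if f.endswith(ext)])
    print(f'{sub}: {n} files (expected {expected})')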
# !pip install -r requirements.txt
Different Methods:
- HSCNN+
- HRNet
- EDSR
- AWAN
- HDNet
- HINet
- MIRNet
- Restormer
- MPRNet
- MST-L
- MST++
MST_Plus_Plus
import torch.nn as nn
import torch
import torch.nn.functional as F
from einops import rearrange
import math
import warnings
from torch.nn.init import _calculate_fan_in_and_fan_out
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        # Computes the standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)
    with torch.no_grad():
        # Sample uniformly between the CDF bounds, then map back through
        # the inverse error function to get a truncated normal.
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)
        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
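A minimal usage sketch of the helper above (note that the defaults clamp samples to the absolute range [a, b] = [-2, 2]):

# Usage sketch: truncated-normal initialisation of a weight tensor.
w = torch.empty(64, 64)
trunc_normal_(w, std=0.02)
assert w.min().item() >= -2. and w.max().item() <= 2.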
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    variance = scale / denom
    if distribution == "truncated_normal":
        # Constant is the std of a standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")

def lecun_normal_(tensor):
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, *args, **kwargs):
        x = self.norm(x)
        return self.fn(x, *args, **kwargs)

class GELU(nn.Module):
    def forward(self, x):
        return F.gelu(x)

def conv(in_channels, out_channels, kernel_size, bias=False, padding=1, stride=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, stride=stride)

def shift_back(inputs, step=2):  # input [bs,28,256,310]  output [bs,28,256,256]
    [bs, nC, row, col] = inputs.shape
    down_sample = 256 // row
    step = float(step) / float(down_sample * down_sample)
    out_col = row
    for i in range(nC):
        inputs[:, i, :, :out_col] = \
            inputs[:, i, :, int(step * i):int(step * i) + out_col]
    return inputs[:, :, :, :out_col]
class MS_MSA(nn.Module):
    def __init__(self, dim, dim_head, heads):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(dim_head * heads, dim, bias=True)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim

    def forward(self, x_in):
        """
        x_in: [b,h,w,c]
        return out: [b,h,w,c]
        """
        b, h, w, c = x_in.shape
        x = x_in.reshape(b, h * w, c)
        q_inp = self.to_q(x)
        k_inp = self.to_k(x)
        v_inp = self.to_v(x)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
                      (q_inp, k_inp, v_inp))
        # q: b,heads,hw,c -> transpose so attention runs along the spectral dim
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        attn = (k @ q.transpose(-2, -1))  # A = K^T * Q
        attn = attn * self.rescale
        attn = attn.softmax(dim=-1)
        x = attn @ v  # b,heads,d,hw
        x = x.permute(0, 3, 1, 2)  # transpose back
        x = x.reshape(b, h * w, self.num_heads * self.dim_head)
        out_c = self.proj(x).view(b, h, w, c)
        out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        out = out_c + out_p
        return out
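Because S-MSA treats each spectral channel as a token, the attention map is channels-by-channels rather than pixels-by-pixels, so its cost scales with the number of bands instead of the image size. A minimal shape check (a sketch with illustrative sizes, not from the repo):

# Shape check for spectral-wise attention: input and output are [b,h,w,c].
msa = MS_MSA(dim=31, dim_head=31, heads=1)
x = torch.randn(2, 16, 16, 31)
print(msa(x).shape)  # torch.Size([2, 16, 16, 31])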
class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
            GELU(),
            nn.Conv2d(dim * mult, dim * mult, 3, 1, 1, bias=False, groups=dim * mult),
            GELU(),
            nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
        )

    def forward(self, x):
        """
        x: [b,h,w,c]
        return out: [b,h,w,c]
        """
        out = self.net(x.permute(0, 3, 1, 2))
        return out.permute(0, 2, 3, 1)
class MSAB(nn.Module):
    def __init__(self, dim, dim_head, heads, num_blocks):
        super().__init__()
        self.blocks = nn.ModuleList([])
        for _ in range(num_blocks):
            self.blocks.append(nn.ModuleList([
                MS_MSA(dim=dim, dim_head=dim_head, heads=heads),
                PreNorm(dim, FeedForward(dim=dim))
            ]))

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        x = x.permute(0, 2, 3, 1)
        for (attn, ff) in self.blocks:
            x = attn(x) + x
            x = ff(x) + x
        out = x.permute(0, 3, 1, 2)
        return out
class MST(nn.Module):
    def __init__(self, in_dim=31, out_dim=31, dim=31, stage=2, num_blocks=[2, 4, 4]):
        super(MST, self).__init__()
        self.dim = dim
        self.stage = stage

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder
        self.encoder_layers = nn.ModuleList([])
        dim_stage = dim
        for i in range(stage):
            self.encoder_layers.append(nn.ModuleList([
                MSAB(dim=dim_stage, num_blocks=num_blocks[i], dim_head=dim,
                     heads=dim_stage // dim),
                nn.Conv2d(dim_stage, dim_stage * 2, 4, 2, 1, bias=False),
            ]))
            dim_stage *= 2

        # Bottleneck
        self.bottleneck = MSAB(
            dim=dim_stage, dim_head=dim, heads=dim_stage // dim, num_blocks=num_blocks[-1])

        # Decoder
        self.decoder_layers = nn.ModuleList([])
        for i in range(stage):
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_stage, dim_stage // 2, stride=2, kernel_size=2,
                                   padding=0, output_padding=0),
                nn.Conv2d(dim_stage, dim_stage // 2, 1, 1, bias=False),
                MSAB(dim=dim_stage // 2, num_blocks=num_blocks[stage - 1 - i], dim_head=dim,
                     heads=(dim_stage // 2) // dim),
            ]))
            dim_stage //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        #### activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        # Embedding
        fea = self.embedding(x)

        # Encoder
        fea_encoder = []
        for (MSAB_block, FeaDownSample) in self.encoder_layers:
            fea = MSAB_block(fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)

        # Bottleneck
        fea = self.bottleneck(fea)

        # Decoder
        for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(torch.cat([fea, fea_encoder[self.stage - 1 - i]], dim=1))
            fea = LeWinBlcok(fea)

        # Mapping
        out = self.mapping(fea) + x

        return out
class MST_Plus_Plus(nn.Module):
    def __init__(self, in_channels=3, out_channels=31, n_feat=31, stage=3):
        super(MST_Plus_Plus, self).__init__()
        self.stage = stage
        self.conv_in = nn.Conv2d(in_channels, n_feat, kernel_size=3,
                                 padding=(3 - 1) // 2, bias=False)
        modules_body = [MST(dim=31, stage=2, num_blocks=[1, 1, 1]) for _ in range(stage)]
        self.body = nn.Sequential(*modules_body)
        self.conv_out = nn.Conv2d(n_feat, out_channels, kernel_size=3,
                                  padding=(3 - 1) // 2, bias=False)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        b, c, h_inp, w_inp = x.shape
        # Pad spatial dims to a multiple of 8 (two 2x downsamples per SST)
        hb, wb = 8, 8
        pad_h = (hb - h_inp % hb) % hb
        pad_w = (wb - w_inp % wb) % wb
        x = F.pad(x, [0, pad_w, 0, pad_h], mode='reflect')
        x = self.conv_in(x)
        h = self.body(x)
        h = self.conv_out(h)
        h += x
        return h[:, :, :h_inp, :w_inp]
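A quick sanity check of the full model (a sketch with illustrative sizes; thanks to the reflect padding and final crop, arbitrary spatial sizes work):

# Shape check: RGB in, 31-band cube out, same spatial size.
model = MST_Plus_Plus()          # in_channels=3, out_channels=31, stage=3
rgb = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    hsi = model(rgb)
print(hsi.shape)                 # torch.Size([1, 31, 64, 64])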
# from architecture.mst_plus_plus import MST_Plus_Plus
import architecture
from utils import my_summary
my_summary(MST_Plus_Plus(), 256, 256, 3, 1)
MST_Plus_Plus(
(conv_in): Conv2d(3, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(body): Sequential(
(0): MST(
(embedding): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(encoder_layers): ModuleList(
(0): ModuleList(
(0): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=31, out_features=31, bias=False)
(to_k): Linear(in_features=31, out_features=31, bias=False)
(to_v): Linear(in_features=31, out_features=31, bias=False)
(proj): Linear(in_features=31, out_features=31, bias=True)
(pos_emb): Sequential(
(0): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
(1): GELU()
(2): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(31, 124, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
(3): GELU()
(4): Conv2d(124, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((31,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(1): Conv2d(31, 62, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)
(1): ModuleList(
(0): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=62, out_features=62, bias=False)
(to_k): Linear(in_features=62, out_features=62, bias=False)
(to_v): Linear(in_features=62, out_features=62, bias=False)
(proj): Linear(in_features=62, out_features=62, bias=True)
(pos_emb): Sequential(
(0): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
(1): GELU()
(2): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(62, 248, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(248, 248, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=248, bias=False)
(3): GELU()
(4): Conv2d(248, 62, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((62,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(1): Conv2d(62, 124, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)
)
(bottleneck): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=124, out_features=124, bias=False)
(to_k): Linear(in_features=124, out_features=124, bias=False)
(to_v): Linear(in_features=124, out_features=124, bias=False)
(proj): Linear(in_features=124, out_features=124, bias=True)
(pos_emb): Sequential(
(0): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
(1): GELU()
(2): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(124, 496, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(496, 496, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=496, bias=False)
(3): GELU()
(4): Conv2d(496, 124, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((124,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(decoder_layers): ModuleList(
(0): ModuleList(
(0): ConvTranspose2d(124, 62, kernel_size=(2, 2), stride=(2, 2))
(1): Conv2d(124, 62, kernel_size=(1, 1), stride=(1, 1), bias=False)
(2): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=62, out_features=62, bias=False)
(to_k): Linear(in_features=62, out_features=62, bias=False)
(to_v): Linear(in_features=62, out_features=62, bias=False)
(proj): Linear(in_features=62, out_features=62, bias=True)
(pos_emb): Sequential(
(0): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
(1): GELU()
(2): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(62, 248, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(248, 248, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=248, bias=False)
(3): GELU()
(4): Conv2d(248, 62, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((62,), eps=1e-05, elementwise_affine=True)
)
)
)
)
)
(1): ModuleList(
(0): ConvTranspose2d(62, 31, kernel_size=(2, 2), stride=(2, 2))
(1): Conv2d(62, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
(2): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=31, out_features=31, bias=False)
(to_k): Linear(in_features=31, out_features=31, bias=False)
(to_v): Linear(in_features=31, out_features=31, bias=False)
(proj): Linear(in_features=31, out_features=31, bias=True)
(pos_emb): Sequential(
(0): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
(1): GELU()
(2): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(31, 124, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
(3): GELU()
(4): Conv2d(124, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((31,), eps=1e-05, elementwise_affine=True)
)
)
)
)
)
)
(mapping): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(lrelu): LeakyReLU(negative_slope=0.1, inplace=True)
)
(1): MST(
(embedding): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(encoder_layers): ModuleList(
(0): ModuleList(
(0): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=31, out_features=31, bias=False)
(to_k): Linear(in_features=31, out_features=31, bias=False)
(to_v): Linear(in_features=31, out_features=31, bias=False)
(proj): Linear(in_features=31, out_features=31, bias=True)
(pos_emb): Sequential(
(0): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
(1): GELU()
(2): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(31, 124, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
(3): GELU()
(4): Conv2d(124, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((31,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(1): Conv2d(31, 62, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)
(1): ModuleList(
(0): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=62, out_features=62, bias=False)
(to_k): Linear(in_features=62, out_features=62, bias=False)
(to_v): Linear(in_features=62, out_features=62, bias=False)
(proj): Linear(in_features=62, out_features=62, bias=True)
(pos_emb): Sequential(
(0): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
(1): GELU()
(2): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(62, 248, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(248, 248, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=248, bias=False)
(3): GELU()
(4): Conv2d(248, 62, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((62,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(1): Conv2d(62, 124, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)
)
(bottleneck): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=124, out_features=124, bias=False)
(to_k): Linear(in_features=124, out_features=124, bias=False)
(to_v): Linear(in_features=124, out_features=124, bias=False)
(proj): Linear(in_features=124, out_features=124, bias=True)
(pos_emb): Sequential(
(0): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
(1): GELU()
(2): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(124, 496, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(496, 496, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=496, bias=False)
(3): GELU()
(4): Conv2d(496, 124, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((124,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(decoder_layers): ModuleList(
(0): ModuleList(
(0): ConvTranspose2d(124, 62, kernel_size=(2, 2), stride=(2, 2))
(1): Conv2d(124, 62, kernel_size=(1, 1), stride=(1, 1), bias=False)
(2): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=62, out_features=62, bias=False)
(to_k): Linear(in_features=62, out_features=62, bias=False)
(to_v): Linear(in_features=62, out_features=62, bias=False)
(proj): Linear(in_features=62, out_features=62, bias=True)
(pos_emb): Sequential(
(0): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
(1): GELU()
(2): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(62, 248, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(248, 248, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=248, bias=False)
(3): GELU()
(4): Conv2d(248, 62, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((62,), eps=1e-05, elementwise_affine=True)
)
)
)
)
)
(1): ModuleList(
(0): ConvTranspose2d(62, 31, kernel_size=(2, 2), stride=(2, 2))
(1): Conv2d(62, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
(2): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=31, out_features=31, bias=False)
(to_k): Linear(in_features=31, out_features=31, bias=False)
(to_v): Linear(in_features=31, out_features=31, bias=False)
(proj): Linear(in_features=31, out_features=31, bias=True)
(pos_emb): Sequential(
(0): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
(1): GELU()
(2): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(31, 124, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
(3): GELU()
(4): Conv2d(124, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((31,), eps=1e-05, elementwise_affine=True)
)
)
)
)
)
)
(mapping): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(lrelu): LeakyReLU(negative_slope=0.1, inplace=True)
)
(2): MST(
(embedding): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(encoder_layers): ModuleList(
(0): ModuleList(
(0): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=31, out_features=31, bias=False)
(to_k): Linear(in_features=31, out_features=31, bias=False)
(to_v): Linear(in_features=31, out_features=31, bias=False)
(proj): Linear(in_features=31, out_features=31, bias=True)
(pos_emb): Sequential(
(0): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
(1): GELU()
(2): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(31, 124, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
(3): GELU()
(4): Conv2d(124, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((31,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(1): Conv2d(31, 62, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)
(1): ModuleList(
(0): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=62, out_features=62, bias=False)
(to_k): Linear(in_features=62, out_features=62, bias=False)
(to_v): Linear(in_features=62, out_features=62, bias=False)
(proj): Linear(in_features=62, out_features=62, bias=True)
(pos_emb): Sequential(
(0): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
(1): GELU()
(2): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(62, 248, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(248, 248, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=248, bias=False)
(3): GELU()
(4): Conv2d(248, 62, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((62,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(1): Conv2d(62, 124, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)
)
(bottleneck): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=124, out_features=124, bias=False)
(to_k): Linear(in_features=124, out_features=124, bias=False)
(to_v): Linear(in_features=124, out_features=124, bias=False)
(proj): Linear(in_features=124, out_features=124, bias=True)
(pos_emb): Sequential(
(0): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
(1): GELU()
(2): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(124, 496, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(496, 496, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=496, bias=False)
(3): GELU()
(4): Conv2d(496, 124, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((124,), eps=1e-05, elementwise_affine=True)
)
)
)
)
(decoder_layers): ModuleList(
(0): ModuleList(
(0): ConvTranspose2d(124, 62, kernel_size=(2, 2), stride=(2, 2))
(1): Conv2d(124, 62, kernel_size=(1, 1), stride=(1, 1), bias=False)
(2): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=62, out_features=62, bias=False)
(to_k): Linear(in_features=62, out_features=62, bias=False)
(to_v): Linear(in_features=62, out_features=62, bias=False)
(proj): Linear(in_features=62, out_features=62, bias=True)
(pos_emb): Sequential(
(0): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
(1): GELU()
(2): Conv2d(62, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=62, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(62, 248, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(248, 248, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=248, bias=False)
(3): GELU()
(4): Conv2d(248, 62, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((62,), eps=1e-05, elementwise_affine=True)
)
)
)
)
)
(1): ModuleList(
(0): ConvTranspose2d(62, 31, kernel_size=(2, 2), stride=(2, 2))
(1): Conv2d(62, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
(2): MSAB(
(blocks): ModuleList(
(0): ModuleList(
(0): MS_MSA(
(to_q): Linear(in_features=31, out_features=31, bias=False)
(to_k): Linear(in_features=31, out_features=31, bias=False)
(to_v): Linear(in_features=31, out_features=31, bias=False)
(proj): Linear(in_features=31, out_features=31, bias=True)
(pos_emb): Sequential(
(0): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
(1): GELU()
(2): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=31, bias=False)
)
)
(1): PreNorm(
(fn): FeedForward(
(net): Sequential(
(0): Conv2d(31, 124, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): GELU()
(2): Conv2d(124, 124, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=124, bias=False)
(3): GELU()
(4): Conv2d(124, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
)
(norm): LayerNorm((31,), eps=1e-05, elementwise_affine=True)
)
)
)
)
)
)
(mapping): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(lrelu): LeakyReLU(negative_slope=0.1, inplace=True)
)
)
(conv_out): Conv2d(31, 31, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
Unsupported operator aten::rsub encountered 2 time(s)
Unsupported operator aten::pad encountered 1 time(s)
Unsupported operator aten::mul encountered 45 time(s)
Unsupported operator aten::linalg_vector_norm encountered 30 time(s)
Unsupported operator aten::clamp_min encountered 30 time(s)
Unsupported operator aten::expand_as encountered 30 time(s)
Unsupported operator aten::div encountered 30 time(s)
Unsupported operator aten::softmax encountered 15 time(s)
Unsupported operator aten::gelu encountered 45 time(s)
Unsupported operator aten::add encountered 48 time(s)
Unsupported operator aten::add_ encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
body.0.lrelu, body.1.lrelu, body.2.lrelu
GMac:20.759536743164062
Params:1619625
# import torch
# # Instantiate the model
# model = MST_Plus_Plus()
# # Load the entire checkpoint
# checkpoint = torch.load('./test_develop_code/model_zoo/mst_plus_plus.pth')
# # Extract the actual state_dict
# model.load_state_dict(checkpoint['state_dict']) # Load only the 'state_dict' part
# print(model)
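The FutureWarning seen in the test logs below comes from torch.load defaulting to weights_only=False. A hedged alternative, assuming PyTorch >= 1.13 and a checkpoint that stores plain tensors under 'state_dict':

# Sketch: stricter checkpoint loading that avoids the pickle warning.
# import torch
# checkpoint = torch.load('./test_develop_code/model_zoo/mst_plus_plus.pth',
#                         map_location='cpu', weights_only=True)
# model = MST_Plus_Plus()
# model.load_state_dict(checkpoint['state_dict'])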
HSCNN_Plus
import torch.nn as nn
import torch
class dfus_block(nn.Module):
    def __init__(self, dim):
        super(dfus_block, self).__init__()
        self.conv1 = nn.Conv2d(dim, 128, 1, 1, 0, bias=False)
        self.conv_up1 = nn.Conv2d(128, 32, 3, 1, 1, bias=False)
        self.conv_up2 = nn.Conv2d(32, 16, 1, 1, 0, bias=False)
        self.conv_down1 = nn.Conv2d(128, 32, 3, 1, 1, bias=False)
        self.conv_down2 = nn.Conv2d(32, 16, 1, 1, 0, bias=False)
        self.conv_fution = nn.Conv2d(96, 32, 1, 1, 0, bias=False)

        #### activation function
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        feat = self.relu(self.conv1(x))
        feat_up1 = self.relu(self.conv_up1(feat))
        feat_up2 = self.relu(self.conv_up2(feat_up1))
        feat_down1 = self.relu(self.conv_down1(feat))
        feat_down2 = self.relu(self.conv_down2(feat_down1))
        feat_fution = torch.cat([feat_up1, feat_up2, feat_down1, feat_down2], dim=1)
        feat_fution = self.relu(self.conv_fution(feat_fution))
        out = torch.cat([x, feat_fution], dim=1)  # dense connection: grow by 32 channels
        return out

class ddfn(nn.Module):
    def __init__(self, dim, num_blocks=78):
        super(ddfn, self).__init__()
        self.conv_up1 = nn.Conv2d(dim, 32, 3, 1, 1, bias=False)
        self.conv_up2 = nn.Conv2d(32, 32, 1, 1, 0, bias=False)
        self.conv_down1 = nn.Conv2d(dim, 32, 3, 1, 1, bias=False)
        self.conv_down2 = nn.Conv2d(32, 32, 1, 1, 0, bias=False)
        dfus_blocks = [dfus_block(dim=128 + 32 * i) for i in range(num_blocks)]
        self.dfus_blocks = nn.Sequential(*dfus_blocks)

        #### activation function
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        feat_up1 = self.relu(self.conv_up1(x))
        feat_up2 = self.relu(self.conv_up2(feat_up1))
        feat_down1 = self.relu(self.conv_down1(x))
        feat_down2 = self.relu(self.conv_down2(feat_down1))
        feat_fution = torch.cat([feat_up1, feat_up2, feat_down1, feat_down2], dim=1)
        out = self.dfus_blocks(feat_fution)
        return out

class HSCNN_Plus(nn.Module):
    def __init__(self, in_channels=3, out_channels=31, num_blocks=30):
        super(HSCNN_Plus, self).__init__()
        self.ddfn = ddfn(dim=in_channels, num_blocks=num_blocks)
        self.conv_out = nn.Conv2d(128 + 32 * num_blocks, out_channels, 1, 1, 0, bias=False)

    def forward(self, x):
        """
        x: [b,c,h,w]
        return out: [b,c,h,w]
        """
        fea = self.ddfn(x)
        out = self.conv_out(fea)
        return out
from utils import my_summary
my_summary(HSCNN_Plus(), 256, 256, 3, 1)
HSCNN_Plus(
(ddfn): ddfn(
(conv_up1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(dfus_blocks): Sequential(
(0): dfus_block(
(conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(1): dfus_block(
(conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(2): dfus_block(
(conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(3): dfus_block(
(conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(4): dfus_block(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(5): dfus_block(
(conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(6): dfus_block(
(conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(7): dfus_block(
(conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(8): dfus_block(
(conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(9): dfus_block(
(conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(10): dfus_block(
(conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(11): dfus_block(
(conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(12): dfus_block(
(conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(13): dfus_block(
(conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(14): dfus_block(
(conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(15): dfus_block(
(conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(16): dfus_block(
(conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(17): dfus_block(
(conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(18): dfus_block(
(conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(19): dfus_block(
(conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(20): dfus_block(
(conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(21): dfus_block(
(conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(22): dfus_block(
(conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(23): dfus_block(
(conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(24): dfus_block(
(conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(25): dfus_block(
(conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(26): dfus_block(
(conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(27): dfus_block(
(conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(28): dfus_block(
(conv1): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
(29): dfus_block(
(conv1): Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_up1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_up2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_down1): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(conv_down2): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(conv_fution): Conv2d(96, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(relu): ReLU(inplace=True)
)
)
(relu): ReLU(inplace=True)
)
(conv_out): Conv2d(1088, 31, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
GMac:283.5390625
Params:4645504
3. Evaluation on the Test Set:
Download the pretrained model zoo from (Google Drive / Baidu Disk, code: mst1) and place the models in /MST-plus-plus/test_challenge_code/model_zoo/.
Run the following command to test the model on the testing RGB images.
cd test_challenge_code
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/test_challenge_code
Test MST++ (Transformer Model)
!python test.py --data_root ../dataset/ --method mst_plus_plus --pretrained_model_path ./model_zoo/mst_plus_plus.pth --outf ./exp/mst_plus_plus/ --gpu_id 0
load model from ./model_zoo/mst_plus_plus.pth
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/test_challenge_code/architecture/__init__.py:38: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(pretrained_model_path)
ARAD_1K_0951.jpg
ARAD_1K_0952.jpg
ARAD_1K_0953.jpg
ARAD_1K_0954.jpg
ARAD_1K_0955.jpg
ARAD_1K_0956.jpg
ARAD_1K_0957.jpg
ARAD_1K_0958.jpg
ARAD_1K_0959.jpg
ARAD_1K_0960.jpg
ARAD_1K_0961.jpg
ARAD_1K_0962.jpg
ARAD_1K_0963.jpg
ARAD_1K_0964.jpg
ARAD_1K_0965.jpg
ARAD_1K_0966.jpg
ARAD_1K_0967.jpg
ARAD_1K_0968.jpg
ARAD_1K_0969.jpg
ARAD_1K_0970.jpg
ARAD_1K_0971.jpg
ARAD_1K_0972.jpg
ARAD_1K_0973.jpg
ARAD_1K_0974.jpg
ARAD_1K_0975.jpg
ARAD_1K_0976.jpg
ARAD_1K_0977.jpg
ARAD_1K_0978.jpg
ARAD_1K_0979.jpg
ARAD_1K_0980.jpg
ARAD_1K_0981.jpg
ARAD_1K_0982.jpg
ARAD_1K_0983.jpg
ARAD_1K_0984.jpg
ARAD_1K_0985.jpg
ARAD_1K_0986.jpg
ARAD_1K_0987.jpg
ARAD_1K_0988.jpg
ARAD_1K_0989.jpg
ARAD_1K_0990.jpg
ARAD_1K_0991.jpg
ARAD_1K_0992.jpg
ARAD_1K_0993.jpg
ARAD_1K_0994.jpg
ARAD_1K_0995.jpg
ARAD_1K_0996.jpg
ARAD_1K_0997.jpg
ARAD_1K_0998.jpg
ARAD_1K_0999.jpg
ARAD_1K_1000.jpg
./exp/mst_plus_plus/
./exp/mst_plus_plus//submission
Cropping files from input directory
100%|████████████████████████████████████████| 50/50 [00:16<00:00, 3.02it/s]
Compressing submission
100%|████████████████████████████████████████| 50/50 [00:08<00:00, 6.06it/s]
Removing temporary files
100%|████████████████████████████████████████| 50/50 [00:00<00:00, 1098.49it/s]
Verifying submission size - SUCCESS!
Submission generated @ ./exp/mst_plus_plus//submission/submission.zip
Test HSCNN+
!python test.py --data_root ../dataset/ --method hscnn_plus --pretrained_model_path ./model_zoo/hscnn_plus.pth --outf ./exp/hscnn_plus/ --gpu_id 0
load model from ./model_zoo/hscnn_plus.pth
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/test_challenge_code/architecture/__init__.py:38: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(pretrained_model_path)
ARAD_1K_0951.jpg
ARAD_1K_0952.jpg
ARAD_1K_0953.jpg
ARAD_1K_0954.jpg
ARAD_1K_0955.jpg
ARAD_1K_0956.jpg
ARAD_1K_0957.jpg
ARAD_1K_0958.jpg
ARAD_1K_0959.jpg
ARAD_1K_0960.jpg
ARAD_1K_0961.jpg
ARAD_1K_0962.jpg
ARAD_1K_0963.jpg
ARAD_1K_0964.jpg
ARAD_1K_0965.jpg
ARAD_1K_0966.jpg
ARAD_1K_0967.jpg
ARAD_1K_0968.jpg
ARAD_1K_0969.jpg
ARAD_1K_0970.jpg
ARAD_1K_0971.jpg
ARAD_1K_0972.jpg
ARAD_1K_0973.jpg
ARAD_1K_0974.jpg
ARAD_1K_0975.jpg
ARAD_1K_0976.jpg
ARAD_1K_0977.jpg
ARAD_1K_0978.jpg
ARAD_1K_0979.jpg
ARAD_1K_0980.jpg
ARAD_1K_0981.jpg
ARAD_1K_0982.jpg
ARAD_1K_0983.jpg
ARAD_1K_0984.jpg
ARAD_1K_0985.jpg
ARAD_1K_0986.jpg
ARAD_1K_0987.jpg
ARAD_1K_0988.jpg
ARAD_1K_0989.jpg
ARAD_1K_0990.jpg
ARAD_1K_0991.jpg
ARAD_1K_0992.jpg
ARAD_1K_0993.jpg
ARAD_1K_0994.jpg
ARAD_1K_0995.jpg
ARAD_1K_0996.jpg
ARAD_1K_0997.jpg
ARAD_1K_0998.jpg
ARAD_1K_0999.jpg
ARAD_1K_1000.jpg
./exp/hscnn_plus/
./exp/hscnn_plus//submission
Cropping files from input directory
100%|████████████████████████████████████████| 50/50 [00:16<00:00, 3.06it/s]
Compressing submission
100%|████████████████████████████████████████| 50/50 [00:08<00:00, 5.94it/s]
Removing temporary files
100%|████████████████████████████████████████| 50/50 [00:00<00:00, 1004.40it/s]
Verifying submission size - SUCCESS!
Submission generated @ ./exp/hscnn_plus//submission/submission.zip
The results and submission.zip will be saved in /MST-plus-plus/test_challenge_code/exp/.
4. Prediction:
Download the pretrained model zoo from (Google Drive / Baidu Disk, code: mst1) and place the models in /MST-plus-plus/predict_code/model_zoo/.
Run the following command to reconstruct your own RGB image.
cd ..
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus
cd predict_code
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/predict_code
Reconstruct by MST++
!python test.py --rgb_path ./demo/ARAD_1K_0912.jpg --method mst_plus_plus --pretrained_model_path ./model_zoo/mst_plus_plus.pth --outf ./exp/mst_plus_plus/ --gpu_id 0
load model from ./model_zoo/mst_plus_plus.pth
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/predict_code/architecture/__init__.py:38: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(pretrained_model_path)
Reconstructing ./demo/ARAD_1K_0912.jpg
The reconstructed hyper spectral image are saved as ./exp/mst_plus_plus/ARAD_1K_0912.mat.
Reconstruct by HSCNN+
!python test.py --rgb_path ./demo/ARAD_1K_1000.jpg --method hscnn_plus --pretrained_model_path ./model_zoo/hscnn_plus.pth --outf ./exp/hscnn_plus/ --gpu_id 0
load model from ./model_zoo/hscnn_plus.pth
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/predict_code/architecture/__init__.py:38: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(pretrained_model_path)
Reconstructing ./demo/ARAD_1K_1000.jpg
The reconstructed hyper spectral image are saved as ./exp/hscnn_plus/ARAD_1K_1000.mat.
You can replace './demo/ARAD_1K_0912.jpg' with your own RGB image path. The reconstructed results will be saved in /MST-plus-plus/predict_code/exp/.
5. Visualization:
Put the reconstructed HSI in visualization/simulation_results/results/.
Generate the RGB images of the reconstructed HSIs:
Move the reconstructed HSI from /MST-plus-plus/predict_code/exp/ to visualization/simulation_results/results/ and rename it as method.mat, e.g., mst_s.mat. (Please read the README.md file in each folder.)
# cd visualization
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/visualization
# Run show_simulation.m
cd ..
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus
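If MATLAB is unavailable for show_simulation.m, the reconstructed cube can also be inspected directly in Python: MATLAB v7.3 .mat files are HDF5 containers, so h5py can read them, as in the snippet below.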
import h5py
import numpy as np
import matplotlib.pyplot as plt

def load_and_visualize_mat_v7_3(file_path):
    try:
        # Load the .mat file (v7.3 files are HDF5 containers)
        with h5py.File(file_path, 'r') as mat_file:
            # Load the full 'cube' dataset into a NumPy array
            cube = mat_file['cube'][()]

        # Check the shape of the data cube
        print(f"Data cube shape: {cube.shape}")

        # Visualize one of the spectral bands (e.g., the 15th band)
        band_index = 15  # Choose a band to visualize
        if band_index < cube.shape[0]:
            band_data = cube[band_index, :, :]
            band_data_corrected = np.transpose(band_data)  # Correct rotation

            plt.imshow(band_data_corrected, cmap='gray')
            plt.title(f'Spectral Band {band_index + 1}')
            plt.colorbar()
            plt.show()
        else:
            print(f"Band index {band_index} is out of range for cube with shape {cube.shape}")

        # Visualize an average over all bands (an RGB-like grayscale view)
        average_band = np.mean(cube, axis=0)
        average_band_corrected = np.transpose(average_band)  # Correct rotation

        plt.imshow(average_band_corrected, cmap='gray')
        plt.title('Average Over All Bands')
        plt.colorbar()
        plt.show()
    except FileNotFoundError:
        print("The specified file was not found. Please check the file path.")
    except Exception as e:
        print(f"An error occurred: {e}")

# Replace with the path to your .mat file
file_path = "./visualization/simulation_results/results/mst_plus_plus.mat"
load_and_visualize_mat_v7_3(file_path)
Data cube shape: (31, 512, 482)
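A rough pseudo-RGB preview can also be built by sampling three of the 31 bands. The band indices below are assumptions based on the 400-700 nm range at ~10 nm steps; this is not the repo's rendering method (show_simulation.m), just a quick approximation:
import h5py
import numpy as np
import matplotlib.pyplot as plt

# Load the reconstructed cube (same file as above), shape (31, 512, 482).
with h5py.File('./visualization/simulation_results/results/mst_plus_plus.mat', 'r') as f:
    cube = f['cube'][()]

# Assumed band indices: ~650 nm (red), ~550 nm (green), ~450 nm (blue).
rgb = np.stack([cube[25], cube[15], cube[5]], axis=-1)
rgb = np.transpose(rgb, (1, 0, 2))                        # same rotation fix as above
rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)  # scale to [0, 1]

plt.imshow(rgb)
plt.title('Pseudo-RGB from reconstructed HSI')
plt.axis('off')
plt.show()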
6. Evaluation on the Validation Set:
Download the pretrained model zoo from (Google Drive / Baidu Disk, code: mst1) and place the models in /MST-plus-plus/test_develop_code/model_zoo/.
Run the following command to test the model on the validation RGB images.
cd test_develop_code
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/test_develop_code
ls
__pycache__/ exp/ model_zoo/ utils.py
architecture/ hsi_dataset.py test.py
Test MST++
!python test.py --data_root ../dataset/ --method mst_plus_plus --pretrained_model_path ./model_zoo/mst_plus_plus.pth --outf ./exp/mst_plus_plus/ --gpu_id 0
len(hyper_valid) of ntire2022 dataset:50
len(bgr_valid) of ntire2022 dataset:50
Ntire2022 scene 0 is loaded.
...
Ntire2022 scene 49 is loaded.
load model from ./model_zoo/mst_plus_plus.pth
method:mst_plus_plus, mrae:0.1645723432302475, rmse:0.02476534992456436, psnr:18.959978103637695
Test HSCNN+
!python test.py --data_root ../dataset/ --method hscnn_plus --pretrained_model_path ./model_zoo/hscnn_plus.pth --outf ./exp/hscnn_plus/ --gpu_id 0
len(hyper_valid) of ntire2022 dataset:50
len(bgr_valid) of ntire2022 dataset:50
Ntire2022 scene 0 is loaded.
...
Ntire2022 scene 49 is loaded.
load model from ./model_zoo/hscnn_plus.pth
method:hscnn_plus, mrae:0.38142824172973633, rmse:0.05884117633104324, psnr:26.3621826171875
Test HDNet
!python test.py --data_root ../dataset/ --method hdnet --pretrained_model_path ./model_zoo/hdnet.pth --outf ./exp/hdnet/ --gpu_id 0
len(hyper_valid) of ntire2022 dataset:50
len(bgr_valid) of ntire2022 dataset:50
Ntire2022 scene 0 is loaded.
...
Ntire2022 scene 49 is loaded.
load model from ./model_zoo/hdnet.pth
method:hdnet, mrae:0.20484349131584167, rmse:0.03170507773756981, psnr:32.129539489746094
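For reference, the MRAE, RMSE, and PSNR values reported in these runs can be computed as follows. This is a minimal NumPy sketch for illustration; the repo has its own PyTorch implementations (utils.py is listed in the directory above):
import numpy as np

def mrae(pred, gt, eps=1e-8):
    # Mean relative absolute error over all pixels and bands.
    return np.mean(np.abs(pred - gt) / (gt + eps))

def rmse(pred, gt):
    # Root mean square error.
    return np.sqrt(np.mean((pred - gt) ** 2))

def psnr(pred, gt, data_range=1.0):
    # Peak signal-to-noise ratio in dB, for data scaled to [0, data_range].
    mse = np.mean((pred - gt) ** 2)
    return 10 * np.log10(data_range ** 2 / mse)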
7. Training:
cd train_code
/root/Ninjalabo/HSI/MST-plus-plus/MST-plus-plus/train_code
# !python train.py --method mst_plus_plus --batch_size 20 --end_epoch 300 --init_lr 4e-4 --outf ./exp/mst_plus_plus/ --data_root ../dataset/ --patch_size 128 --stride 8 --gpu_id 0
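The --patch_size and --stride flags control how overlapping training crops are sampled from each 482x512 image. A back-of-the-envelope sketch of the resulting crop count (simplified; the repo's dataset loader handles the actual sampling plus augmentation):
# Rough sketch: how many patches a stride-8 sliding window yields
# from one 482x512 training image.
def count_patches(h, w, patch_size=128, stride=8):
    per_col = (h - patch_size) // stride + 1
    per_row = (w - patch_size) // stride + 1
    return per_col * per_row

print(count_patches(482, 512))  # 45 * 49 = 2205 overlapping crops per image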
Supported Algorithms:
- Figure: PSNR-Params-FLOPs comparisons with other spectral reconstruction algorithms.
- FLOPs (floating-point operations): computational complexity
- PSNR (peak signal-to-noise ratio): reconstruction performance
- Circle radius: Params, i.e., memory cost
- MST++ surpasses the other methods while requiring significantly fewer FLOPs and Params.
Comparison with Other Methods
- MST++ achieved the highest PSNR of 34.32 dB on the validation set, a significant improvement over previous methods.
- In terms of efficiency, MST++ required only 23.05G FLOPs, making it one of the most lightweight models.
- MST++ achieved the lowest error rates in both RMSE (root mean square error) and MRAE (mean relative absolute error).
Experiment Results
The dataset for evaluation consisted of RGB-HSI pairs with 31 spectral bands.
Quantitative Results:
* MST++ outperformed all other models in metrics like PSNR, RMSE, and MRAE.
Qualitative Results:
* Visual comparisons show that MST++ produces sharper, more detailed hyperspectral reconstructions with fewer artifacts.
* It excels at preserving spatial smoothness and restoring spectral consistency.
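One simple way to inspect the spectral-consistency claim yourself is to plot the 31-band spectrum of a single pixel from a reconstruction against its ground truth. A sketch, with assumed file paths and an arbitrary pixel:
import h5py
import numpy as np
import matplotlib.pyplot as plt

def load_cube(path):
    # MATLAB v7.3 .mat files are HDF5; the HSI is stored under 'cube'.
    with h5py.File(path, 'r') as f:
        return f['cube'][()]           # shape (31, H, W)

# Assumed paths: a validation reconstruction and its ground-truth HSI.
pred = load_cube('./exp/mst_plus_plus/ARAD_1K_0901.mat')
gt = load_cube('../dataset/Train_Spec/ARAD_1K_0901.mat')

wavelengths = np.arange(400, 701, 10)  # 31 bands, 400-700 nm
y, x = 256, 240                        # arbitrary pixel
plt.plot(wavelengths, gt[:, y, x], label='ground truth')
plt.plot(wavelengths, pred[:, y, x], '--', label='MST++')
plt.xlabel('Wavelength (nm)')
plt.ylabel('Intensity')
plt.legend()
plt.show()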