Repository: CUHK-AIM-Group/U-KAN Branch: main Commit: b20bf63490f0 Files: 42 Total size: 236.5 KB Directory structure: gitextract_suz9sp40/ ├── Diffusion_UKAN/ │ ├── Diffusion/ │ │ ├── Diffusion.py │ │ ├── Model.py │ │ ├── Model_ConvKan.py │ │ ├── Model_UKAN_Hybrid.py │ │ ├── Model_UMLP.py │ │ ├── Train.py │ │ ├── UNet.py │ │ ├── __init__.py │ │ ├── kan_utils/ │ │ │ ├── __init__.py │ │ │ ├── fastkanconv.py │ │ │ └── kan.py │ │ └── utils.py │ ├── Main.py │ ├── Main_Test.py │ ├── README.md │ ├── Scheduler.py │ ├── data/ │ │ └── readme.txt │ ├── inception-score-pytorch/ │ │ ├── LICENSE.md │ │ ├── README.md │ │ └── inception_score.py │ ├── released_models/ │ │ └── readme.txt │ ├── requirements.txt │ ├── tools/ │ │ ├── resive_cvc.py │ │ ├── resize_busi.py │ │ └── resize_glas.py │ └── training_scripts/ │ ├── busi.sh │ ├── cvc.sh │ └── glas.sh ├── README.md └── Seg_UKAN/ ├── LICENSE ├── archs.py ├── config.py ├── dataset.py ├── environment.yml ├── kan.py ├── losses.py ├── metrics.py ├── requirements.txt ├── scripts.sh ├── train.py ├── utils.py └── val.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: Diffusion_UKAN/Diffusion/Diffusion.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from tqdm import tqdm def extract(v, t, x_shape): """ Extract some coefficients at specified timesteps, then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. """ device = t.device out = torch.gather(v, index=t, dim=0).float().to(device) return out.view([t.shape[0]] + [1] * (len(x_shape) - 1)) class GaussianDiffusionTrainer(nn.Module): def __init__(self, model, beta_1, beta_T, T): super().__init__() self.model = model self.T = T self.register_buffer( 'betas', torch.linspace(beta_1, beta_T, T).double()) alphas = 1. - self.betas alphas_bar = torch.cumprod(alphas, dim=0) # calculations for diffusion q(x_t | x_{t-1}) and others self.register_buffer( 'sqrt_alphas_bar', torch.sqrt(alphas_bar)) self.register_buffer( 'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar)) def forward(self, x_0): """ Algorithm 1. """ t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device) noise = torch.randn_like(x_0) x_t = ( extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 + extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise) loss = F.mse_loss(self.model(x_t, t), noise, reduction='none') return loss class GaussianDiffusionSampler(nn.Module): def __init__(self, model, beta_1, beta_T, T): super().__init__() self.model = model self.T = T self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double()) alphas = 1. - self.betas alphas_bar = torch.cumprod(alphas, dim=0) alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T] self.register_buffer('coeff1', torch.sqrt(1. / alphas)) self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar)) self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar)) def predict_xt_prev_mean_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( extract(self.coeff1, t, x_t.shape) * x_t - extract(self.coeff2, t, x_t.shape) * eps ) def p_mean_variance(self, x_t, t): # below: only log_variance is used in the KL computations var = torch.cat([self.posterior_var[1:2], self.betas[1:]]) var = extract(var, t, x_t.shape) eps = self.model(x_t, t) xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps) return xt_prev_mean, var def forward(self, x_T): """ Algorithm 2. """ x_t = x_T print('Start Sampling') for time_step in tqdm(reversed(range(self.T))): t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step mean, var= self.p_mean_variance(x_t=x_t, t=t) # no noise when t == 0 if time_step > 0: noise = torch.randn_like(x_t) else: noise = 0 x_t = mean + torch.sqrt(var) * noise assert torch.isnan(x_t).int().sum() == 0, "nan in tensor." x_0 = x_t return torch.clip(x_0, -1, 1) ================================================ FILE: Diffusion_UKAN/Diffusion/Model.py ================================================ import math import torch from torch import nn from torch.nn import init from torch.nn import functional as F class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) class TimeEmbedding(nn.Module): def __init__(self, T, d_model, dim): assert d_model % 2 == 0 super().__init__() emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000) emb = torch.exp(-emb) pos = torch.arange(T).float() emb = pos[:, None] * emb[None, :] assert list(emb.shape) == [T, d_model // 2] emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1) assert list(emb.shape) == [T, d_model // 2, 2] emb = emb.view(T, d_model) self.timembedding = nn.Sequential( nn.Embedding.from_pretrained(emb), nn.Linear(d_model, dim), Swish(), nn.Linear(dim, dim), ) self.initialize() def initialize(self): for module in self.modules(): if isinstance(module, nn.Linear): init.xavier_uniform_(module.weight) init.zeros_(module.bias) def forward(self, t): emb = self.timembedding(t) return emb class DownSample(nn.Module): def __init__(self, in_ch): super().__init__() self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) self.initialize() def initialize(self): init.xavier_uniform_(self.main.weight) init.zeros_(self.main.bias) def forward(self, x, temb): x = self.main(x) return x class UpSample(nn.Module): def __init__(self, in_ch): super().__init__() self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) self.initialize() def initialize(self): init.xavier_uniform_(self.main.weight) init.zeros_(self.main.bias) def forward(self, x, temb): _, _, H, W = x.shape x = F.interpolate( x, scale_factor=2, mode='nearest') x = self.main(x) return x class AttnBlock(nn.Module): def __init__(self, in_ch): super().__init__() self.group_norm = nn.GroupNorm(32, in_ch) self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.initialize() def initialize(self): for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: init.xavier_uniform_(module.weight) init.zeros_(module.bias) init.xavier_uniform_(self.proj.weight, gain=1e-5) def forward(self, x): B, C, H, W = x.shape h = self.group_norm(x) q = self.proj_q(h) k = self.proj_k(h) v = self.proj_v(h) q = q.permute(0, 2, 3, 1).view(B, H * W, C) k = k.view(B, C, H * W) w = torch.bmm(q, k) * (int(C) ** (-0.5)) assert list(w.shape) == [B, H * W, H * W] w = F.softmax(w, dim=-1) v = v.permute(0, 2, 3, 1).view(B, H * W, C) h = torch.bmm(w, v) assert list(h.shape) == [B, H * W, C] h = h.view(B, H, W, C).permute(0, 3, 1, 2) h = self.proj(h) return x + h class ResBlock(nn.Module): def __init__(self, in_ch, out_ch, tdim, dropout, attn=False): super().__init__() self.block1 = nn.Sequential( nn.GroupNorm(32, in_ch), Swish(), nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(tdim, out_ch), ) self.block2 = nn.Sequential( nn.GroupNorm(32, out_ch), Swish(), nn.Dropout(dropout), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), ) if in_ch != out_ch: self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) else: self.shortcut = nn.Identity() if attn: self.attn = AttnBlock(out_ch) else: self.attn = nn.Identity() self.initialize() def initialize(self): for module in self.modules(): if isinstance(module, (nn.Conv2d, nn.Linear)): init.xavier_uniform_(module.weight) init.zeros_(module.bias) init.xavier_uniform_(self.block2[-1].weight, gain=1e-5) def forward(self, x, temb): h = self.block1(x) h += self.temb_proj(temb)[:, :, None, None] h = self.block2(h) h = h + self.shortcut(x) h = self.attn(h) return h # return x class KANLinear(torch.nn.Module): def __init__( self, in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KANLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.grid_size = grid_size self.spline_order = spline_order h = (grid_range[1] - grid_range[0]) / grid_size grid = ( ( torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0] ) .expand(in_features, -1) .contiguous() ) self.register_buffer("grid", grid) self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) self.spline_weight = torch.nn.Parameter( torch.Tensor(out_features, in_features, grid_size + spline_order) ) if enable_standalone_scale_spline: self.spline_scaler = torch.nn.Parameter( torch.Tensor(out_features, in_features) ) self.scale_noise = scale_noise self.scale_base = scale_base self.scale_spline = scale_spline self.enable_standalone_scale_spline = enable_standalone_scale_spline self.base_activation = base_activation() self.grid_eps = grid_eps self.reset_parameters() def reset_parameters(self): torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base) with torch.no_grad(): noise = ( ( torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2 ) * self.scale_noise / self.grid_size ) self.spline_weight.data.copy_( (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) * self.curve2coeff( self.grid.T[self.spline_order : -self.spline_order], noise, ) ) if self.enable_standalone_scale_spline: # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) def b_splines(self, x: torch.Tensor): """ Compute the B-spline bases for the given input tensor. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). Returns: torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features grid: torch.Tensor = ( self.grid ) # (in_features, grid_size + 2 * spline_order + 1) x = x.unsqueeze(-1) bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) for k in range(1, self.spline_order + 1): bases = ( (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1] ) + ( (grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:] ) assert bases.size() == ( x.size(0), self.in_features, self.grid_size + self.spline_order, ) return bases.contiguous() def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): """ Compute the coefficients of the curve that interpolates the given points. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). Returns: torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features assert y.size() == (x.size(0), self.in_features, self.out_features) A = self.b_splines(x).transpose( 0, 1 ) # (in_features, batch_size, grid_size + spline_order) B = y.transpose(0, 1) # (in_features, batch_size, out_features) solution = torch.linalg.lstsq( A, B ).solution # (in_features, grid_size + spline_order, out_features) result = solution.permute( 2, 0, 1 ) # (out_features, in_features, grid_size + spline_order) assert result.size() == ( self.out_features, self.in_features, self.grid_size + self.spline_order, ) return result.contiguous() @property def scaled_spline_weight(self): return self.spline_weight * ( self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0 ) def forward(self, x: torch.Tensor): assert x.dim() == 2 and x.size(1) == self.in_features base_output = F.linear(self.base_activation(x), self.base_weight) spline_output = F.linear( self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1), ) return base_output + spline_output @torch.no_grad() def update_grid(self, x: torch.Tensor, margin=0.01): assert x.dim() == 2 and x.size(1) == self.in_features batch = x.size(0) splines = self.b_splines(x) # (batch, in, coeff) splines = splines.permute(1, 0, 2) # (in, batch, coeff) orig_coeff = self.scaled_spline_weight # (out, in, coeff) orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) unreduced_spline_output = unreduced_spline_output.permute( 1, 0, 2 ) # (batch, in, out) # sort each channel individually to collect data distribution x_sorted = torch.sort(x, dim=0)[0] grid_adaptive = x_sorted[ torch.linspace( 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device ) ] uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size grid_uniform = ( torch.arange( self.grid_size + 1, dtype=torch.float32, device=x.device ).unsqueeze(1) * uniform_step + x_sorted[0] - margin ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive grid = torch.concatenate( [ grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), ], dim=0, ) self.grid.copy_(grid.T) self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): """ Compute the regularization loss. This is a dumb simulation of the original L1 regularization as stated in the paper, since the original one requires computing absolutes and entropy from the expanded (batch, in_features, out_features) intermediate tensor, which is hidden behind the F.linear function if we want an memory efficient implementation. The L1 regularization is now computed as mean absolute value of the spline weights. The authors implementation also includes this term in addition to the sample-based regularization. """ l1_fake = self.spline_weight.abs().mean(-1) regularization_loss_activation = l1_fake.sum() p = l1_fake / regularization_loss_activation regularization_loss_entropy = -torch.sum(p * p.log()) return ( regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy ) class Ukan(nn.Module): def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record output channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): out_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.middleblocks1 = nn.ModuleList([ ResBlock(now_ch, now_ch, tdim, dropout, attn=True), # ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.middleblocks2 = nn.ModuleList([ # ResBlock(now_ch, now_ch, tdim, dropout, attn=True), ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): out_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), Swish(), nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) ) grid_size=5 spline_order=3 scale_noise=0.1 scale_base=1.0 scale_spline=1.0 base_activation=torch.nn.SiLU grid_eps=0.02 grid_range=[-1, 1] kan_c=512 self.fc1 = KANLinear( kan_c, kan_c *2, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) # print(now_ch) # self.dwconv = DWConv(kan_c *2) self.act = nn.GELU() self.fc2 = KANLinear( kan_c *2, kan_c, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) # Middle # for layer in self.middleblocks1: # h = layer(h, temb) B, C, H, W = h.shape # transform B, C, H, W into B*H*W, C h = h.permute(0, 2, 3, 1).reshape(B*H*W, C) h =self.fc2( self.fc1(h)) # transform B*H*W, C into B, C, H, W h = h.reshape(B, H, W, C).permute(0, 3, 1, 2) # for layer in self.middleblocks2: # h = layer(h, temb) ### Stage 3 # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h class DW_bn_relu(nn.Module): def __init__(self, dim=768): super(DW_bn_relu, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) self.bn = nn.BatchNorm2d(dim) self.relu = nn.ReLU() def forward(self, x, H, W): B, N, C = x.shape x = x.transpose(1, 2).view(B, C, H, W) x = self.dwconv(x) x = self.bn(x) x = self.relu(x) x = x.flatten(2).transpose(1, 2) return x class kan(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5, version=4): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.dim = in_features grid_size=5 spline_order=3 scale_noise=0.1 scale_base=1.0 scale_spline=1.0 base_activation=torch.nn.SiLU grid_eps=0.02 grid_range=[-1, 1] # self.fc1 = nn.Linear(in_features, hidden_features) self.fc1 = KANLinear( in_features, hidden_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) # self.fc2 = nn.Linear(hidden_features, out_features) self.fc2 = KANLinear( hidden_features, out_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) self.fc3 = KANLinear( hidden_features, out_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) # ############################################## self.version = 4 # version 4 hard code ���ܶ����� # ############################################## if self.version == 1: self.dwconv_1 = DWConv(hidden_features) self.act_1 = act_layer() self.dwconv_2 = DWConv(hidden_features) self.act_2 = act_layer() self.dwconv_3 = DWConv(hidden_features) self.act_3 = act_layer() self.dwconv_4 = DWConv(hidden_features) self.act_4 = act_layer() elif self.version == 2: self.dwconv_1 = DWConv(hidden_features) self.act_1 = act_layer() self.dwconv_2 = DWConv(hidden_features) self.act_2 = act_layer() self.dwconv_3 = DWConv(hidden_features) self.act_3 = act_layer() elif self.version == 3: self.dwconv_1 = DW_bn_relu(hidden_features) self.dwconv_2 = DW_bn_relu(hidden_features) self.dwconv_3 = DW_bn_relu(hidden_features) elif self.version == 4: self.dwconv_1 = DW_bn_relu(hidden_features) self.dwconv_2 = DW_bn_relu(hidden_features) self.dwconv_3 = DW_bn_relu(hidden_features) elif self.version == 5: self.dwconv_1 = DWConv(hidden_features) self.act_1 = act_layer() self.dwconv_2 = DWConv(hidden_features) self.act_2 = act_layer() self.dwconv_3 = DWConv(hidden_features) self.act_3 = act_layer() elif self.version == 6: self.dwconv_1 = DWConv(hidden_features) self.act_1 = act_layer() self.dwconv_2 = DWConv(hidden_features) self.act_2 = act_layer() self.dwconv_3 = DWConv(hidden_features) self.act_3 = act_layer() elif self.version == 7: self.dwconv_1 = DWConv(hidden_features) self.act_1 = act_layer() self.dwconv_2 = DWConv(hidden_features) self.act_2 = act_layer() self.dwconv_3 = DWConv(hidden_features) self.act_3 = act_layer() elif self.version == 8: self.dwconv_1 = DW_bn_relu(hidden_features) self.dwconv_2 = DW_bn_relu(hidden_features) self.dwconv_3 = DW_bn_relu(hidden_features) self.drop = nn.Dropout(drop) self.shift_size = shift_size self.pad = shift_size // 2 self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x, H, W): # pdb.set_trace() B, N, C = x.shape if self.version == 1: x = self.dwconv_1(x, H, W) x = self.act_1(x) x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_2(x, H, W) x = self.act_2(x) x = self.fc2(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_3(x, H, W) x = self.act_3(x) x = self.fc3(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_4(x, H, W) x = self.act_4(x) elif self.version == 2: x = self.dwconv_1(x, H, W) x = self.act_1(x) x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_2(x, H, W) x = self.act_2(x) x = self.fc2(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_3(x, H, W) x = self.act_3(x) x = self.fc3(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() elif self.version == 3: x = self.dwconv_1(x, H, W) x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_2(x, H, W) x = self.fc2(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_3(x, H, W) x = self.fc3(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() elif self.version == 4: x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_1(x, H, W) x = self.fc2(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_2(x, H, W) x = self.fc3(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_3(x, H, W) elif self.version == 5: x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_1(x, H, W) x = self.act_1(x) x = self.fc2(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_2(x, H, W) x = self.act_2(x) x = self.fc3(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_3(x, H, W) x = self.act_3(x) elif self.version == 6: x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_1(x, H, W) x = self.act_1(x) x = self.fc2(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_2(x, H, W) x = self.act_2(x) x = self.fc3(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_3(x, H, W) x = self.act_3(x) elif self.version == 7: x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_1(x, H, W) x = self.act_1(x) x = self.drop(x) x = self.fc2(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_2(x, H, W) x = self.act_2(x) x = self.drop(x) x = self.fc3(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_3(x, H, W) x = self.act_3(x) x = self.drop(x) elif self.version == 8: x = self.dwconv_1(x, H, W) x = self.drop(x) x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_2(x, H, W) x = self.drop(x) x = self.fc2(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_3(x, H, W) x = self.drop(x) x = self.fc3(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() return x class Ukan_v3(nn.Module): def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record output channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): out_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.middleblocks1 = nn.ModuleList([ ResBlock(now_ch, now_ch, tdim, dropout, attn=True), # ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.middleblocks2 = nn.ModuleList([ # ResBlock(now_ch, now_ch, tdim, dropout, attn=True), ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): out_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), Swish(), nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) ) grid_size=5 spline_order=3 scale_noise=0.1 scale_base=1.0 scale_spline=1.0 base_activation=torch.nn.SiLU grid_eps=0.02 grid_range=[-1, 1] kan_c=512 self.kan1 = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=4) self.kan2 = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=4) self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) # Middle # for layer in self.middleblocks1: # h = layer(h, temb) B, C, H, W = h.shape # transform B, C, H, W into B*H*W, C h = h.reshape(B,C, H*W).permute(0, 2, 1) h = self.kan1(h, H, W) h = self.kan2(h, H, W) h = h.permute(0, 2, 1).reshape(B, C, H, W) # h =self.fc2( self.fc1(h)) # transform B*H*W, C into B, C, H, W # h = h.reshape(B, H, W, C).permute(0, 3, 1, 2) # B, N, C = x.shape # for layer in self.middleblocks2: # h = layer(h, temb) ### Stage 3 # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h class Ukan_v2(nn.Module): def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout,version=4): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record output channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): out_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.middleblocks1 = nn.ModuleList([ ResBlock(now_ch, now_ch, tdim, dropout, attn=True), # ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.middleblocks2 = nn.ModuleList([ # ResBlock(now_ch, now_ch, tdim, dropout, attn=True), ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): out_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), Swish(), nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) ) grid_size=5 spline_order=3 scale_noise=0.1 scale_base=1.0 scale_spline=1.0 base_activation=torch.nn.SiLU grid_eps=0.02 grid_range=[-1, 1] kan_c=512 self.kan = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=version) self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) # Middle # for layer in self.middleblocks1: # h = layer(h, temb) B, C, H, W = h.shape # transform B, C, H, W into B*H*W, C h = h.reshape(B,C, H*W).permute(0, 2, 1) h = self.kan(h, H, W) h = h.permute(0, 2, 1).reshape(B, C, H, W) # h =self.fc2( self.fc1(h)) # transform B*H*W, C into B, C, H, W # h = h.reshape(B, H, W, C).permute(0, 3, 1, 2) # B, N, C = x.shape # for layer in self.middleblocks2: # h = layer(h, temb) ### Stage 3 # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h class UNet(nn.Module): def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record output channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): out_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.middleblocks = nn.ModuleList([ ResBlock(now_ch, now_ch, tdim, dropout, attn=True), ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): out_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), Swish(), nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) ) self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) # Middle for layer in self.middleblocks: h = layer(h, temb) # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h class UNet_MLP(nn.Module): def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record output channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): out_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.middleblocks1 = nn.ModuleList([ ResBlock(now_ch, now_ch, tdim, dropout, attn=True), # ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.middleblocks2 = nn.ModuleList([ # ResBlock(now_ch, now_ch, tdim, dropout, attn=True), ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): out_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), Swish(), nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) ) kan_c=512 self.fc1 = nn.Linear( kan_c, kan_c *2, ) self.act = nn.GELU() self.fc2 = nn.Linear( kan_c *2, kan_c, ) self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) # Middle # for layer in self.middleblocks1: # h = layer(h, temb) B, C, H, W = h.shape # transform B, C, H, W into B*H*W, C h = h.permute(0, 2, 3, 1).reshape(B*H*W, C) h =self.fc2(self.act(self.fc1(h))) # transform B*H*W, C into B, C, H, W h = h.reshape(B, H, W, C).permute(0, 3, 1, 2) # for layer in self.middleblocks2: # h = layer(h, temb) ### Stage 3 # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h if __name__ == '__main__': batch_size = 8 model = UNet( T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1], num_res_blocks=2, dropout=0.1) x = torch.randn(batch_size, 3, 32, 32) t = torch.randint(1000, (batch_size, )) y = model(x, t) print(y.shape) ================================================ FILE: Diffusion_UKAN/Diffusion/Model_ConvKan.py ================================================ import math import torch from torch import nn from torch.nn import init from torch.nn import functional as F from Diffusion.kan_utils.fastkanconv import FastKANConvLayer, SplineConv2D class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) class TimeEmbedding(nn.Module): def __init__(self, T, d_model, dim): assert d_model % 2 == 0 super().__init__() emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000) emb = torch.exp(-emb) pos = torch.arange(T).float() emb = pos[:, None] * emb[None, :] assert list(emb.shape) == [T, d_model // 2] emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1) assert list(emb.shape) == [T, d_model // 2, 2] emb = emb.view(T, d_model) self.timembedding = nn.Sequential( nn.Embedding.from_pretrained(emb), nn.Linear(d_model, dim), Swish(), nn.Linear(dim, dim), ) self.initialize() def initialize(self): for module in self.modules(): if isinstance(module, nn.Linear): init.xavier_uniform_(module.weight) init.zeros_(module.bias) def forward(self, t): emb = self.timembedding(t) return emb class DownSample(nn.Module): def __init__(self, in_ch): super().__init__() # self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) self.main = FastKANConvLayer(in_ch, in_ch, 3, stride=2, padding=1) # self.initialize() def initialize(self): init.xavier_uniform_(self.main.weight) init.zeros_(self.main.bias) def forward(self, x, temb): x = self.main(x) return x class UpSample(nn.Module): def __init__(self, in_ch): super().__init__() # self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) self.main = FastKANConvLayer(in_ch, in_ch, 3, stride=1, padding=1) # self.initialize() def initialize(self): init.xavier_uniform_(self.main.weight) init.zeros_(self.main.bias) def forward(self, x, temb): _, _, H, W = x.shape x = F.interpolate( x, scale_factor=2, mode='nearest') x = self.main(x) return x class AttnBlock(nn.Module): def __init__(self, in_ch): super().__init__() self.group_norm = nn.GroupNorm(32, in_ch) self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.initialize() def initialize(self): for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: init.xavier_uniform_(module.weight) init.zeros_(module.bias) init.xavier_uniform_(self.proj.weight, gain=1e-5) def forward(self, x): B, C, H, W = x.shape h = self.group_norm(x) q = self.proj_q(h) k = self.proj_k(h) v = self.proj_v(h) q = q.permute(0, 2, 3, 1).view(B, H * W, C) k = k.view(B, C, H * W) w = torch.bmm(q, k) * (int(C) ** (-0.5)) assert list(w.shape) == [B, H * W, H * W] w = F.softmax(w, dim=-1) v = v.permute(0, 2, 3, 1).view(B, H * W, C) h = torch.bmm(w, v) assert list(h.shape) == [B, H * W, C] h = h.view(B, H, W, C).permute(0, 3, 1, 2) h = self.proj(h) return x + h class ResBlock(nn.Module): def __init__(self, in_ch, out_ch, tdim, dropout, attn=False): super().__init__() self.block1 = nn.Sequential( nn.GroupNorm(32, in_ch), # Swish(), # nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), FastKANConvLayer(in_ch, out_ch, 3, stride=1, padding=1), ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(tdim, out_ch), ) self.block2 = nn.Sequential( nn.GroupNorm(32, out_ch), # Swish(), nn.Dropout(dropout), # nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), FastKANConvLayer(out_ch, out_ch, 3, stride=1, padding=1), ) if in_ch != out_ch: # self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) self.shortcut = FastKANConvLayer(in_ch, out_ch, 1, stride=1, padding=0) else: self.shortcut = nn.Identity() if attn: self.attn = AttnBlock(out_ch) else: self.attn = nn.Identity() self.initialize() def initialize(self): for module in self.modules(): if isinstance(module, (nn.Conv2d, nn.Linear)) and not isinstance(module, (SplineConv2D)): init.xavier_uniform_(module.weight) init.zeros_(module.bias) # init.xavier_uniform_(self.block2[-1].weight, gain=1e-5) def forward(self, x, temb): h = self.block1(x) h += self.temb_proj(temb)[:, :, None, None] h = self.block2(h) h = h + self.shortcut(x) h = self.attn(h) return h # return x class UNet_ConvKan(nn.Module): def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record output channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): out_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.middleblocks = nn.ModuleList([ ResBlock(now_ch, now_ch, tdim, dropout, attn=True), ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): out_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), # Swish(), # nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) FastKANConvLayer(now_ch, 3, 3, stride=1, padding=1) ) # self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) # Middle for layer in self.middleblocks: h = layer(h, temb) # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h if __name__ == '__main__': batch_size = 8 model = UNet_ConvKan( T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1], num_res_blocks=2, dropout=0.1) x = torch.randn(batch_size, 3, 32, 32) t = torch.randint(1000, (batch_size, )) y = model(x, t) print(y.shape) ================================================ FILE: Diffusion_UKAN/Diffusion/Model_UKAN_Hybrid.py ================================================ import math import torch from torch import nn from torch.nn import init from torch.nn import functional as F from timm.models.layers import DropPath, to_2tuple, trunc_normal_ class KANLinear(torch.nn.Module): def __init__( self, in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KANLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.grid_size = grid_size self.spline_order = spline_order h = (grid_range[1] - grid_range[0]) / grid_size grid = ( ( torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0] ) .expand(in_features, -1) .contiguous() ) self.register_buffer("grid", grid) self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) self.spline_weight = torch.nn.Parameter( torch.Tensor(out_features, in_features, grid_size + spline_order) ) if enable_standalone_scale_spline: self.spline_scaler = torch.nn.Parameter( torch.Tensor(out_features, in_features) ) self.scale_noise = scale_noise self.scale_base = scale_base self.scale_spline = scale_spline self.enable_standalone_scale_spline = enable_standalone_scale_spline self.base_activation = base_activation() self.grid_eps = grid_eps self.reset_parameters() def reset_parameters(self): torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base) with torch.no_grad(): noise = ( ( torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2 ) * self.scale_noise / self.grid_size ) self.spline_weight.data.copy_( (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) * self.curve2coeff( self.grid.T[self.spline_order : -self.spline_order], noise, ) ) if self.enable_standalone_scale_spline: # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) def b_splines(self, x: torch.Tensor): """ Compute the B-spline bases for the given input tensor. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). Returns: torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features grid: torch.Tensor = ( self.grid ) # (in_features, grid_size + 2 * spline_order + 1) x = x.unsqueeze(-1) bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) for k in range(1, self.spline_order + 1): bases = ( (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1] ) + ( (grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:] ) assert bases.size() == ( x.size(0), self.in_features, self.grid_size + self.spline_order, ) return bases.contiguous() def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): """ Compute the coefficients of the curve that interpolates the given points. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). Returns: torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features assert y.size() == (x.size(0), self.in_features, self.out_features) A = self.b_splines(x).transpose( 0, 1 ) # (in_features, batch_size, grid_size + spline_order) B = y.transpose(0, 1) # (in_features, batch_size, out_features) solution = torch.linalg.lstsq( A, B ).solution # (in_features, grid_size + spline_order, out_features) result = solution.permute( 2, 0, 1 ) # (out_features, in_features, grid_size + spline_order) assert result.size() == ( self.out_features, self.in_features, self.grid_size + self.spline_order, ) return result.contiguous() @property def scaled_spline_weight(self): return self.spline_weight * ( self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0 ) def forward(self, x: torch.Tensor): assert x.dim() == 2 and x.size(1) == self.in_features base_output = F.linear(self.base_activation(x), self.base_weight) spline_output = F.linear( self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1), ) return base_output + spline_output @torch.no_grad() def update_grid(self, x: torch.Tensor, margin=0.01): assert x.dim() == 2 and x.size(1) == self.in_features batch = x.size(0) splines = self.b_splines(x) # (batch, in, coeff) splines = splines.permute(1, 0, 2) # (in, batch, coeff) orig_coeff = self.scaled_spline_weight # (out, in, coeff) orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) unreduced_spline_output = unreduced_spline_output.permute( 1, 0, 2 ) # (batch, in, out) # sort each channel individually to collect data distribution x_sorted = torch.sort(x, dim=0)[0] grid_adaptive = x_sorted[ torch.linspace( 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device ) ] uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size grid_uniform = ( torch.arange( self.grid_size + 1, dtype=torch.float32, device=x.device ).unsqueeze(1) * uniform_step + x_sorted[0] - margin ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive grid = torch.concatenate( [ grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), ], dim=0, ) self.grid.copy_(grid.T) self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): """ Compute the regularization loss. This is a dumb simulation of the original L1 regularization as stated in the paper, since the original one requires computing absolutes and entropy from the expanded (batch, in_features, out_features) intermediate tensor, which is hidden behind the F.linear function if we want an memory efficient implementation. The L1 regularization is now computed as mean absolute value of the spline weights. The authors implementation also includes this term in addition to the sample-based regularization. """ l1_fake = self.spline_weight.abs().mean(-1) regularization_loss_activation = l1_fake.sum() p = l1_fake / regularization_loss_activation regularization_loss_entropy = -torch.sum(p * p.log()) return ( regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy ) class KAN(torch.nn.Module): def __init__( self, layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KAN, self).__init__() self.grid_size = grid_size self.spline_order = spline_order self.layers = torch.nn.ModuleList() for in_features, out_features in zip(layers_hidden, layers_hidden[1:]): self.layers.append( KANLinear( in_features, out_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) ) def forward(self, x: torch.Tensor, update_grid=False): for layer in self.layers: if update_grid: layer.update_grid(x) x = layer(x) return x def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): return sum( layer.regularization_loss(regularize_activation, regularize_entropy) for layer in self.layers ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False) def shift(dim): x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))] x_cat = torch.cat(x_shift, 1) x_cat = torch.narrow(x_cat, 2, self.pad, H) x_cat = torch.narrow(x_cat, 3, self.pad, W) return x_cat class OverlapPatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size self.patch_size = patch_size self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] self.num_patches = self.H * self.W self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2)) self.norm = nn.LayerNorm(embed_dim) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x): x = self.proj(x) _, _, H, W = x.shape x = x.flatten(2).transpose(1, 2) x = self.norm(x) return x, H, W class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) def swish(x): return x * torch.sigmoid(x) class TimeEmbedding(nn.Module): def __init__(self, T, d_model, dim): assert d_model % 2 == 0 super().__init__() emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000) emb = torch.exp(-emb) pos = torch.arange(T).float() emb = pos[:, None] * emb[None, :] assert list(emb.shape) == [T, d_model // 2] emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1) assert list(emb.shape) == [T, d_model // 2, 2] emb = emb.view(T, d_model) self.timembedding = nn.Sequential( nn.Embedding.from_pretrained(emb), nn.Linear(d_model, dim), Swish(), nn.Linear(dim, dim), ) self.initialize() def initialize(self): for module in self.modules(): if isinstance(module, nn.Linear): init.xavier_uniform_(module.weight) init.zeros_(module.bias) def forward(self, t): emb = self.timembedding(t) return emb class DownSample(nn.Module): def __init__(self, in_ch): super().__init__() self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) self.initialize() def initialize(self): init.xavier_uniform_(self.main.weight) init.zeros_(self.main.bias) def forward(self, x, temb): x = self.main(x) return x class UpSample(nn.Module): def __init__(self, in_ch): super().__init__() self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) self.initialize() def initialize(self): init.xavier_uniform_(self.main.weight) init.zeros_(self.main.bias) def forward(self, x, temb): _, _, H, W = x.shape x = F.interpolate( x, scale_factor=2, mode='nearest') x = self.main(x) return x class kan(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.dim = in_features grid_size=5 spline_order=3 scale_noise=0.1 scale_base=1.0 scale_spline=1.0 base_activation=Swish grid_eps=0.02 grid_range=[-1, 1] self.fc1 = KANLinear( in_features, hidden_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x, H, W): B, N, C = x.shape x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() return x class shiftedBlock(nn.Module): def __init__(self, dim, mlp_ratio=4.,drop_path=0.,norm_layer=nn.LayerNorm): super().__init__() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.temb_proj = nn.Sequential( Swish(), nn.Linear(256, dim), ) self.kan = kan(in_features=dim, hidden_features=mlp_hidden_dim) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x, H, W, temb): temb = self.temb_proj(temb) x = self.drop_path(self.kan(self.norm2(x), H, W)) x = x + temb.unsqueeze(1) return x class DWConv(nn.Module): def __init__(self, dim=768): super(DWConv, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) def forward(self, x, H, W): B, N, C = x.shape x = x.transpose(1, 2).view(B, C, H, W) x = self.dwconv(x) x = x.flatten(2).transpose(1, 2) return x class DW_bn_relu(nn.Module): def __init__(self, dim=768): super(DW_bn_relu, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) self.bn = nn.GroupNorm(32, dim) # self.relu = Swish() def forward(self, x, H, W): B, N, C = x.shape x = x.transpose(1, 2).view(B, C, H, W) x = self.dwconv(x) x = self.bn(x) x = swish(x) x = x.flatten(2).transpose(1, 2) return x class SingleConv(nn.Module): def __init__(self, in_ch, h_ch): super(SingleConv, self).__init__() self.conv = nn.Sequential( nn.GroupNorm(32, in_ch), Swish(), nn.Conv2d(in_ch, h_ch, 3, padding=1), ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(256, h_ch), ) def forward(self, input, temb): return self.conv(input) + self.temb_proj(temb)[:,:,None, None] class DoubleConv(nn.Module): def __init__(self, in_ch, h_ch): super(DoubleConv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, h_ch, 3, padding=1), nn.GroupNorm(32, h_ch), Swish(), nn.Conv2d(h_ch, h_ch, 3, padding=1), nn.GroupNorm(32, h_ch), Swish() ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(256, h_ch), ) def forward(self, input, temb): return self.conv(input) + self.temb_proj(temb)[:,:,None, None] class D_SingleConv(nn.Module): def __init__(self, in_ch, h_ch): super(D_SingleConv, self).__init__() self.conv = nn.Sequential( nn.GroupNorm(32,in_ch), Swish(), nn.Conv2d(in_ch, h_ch, 3, padding=1), ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(256, h_ch), ) def forward(self, input, temb): return self.conv(input) + self.temb_proj(temb)[:,:,None, None] class D_DoubleConv(nn.Module): def __init__(self, in_ch, h_ch): super(D_DoubleConv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, in_ch, 3, padding=1), nn.GroupNorm(32,in_ch), Swish(), nn.Conv2d(in_ch, h_ch, 3, padding=1), nn.GroupNorm(32,h_ch), Swish() ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(256, h_ch), ) def forward(self, input,temb): return self.conv(input) + self.temb_proj(temb)[:,:,None, None] class AttnBlock(nn.Module): def __init__(self, in_ch): super().__init__() self.group_norm = nn.GroupNorm(32, in_ch) self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.initialize() def initialize(self): for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: init.xavier_uniform_(module.weight) init.zeros_(module.bias) init.xavier_uniform_(self.proj.weight, gain=1e-5) def forward(self, x): B, C, H, W = x.shape h = self.group_norm(x) q = self.proj_q(h) k = self.proj_k(h) v = self.proj_v(h) q = q.permute(0, 2, 3, 1).view(B, H * W, C) k = k.view(B, C, H * W) w = torch.bmm(q, k) * (int(C) ** (-0.5)) assert list(w.shape) == [B, H * W, H * W] w = F.softmax(w, dim=-1) v = v.permute(0, 2, 3, 1).view(B, H * W, C) h = torch.bmm(w, v) assert list(h.shape) == [B, H * W, C] h = h.view(B, H, W, C).permute(0, 3, 1, 2) h = self.proj(h) return x + h class ResBlock(nn.Module): def __init__(self, in_ch, h_ch, tdim, dropout, attn=False): super().__init__() self.block1 = nn.Sequential( nn.GroupNorm(32, in_ch), Swish(), nn.Conv2d(in_ch, h_ch, 3, stride=1, padding=1), ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(tdim, h_ch), ) self.block2 = nn.Sequential( nn.GroupNorm(32, h_ch), Swish(), nn.Dropout(dropout), nn.Conv2d(h_ch, h_ch, 3, stride=1, padding=1), ) if in_ch != h_ch: self.shortcut = nn.Conv2d(in_ch, h_ch, 1, stride=1, padding=0) else: self.shortcut = nn.Identity() if attn: self.attn = AttnBlock(h_ch) else: self.attn = nn.Identity() self.initialize() def initialize(self): for module in self.modules(): if isinstance(module, (nn.Conv2d, nn.Linear)): init.xavier_uniform_(module.weight) init.zeros_(module.bias) init.xavier_uniform_(self.block2[-1].weight, gain=1e-5) def forward(self, x, temb): h = self.block1(x) h += self.temb_proj(temb)[:, :, None, None] h = self.block2(h) h = h + self.shortcut(x) h = self.attn(h) return h class UKan_Hybrid(nn.Module): def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index h of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) attn = [] self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record hput channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): h_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, h_ch=h_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = h_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): h_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, h_ch=h_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = h_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), Swish(), nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) ) # embed_dims = [256, 320, 512] norm_layer = nn.LayerNorm dpr = [0.0, 0.0, 0.0] self.patch_embed3 = OverlapPatchEmbed(img_size=64 // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) self.patch_embed4 = OverlapPatchEmbed(img_size=64 // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) self.norm3 = norm_layer(embed_dims[1]) self.norm4 = norm_layer(embed_dims[2]) self.dnorm3 = norm_layer(embed_dims[1]) self.kan_block1 = nn.ModuleList([shiftedBlock( dim=embed_dims[1], mlp_ratio=1, drop_path=dpr[0], norm_layer=norm_layer)]) self.kan_block2 = nn.ModuleList([shiftedBlock( dim=embed_dims[2], mlp_ratio=1, drop_path=dpr[1], norm_layer=norm_layer)]) self.kan_dblock1 = nn.ModuleList([shiftedBlock( dim=embed_dims[1], mlp_ratio=1, drop_path=dpr[0], norm_layer=norm_layer)]) self.decoder1 = D_SingleConv(embed_dims[2], embed_dims[1]) self.decoder2 = D_SingleConv(embed_dims[1], embed_dims[0]) self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) t3 = h B = x.shape[0] h, H, W = self.patch_embed3(h) for i, blk in enumerate(self.kan_block1): h = blk(h, H, W, temb) h = self.norm3(h) h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() t4 = h h, H, W= self.patch_embed4(h) for i, blk in enumerate(self.kan_block2): h = blk(h, H, W, temb) h = self.norm4(h) h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() ### Stage 4 h = swish(F.interpolate(self.decoder1(h, temb), scale_factor=(2,2), mode ='bilinear')) h = torch.add(h, t4) _, _, H, W = h.shape h = h.flatten(2).transpose(1,2) for i, blk in enumerate(self.kan_dblock1): h = blk(h, H, W, temb) ### Stage 3 h = self.dnorm3(h) h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() h = swish(F.interpolate(self.decoder2(h, temb),scale_factor=(2,2),mode ='bilinear')) h = torch.add(h,t3) # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h if __name__ == '__main__': batch_size = 8 model = UKan_Hybrid( T=1000, ch=64, ch_mult=[1, 2, 2, 2], attn=[], num_res_blocks=2, dropout=0.1) ================================================ FILE: Diffusion_UKAN/Diffusion/Model_UMLP.py ================================================ import math import torch from torch import nn from torch.nn import init from torch.nn import functional as F from timm.models.layers import DropPath, to_2tuple, trunc_normal_ class KANLinear(torch.nn.Module): def __init__( self, in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KANLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.grid_size = grid_size self.spline_order = spline_order h = (grid_range[1] - grid_range[0]) / grid_size grid = ( ( torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0] ) .expand(in_features, -1) .contiguous() ) self.register_buffer("grid", grid) self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) self.spline_weight = torch.nn.Parameter( torch.Tensor(out_features, in_features, grid_size + spline_order) ) if enable_standalone_scale_spline: self.spline_scaler = torch.nn.Parameter( torch.Tensor(out_features, in_features) ) self.scale_noise = scale_noise self.scale_base = scale_base self.scale_spline = scale_spline self.enable_standalone_scale_spline = enable_standalone_scale_spline self.base_activation = base_activation() self.grid_eps = grid_eps self.reset_parameters() def reset_parameters(self): torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base) with torch.no_grad(): noise = ( ( torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2 ) * self.scale_noise / self.grid_size ) self.spline_weight.data.copy_( (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) * self.curve2coeff( self.grid.T[self.spline_order : -self.spline_order], noise, ) ) if self.enable_standalone_scale_spline: # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) def b_splines(self, x: torch.Tensor): """ Compute the B-spline bases for the given input tensor. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). Returns: torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features grid: torch.Tensor = ( self.grid ) # (in_features, grid_size + 2 * spline_order + 1) x = x.unsqueeze(-1) bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) for k in range(1, self.spline_order + 1): bases = ( (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1] ) + ( (grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:] ) assert bases.size() == ( x.size(0), self.in_features, self.grid_size + self.spline_order, ) return bases.contiguous() def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): """ Compute the coefficients of the curve that interpolates the given points. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). Returns: torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features assert y.size() == (x.size(0), self.in_features, self.out_features) A = self.b_splines(x).transpose( 0, 1 ) # (in_features, batch_size, grid_size + spline_order) B = y.transpose(0, 1) # (in_features, batch_size, out_features) solution = torch.linalg.lstsq( A, B ).solution # (in_features, grid_size + spline_order, out_features) result = solution.permute( 2, 0, 1 ) # (out_features, in_features, grid_size + spline_order) assert result.size() == ( self.out_features, self.in_features, self.grid_size + self.spline_order, ) return result.contiguous() @property def scaled_spline_weight(self): return self.spline_weight * ( self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0 ) def forward(self, x: torch.Tensor): assert x.dim() == 2 and x.size(1) == self.in_features base_output = F.linear(self.base_activation(x), self.base_weight) spline_output = F.linear( self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1), ) return base_output + spline_output @torch.no_grad() def update_grid(self, x: torch.Tensor, margin=0.01): assert x.dim() == 2 and x.size(1) == self.in_features batch = x.size(0) splines = self.b_splines(x) # (batch, in, coeff) splines = splines.permute(1, 0, 2) # (in, batch, coeff) orig_coeff = self.scaled_spline_weight # (out, in, coeff) orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) unreduced_spline_output = unreduced_spline_output.permute( 1, 0, 2 ) # (batch, in, out) # sort each channel individually to collect data distribution x_sorted = torch.sort(x, dim=0)[0] grid_adaptive = x_sorted[ torch.linspace( 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device ) ] uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size grid_uniform = ( torch.arange( self.grid_size + 1, dtype=torch.float32, device=x.device ).unsqueeze(1) * uniform_step + x_sorted[0] - margin ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive grid = torch.concatenate( [ grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), ], dim=0, ) self.grid.copy_(grid.T) self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): """ Compute the regularization loss. This is a dumb simulation of the original L1 regularization as stated in the paper, since the original one requires computing absolutes and entropy from the expanded (batch, in_features, out_features) intermediate tensor, which is hidden behind the F.linear function if we want an memory efficient implementation. The L1 regularization is now computed as mean absolute value of the spline weights. The authors implementation also includes this term in addition to the sample-based regularization. """ l1_fake = self.spline_weight.abs().mean(-1) regularization_loss_activation = l1_fake.sum() p = l1_fake / regularization_loss_activation regularization_loss_entropy = -torch.sum(p * p.log()) return ( regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy ) class KAN(torch.nn.Module): def __init__( self, layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KAN, self).__init__() self.grid_size = grid_size self.spline_order = spline_order self.layers = torch.nn.ModuleList() for in_features, out_features in zip(layers_hidden, layers_hidden[1:]): self.layers.append( KANLinear( in_features, out_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) ) def forward(self, x: torch.Tensor, update_grid=False): for layer in self.layers: if update_grid: layer.update_grid(x) x = layer(x) return x def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): return sum( layer.regularization_loss(regularize_activation, regularize_entropy) for layer in self.layers ) def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False) def shift(dim): x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))] x_cat = torch.cat(x_shift, 1) x_cat = torch.narrow(x_cat, 2, self.pad, H) x_cat = torch.narrow(x_cat, 3, self.pad, W) return x_cat class OverlapPatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size self.patch_size = patch_size self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] self.num_patches = self.H * self.W self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2)) self.norm = nn.LayerNorm(embed_dim) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x): x = self.proj(x) _, _, H, W = x.shape x = x.flatten(2).transpose(1, 2) x = self.norm(x) return x, H, W class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) def swish(x): return x * torch.sigmoid(x) class TimeEmbedding(nn.Module): def __init__(self, T, d_model, dim): assert d_model % 2 == 0 super().__init__() emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000) emb = torch.exp(-emb) pos = torch.arange(T).float() emb = pos[:, None] * emb[None, :] assert list(emb.shape) == [T, d_model // 2] emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1) assert list(emb.shape) == [T, d_model // 2, 2] emb = emb.view(T, d_model) self.timembedding = nn.Sequential( nn.Embedding.from_pretrained(emb), nn.Linear(d_model, dim), Swish(), nn.Linear(dim, dim), ) self.initialize() def initialize(self): for module in self.modules(): if isinstance(module, nn.Linear): init.xavier_uniform_(module.weight) init.zeros_(module.bias) def forward(self, t): emb = self.timembedding(t) return emb class DownSample(nn.Module): def __init__(self, in_ch): super().__init__() self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) self.initialize() def initialize(self): init.xavier_uniform_(self.main.weight) init.zeros_(self.main.bias) def forward(self, x, temb): x = self.main(x) return x class UpSample(nn.Module): def __init__(self, in_ch): super().__init__() self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) self.initialize() def initialize(self): init.xavier_uniform_(self.main.weight) init.zeros_(self.main.bias) def forward(self, x, temb): _, _, H, W = x.shape x = F.interpolate( x, scale_factor=2, mode='nearest') x = self.main(x) return x class kan(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5, version=4, kan_val=False): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.dim = in_features grid_size=5 spline_order=3 scale_noise=0.1 scale_base=1.0 scale_spline=1.0 base_activation=Swish grid_eps=0.02 grid_range=[-1, 1] if kan_val: self.fc1 = nn.Linear(in_features, hidden_features) self.fc2 = nn.Linear(hidden_features, out_features) self.fc3 = nn.Linear(hidden_features, out_features) else: self.fc1 = nn.Sequential( nn.Linear(in_features, hidden_features), Swish(), nn.Linear(hidden_features, out_features)) self.drop = nn.Dropout(drop) self.shift_size = shift_size self.pad = shift_size // 2 self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x, H, W): B, N, C = x.shape x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() return x class shiftedBlock(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, version=1, kan_val=False): super().__init__() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.temb_proj = nn.Sequential( Swish(), nn.Linear(256, dim), ) # self.mlp = shiftmlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.kan = kan(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, kan_val=kan_val) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x, H, W, temb): temb = self.temb_proj(temb) # x = x + self.drop_path(self.kan(self.norm2(x), H, W)) x = self.drop_path(self.kan(self.norm2(x), H, W)) x = x + temb.unsqueeze(1) return x class DWConv(nn.Module): def __init__(self, dim=768): super(DWConv, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) def forward(self, x, H, W): B, N, C = x.shape x = x.transpose(1, 2).view(B, C, H, W) x = self.dwconv(x) x = x.flatten(2).transpose(1, 2) return x class DW_bn_relu(nn.Module): def __init__(self, dim=768): super(DW_bn_relu, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) self.bn = nn.GroupNorm(32, dim) # self.relu = Swish() def forward(self, x, H, W): B, N, C = x.shape x = x.transpose(1, 2).view(B, C, H, W) x = self.dwconv(x) x = self.bn(x) x = swish(x) x = x.flatten(2).transpose(1, 2) return x class SingleConv(nn.Module): def __init__(self, in_ch, h_ch): super(SingleConv, self).__init__() self.conv = nn.Sequential( nn.GroupNorm(32, in_ch), Swish(), nn.Conv2d(in_ch, h_ch, 3, padding=1), ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(256, h_ch), ) def forward(self, input, temb): return self.conv(input) + self.temb_proj(temb)[:,:,None, None] class DoubleConv(nn.Module): def __init__(self, in_ch, h_ch): super(DoubleConv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, h_ch, 3, padding=1), nn.GroupNorm(32, h_ch), Swish(), nn.Conv2d(h_ch, h_ch, 3, padding=1), nn.GroupNorm(32, h_ch), Swish() ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(256, h_ch), ) def forward(self, input, temb): return self.conv(input) + self.temb_proj(temb)[:,:,None, None] class D_SingleConv(nn.Module): def __init__(self, in_ch, h_ch): super(D_SingleConv, self).__init__() self.conv = nn.Sequential( nn.GroupNorm(32,in_ch), Swish(), nn.Conv2d(in_ch, h_ch, 3, padding=1), ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(256, h_ch), ) def forward(self, input, temb): return self.conv(input) + self.temb_proj(temb)[:,:,None, None] class D_DoubleConv(nn.Module): def __init__(self, in_ch, h_ch): super(D_DoubleConv, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, in_ch, 3, padding=1), nn.GroupNorm(32,in_ch), Swish(), nn.Conv2d(in_ch, h_ch, 3, padding=1), nn.GroupNorm(32,h_ch), Swish() ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(256, h_ch), ) def forward(self, input,temb): return self.conv(input) + self.temb_proj(temb)[:,:,None, None] class AttnBlock(nn.Module): def __init__(self, in_ch): super().__init__() self.group_norm = nn.GroupNorm(32, in_ch) self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.initialize() def initialize(self): for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: init.xavier_uniform_(module.weight) init.zeros_(module.bias) init.xavier_uniform_(self.proj.weight, gain=1e-5) def forward(self, x): B, C, H, W = x.shape h = self.group_norm(x) q = self.proj_q(h) k = self.proj_k(h) v = self.proj_v(h) q = q.permute(0, 2, 3, 1).view(B, H * W, C) k = k.view(B, C, H * W) w = torch.bmm(q, k) * (int(C) ** (-0.5)) assert list(w.shape) == [B, H * W, H * W] w = F.softmax(w, dim=-1) v = v.permute(0, 2, 3, 1).view(B, H * W, C) h = torch.bmm(w, v) assert list(h.shape) == [B, H * W, C] h = h.view(B, H, W, C).permute(0, 3, 1, 2) h = self.proj(h) return x + h class ResBlock(nn.Module): def __init__(self, in_ch, h_ch, tdim, dropout, attn=False): super().__init__() self.block1 = nn.Sequential( nn.GroupNorm(32, in_ch), Swish(), nn.Conv2d(in_ch, h_ch, 3, stride=1, padding=1), ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(tdim, h_ch), ) self.block2 = nn.Sequential( nn.GroupNorm(32, h_ch), Swish(), nn.Dropout(dropout), nn.Conv2d(h_ch, h_ch, 3, stride=1, padding=1), ) if in_ch != h_ch: self.shortcut = nn.Conv2d(in_ch, h_ch, 1, stride=1, padding=0) else: self.shortcut = nn.Identity() if attn: self.attn = AttnBlock(h_ch) else: self.attn = nn.Identity() self.initialize() def initialize(self): for module in self.modules(): if isinstance(module, (nn.Conv2d, nn.Linear)): init.xavier_uniform_(module.weight) init.zeros_(module.bias) init.xavier_uniform_(self.block2[-1].weight, gain=1e-5) def forward(self, x, temb): h = self.block1(x) h += self.temb_proj(temb)[:, :, None, None] h = self.block2(h) h = h + self.shortcut(x) h = self.attn(h) return h class UMLP(nn.Module): def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index h of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) attn = [] self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record hput channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): h_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, h_ch=h_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = h_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): h_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, h_ch=h_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = h_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), Swish(), nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) ) # embed_dims = [256, 320, 512] drop_rate = 0.0 attn_drop_rate = 0.0 kan_val = False version = 4 sr_ratios = [8, 4, 2, 1] num_heads=[1, 2, 4, 8] qkv_bias=False qk_scale=None norm_layer = nn.LayerNorm dpr = [0.0, 0.0, 0.0] self.patch_embed3 = OverlapPatchEmbed(img_size=64 // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) self.patch_embed4 = OverlapPatchEmbed(img_size=64 // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) self.norm3 = norm_layer(embed_dims[1]) self.norm4 = norm_layer(embed_dims[2]) self.dnorm3 = norm_layer(embed_dims[1]) self.dnorm4 = norm_layer(embed_dims[0]) self.kan_block1 = nn.ModuleList([shiftedBlock( dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer, sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)]) self.kan_block2 = nn.ModuleList([shiftedBlock( dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer, sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)]) self.kan_dblock1 = nn.ModuleList([shiftedBlock( dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer, sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)]) # self.kan_dblock2 = nn.ModuleList([shiftedBlock( # dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale, # drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer, # sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)]) self.decoder1 = D_SingleConv(embed_dims[2], embed_dims[1]) self.decoder2 = D_SingleConv(embed_dims[1], embed_dims[0]) self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) t3 = h B = x.shape[0] h, H, W = self.patch_embed3(h) for i, blk in enumerate(self.kan_block1): h = blk(h, H, W, temb) h = self.norm3(h) h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() t4 = h h, H, W= self.patch_embed4(h) for i, blk in enumerate(self.kan_block2): h = blk(h, H, W, temb) h = self.norm4(h) h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() ### Stage 4 h = swish(F.interpolate(self.decoder1(h, temb), scale_factor=(2,2), mode ='bilinear')) h = torch.add(h, t4) _, _, H, W = h.shape h = h.flatten(2).transpose(1,2) for i, blk in enumerate(self.kan_dblock1): h = blk(h, H, W, temb) ### Stage 3 h = self.dnorm3(h) h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() h = swish(F.interpolate(self.decoder2(h, temb),scale_factor=(2,2),mode ='bilinear')) h = torch.add(h,t3) # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h if __name__ == '__main__': batch_size = 8 model = UMLP( T=1000, ch=64, ch_mult=[1, 2, 2, 2], attn=[], num_res_blocks=2, dropout=0.1) ================================================ FILE: Diffusion_UKAN/Diffusion/Train.py ================================================ import os from typing import Dict import torch import torch.optim as optim from tqdm import tqdm from torch.utils.data import DataLoader from torchvision import transforms, transforms # from torchvision.datasets import CIFAR10 from torchvision.utils import save_image from Diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer from Diffusion.UNet import UNet, UNet_Baseline from Diffusion.Model_ConvKan import UNet_ConvKan from Diffusion.Model_UMLP import UMLP from Diffusion.Model_UKAN_Hybrid import UKan_Hybrid from Scheduler import GradualWarmupScheduler from skimage import io import os from torchvision.transforms import ToTensor, Normalize, Compose from torch.utils.data import Dataset import sys model_dict = { 'UNet': UNet, 'UNet_ConvKan': UNet_ConvKan, # dose not work 'UMLP': UMLP, 'UKan_Hybrid': UKan_Hybrid, 'UNet_Baseline': UNet_Baseline, } class UnlabeledDataset(Dataset): def __init__(self, folder, transform=None, repeat_n=1): self.folder = folder self.transform = transform # self.image_files = os.listdir(folder) * repeat_n self.image_files = os.listdir(folder) def __len__(self): return len(self.image_files) def __getitem__(self, idx): image_file = self.image_files[idx] image_path = os.path.join(self.folder, image_file) image = io.imread(image_path) if self.transform: image = self.transform(image) return image, torch.Tensor([0]) def train(modelConfig: Dict): device = torch.device(modelConfig["device"]) log_print = True if log_print: file = open(modelConfig["save_weight_dir"]+'log.txt', "w") sys.stdout = file transform = Compose([ ToTensor(), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) if modelConfig["dataset"] == 'cvc': dataset = UnlabeledDataset('data/cvc/images_64/', transform=transform, repeat_n=modelConfig["dataset_repeat"]) elif modelConfig["dataset"] == 'glas': dataset = UnlabeledDataset('data/glas/images_64/', transform=transform, repeat_n=modelConfig["dataset_repeat"]) elif modelConfig["dataset"] == 'glas_resize': dataset = UnlabeledDataset('data/glas/images_64_resize/', transform=transform, repeat_n=modelConfig["dataset_repeat"]) elif modelConfig["dataset"] == 'busi': dataset = UnlabeledDataset('data/busi/images_64/', transform=transform, repeat_n=modelConfig["dataset_repeat"]) else: raise ValueError('dataset not found') print('modelConfig: ') for key, value in modelConfig.items(): print(key, ' : ', value) dataloader = DataLoader( dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True) print('Using {}'.format(modelConfig["model"])) # model setup net_model =model_dict[modelConfig["model"]](T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"], num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device) if modelConfig["training_load_weight"] is not None: net_model.load_state_dict(torch.load(os.path.join( modelConfig["save_weight_dir"], modelConfig["training_load_weight"]), map_location=device)) optimizer = torch.optim.AdamW( net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4) cosineScheduler = optim.lr_scheduler.CosineAnnealingLR( optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1) warmUpScheduler = GradualWarmupScheduler( optimizer=optimizer, multiplier=modelConfig["multiplier"], warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler) trainer = GaussianDiffusionTrainer( net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device) # start training for e in range(1,modelConfig["epoch"]+1): with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader: for images, labels in tqdmDataLoader: # train optimizer.zero_grad() x_0 = images.to(device) loss = trainer(x_0).sum() / 1000. loss.backward() torch.nn.utils.clip_grad_norm_( net_model.parameters(), modelConfig["grad_clip"]) optimizer.step() tqdmDataLoader.set_postfix(ordered_dict={ "epoch": e, "loss: ": loss.item(), "img shape: ": x_0.shape, "LR": optimizer.state_dict()['param_groups'][0]["lr"] }) # print version if log_print: print("epoch: ", e, "loss: ", loss.item(), "img shape: ", x_0.shape, "LR: ", optimizer.state_dict()['param_groups'][0]["lr"]) warmUpScheduler.step() if e % 50 ==0: torch.save(net_model.state_dict(), os.path.join( modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt")) modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(e) eval_tmp(modelConfig, e) torch.save(net_model.state_dict(), os.path.join( modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt")) if log_print: file.close() sys.stdout = sys.__stdout__ def eval_tmp(modelConfig: Dict, nme: int): # load model and evaluate with torch.no_grad(): device = torch.device(modelConfig["device"]) model = model_dict[modelConfig["model"]](T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"], num_res_blocks=modelConfig["num_res_blocks"], dropout=0.) ckpt = torch.load(os.path.join( modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device) model.load_state_dict(ckpt) print("model load weight done.") model.eval() sampler = GaussianDiffusionSampler( model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device) # Sampled from standard normal distribution noisyImage = torch.randn( size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device) # saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1) # save_image(saveNoisy, os.path.join( # modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"]) sampledImgs = sampler(noisyImage) sampledImgs = sampledImgs * 0.5 + 0.5 # [0 ~ 1] save_root = modelConfig["sampled_dir"].replace('Gens','Tmp') os.makedirs(save_root, exist_ok=True) save_image(sampledImgs, os.path.join( save_root, modelConfig["sampledImgName"].replace('.png','_{}.png').format(nme)), nrow=modelConfig["nrow"]) if nme < 0.95 * modelConfig["epoch"]: os.remove(os.path.join( modelConfig["save_weight_dir"], modelConfig["test_load_weight"])) def eval(modelConfig: Dict): # load model and evaluate with torch.no_grad(): device = torch.device(modelConfig["device"]) model = model_dict[modelConfig["model"]](T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"], num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device) ckpt = torch.load(os.path.join( modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device) model.load_state_dict(ckpt) print("model load weight done.") model.eval() sampler = GaussianDiffusionSampler( model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device) # Sampled from standard normal distribution noisyImage = torch.randn( size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device) # saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1) # save_image(saveNoisy, os.path.join( # modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"]) sampledImgs = sampler(noisyImage) sampledImgs = sampledImgs * 0.5 + 0.5 # [0 ~ 1] for i, image in enumerate(sampledImgs): save_image(image, os.path.join(modelConfig["sampled_dir"], modelConfig["sampledImgName"].replace('.png','_{}.png').format(i)), nrow=modelConfig["nrow"]) ================================================ FILE: Diffusion_UKAN/Diffusion/UNet.py ================================================ import math import torch from torch import nn from torch.nn import init from torch.nn import functional as F class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) class TimeEmbedding(nn.Module): def __init__(self, T, d_model, dim): assert d_model % 2 == 0 super().__init__() emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000) emb = torch.exp(-emb) pos = torch.arange(T).float() emb = pos[:, None] * emb[None, :] assert list(emb.shape) == [T, d_model // 2] emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1) assert list(emb.shape) == [T, d_model // 2, 2] emb = emb.view(T, d_model) self.timembedding = nn.Sequential( nn.Embedding.from_pretrained(emb), nn.Linear(d_model, dim), Swish(), nn.Linear(dim, dim), ) self.initialize() def initialize(self): for module in self.modules(): if isinstance(module, nn.Linear): init.xavier_uniform_(module.weight) init.zeros_(module.bias) def forward(self, t): emb = self.timembedding(t) return emb class DownSample(nn.Module): def __init__(self, in_ch): super().__init__() self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1) self.initialize() def initialize(self): init.xavier_uniform_(self.main.weight) init.zeros_(self.main.bias) def forward(self, x, temb): x = self.main(x) return x class UpSample(nn.Module): def __init__(self, in_ch): super().__init__() self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1) self.initialize() def initialize(self): init.xavier_uniform_(self.main.weight) init.zeros_(self.main.bias) def forward(self, x, temb): _, _, H, W = x.shape x = F.interpolate( x, scale_factor=2, mode='nearest') x = self.main(x) return x class AttnBlock(nn.Module): def __init__(self, in_ch): super().__init__() self.group_norm = nn.GroupNorm(32, in_ch) self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0) self.initialize() def initialize(self): for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]: init.xavier_uniform_(module.weight) init.zeros_(module.bias) init.xavier_uniform_(self.proj.weight, gain=1e-5) def forward(self, x): B, C, H, W = x.shape h = self.group_norm(x) q = self.proj_q(h) k = self.proj_k(h) v = self.proj_v(h) q = q.permute(0, 2, 3, 1).view(B, H * W, C) k = k.view(B, C, H * W) w = torch.bmm(q, k) * (int(C) ** (-0.5)) assert list(w.shape) == [B, H * W, H * W] w = F.softmax(w, dim=-1) v = v.permute(0, 2, 3, 1).view(B, H * W, C) h = torch.bmm(w, v) assert list(h.shape) == [B, H * W, C] h = h.view(B, H, W, C).permute(0, 3, 1, 2) h = self.proj(h) return x + h class ResBlock(nn.Module): def __init__(self, in_ch, out_ch, tdim, dropout, attn=False): super().__init__() self.block1 = nn.Sequential( nn.GroupNorm(32, in_ch), Swish(), nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1), ) self.temb_proj = nn.Sequential( Swish(), nn.Linear(tdim, out_ch), ) self.block2 = nn.Sequential( nn.GroupNorm(32, out_ch), Swish(), nn.Dropout(dropout), nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1), ) if in_ch != out_ch: self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) else: self.shortcut = nn.Identity() if attn: self.attn = AttnBlock(out_ch) else: self.attn = nn.Identity() self.initialize() def initialize(self): for module in self.modules(): if isinstance(module, (nn.Conv2d, nn.Linear)): init.xavier_uniform_(module.weight) init.zeros_(module.bias) init.xavier_uniform_(self.block2[-1].weight, gain=1e-5) def forward(self, x, temb): h = self.block1(x) h += self.temb_proj(temb)[:, :, None, None] h = self.block2(h) h = h + self.shortcut(x) h = self.attn(h) return h class UNet(nn.Module): def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record output channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): out_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.middleblocks = nn.ModuleList([ ResBlock(now_ch, now_ch, tdim, dropout, attn=True), ResBlock(now_ch, now_ch, tdim, dropout, attn=False), ]) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): out_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), Swish(), nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) ) self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) # Middle # torch.Size([8, 512, 4, 4]) for layer in self.middleblocks: h = layer(h, temb) # torch.Size([8, 512, 4, 4]) # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h class UNet_Baseline(nn.Module): # Remove the middle blocks def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout): super().__init__() assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound' tdim = ch * 4 self.time_embedding = TimeEmbedding(T, ch, tdim) self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1) self.downblocks = nn.ModuleList() chs = [ch] # record output channel when dowmsample for upsample now_ch = ch for i, mult in enumerate(ch_mult): out_ch = ch * mult for _ in range(num_res_blocks): self.downblocks.append(ResBlock( in_ch=now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch chs.append(now_ch) if i != len(ch_mult) - 1: self.downblocks.append(DownSample(now_ch)) chs.append(now_ch) self.upblocks = nn.ModuleList() for i, mult in reversed(list(enumerate(ch_mult))): out_ch = ch * mult for _ in range(num_res_blocks + 1): self.upblocks.append(ResBlock( in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim, dropout=dropout, attn=(i in attn))) now_ch = out_ch if i != 0: self.upblocks.append(UpSample(now_ch)) assert len(chs) == 0 self.tail = nn.Sequential( nn.GroupNorm(32, now_ch), Swish(), nn.Conv2d(now_ch, 3, 3, stride=1, padding=1) ) self.initialize() def initialize(self): init.xavier_uniform_(self.head.weight) init.zeros_(self.head.bias) init.xavier_uniform_(self.tail[-1].weight, gain=1e-5) init.zeros_(self.tail[-1].bias) def forward(self, x, t): # Timestep embedding temb = self.time_embedding(t) # Downsampling h = self.head(x) hs = [h] for layer in self.downblocks: h = layer(h, temb) hs.append(h) # Upsampling for layer in self.upblocks: if isinstance(layer, ResBlock): h = torch.cat([h, hs.pop()], dim=1) h = layer(h, temb) h = self.tail(h) assert len(hs) == 0 return h if __name__ == '__main__': batch_size = 8 model = UNet( T=1000, ch=64, ch_mult=[1, 2, 2, 2], attn=[1], num_res_blocks=2, dropout=0.1) ================================================ FILE: Diffusion_UKAN/Diffusion/__init__.py ================================================ from .Diffusion import * from .UNet import * from .Train import * ================================================ FILE: Diffusion_UKAN/Diffusion/kan_utils/__init__.py ================================================ from .kan import * from .fastkanconv import * # from .kan_convolutional import * ================================================ FILE: Diffusion_UKAN/Diffusion/kan_utils/fastkanconv.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from typing import List, Tuple, Union class PolynomialFunction(nn.Module): def __init__(self, degree: int = 3): super().__init__() self.degree = degree def forward(self, x): return torch.stack([x ** i for i in range(self.degree)], dim=-1) class BSplineFunction(nn.Module): def __init__(self, grid_min: float = -2., grid_max: float = 2., degree: int = 3, num_basis: int = 8): super(BSplineFunction, self).__init__() self.degree = degree self.num_basis = num_basis self.knots = torch.linspace(grid_min, grid_max, num_basis + degree + 1) # Uniform knots def basis_function(self, i, k, t): if k == 0: return ((self.knots[i] <= t) & (t < self.knots[i + 1])).float() else: left_num = (t - self.knots[i]) * self.basis_function(i, k - 1, t) left_den = self.knots[i + k] - self.knots[i] left = left_num / left_den if left_den != 0 else 0 right_num = (self.knots[i + k + 1] - t) * self.basis_function(i + 1, k - 1, t) right_den = self.knots[i + k + 1] - self.knots[i + 1] right = right_num / right_den if right_den != 0 else 0 return left + right def forward(self, x): x = x.squeeze() # Assuming x is of shape (B, 1) basis_functions = torch.stack([self.basis_function(i, self.degree, x) for i in range(self.num_basis)], dim=-1) return basis_functions class ChebyshevFunction(nn.Module): def __init__(self, degree: int = 4): super(ChebyshevFunction, self).__init__() self.degree = degree def forward(self, x): chebyshev_polynomials = [torch.ones_like(x), x] for n in range(2, self.degree): chebyshev_polynomials.append(2 * x * chebyshev_polynomials[-1] - chebyshev_polynomials[-2]) return torch.stack(chebyshev_polynomials, dim=-1) class FourierBasisFunction(nn.Module): def __init__(self, num_frequencies: int = 4, period: float = 1.0): super(FourierBasisFunction, self).__init__() assert num_frequencies % 2 == 0, "num_frequencies must be even" self.num_frequencies = num_frequencies self.period = nn.Parameter(torch.Tensor([period]), requires_grad=False) def forward(self, x): frequencies = torch.arange(1, self.num_frequencies // 2 + 1, device=x.device) sin_components = torch.sin(2 * torch.pi * frequencies * x[..., None] / self.period) cos_components = torch.cos(2 * torch.pi * frequencies * x[..., None] / self.period) basis_functions = torch.cat([sin_components, cos_components], dim=-1) return basis_functions class RadialBasisFunction(nn.Module): def __init__( self, grid_min: float = -2., grid_max: float = 2., num_grids: int = 4, denominator: float = None, ): super().__init__() grid = torch.linspace(grid_min, grid_max, num_grids) self.grid = torch.nn.Parameter(grid, requires_grad=False) self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1) def forward(self, x): return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2) class SplineConv2D(nn.Conv2d): def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]] = 3, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, init_scale: float = 0.1, padding_mode: str = "zeros", **kw ) -> None: self.init_scale = init_scale super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, **kw ) def reset_parameters(self) -> None: nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale) if self.bias is not None: nn.init.zeros_(self.bias) class FastKANConvLayer(nn.Module): def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]] = 3, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, grid_min: float = -2., grid_max: float = 2., num_grids: int = 4, use_base_update: bool = True, base_activation = F.silu, spline_weight_init_scale: float = 0.1, padding_mode: str = "zeros", kan_type: str = "BSpline", # kan_type: str = "RBF", ) -> None: super().__init__() if kan_type == "RBF": self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids) elif kan_type == "Fourier": self.rbf = FourierBasisFunction(num_grids) elif kan_type == "Poly": self.rbf = PolynomialFunction(num_grids) elif kan_type == "Chebyshev": self.rbf = ChebyshevFunction(num_grids) elif kan_type == "BSpline": self.rbf = BSplineFunction(grid_min, grid_max, 3, num_grids) self.spline_conv = SplineConv2D(in_channels * num_grids, out_channels, kernel_size, stride, padding, dilation, groups, bias, spline_weight_init_scale, padding_mode) self.use_base_update = use_base_update if use_base_update: self.base_activation = base_activation self.base_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode) def forward(self, x): batch_size, channels, height, width = x.shape x_rbf = self.rbf(x.view(batch_size, channels, -1)).view(batch_size, channels, height, width, -1) x_rbf = x_rbf.permute(0, 4, 1, 2, 3).contiguous().view(batch_size, -1, height, width) # Apply spline convolution ret = self.spline_conv(x_rbf) if self.use_base_update: base = self.base_conv(self.base_activation(x)) ret = ret + base return ret ================================================ FILE: Diffusion_UKAN/Diffusion/kan_utils/kan.py ================================================ import torch import torch.nn.functional as F import math class KANLinear(torch.nn.Module): def __init__( self, in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KANLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.grid_size = grid_size self.spline_order = spline_order h = (grid_range[1] - grid_range[0]) / grid_size grid = ( ( torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0] ) .expand(in_features, -1) .contiguous() ) self.register_buffer("grid", grid) self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) self.spline_weight = torch.nn.Parameter( torch.Tensor(out_features, in_features, grid_size + spline_order) ) if enable_standalone_scale_spline: self.spline_scaler = torch.nn.Parameter( torch.Tensor(out_features, in_features) ) self.scale_noise = scale_noise self.scale_base = scale_base self.scale_spline = scale_spline self.enable_standalone_scale_spline = enable_standalone_scale_spline self.base_activation = base_activation() self.grid_eps = grid_eps self.reset_parameters() def reset_parameters(self): torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base) with torch.no_grad(): noise = ( ( torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2 ) * self.scale_noise / self.grid_size ) self.spline_weight.data.copy_( (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) * self.curve2coeff( self.grid.T[self.spline_order : -self.spline_order], noise, ) ) if self.enable_standalone_scale_spline: # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) def b_splines(self, x: torch.Tensor): """ Compute the B-spline bases for the given input tensor. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). Returns: torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features grid: torch.Tensor = ( self.grid ) # (in_features, grid_size + 2 * spline_order + 1) x = x.unsqueeze(-1) bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) for k in range(1, self.spline_order + 1): bases = ( (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1] ) + ( (grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:] ) assert bases.size() == ( x.size(0), self.in_features, self.grid_size + self.spline_order, ) return bases.contiguous() def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): """ Compute the coefficients of the curve that interpolates the given points. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). Returns: torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features assert y.size() == (x.size(0), self.in_features, self.out_features) A = self.b_splines(x).transpose( 0, 1 ) # (in_features, batch_size, grid_size + spline_order) B = y.transpose(0, 1) # (in_features, batch_size, out_features) solution = torch.linalg.lstsq( A, B ).solution # (in_features, grid_size + spline_order, out_features) result = solution.permute( 2, 0, 1 ) # (out_features, in_features, grid_size + spline_order) assert result.size() == ( self.out_features, self.in_features, self.grid_size + self.spline_order, ) return result.contiguous() @property def scaled_spline_weight(self): return self.spline_weight * ( self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0 ) def forward(self, x: torch.Tensor): assert x.dim() == 2 and x.size(1) == self.in_features base_output = F.linear(self.base_activation(x), self.base_weight) spline_output = F.linear( self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1), ) return base_output + spline_output @torch.no_grad() def update_grid(self, x: torch.Tensor, margin=0.01): assert x.dim() == 2 and x.size(1) == self.in_features batch = x.size(0) splines = self.b_splines(x) # (batch, in, coeff) splines = splines.permute(1, 0, 2) # (in, batch, coeff) orig_coeff = self.scaled_spline_weight # (out, in, coeff) orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) unreduced_spline_output = unreduced_spline_output.permute( 1, 0, 2 ) # (batch, in, out) # sort each channel individually to collect data distribution x_sorted = torch.sort(x, dim=0)[0] grid_adaptive = x_sorted[ torch.linspace( 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device ) ] uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size grid_uniform = ( torch.arange( self.grid_size + 1, dtype=torch.float32, device=x.device ).unsqueeze(1) * uniform_step + x_sorted[0] - margin ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive grid = torch.concatenate( [ grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), ], dim=0, ) self.grid.copy_(grid.T) self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): """ Compute the regularization loss. This is a dumb simulation of the original L1 regularization as stated in the paper, since the original one requires computing absolutes and entropy from the expanded (batch, in_features, out_features) intermediate tensor, which is hidden behind the F.linear function if we want an memory efficient implementation. The L1 regularization is now computed as mean absolute value of the spline weights. The authors implementation also includes this term in addition to the sample-based regularization. """ l1_fake = self.spline_weight.abs().mean(-1) regularization_loss_activation = l1_fake.sum() p = l1_fake / regularization_loss_activation regularization_loss_entropy = -torch.sum(p * p.log()) return ( regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy ) class KAN(torch.nn.Module): def __init__( self, layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KAN, self).__init__() self.grid_size = grid_size self.spline_order = spline_order self.layers = torch.nn.ModuleList() for in_features, out_features in zip(layers_hidden, layers_hidden[1:]): self.layers.append( KANLinear( in_features, out_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) ) def forward(self, x: torch.Tensor, update_grid=False): for layer in self.layers: if update_grid: layer.update_grid(x) x = layer(x) return x def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): return sum( layer.regularization_loss(regularize_activation, regularize_entropy) for layer in self.layers ) ================================================ FILE: Diffusion_UKAN/Diffusion/utils.py ================================================ import argparse import torch.nn as nn class qkv_transform(nn.Conv1d): """Conv1d for qkv_transform""" def str2bool(v): if v.lower() in ['true', 1]: return True elif v.lower() in ['false', 0]: return False else: raise argparse.ArgumentTypeError('Boolean value expected.') def count_params(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count ================================================ FILE: Diffusion_UKAN/Main.py ================================================ from Diffusion.Train import train, eval import os import argparse import torch import numpy as np def main(model_config = None): if model_config is not None: modelConfig = model_config if modelConfig["state"] == "train": train(modelConfig) modelConfig['batch_size'] = 64 modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(modelConfig['epoch']) for i in range(32): modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i) eval(modelConfig) else: for i in range(32): modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i) eval(modelConfig) def seed_all(args): torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False np.random.seed(args.seed) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--state', type=str, default='train') # train or eval parser.add_argument('--dataset', type=str, default='cvc') # busi, glas, cvc parser.add_argument('--epoch', type=int, default=1000) # 1000 for cvc/glas, 5000 for busi parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--T', type=int, default=1000) parser.add_argument('--channel', type=int, default=64) # 64 or 128 parser.add_argument('--test_load_weight', type=str, default='ckpt_1000_.pt') parser.add_argument('--num_res_blocks', type=int, default=2) parser.add_argument('--dropout', type=float, default=0.15) parser.add_argument('--lr', type=float, default=2e-4) parser.add_argument('--img_size', type=float, default=64) parser.add_argument('--dataset_repeat', type=int, default=1) # did not use parser.add_argument('--seed', type=int, default=0) # did not use parser.add_argument('--model', type=str, default='UKAN_Hybrid') parser.add_argument('--exp_nme', type=str, default='UKAN_Hybrid') parser.add_argument('--save_root', type=str, default='./Output/') args = parser.parse_args() save_root = args.save_root if args.seed != 0: seed_all(args) modelConfig = { "dataset": args.dataset, "state": args.state, # or eval "epoch": args.epoch, "batch_size": args.batch_size, "T": args.T, "channel": args.channel, "channel_mult": [1, 2, 3, 4], "attn": [2], "num_res_blocks": args.num_res_blocks, "dropout": args.dropout, "lr": args.lr, "multiplier": 2., "beta_1": 1e-4, "beta_T": 0.02, "img_size": 64, "grad_clip": 1., "device": "cuda", ### MAKE SURE YOU HAVE A GPU !!! "training_load_weight": None, "save_weight_dir": os.path.join(save_root, args.exp_nme, "Weights"), "sampled_dir": os.path.join(save_root, args.exp_nme, "Gens"), "test_load_weight": args.test_load_weight, "sampledNoisyImgName": "NoisyNoGuidenceImgs.png", "sampledImgName": "SampledNoGuidenceImgs.png", "nrow": 8, "model":args.model, "version": 1, "dataset_repeat": args.dataset_repeat, "seed": args.seed, "save_root": args.save_root, } os.makedirs(modelConfig["save_weight_dir"], exist_ok=True) os.makedirs(modelConfig["sampled_dir"], exist_ok=True) # backup import shutil shutil.copy("Diffusion/Model_UKAN_Hybrid.py", os.path.join(save_root, args.exp_nme)) shutil.copy("Diffusion/Train.py", os.path.join(save_root, args.exp_nme)) main(modelConfig) ================================================ FILE: Diffusion_UKAN/Main_Test.py ================================================ from Diffusion.Train import train, eval, eval_tmp import os import argparse import torch def main(model_config = None): if model_config is not None: modelConfig = model_config if modelConfig["state"] == "train": train(modelConfig) modelConfig['batch_size'] = 64 modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(modelConfig['epoch']) for i in range(32): modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i) eval(modelConfig) else: for i in range(1): modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i) eval_tmp(modelConfig,1000) # for grid visualization # eval(modelConfig) # for metric evaluation def seed_all(args): torch.manual_seed(args.seed) torch.cuda.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False import numpy as np np.random.seed(args.seed) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--state', type=str, default='eval') parser.add_argument('--dataset', type=str, default='cvc') # busi, glas, cvc parser.add_argument('--epoch', type=int, default=1000) parser.add_argument('--batch_size', type=int, default=32) parser.add_argument('--T', type=int, default=1000) parser.add_argument('--channel', type=int, default=64) parser.add_argument('--test_load_weight', type=str, default='ckpt_1000_.pt') parser.add_argument('--num_res_blocks', type=int, default=2) parser.add_argument('--dropout', type=float, default=0.15) parser.add_argument('--lr', type=float, default=1e-4) parser.add_argument('--img_size', type=float, default=64) # 64 or 128 parser.add_argument('--dataset_repeat', type=int, default=1) # didnot use parser.add_argument('--seed', type=int, default=0) parser.add_argument('--model', type=str, default='UKan_Hybrid') parser.add_argument('--exp_nme', type=str, default='./') parser.add_argument('--save_root', type=str, default='released_models/ukan_cvc') # parser.add_argument('--save_root', type=str, default='released_models/ukan_glas') # parser.add_argument('--save_root', type=str, default='released_models/ukan_busi') args = parser.parse_args() save_root = args.save_root if args.seed != 0: seed_all(args) modelConfig = { "dataset": args.dataset, "state": args.state, # or eval "epoch": args.epoch, "batch_size": args.batch_size, "T": args.T, "channel": args.channel, "channel_mult": [1, 2, 3, 4], "attn": [2], "num_res_blocks": args.num_res_blocks, "dropout": args.dropout, "lr": args.lr, "multiplier": 2., "beta_1": 1e-4, "beta_T": 0.02, "img_size": 64, "grad_clip": 1., "device": "cuda", ### MAKE SURE YOU HAVE A GPU !!! "training_load_weight": None, "save_weight_dir": os.path.join(save_root, args.exp_nme, "Weights"), "sampled_dir": os.path.join(save_root, args.exp_nme, "FinalCheck"), "test_load_weight": args.test_load_weight, "sampledNoisyImgName": "NoisyNoGuidenceImgs.png", "sampledImgName": "SampledNoGuidenceImgs.png", "nrow": 8, "model":args.model, "version": 1, "dataset_repeat": args.dataset_repeat, "seed": args.seed, "save_root": args.save_root, } os.makedirs(modelConfig["save_weight_dir"], exist_ok=True) os.makedirs(modelConfig["sampled_dir"], exist_ok=True) main(modelConfig) ================================================ FILE: Diffusion_UKAN/README.md ================================================ # Diffusion UKAN (arxiv) > [**U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**](https://arxiv.org/abs/2406.02918)
> [Chenxin Li](https://xggnet.github.io/)\*, [Xinyu Liu](https://xinyuliu-jeffrey.github.io/)\*, [Wuyang Li](https://wymancv.github.io/wuyang.github.io/)\*, [Cheng Wang](https://scholar.google.com/citations?user=AM7gvyUAAAAJ&hl=en)\*, [Hengyu Liu](), [Yixuan Yuan](https://www.ee.cuhk.edu.hk/~yxyuan/people/people.htm)
The Chinese Univerisity of Hong Kong Contact: wuyangli@cuhk.edu.hk ## 💡 Environment You can change the torch and Cuda versions to satisfy your device. ```bash conda create --name UKAN python=3.10 conda activate UKAN conda install cudatoolkit=11.3 pip install -r requirement.txt ``` ## 🖼️ Gallery of Diffusion UKAN ![image](./assets/gen.png) ## 📚 Prepare datasets Download the pre-processed dataset from [Onedrive](https://gocuhk-my.sharepoint.com/:u:/g/personal/wuyangli_cuhk_edu_hk/ESqX-V_eLSBEuaJXAzf64JMB16xF9kz3661pJSwQ-hOspg?e=XdABCH) and unzip it into the project folder. The data is pre-processed by the scripts in [tools](./tools). ``` Diffusion_UKAN | data | └─ cvc | └─ images_64 | └─ busi | └─ images_64 | └─ glas | └─ images_64 ``` ## 📦 Prepare pre-trained models Download released_models from [Onedrive](https://gocuhk-my.sharepoint.com/:u:/g/personal/wuyangli_cuhk_edu_hk/EUVSH8QFUmpJlxyoEj8Pr2IB8PzGbVJg53rc6GcqxGgLDg?e=a4glNt) and unzip it in the project folder. ``` Diffusion_UKAN | released_models | └─ ukan_cvc | └─ FinalCheck   # generated toy images (see next section) | └─ Gens         # the generated images used for evaluation in our paper | └─ Tmp          # saved generated images during model training with a 50-epoch interval | └─ Weights      # The final checkpoint | └─ FID.txt      # raw evaluation data | └─ IS.txt       # raw evaluation data   | └─ ukan_busi | └─ ukan_glas ``` ## 🧸 Toy example Images will be generated in `released_models/ukan_cvc/FinalCheck` by running this: ```python python Main_Test.py ``` ## 🔥 Training Please refer to the [training_scripts](./training_scripts) folder. Besides, you can play with different network variations by modifying `MODEL` according to the following dictionary, ```python model_dict = { 'UNet': UNet, 'UNet_ConvKan': UNet_ConvKan, 'UMLP': UMLP, 'UKan_Hybrid': UKan_Hybrid, 'UNet_Baseline': UNet_Baseline, } ``` ## 🤞 Acknowledgement Thanks for We mainly appreciate these excellent projects - [Simple DDPM](https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-) - [Kolmogorov-Arnold Network](https://github.com/mintisan/awesome-kan) - [Efficient Kolmogorov-Arnold Network](https://github.com/Blealtan/efficient-kan.git) ## 📜Citation If you find this work helpful for your project, please consider citing the following paper: ``` @article{li2024ukan, title={U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation}, author={Li, Chenxin and Liu, Xinyu and Li, Wuyang and Wang, Cheng and Liu, Hengyu and Yuan, Yixuan}, journal={arXiv preprint arXiv:2406.02918}, year={2024} } ``` ================================================ FILE: Diffusion_UKAN/Scheduler.py ================================================ from torch.optim.lr_scheduler import _LRScheduler class GradualWarmupScheduler(_LRScheduler): def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None): self.multiplier = multiplier self.total_epoch = warm_epoch self.after_scheduler = after_scheduler self.finished = False self.last_epoch = None self.base_lrs = None super().__init__(optimizer) def get_lr(self): if self.last_epoch > self.total_epoch: if self.after_scheduler: if not self.finished: self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] self.finished = True return self.after_scheduler.get_lr() return [base_lr * self.multiplier for base_lr in self.base_lrs] return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] def step(self, epoch=None, metrics=None): if self.finished and self.after_scheduler: if epoch is None: self.after_scheduler.step(None) else: self.after_scheduler.step(epoch - self.total_epoch) else: return super(GradualWarmupScheduler, self).step(epoch) ================================================ FILE: Diffusion_UKAN/data/readme.txt ================================================ download data.zip and unzip here ================================================ FILE: Diffusion_UKAN/inception-score-pytorch/LICENSE.md ================================================ Copyright 2017 Shane T. Barratt Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Diffusion_UKAN/inception-score-pytorch/README.md ================================================ # Inception Score Pytorch Pytorch was lacking code to calculate the Inception Score for GANs. This repository fills this gap. However, we do not recommend using the Inception Score to evaluate generative models, see [our note](https://arxiv.org/abs/1801.01973) for why. ## Getting Started Clone the repository and navigate to it: ``` $ git clone git@github.com:sbarratt/inception-score-pytorch.git $ cd inception-score-pytorch ``` To generate random 64x64 images and calculate the inception score, do the following: ``` $ python inception_score.py ``` The only function is `inception_score`. It takes a list of numpy images normalized to the range [0,1] and a set of arguments and then calculates the inception score. Please assure your images are 3x299x299 and if not (e.g. your GAN was trained on CIFAR), pass `resize=True` to the function to have it automatically resize using bilinear interpolation before passing the images to the inception network. ```python def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1): """Computes the inception score of the generated images imgs imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1] cuda -- whether or not to run on GPU batch_size -- batch size for feeding into Inception v3 splits -- number of splits """ ``` ### Prerequisites You will need [torch](http://pytorch.org/), [torchvision](https://github.com/pytorch/vision), [numpy/scipy](https://scipy.org/). ## License This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details ## Acknowledgments * Inception Score from [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498) ================================================ FILE: Diffusion_UKAN/inception-score-pytorch/inception_score.py ================================================ import torch from torch import nn from torch.autograd import Variable from torch.nn import functional as F import torch.utils.data from torchvision.models.inception import inception_v3 import os from skimage import io import cv2 import os import numpy as np from scipy.stats import entropy import torchvision.datasets as dset import torchvision.transforms as transforms import argparse def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=32): """Computes the inception score of the generated images imgs imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1] cuda -- whether or not to run on GPU batch_size -- batch size for feeding into Inception v3 splits -- number of splits """ N = len(imgs) assert batch_size > 0 assert N > batch_size # Set up dtype if cuda: dtype = torch.cuda.FloatTensor else: if torch.cuda.is_available(): print("WARNING: You have a CUDA device, so you should probably set cuda=True") dtype = torch.FloatTensor # Set up dataloader dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) # Load inception model inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype) inception_model.eval(); up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype) def get_pred(x): if resize: x = up(x) x = inception_model(x) return F.softmax(x).data.cpu().numpy() # Get predictions preds = np.zeros((N, 1000)) for i, batch in enumerate(dataloader, 0): batch = batch.type(dtype) batchv = Variable(batch) batch_size_i = batch.size()[0] preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv) # Now compute the mean kl-div split_scores = [] for k in range(splits): part = preds[k * (N // splits): (k+1) * (N // splits), :] py = np.mean(part, axis=0) scores = [] for i in range(part.shape[0]): pyx = part[i, :] scores.append(entropy(pyx, py)) split_scores.append(np.exp(np.mean(scores))) return np.mean(split_scores), np.std(split_scores) class UnlabeledDataset(torch.utils.data.Dataset): def __init__(self, folder, transform=None): self.folder = folder self.transform = transform self.image_files = os.listdir(folder) def __len__(self): return len(self.image_files) def __getitem__(self, idx): image_file = self.image_files[idx] image_path = os.path.join(self.folder, image_file) image = io.imread(image_path) if self.transform: image = self.transform(image) return image class IgnoreLabelDataset(torch.utils.data.Dataset): def __init__(self, orig): self.orig = orig def __getitem__(self, index): return self.orig[index][0] def __len__(self): return len(self.orig) if __name__ == '__main__': # cifar = dset.CIFAR10(root='data/', download=True, # transform=transforms.Compose([ # transforms.Resize(32),`` # transforms.ToTensor(), # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # ]) # ) transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # set args parser = argparse.ArgumentParser() parser.add_argument('--data-root', type=str, default='/data/wyli/code/TinyDDPM/Output/unet_busi/Gens/') args = parser.parse_args() dataset = UnlabeledDataset(args.data_root, transform=transform) print ("Calculating Inception Score...") print (inception_score(dataset, cuda=True, batch_size=1, resize=True, splits=10)) ================================================ FILE: Diffusion_UKAN/released_models/readme.txt ================================================ download released_models.zip and unzip here ================================================ FILE: Diffusion_UKAN/requirements.txt ================================================ pytorch-fid==0.30.0 torch==2.3.0 torchvision==0.18.0 tqdm timm==0.9.16 scikit-image==0.23.1 ================================================ FILE: Diffusion_UKAN/tools/resive_cvc.py ================================================ import os from skimage import io, transform from skimage.util import img_as_ubyte import numpy as np # Define the source and destination directories src_dir = '/data/wyli/data/CVC-ClinicDB/Original/' dst_dir = '/data/wyli/data/cvc/images_64/' os.makedirs(dst_dir, exist_ok=True) # Get a list of all the image files in the source directory image_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))] # Define the size of the crop box crop_size = np.array([288 ,288]) # Define the size of the resized image resize_size = (64, 64) for image_file in image_files: # Load the image image = io.imread(os.path.join(src_dir, image_file)) # print(image.shape) # Calculate the center of the image center = np.array(image.shape[:2]) // 2 # Calculate the start and end points of the crop box start = center - crop_size // 2 end = start + crop_size # Crop the image cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]]) # Resize the cropped image resized_image = transform.resize(cropped_image, resize_size, mode='reflect') # Save the resized image to the destination directory io.imsave(os.path.join(dst_dir, image_file), img_as_ubyte(resized_image)) ================================================ FILE: Diffusion_UKAN/tools/resize_busi.py ================================================ import os from skimage import io, transform from skimage.util import img_as_ubyte import numpy as np # Define the source and destination directories src_dir = '/data/wyli/data/busi/images/' dst_dir = '/data/wyli/data/busi/images_64/' os.makedirs(dst_dir, exist_ok=True) # Get a list of all the image files in the source directory image_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))] # Define the size of the crop box crop_size = np.array([400 ,400]) # Define the size of the resized image # resize_size = (64, 64) resize_size = (64, 64) for image_file in image_files: # Load the image image = io.imread(os.path.join(src_dir, image_file)) print(image.shape) # Calculate the center of the image center = np.array(image.shape[:2]) // 2 # Calculate the start and end points of the crop box start = center - crop_size // 2 end = start + crop_size # Crop the image cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]]) # Resize the cropped image resized_image = transform.resize(cropped_image, resize_size, mode='reflect') # Save the resized image to the destination directory io.imsave(os.path.join(dst_dir, image_file), img_as_ubyte(resized_image)) ================================================ FILE: Diffusion_UKAN/tools/resize_glas.py ================================================ import os from skimage import io, transform from skimage.util import img_as_ubyte import numpy as np import random # Define the source and destination directories src_dir = '/data/wyli/data/glas/images/' dst_dir = '/data/wyli/data/glas/images_64/' os.makedirs(dst_dir, exist_ok=True) # Get a list of all the image files in the source directory image_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))] # Define the size of the crop box crop_size = np.array([64, 64]) # Define the number of crops per image K = 5 for image_file in image_files: # Load the image image = io.imread(os.path.join(src_dir, image_file)) # Get the size of the image image_size = np.array(image.shape[:2]) for i in range(K): # Calculate a random start point for the crop box start = np.array([random.randint(0, image_size[0] - crop_size[0]), random.randint(0, image_size[1] - crop_size[1])]) # Calculate the end point of the crop box end = start + crop_size # Crop the image cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]]) # Save the cropped image to the destination directory io.imsave(os.path.join(dst_dir, f"{image_file}_{i}.png"), cropped_image) ================================================ FILE: Diffusion_UKAN/training_scripts/busi.sh ================================================ ##!/bin/bash source ~/miniconda3/etc/profile.d/conda.sh conda activate kan GPU=0 MODEL='UKan_Hybrid' EXP_NME='UKan_cvc' SAVE_ROOT='./Output/' DATASET='busi' cd ../ CUDA_VISIBLE_DEVICES=${GPU} python Main.py \ --model ${MODEL} \ --exp_nme ${EXP_NME} \ --batch_size 32 \ --channel 64 \ --dataset ${DATASET} \ --epoch 5000 \ --save_root ${SAVE_ROOT} # --lr 1e-4 # calcuate FID and IS CUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid "data/${DATASET}/images_64/" "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/FID.txt" 2>&1 cd inception-score-pytorch CUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/IS.txt" 2>&1 ================================================ FILE: Diffusion_UKAN/training_scripts/cvc.sh ================================================ ##!/bin/bash source ~/miniconda3/etc/profile.d/conda.sh conda activate kan GPU=0 MODEL='UKan_Hybrid' EXP_NME='UKan_cvc' SAVE_ROOT='./Output/' DATASET='cvc' cd ../ CUDA_VISIBLE_DEVICES=${GPU} python Main.py \ --model ${MODEL} \ --exp_nme ${EXP_NME} \ --batch_size 32 \ --channel 64 \ --dataset ${DATASET} \ --epoch 1000 \ --save_root ${SAVE_ROOT} # --lr 1e-4 # calcuate FID and IS CUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid "data/${DATASET}/images_64/" "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/FID.txt" 2>&1 cd inception-score-pytorch CUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/IS.txt" 2>&1 ================================================ FILE: Diffusion_UKAN/training_scripts/glas.sh ================================================ ##!/bin/bash source ~/miniconda3/etc/profile.d/conda.sh conda activate kan GPU=0 MODEL='UKan_Hybrid' EXP_NME='UKan_glas' SAVE_ROOT='./Output/' DATASET='glas' cd ../ CUDA_VISIBLE_DEVICES=${GPU} python Main.py \ --model ${MODEL} \ --exp_nme ${EXP_NME} \ --batch_size 32 \ --channel 64 \ --dataset ${DATASET} \ --epoch 1000 \ --save_root ${SAVE_ROOT} # --lr 1e-4 # calcuate FID and IS CUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid "data/${DATASET}/images_64/" "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/FID.txt" 2>&1 cd inception-score-pytorch CUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/IS.txt" 2>&1 ================================================ FILE: README.md ================================================ # U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation :pushpin: This is an official PyTorch implementation of **U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation** [[`Project Page`](https://yes-u-kan.github.io/)] [[`arXiv`](https://arxiv.org/abs/2406.02918)] [[`BibTeX`](#citation)]

> [**U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**](https://arxiv.org/abs/2406.02918)
> [Chenxin Li](https://xggnet.github.io/)1\*, [Xinyu Liu](https://xinyuliu-jeffrey.github.io/)1\*, [Wuyang Li](https://wymancv.github.io/wuyang.github.io/)1\*, [Cheng Wang](https://scholar.google.com/citations?user=AM7gvyUAAAAJ&hl=en)1\*, [Hengyu Liu](https://liuhengyu321.github.io/)1, [Yifan Liu](https://yifliu3.github.io/)1, [Chen Zhen](https://franciszchen.github.io/)2, [Yixuan Yuan](https://www.ee.cuhk.edu.hk/~yxyuan/people/people.htm)1✉
1The Chinese Univerisity of Hong Kong, 2Centre for Artificial Intelligence and Robotics, Hong Kong We explore the untapped potential of Kolmogorov-Anold Network (aka. KAN) in improving backbones for vision tasks. We investigate, modify and re-design the established U-Net pipeline by integrating the dedicated KAN layers on the tokenized intermediate representation, termed U-KAN. Rigorous medical image segmentation benchmarks verify the superiority of U-KAN by higher accuracy even with less computation cost. We further delved into the potential of U-KAN as an alternative U-Net noise predictor in diffusion models, demonstrating its applicability in generating task-oriented model architectures. These endeavours unveil valuable insights and sheds light on the prospect that with U-KAN, you can make strong backbone for medical image segmentation and generation.
UKAN overview
## 📰News **[NOTE]** Random seed is essential for eval metric, and all reported results are calculated over three random runs with seeds of 2981, 6142, 1187, following rolling-UNet. We think most issues are related with this. **[2024.10]** U-KAN is accepted by AAAI-25. **[2024.6]** Some modifications are done in Seg_UKAN for better performance reproduction. The previous code can be quickly updated by replacing the contents of train.py and archs.py with the new ones. **[2024.6]** Model checkpoints and training logs are released! **[2024.6]** Code and paper of U-KAN are released! ## 💡Key Features - The first effort to incorporate the advantage of emerging KAN to improve established U-Net pipeline to be more **accurate, efficient and interpretable**. - A Segmentation U-KAN with **tokenized KAN block to effectively steer the KAN operators** to be compatible with the exiting convolution-based designs. - A Diffusion U-KAN as an **improved noise predictor** demonstrates its potential in backboning generative tasks and broader vision settings. ## 🛠Setup ```bash git clone https://github.com/CUHK-AIM-Group/U-KAN.git cd U-KAN conda create -n ukan python=3.10 conda activate ukan cd Seg_UKAN && pip install -r requirements.txt ``` **Tips A**: We test the framework using pytorch=1.13.0, and the CUDA compile version=11.6. Other versions should be also fine but not totally ensured. ## 📚Data Preparation **BUSI**: The dataset can be found [here](https://www.kaggle.com/datasets/aryashah2k/breast-ultrasound-images-dataset). **GLAS**: The dataset can be found [here](https://websignon.warwick.ac.uk/origin/slogin?shire=https%3A%2F%2Fwarwick.ac.uk%2Fsitebuilder2%2Fshire-read&providerId=urn%3Awarwick.ac.uk%3Asitebuilder2%3Aread%3Aservice&target=https%3A%2F%2Fwarwick.ac.uk%2Ffac%2Fcross_fac%2Ftia%2Fdata%2Fglascontest&status=notloggedin). **CVC-ClinicDB**: The dataset can be found [here](https://www.dropbox.com/s/p5qe9eotetjnbmq/CVC-ClinicDB.rar?e=3&dl=0). We also provide all the [pre-processed dataset](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/ErDlT-t0WoBNlKhBlbYfReYB-iviSCmkNRb1GqZ90oYjJA?e=hrPNWD) without requiring any further data processing. You can directly download and put them into the data dir. The resulted file structure is as follows. ``` Seg_UKAN ├── inputs │ ├── busi │ ├── images │ ├── malignant (1).png | ├── ... | ├── masks │ ├── 0 │ ├── malignant (1)_mask.png | ├── ... │ ├── GLAS │ ├── images │ ├── 0.png | ├── ... | ├── masks │ ├── 0 │ ├── 0.png | ├── ... │ ├── CVC-ClinicDB │ ├── images │ ├── 0.png | ├── ... | ├── masks │ ├── 0 │ ├── 0.png | ├── ... ``` ## 🔖Evaluating Segmentation U-KAN You can directly evaluate U-KAN from the checkpoint model. Here is an example for quick usage for using our **pre-trained models** in [Segmentation Model Zoo](#segmentation-model-zoo): 1. Download the pre-trained weights and put them to ```{args.output_dir}/{args.name}/model.pth``` 2. Run the following scripts to ```bash cd Seg_UKAN python val.py --name ${dataset}_UKAN --output_dir [YOUR_OUTPUT_DIR] ``` ## ⏳Training Segmentation U-KAN You can simply train U-KAN on a single GPU by specifing the dataset name ```--dataset``` and input size ```--input_size```. ```bash cd Seg_UKAN python train.py --arch UKAN --dataset {dataset} --input_w {input_size} --input_h {input_size} --name {dataset}_UKAN --data_dir [YOUR_DATA_DIR] ``` For example, train U-KAN with the resolution of 256x256 with a single GPU on the BUSI dataset in the ```inputs``` dir: ```bash cd Seg_UKAN python train.py --arch UKAN --dataset busi --input_w 256 --input_h 256 --name busi_UKAN --data_dir ./inputs ``` Please see Seg_UKAN/scripts.sh for more details. Note that the resolution of glas is 512x512, differing with other datasets (256x256). **[Quick Update]** Please follow the seeds of 2981, 6142, 1187 to fully reproduce the paper experimental results. All compared methods are evaluated on the same seed setting. ## 🎪Segmentation Model Zoo We provide all the pre-trained model [checkpoints](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/Ej6yZBSIrU5Ds9q-gQdhXqwBbpov5_MaWF483uZHm2lccA?e=rmlHMo) Here is an overview of the released performance&checkpoints. Note that results on a single run and the reported average results in the paper differ. |Method| Dataset | IoU | F1 | Checkpoints | |-----|------|-----|-----|-----| |Seg U-KAN| BUSI | 65.26 | 78.75 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EjktWkXytkZEgN3EzN2sJKIBfHCeEnJnCnazC68pWCy7kQ?e=4JBLIc)| |Seg U-KAN| GLAS | 87.51 | 93.33 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EunQ9KRf6n1AqCJ40FWZF-QB25GMOoF7hoIwU15fefqFbw?e=m7kXwe)| |Seg U-KAN| CVC-ClinicDB | 85.61 | 92.19 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/Ekhb3PEmwZZMumSG69wPRRQBymYIi0PFNuLJcVNmmK1fjA?e=5XzVSi)| The parameter ``--no_kan'' denotes the baseline model that is replaced the KAN layers with MLP layers. Please note: it is reasonable to find occasional inconsistencies in performance, and the average results over multiple runs can reveal a more obvious trend. |Method| Layer Type | IoU | F1 | Checkpoints | |-----|------|-----|-----|-----| |Seg U-KAN (--no_kan)| MLP Layer | 63.49 | 77.07 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EmEH_qokqIFNtP59yU7vY_4Bq4Yc424zuYufwaJuiAGKiw?e=IJ3clx)| |Seg U-KAN| KAN Layer | 65.26 | 78.75 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EjktWkXytkZEgN3EzN2sJKIBfHCeEnJnCnazC68pWCy7kQ?e=4JBLIc)| ## 🎇Medical Image Generation with Diffusion U-KAN Please refer to [Diffusion_UKAN](./Diffusion_UKAN/README.md) ## 🛒TODO List - [X] Release code for Seg U-KAN. - [X] Release code for Diffusion U-KAN. - [X] Upload the pretrained checkpoints. ## 🎈Acknowledgements Greatly appreciate the tremendous effort for the following projects! - [CKAN](https://github.com/AntonioTepsich/Convolutional-KANs) ## 📜Citation If you find this work helpful for your project,please consider citing the following paper: ``` @article{li2024ukan, title={U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation}, author={Li, Chenxin and Liu, Xinyu and Li, Wuyang and Wang, Cheng and Liu, Hengyu and Yuan, Yixuan}, journal={arXiv preprint arXiv:2406.02918}, year={2024} ''' } ================================================ FILE: Seg_UKAN/LICENSE ================================================ MIT License Copyright (c) 2022 Jeya Maria Jose Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: Seg_UKAN/archs.py ================================================ import torch from torch import nn import torch import torchvision from torch import nn from torch.autograd import Variable from torch.utils.data import DataLoader from torchvision import transforms from torchvision.utils import save_image import torch.nn.functional as F import os import matplotlib.pyplot as plt from utils import * import timm from timm.models.layers import DropPath, to_2tuple, trunc_normal_ import types import math from abc import ABCMeta, abstractmethod # from mmcv.cnn import ConvModule from pdb import set_trace as st from kan import KANLinear, KAN from torch.nn import init class KANLayer(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., no_kan=False): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.dim = in_features grid_size=5 spline_order=3 scale_noise=0.1 scale_base=1.0 scale_spline=1.0 base_activation=torch.nn.SiLU grid_eps=0.02 grid_range=[-1, 1] if not no_kan: self.fc1 = KANLinear( in_features, hidden_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) self.fc2 = KANLinear( hidden_features, out_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) self.fc3 = KANLinear( hidden_features, out_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) # # TODO # self.fc4 = KANLinear( # hidden_features, # out_features, # grid_size=grid_size, # spline_order=spline_order, # scale_noise=scale_noise, # scale_base=scale_base, # scale_spline=scale_spline, # base_activation=base_activation, # grid_eps=grid_eps, # grid_range=grid_range, # ) else: self.fc1 = nn.Linear(in_features, hidden_features) self.fc2 = nn.Linear(hidden_features, out_features) self.fc3 = nn.Linear(hidden_features, out_features) # TODO # self.fc1 = nn.Linear(in_features, hidden_features) self.dwconv_1 = DW_bn_relu(hidden_features) self.dwconv_2 = DW_bn_relu(hidden_features) self.dwconv_3 = DW_bn_relu(hidden_features) # # TODO # self.dwconv_4 = DW_bn_relu(hidden_features) self.drop = nn.Dropout(drop) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x, H, W): # pdb.set_trace() B, N, C = x.shape x = self.fc1(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_1(x, H, W) x = self.fc2(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_2(x, H, W) x = self.fc3(x.reshape(B*N,C)) x = x.reshape(B,N,C).contiguous() x = self.dwconv_3(x, H, W) # # TODO # x = x.reshape(B,N,C).contiguous() # x = self.dwconv_4(x, H, W) return x class KANBlock(nn.Module): def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, no_kan=False): super().__init__() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim) self.layer = KANLayer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, no_kan=no_kan) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x, H, W): x = x + self.drop_path(self.layer(self.norm2(x), H, W)) return x class DWConv(nn.Module): def __init__(self, dim=768): super(DWConv, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) def forward(self, x, H, W): B, N, C = x.shape x = x.transpose(1, 2).view(B, C, H, W) x = self.dwconv(x) x = x.flatten(2).transpose(1, 2) return x class DW_bn_relu(nn.Module): def __init__(self, dim=768): super(DW_bn_relu, self).__init__() self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) self.bn = nn.BatchNorm2d(dim) self.relu = nn.ReLU() def forward(self, x, H, W): B, N, C = x.shape x = x.transpose(1, 2).view(B, C, H, W) x = self.dwconv(x) x = self.bn(x) x = self.relu(x) x = x.flatten(2).transpose(1, 2) return x class PatchEmbed(nn.Module): """ Image to Patch Embedding """ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) self.img_size = img_size self.patch_size = patch_size self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] self.num_patches = self.H * self.W self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2)) self.norm = nn.LayerNorm(embed_dim) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() def forward(self, x): x = self.proj(x) _, _, H, W = x.shape x = x.flatten(2).transpose(1, 2) x = self.norm(x) return x, H, W class ConvLayer(nn.Module): def __init__(self, in_ch, out_ch): super(ConvLayer, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) ) def forward(self, input): return self.conv(input) class D_ConvLayer(nn.Module): def __init__(self, in_ch, out_ch): super(D_ConvLayer, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, in_ch, 3, padding=1), nn.BatchNorm2d(in_ch), nn.ReLU(inplace=True), nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True) ) def forward(self, input): return self.conv(input) class UKAN(nn.Module): def __init__(self, num_classes, input_channels=3, deep_supervision=False, img_size=224, patch_size=16, in_chans=3, embed_dims=[256, 320, 512], no_kan=False, drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[1, 1, 1], **kwargs): super().__init__() kan_input_dim = embed_dims[0] self.encoder1 = ConvLayer(3, kan_input_dim//8) self.encoder2 = ConvLayer(kan_input_dim//8, kan_input_dim//4) self.encoder3 = ConvLayer(kan_input_dim//4, kan_input_dim) self.norm3 = norm_layer(embed_dims[1]) self.norm4 = norm_layer(embed_dims[2]) self.dnorm3 = norm_layer(embed_dims[1]) self.dnorm4 = norm_layer(embed_dims[0]) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] self.block1 = nn.ModuleList([KANBlock( dim=embed_dims[1], drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer )]) self.block2 = nn.ModuleList([KANBlock( dim=embed_dims[2], drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer )]) self.dblock1 = nn.ModuleList([KANBlock( dim=embed_dims[1], drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer )]) self.dblock2 = nn.ModuleList([KANBlock( dim=embed_dims[0], drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer )]) self.patch_embed3 = PatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) self.patch_embed4 = PatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) self.decoder1 = D_ConvLayer(embed_dims[2], embed_dims[1]) self.decoder2 = D_ConvLayer(embed_dims[1], embed_dims[0]) self.decoder3 = D_ConvLayer(embed_dims[0], embed_dims[0]//4) self.decoder4 = D_ConvLayer(embed_dims[0]//4, embed_dims[0]//8) self.decoder5 = D_ConvLayer(embed_dims[0]//8, embed_dims[0]//8) self.final = nn.Conv2d(embed_dims[0]//8, num_classes, kernel_size=1) self.soft = nn.Softmax(dim =1) def forward(self, x): B = x.shape[0] ### Encoder ### Conv Stage ### Stage 1 out = F.relu(F.max_pool2d(self.encoder1(x), 2, 2)) t1 = out ### Stage 2 out = F.relu(F.max_pool2d(self.encoder2(out), 2, 2)) t2 = out ### Stage 3 out = F.relu(F.max_pool2d(self.encoder3(out), 2, 2)) t3 = out ### Tokenized KAN Stage ### Stage 4 out, H, W = self.patch_embed3(out) for i, blk in enumerate(self.block1): out = blk(out, H, W) out = self.norm3(out) out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() t4 = out ### Bottleneck out, H, W= self.patch_embed4(out) for i, blk in enumerate(self.block2): out = blk(out, H, W) out = self.norm4(out) out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() ### Stage 4 out = F.relu(F.interpolate(self.decoder1(out), scale_factor=(2,2), mode ='bilinear')) out = torch.add(out, t4) _, _, H, W = out.shape out = out.flatten(2).transpose(1,2) for i, blk in enumerate(self.dblock1): out = blk(out, H, W) ### Stage 3 out = self.dnorm3(out) out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() out = F.relu(F.interpolate(self.decoder2(out),scale_factor=(2,2),mode ='bilinear')) out = torch.add(out,t3) _,_,H,W = out.shape out = out.flatten(2).transpose(1,2) for i, blk in enumerate(self.dblock2): out = blk(out, H, W) out = self.dnorm4(out) out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() out = F.relu(F.interpolate(self.decoder3(out),scale_factor=(2,2),mode ='bilinear')) out = torch.add(out,t2) out = F.relu(F.interpolate(self.decoder4(out),scale_factor=(2,2),mode ='bilinear')) out = torch.add(out,t1) out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear')) return self.final(out) ================================================ FILE: Seg_UKAN/config.py ================================================ # -------------------------------------------------------- # Swin Transformer # Copyright (c) 2021 Microsoft # Licensed under The MIT License [see LICENSE for details] # Written by Ze Liu # --------------------------------------------------------' import os import yaml from yacs.config import CfgNode as CN _C = CN() # Base config files _C.BASE = [''] # ----------------------------------------------------------------------------- # Data settings # ----------------------------------------------------------------------------- _C.DATA = CN() # Batch size for a single GPU, could be overwritten by command line argument _C.DATA.BATCH_SIZE = 1 # Path to dataset, could be overwritten by command line argument _C.DATA.DATA_PATH = '' # Dataset name _C.DATA.DATASET = 'imagenet' # Input image size _C.DATA.IMG_SIZE = 256 # Interpolation to resize image (random, bilinear, bicubic) _C.DATA.INTERPOLATION = 'bicubic' # Use zipped dataset instead of folder dataset # could be overwritten by command line argument _C.DATA.ZIP_MODE = False # Cache Data in Memory, could be overwritten by command line argument _C.DATA.CACHE_MODE = 'part' # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. _C.DATA.PIN_MEMORY = True # Number of data loading threads _C.DATA.NUM_WORKERS = 8 # ----------------------------------------------------------------------------- # Model settings # ----------------------------------------------------------------------------- _C.MODEL = CN() # Model type _C.MODEL.TYPE = 'swin' # Model name _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' # Checkpoint to resume, could be overwritten by command line argument _C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth' _C.MODEL.RESUME = '' # Number of classes, overwritten in data preparation _C.MODEL.NUM_CLASSES = 1000 # Dropout rate _C.MODEL.DROP_RATE = 0.0 # Drop path rate _C.MODEL.DROP_PATH_RATE = 0.1 # Label Smoothing _C.MODEL.LABEL_SMOOTHING = 0.1 # Swin Transformer parameters _C.MODEL.SWIN = CN() _C.MODEL.SWIN.PATCH_SIZE = 4 _C.MODEL.SWIN.IN_CHANS = 3 _C.MODEL.SWIN.EMBED_DIM = 96 _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] _C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2] _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] _C.MODEL.SWIN.WINDOW_SIZE = 4 _C.MODEL.SWIN.MLP_RATIO = 4. _C.MODEL.SWIN.QKV_BIAS = True _C.MODEL.SWIN.QK_SCALE = False _C.MODEL.SWIN.APE = False _C.MODEL.SWIN.PATCH_NORM = True _C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first" # ----------------------------------------------------------------------------- # Training settings # ----------------------------------------------------------------------------- _C.TRAIN = CN() _C.TRAIN.START_EPOCH = 0 _C.TRAIN.EPOCHS = 300 _C.TRAIN.WARMUP_EPOCHS = 20 _C.TRAIN.WEIGHT_DECAY = 0.05 _C.TRAIN.BASE_LR = 5e-4 _C.TRAIN.WARMUP_LR = 5e-7 _C.TRAIN.MIN_LR = 5e-6 # Clip gradient norm _C.TRAIN.CLIP_GRAD = 5.0 # Auto resume from latest checkpoint _C.TRAIN.AUTO_RESUME = True # Gradient accumulation steps # could be overwritten by command line argument _C.TRAIN.ACCUMULATION_STEPS = 0 # Whether to use gradient checkpointing to save memory # could be overwritten by command line argument _C.TRAIN.USE_CHECKPOINT = False # LR scheduler _C.TRAIN.LR_SCHEDULER = CN() _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' # Epoch interval to decay LR, used in StepLRScheduler _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 # LR decay rate, used in StepLRScheduler _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 # Optimizer _C.TRAIN.OPTIMIZER = CN() _C.TRAIN.OPTIMIZER.NAME = 'adamw' # Optimizer Epsilon _C.TRAIN.OPTIMIZER.EPS = 1e-8 # Optimizer Betas _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) # SGD momentum _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 # ----------------------------------------------------------------------------- # Augmentation settings # ----------------------------------------------------------------------------- _C.AUG = CN() # Color jitter factor _C.AUG.COLOR_JITTER = 0.4 # Use AutoAugment policy. "v0" or "original" _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' # Random erase prob _C.AUG.REPROB = 0.25 # Random erase mode _C.AUG.REMODE = 'pixel' # Random erase count _C.AUG.RECOUNT = 1 # Mixup alpha, mixup enabled if > 0 _C.AUG.MIXUP = 0.8 # Cutmix alpha, cutmix enabled if > 0 _C.AUG.CUTMIX = 1.0 # Cutmix min/max ratio, overrides alpha and enables cutmix if set _C.AUG.CUTMIX_MINMAX = False # Probability of performing mixup or cutmix when either/both is enabled _C.AUG.MIXUP_PROB = 1.0 # Probability of switching to cutmix when both mixup and cutmix enabled _C.AUG.MIXUP_SWITCH_PROB = 0.5 # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" _C.AUG.MIXUP_MODE = 'batch' # ----------------------------------------------------------------------------- # Testing settings # ----------------------------------------------------------------------------- _C.TEST = CN() # Whether to use center crop when testing _C.TEST.CROP = True # ----------------------------------------------------------------------------- # Misc # ----------------------------------------------------------------------------- # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') # overwritten by command line argument _C.AMP_OPT_LEVEL = '' # Path to output folder, overwritten by command line argument _C.OUTPUT = '' # Tag of experiment, overwritten by command line argument _C.TAG = 'default' # Frequency to save checkpoint _C.SAVE_FREQ = 1 # Frequency to logging info _C.PRINT_FREQ = 10 # Fixed random seed _C.SEED = 0 # Perform evaluation only, overwritten by command line argument _C.EVAL_MODE = False # Test throughput only, overwritten by command line argument _C.THROUGHPUT_MODE = False # local rank for DistributedDataParallel, given by command line argument _C.LOCAL_RANK = 0 def _update_config_from_file(config, cfg_file): config.defrost() with open(cfg_file, 'r') as f: yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) for cfg in yaml_cfg.setdefault('BASE', ['']): if cfg: _update_config_from_file( config, os.path.join(os.path.dirname(cfg_file), cfg) ) print('=> merge config from {}'.format(cfg_file)) config.merge_from_file(cfg_file) config.freeze() def update_config(config, args): _update_config_from_file(config, args.cfg) config.defrost() if args.opts: config.merge_from_list(args.opts) # merge from specific arguments if args.batch_size: config.DATA.BATCH_SIZE = args.batch_size if args.zip: config.DATA.ZIP_MODE = True if args.cache_mode: config.DATA.CACHE_MODE = args.cache_mode if args.resume: config.MODEL.RESUME = args.resume if args.accumulation_steps: config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps if args.use_checkpoint: config.TRAIN.USE_CHECKPOINT = True if args.amp_opt_level: config.AMP_OPT_LEVEL = args.amp_opt_level if args.tag: config.TAG = args.tag if args.eval: config.EVAL_MODE = True if args.throughput: config.THROUGHPUT_MODE = True config.freeze() def get_config(args): """Get a yacs CfgNode object with default values.""" # Return a clone so that the defaults will not be altered # This is for the "local variable" use pattern config = _C.clone() # update_config(config, args) return config ================================================ FILE: Seg_UKAN/dataset.py ================================================ import os import cv2 import numpy as np import torch import torch.utils.data class Dataset(torch.utils.data.Dataset): def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None): """ Args: img_ids (list): Image ids. img_dir: Image file directory. mask_dir: Mask file directory. img_ext (str): Image file extension. mask_ext (str): Mask file extension. num_classes (int): Number of classes. transform (Compose, optional): Compose transforms of albumentations. Defaults to None. Note: Make sure to put the files as the following structure: ├── images | ├── 0a7e06.jpg │ ├── 0aab0a.jpg │ ├── 0b1761.jpg │ ├── ... | └── masks ├── 0 | ├── 0a7e06.png | ├── 0aab0a.png | ├── 0b1761.png | ├── ... | ├── 1 | ├── 0a7e06.png | ├── 0aab0a.png | ├── 0b1761.png | ├── ... ... """ self.img_ids = img_ids self.img_dir = img_dir self.mask_dir = mask_dir self.img_ext = img_ext self.mask_ext = mask_ext self.num_classes = num_classes self.transform = transform def __len__(self): return len(self.img_ids) def __getitem__(self, idx): img_id = self.img_ids[idx] img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext)) mask = [] for i in range(self.num_classes): # print(os.path.join(self.mask_dir, str(i), # img_id + self.mask_ext)) mask.append(cv2.imread(os.path.join(self.mask_dir, str(i), img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None]) mask = np.dstack(mask) if self.transform is not None: augmented = self.transform(image=img, mask=mask) img = augmented['image'] mask = augmented['mask'] img = img.astype('float32') / 255 img = img.transpose(2, 0, 1) mask = mask.astype('float32') / 255 mask = mask.transpose(2, 0, 1) if mask.max()<1: mask[mask>0] = 1.0 return img, mask, {'img_id': img_id} ================================================ FILE: Seg_UKAN/environment.yml ================================================ name: ukan channels: - defaults dependencies: - _libgcc_mutex=0.1=main - _openmp_mutex=4.5=1_gnu - ca-certificates=2021.10.26=h06a4308_2 - certifi=2021.5.30=py36h06a4308_0 - ld_impl_linux-64=2.35.1=h7274673_9 - libffi=3.3=he6710b0_2 - libgcc-ng=9.3.0=h5101ec6_17 - libgomp=9.3.0=h5101ec6_17 - libstdcxx-ng=9.3.0=hd4cf53a_17 - ncurses=6.3=h7f8727e_2 - openssl=1.1.1l=h7f8727e_0 - pip=21.2.2=py36h06a4308_0 - python=3.6.13=h12debd9_1 - readline=8.1=h27cfd23_0 - setuptools=58.0.4=py36h06a4308_0 - sqlite=3.36.0=hc218d9a_0 - tk=8.6.11=h1ccaba5_0 - wheel=0.37.0=pyhd3eb1b0_1 - xz=5.2.5=h7b6447c_0 - zlib=1.2.11=h7b6447c_3 - pip: - addict==2.4.0 - dataclasses==0.8 - mmcv-full==1.2.7 - numpy==1.19.5 - opencv-python==4.5.1.48 - perceptual==0.1 - pillow==8.4.0 - scikit-image==0.17.2 - scipy==1.5.4 - tifffile==2020.9.3 - timm==0.3.2 - torch==1.7.1 - torchvision==0.8.2 - typing-extensions==4.0.0 - yapf==0.31.0 # prefix: /home/jeyamariajose/anaconda3/envs/transweather ================================================ FILE: Seg_UKAN/kan.py ================================================ import torch import torch.nn.functional as F import math class KANLinear(torch.nn.Module): def __init__( self, in_features, out_features, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, enable_standalone_scale_spline=True, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KANLinear, self).__init__() self.in_features = in_features self.out_features = out_features self.grid_size = grid_size self.spline_order = spline_order h = (grid_range[1] - grid_range[0]) / grid_size grid = ( ( torch.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0] ) .expand(in_features, -1) .contiguous() ) self.register_buffer("grid", grid) self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features)) self.spline_weight = torch.nn.Parameter( torch.Tensor(out_features, in_features, grid_size + spline_order) ) if enable_standalone_scale_spline: self.spline_scaler = torch.nn.Parameter( torch.Tensor(out_features, in_features) ) self.scale_noise = scale_noise self.scale_base = scale_base self.scale_spline = scale_spline self.enable_standalone_scale_spline = enable_standalone_scale_spline self.base_activation = base_activation() self.grid_eps = grid_eps self.reset_parameters() def reset_parameters(self): torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base) with torch.no_grad(): noise = ( ( torch.rand(self.grid_size + 1, self.in_features, self.out_features) - 1 / 2 ) * self.scale_noise / self.grid_size ) self.spline_weight.data.copy_( (self.scale_spline if not self.enable_standalone_scale_spline else 1.0) * self.curve2coeff( self.grid.T[self.spline_order : -self.spline_order], noise, ) ) if self.enable_standalone_scale_spline: # torch.nn.init.constant_(self.spline_scaler, self.scale_spline) torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline) def b_splines(self, x: torch.Tensor): """ Compute the B-spline bases for the given input tensor. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). Returns: torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features grid: torch.Tensor = ( self.grid ) # (in_features, grid_size + 2 * spline_order + 1) x = x.unsqueeze(-1) bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype) for k in range(1, self.spline_order + 1): bases = ( (x - grid[:, : -(k + 1)]) / (grid[:, k:-1] - grid[:, : -(k + 1)]) * bases[:, :, :-1] ) + ( (grid[:, k + 1 :] - x) / (grid[:, k + 1 :] - grid[:, 1:(-k)]) * bases[:, :, 1:] ) assert bases.size() == ( x.size(0), self.in_features, self.grid_size + self.spline_order, ) return bases.contiguous() def curve2coeff(self, x: torch.Tensor, y: torch.Tensor): """ Compute the coefficients of the curve that interpolates the given points. Args: x (torch.Tensor): Input tensor of shape (batch_size, in_features). y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features). Returns: torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order). """ assert x.dim() == 2 and x.size(1) == self.in_features assert y.size() == (x.size(0), self.in_features, self.out_features) A = self.b_splines(x).transpose( 0, 1 ) # (in_features, batch_size, grid_size + spline_order) B = y.transpose(0, 1) # (in_features, batch_size, out_features) solution = torch.linalg.lstsq( A, B ).solution # (in_features, grid_size + spline_order, out_features) result = solution.permute( 2, 0, 1 ) # (out_features, in_features, grid_size + spline_order) assert result.size() == ( self.out_features, self.in_features, self.grid_size + self.spline_order, ) return result.contiguous() @property def scaled_spline_weight(self): return self.spline_weight * ( self.spline_scaler.unsqueeze(-1) if self.enable_standalone_scale_spline else 1.0 ) def forward(self, x: torch.Tensor): assert x.dim() == 2 and x.size(1) == self.in_features base_output = F.linear(self.base_activation(x), self.base_weight) spline_output = F.linear( self.b_splines(x).view(x.size(0), -1), self.scaled_spline_weight.view(self.out_features, -1), ) return base_output + spline_output @torch.no_grad() def update_grid(self, x: torch.Tensor, margin=0.01): assert x.dim() == 2 and x.size(1) == self.in_features batch = x.size(0) splines = self.b_splines(x) # (batch, in, coeff) splines = splines.permute(1, 0, 2) # (in, batch, coeff) orig_coeff = self.scaled_spline_weight # (out, in, coeff) orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out) unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out) unreduced_spline_output = unreduced_spline_output.permute( 1, 0, 2 ) # (batch, in, out) # sort each channel individually to collect data distribution x_sorted = torch.sort(x, dim=0)[0] grid_adaptive = x_sorted[ torch.linspace( 0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device ) ] uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size grid_uniform = ( torch.arange( self.grid_size + 1, dtype=torch.float32, device=x.device ).unsqueeze(1) * uniform_step + x_sorted[0] - margin ) grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive grid = torch.concatenate( [ grid[:1] - uniform_step * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1), grid, grid[-1:] + uniform_step * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1), ], dim=0, ) self.grid.copy_(grid.T) self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output)) def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): """ Compute the regularization loss. This is a dumb simulation of the original L1 regularization as stated in the paper, since the original one requires computing absolutes and entropy from the expanded (batch, in_features, out_features) intermediate tensor, which is hidden behind the F.linear function if we want an memory efficient implementation. The L1 regularization is now computed as mean absolute value of the spline weights. The authors implementation also includes this term in addition to the sample-based regularization. """ l1_fake = self.spline_weight.abs().mean(-1) regularization_loss_activation = l1_fake.sum() p = l1_fake / regularization_loss_activation regularization_loss_entropy = -torch.sum(p * p.log()) return ( regularize_activation * regularization_loss_activation + regularize_entropy * regularization_loss_entropy ) class KAN(torch.nn.Module): def __init__( self, layers_hidden, grid_size=5, spline_order=3, scale_noise=0.1, scale_base=1.0, scale_spline=1.0, base_activation=torch.nn.SiLU, grid_eps=0.02, grid_range=[-1, 1], ): super(KAN, self).__init__() self.grid_size = grid_size self.spline_order = spline_order self.layers = torch.nn.ModuleList() for in_features, out_features in zip(layers_hidden, layers_hidden[1:]): self.layers.append( KANLinear( in_features, out_features, grid_size=grid_size, spline_order=spline_order, scale_noise=scale_noise, scale_base=scale_base, scale_spline=scale_spline, base_activation=base_activation, grid_eps=grid_eps, grid_range=grid_range, ) ) def forward(self, x: torch.Tensor, update_grid=False): for layer in self.layers: if update_grid: layer.update_grid(x) x = layer(x) return x def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0): return sum( layer.regularization_loss(regularize_activation, regularize_entropy) for layer in self.layers ) ================================================ FILE: Seg_UKAN/losses.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F try: from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge except ImportError: pass __all__ = ['BCEDiceLoss', 'LovaszHingeLoss'] class BCEDiceLoss(nn.Module): def __init__(self): super().__init__() def forward(self, input, target): bce = F.binary_cross_entropy_with_logits(input, target) smooth = 1e-5 input = torch.sigmoid(input) num = target.size(0) input = input.view(num, -1) target = target.view(num, -1) intersection = (input * target) dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth) dice = 1 - dice.sum() / num return 0.5 * bce + dice class LovaszHingeLoss(nn.Module): def __init__(self): super().__init__() def forward(self, input, target): input = input.squeeze(1) target = target.squeeze(1) loss = lovasz_hinge(input, target, per_image=True) return loss ================================================ FILE: Seg_UKAN/metrics.py ================================================ import numpy as np import torch import torch.nn.functional as F from medpy.metric.binary import jc, dc, hd, hd95, recall, specificity, precision def iou_score(output, target): smooth = 1e-5 if torch.is_tensor(output): output = torch.sigmoid(output).data.cpu().numpy() if torch.is_tensor(target): target = target.data.cpu().numpy() output_ = output > 0.5 target_ = target > 0.5 intersection = (output_ & target_).sum() union = (output_ | target_).sum() iou = (intersection + smooth) / (union + smooth) dice = (2* iou) / (iou+1) try: hd95_ = hd95(output_, target_) except: hd95_ = 0 return iou, dice, hd95_ def dice_coef(output, target): smooth = 1e-5 output = torch.sigmoid(output).view(-1).data.cpu().numpy() target = target.view(-1).data.cpu().numpy() intersection = (output * target).sum() return (2. * intersection + smooth) / \ (output.sum() + target.sum() + smooth) def indicators(output, target): if torch.is_tensor(output): output = torch.sigmoid(output).data.cpu().numpy() if torch.is_tensor(target): target = target.data.cpu().numpy() output_ = output > 0.5 target_ = target > 0.5 iou_ = jc(output_, target_) dice_ = dc(output_, target_) hd_ = hd(output_, target_) hd95_ = hd95(output_, target_) recall_ = recall(output_, target_) specificity_ = specificity(output_, target_) precision_ = precision(output_, target_) return iou_, dice_, hd_, hd95_, recall_, specificity_, precision_ ================================================ FILE: Seg_UKAN/requirements.txt ================================================ addict==2.4.0 dataclasses pandas pyyaml albumentations tqdm tensorboardX # mmcv-full==1.2.7 numpy opencv-python perceptual==0.1 pillow==8.4.0 scikit-image==0.17.2 scipy==1.5.4 tifffile==2020.9.3 timm==0.3.2 typing-extensions==4.0.0 yapf==0.31.0 pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116 ================================================ FILE: Seg_UKAN/scripts.sh ================================================ dataset=busi input_size=256 python train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_size} --name ${dataset}_UKAN --data_dir [YOUR_DATA_DIR] python val.py --name ${dataset}_UKAN dataset=glas input_size=512 python train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_size} --name ${dataset}_UKAN --data_dir [YOUR_DATA_DIR] python val.py --name ${dataset}_UKAN dataset=cvc input_size=256 python train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_size} --name ${dataset}_UKAN --data_dir [YOUR_DATA_DIR] python val.py --name ${dataset}_UKAN ================================================ FILE: Seg_UKAN/train.py ================================================ import argparse import os from collections import OrderedDict from glob import glob import random import numpy as np import pandas as pd import torch import torch.backends.cudnn as cudnn import torch.nn as nn import torch.optim as optim import yaml from albumentations.augmentations import transforms from albumentations.augmentations import geometric from albumentations.core.composition import Compose, OneOf from sklearn.model_selection import train_test_split from torch.optim import lr_scheduler from tqdm import tqdm from albumentations import RandomRotate90, Resize import archs import losses from dataset import Dataset from metrics import iou_score, indicators from utils import AverageMeter, str2bool from tensorboardX import SummaryWriter import shutil import os import subprocess from pdb import set_trace as st ARCH_NAMES = archs.__all__ LOSS_NAMES = losses.__all__ LOSS_NAMES.append('BCEWithLogitsLoss') def list_type(s): str_list = s.split(',') int_list = [int(a) for a in str_list] return int_list def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--name', default=None, help='model name: (default: arch+timestamp)') parser.add_argument('--epochs', default=400, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-b', '--batch_size', default=8, type=int, metavar='N', help='mini-batch size (default: 16)') parser.add_argument('--dataseed', default=2981, type=int, help='') # model parser.add_argument('--arch', '-a', metavar='ARCH', default='UKAN') parser.add_argument('--deep_supervision', default=False, type=str2bool) parser.add_argument('--input_channels', default=3, type=int, help='input channels') parser.add_argument('--num_classes', default=1, type=int, help='number of classes') parser.add_argument('--input_w', default=256, type=int, help='image width') parser.add_argument('--input_h', default=256, type=int, help='image height') parser.add_argument('--input_list', type=list_type, default=[128, 160, 256]) # loss parser.add_argument('--loss', default='BCEDiceLoss', choices=LOSS_NAMES, help='loss: ' + ' | '.join(LOSS_NAMES) + ' (default: BCEDiceLoss)') # dataset parser.add_argument('--dataset', default='busi', help='dataset name') parser.add_argument('--data_dir', default='inputs', help='dataset dir') parser.add_argument('--output_dir', default='outputs', help='ouput dir') # optimizer parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'], help='loss: ' + ' | '.join(['Adam', 'SGD']) + ' (default: Adam)') parser.add_argument('--lr', '--learning_rate', default=1e-4, type=float, metavar='LR', help='initial learning rate') parser.add_argument('--momentum', default=0.9, type=float, help='momentum') parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay') parser.add_argument('--nesterov', default=False, type=str2bool, help='nesterov') parser.add_argument('--kan_lr', default=1e-2, type=float, metavar='LR', help='initial learning rate') parser.add_argument('--kan_weight_decay', default=1e-4, type=float, help='weight decay') # scheduler parser.add_argument('--scheduler', default='CosineAnnealingLR', choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR']) parser.add_argument('--min_lr', default=1e-5, type=float, help='minimum learning rate') parser.add_argument('--factor', default=0.1, type=float) parser.add_argument('--patience', default=2, type=int) parser.add_argument('--milestones', default='1,2', type=str) parser.add_argument('--gamma', default=2/3, type=float) parser.add_argument('--early_stopping', default=-1, type=int, metavar='N', help='early stopping (default: -1)') parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', ) parser.add_argument('--num_workers', default=4, type=int) parser.add_argument('--no_kan', action='store_true') config = parser.parse_args() return config def train(config, train_loader, model, criterion, optimizer): avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()} model.train() pbar = tqdm(total=len(train_loader)) for input, target, _ in train_loader: input = input.cuda() target = target.cuda() # compute output if config['deep_supervision']: outputs = model(input) loss = 0 for output in outputs: loss += criterion(output, target) loss /= len(outputs) iou, dice, _ = iou_score(outputs[-1], target) iou_, dice_, hd_, hd95_, recall_, specificity_, precision_ = indicators(outputs[-1], target) else: output = model(input) loss = criterion(output, target) iou, dice, _ = iou_score(output, target) iou_, dice_, hd_, hd95_, recall_, specificity_, precision_ = indicators(output, target) # compute gradient and do optimizing step optimizer.zero_grad() loss.backward() optimizer.step() avg_meters['loss'].update(loss.item(), input.size(0)) avg_meters['iou'].update(iou, input.size(0)) postfix = OrderedDict([ ('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg), ]) pbar.set_postfix(postfix) pbar.update(1) pbar.close() return OrderedDict([('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg)]) def validate(config, val_loader, model, criterion): avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter(), 'dice': AverageMeter()} # switch to evaluate mode model.eval() with torch.no_grad(): pbar = tqdm(total=len(val_loader)) for input, target, _ in val_loader: input = input.cuda() target = target.cuda() # compute output if config['deep_supervision']: outputs = model(input) loss = 0 for output in outputs: loss += criterion(output, target) loss /= len(outputs) iou, dice, _ = iou_score(outputs[-1], target) else: output = model(input) loss = criterion(output, target) iou, dice, _ = iou_score(output, target) avg_meters['loss'].update(loss.item(), input.size(0)) avg_meters['iou'].update(iou, input.size(0)) avg_meters['dice'].update(dice, input.size(0)) postfix = OrderedDict([ ('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg), ('dice', avg_meters['dice'].avg) ]) pbar.set_postfix(postfix) pbar.update(1) pbar.close() return OrderedDict([('loss', avg_meters['loss'].avg), ('iou', avg_meters['iou'].avg), ('dice', avg_meters['dice'].avg)]) def seed_torch(seed=1029): random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True def main(): seed_torch() config = vars(parse_args()) exp_name = config.get('name') output_dir = config.get('output_dir') my_writer = SummaryWriter(f'{output_dir}/{exp_name}') if config['name'] is None: if config['deep_supervision']: config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch']) else: config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch']) os.makedirs(f'{output_dir}/{exp_name}', exist_ok=True) print('-' * 20) for key in config: print('%s: %s' % (key, config[key])) print('-' * 20) with open(f'{output_dir}/{exp_name}/config.yml', 'w') as f: yaml.dump(config, f) # define loss function (criterion) if config['loss'] == 'BCEWithLogitsLoss': criterion = nn.BCEWithLogitsLoss().cuda() else: criterion = losses.__dict__[config['loss']]().cuda() cudnn.benchmark = True # create model model = archs.__dict__[config['arch']](config['num_classes'], config['input_channels'], config['deep_supervision'], embed_dims=config['input_list'], no_kan=config['no_kan']) model = model.cuda() param_groups = [] kan_fc_params = [] other_params = [] for name, param in model.named_parameters(): # print(name, "=>", param.shape) if 'layer' in name.lower() and 'fc' in name.lower(): # higher lr for kan layers # kan_fc_params.append(name) param_groups.append({'params': param, 'lr': config['kan_lr'], 'weight_decay': config['kan_weight_decay']}) else: # other_params.append(name) param_groups.append({'params': param, 'lr': config['lr'], 'weight_decay': config['weight_decay']}) # st() if config['optimizer'] == 'Adam': optimizer = optim.Adam(param_groups) elif config['optimizer'] == 'SGD': optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'], nesterov=config['nesterov'], weight_decay=config['weight_decay']) else: raise NotImplementedError if config['scheduler'] == 'CosineAnnealingLR': scheduler = lr_scheduler.CosineAnnealingLR( optimizer, T_max=config['epochs'], eta_min=config['min_lr']) elif config['scheduler'] == 'ReduceLROnPlateau': scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'], verbose=1, min_lr=config['min_lr']) elif config['scheduler'] == 'MultiStepLR': scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma']) elif config['scheduler'] == 'ConstantLR': scheduler = None else: raise NotImplementedError shutil.copy2('train.py', f'{output_dir}/{exp_name}/') shutil.copy2('archs.py', f'{output_dir}/{exp_name}/') dataset_name = config['dataset'] img_ext = '.png' if dataset_name == 'busi': mask_ext = '_mask.png' elif dataset_name == 'glas': mask_ext = '.png' elif dataset_name == 'cvc': mask_ext = '.png' # Data loading code img_ids = sorted(glob(os.path.join(config['data_dir'], config['dataset'], 'images', '*' + img_ext))) img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids] train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=config['dataseed']) train_transform = Compose([ RandomRotate90(), geometric.transforms.Flip(), Resize(config['input_h'], config['input_w']), transforms.Normalize(), ]) val_transform = Compose([ Resize(config['input_h'], config['input_w']), transforms.Normalize(), ]) train_dataset = Dataset( img_ids=train_img_ids, img_dir=os.path.join(config['data_dir'], config['dataset'], 'images'), mask_dir=os.path.join(config['data_dir'], config['dataset'], 'masks'), img_ext=img_ext, mask_ext=mask_ext, num_classes=config['num_classes'], transform=train_transform) val_dataset = Dataset( img_ids=val_img_ids, img_dir=os.path.join(config['data_dir'] ,config['dataset'], 'images'), mask_dir=os.path.join(config['data_dir'], config['dataset'], 'masks'), img_ext=img_ext, mask_ext=mask_ext, num_classes=config['num_classes'], transform=val_transform) train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'], drop_last=True) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'], drop_last=False) log = OrderedDict([ ('epoch', []), ('lr', []), ('loss', []), ('iou', []), ('val_loss', []), ('val_iou', []), ('val_dice', []), ]) best_iou = 0 best_dice= 0 trigger = 0 for epoch in range(config['epochs']): print('Epoch [%d/%d]' % (epoch, config['epochs'])) # train for one epoch train_log = train(config, train_loader, model, criterion, optimizer) # evaluate on validation set val_log = validate(config, val_loader, model, criterion) if config['scheduler'] == 'CosineAnnealingLR': scheduler.step() elif config['scheduler'] == 'ReduceLROnPlateau': scheduler.step(val_log['loss']) print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f' % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou'])) log['epoch'].append(epoch) log['lr'].append(config['lr']) log['loss'].append(train_log['loss']) log['iou'].append(train_log['iou']) log['val_loss'].append(val_log['loss']) log['val_iou'].append(val_log['iou']) log['val_dice'].append(val_log['dice']) pd.DataFrame(log).to_csv(f'{output_dir}/{exp_name}/log.csv', index=False) my_writer.add_scalar('train/loss', train_log['loss'], global_step=epoch) my_writer.add_scalar('train/iou', train_log['iou'], global_step=epoch) my_writer.add_scalar('val/loss', val_log['loss'], global_step=epoch) my_writer.add_scalar('val/iou', val_log['iou'], global_step=epoch) my_writer.add_scalar('val/dice', val_log['dice'], global_step=epoch) my_writer.add_scalar('val/best_iou_value', best_iou, global_step=epoch) my_writer.add_scalar('val/best_dice_value', best_dice, global_step=epoch) trigger += 1 if val_log['iou'] > best_iou: torch.save(model.state_dict(), f'{output_dir}/{exp_name}/model.pth') best_iou = val_log['iou'] best_dice = val_log['dice'] print("=> saved best model") print('IoU: %.4f' % best_iou) print('Dice: %.4f' % best_dice) trigger = 0 # early stopping if config['early_stopping'] >= 0 and trigger >= config['early_stopping']: print("=> early stopping") break torch.cuda.empty_cache() if __name__ == '__main__': main() ================================================ FILE: Seg_UKAN/utils.py ================================================ import argparse import torch.nn as nn class qkv_transform(nn.Conv1d): """Conv1d for qkv_transform""" def str2bool(v): if v.lower() in ['true', 1]: return True elif v.lower() in ['false', 0]: return False else: raise argparse.ArgumentTypeError('Boolean value expected.') def count_params(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count ================================================ FILE: Seg_UKAN/val.py ================================================ #! /data/cxli/miniconda3/envs/th200/bin/python import argparse import os from glob import glob import random import numpy as np import cv2 import torch import torch.backends.cudnn as cudnn import yaml from albumentations.augmentations import transforms from albumentations.core.composition import Compose from sklearn.model_selection import train_test_split from tqdm import tqdm from collections import OrderedDict import archs from dataset import Dataset from metrics import iou_score from utils import AverageMeter from albumentations import RandomRotate90,Resize import time from PIL import Image def parse_args(): parser = argparse.ArgumentParser() parser.add_argument('--name', default=None, help='model name') parser.add_argument('--output_dir', default='outputs', help='ouput dir') args = parser.parse_args() return args def seed_torch(seed=1029): random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.backends.cudnn.benchmark = False torch.backends.cudnn.deterministic = True def main(): seed_torch() args = parse_args() with open(f'{args.output_dir}/{args.name}/config.yml', 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader) print('-'*20) for key in config.keys(): print('%s: %s' % (key, str(config[key]))) print('-'*20) cudnn.benchmark = True model = archs.__dict__[config['arch']](config['num_classes'], config['input_channels'], config['deep_supervision'], embed_dims=config['input_list']) model = model.cuda() dataset_name = config['dataset'] img_ext = '.png' if dataset_name == 'busi': mask_ext = '_mask.png' elif dataset_name == 'glas': mask_ext = '.png' elif dataset_name == 'cvc': mask_ext = '.png' # Data loading code img_ids = sorted(glob(os.path.join(config['data_dir'], config['dataset'], 'images', '*' + img_ext))) # img_ids.sort() img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids] _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=config['dataseed']) ckpt = torch.load(f'{args.output_dir}/{args.name}/model.pth') try: model.load_state_dict(ckpt) except: print("Pretrained model keys:", ckpt.keys()) print("Current model keys:", model.state_dict().keys()) pretrained_dict = {k: v for k, v in ckpt.items() if k in model.state_dict()} current_dict = model.state_dict() diff_keys = set(current_dict.keys()) - set(pretrained_dict.keys()) print("Difference in model keys:") for key in diff_keys: print(f"Key: {key}") model.load_state_dict(ckpt, strict=False) model.eval() val_transform = Compose([ Resize(config['input_h'], config['input_w']), transforms.Normalize(), ]) val_dataset = Dataset( img_ids=val_img_ids, img_dir=os.path.join(config['data_dir'], config['dataset'], 'images'), mask_dir=os.path.join(config['data_dir'], config['dataset'], 'masks'), img_ext=img_ext, mask_ext=mask_ext, num_classes=config['num_classes'], transform=val_transform) val_loader = torch.utils.data.DataLoader( val_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=config['num_workers'], drop_last=False) iou_avg_meter = AverageMeter() dice_avg_meter = AverageMeter() hd95_avg_meter = AverageMeter() with torch.no_grad(): for input, target, meta in tqdm(val_loader, total=len(val_loader)): input = input.cuda() target = target.cuda() model = model.cuda() # compute output output = model(input) iou, dice, hd95_ = iou_score(output, target) iou_avg_meter.update(iou, input.size(0)) dice_avg_meter.update(dice, input.size(0)) hd95_avg_meter.update(hd95_, input.size(0)) output = torch.sigmoid(output).cpu().numpy() output[output>=0.5]=1 output[output<0.5]=0 os.makedirs(os.path.join(args.output_dir, config['name'], 'out_val'), exist_ok=True) for pred, img_id in zip(output, meta['img_id']): pred_np = pred[0].astype(np.uint8) pred_np = pred_np * 255 img = Image.fromarray(pred_np, 'L') img.save(os.path.join(args.output_dir, config['name'], 'out_val/{}.jpg'.format(img_id))) print(config['name']) print('IoU: %.4f' % iou_avg_meter.avg) print('Dice: %.4f' % dice_avg_meter.avg) print('HD95: %.4f' % hd95_avg_meter.avg) if __name__ == '__main__': main()