Full Code of CUHK-AIM-Group/U-KAN for AI

main b20bf63490f0 cached
42 files
236.5 KB
64.2k tokens
405 symbols
1 requests
Download .txt
Showing preview only (250K chars total). Download the full file or copy to clipboard to get everything.
Repository: CUHK-AIM-Group/U-KAN
Branch: main
Commit: b20bf63490f0
Files: 42
Total size: 236.5 KB

Directory structure:
gitextract_suz9sp40/

├── Diffusion_UKAN/
│   ├── Diffusion/
│   │   ├── Diffusion.py
│   │   ├── Model.py
│   │   ├── Model_ConvKan.py
│   │   ├── Model_UKAN_Hybrid.py
│   │   ├── Model_UMLP.py
│   │   ├── Train.py
│   │   ├── UNet.py
│   │   ├── __init__.py
│   │   ├── kan_utils/
│   │   │   ├── __init__.py
│   │   │   ├── fastkanconv.py
│   │   │   └── kan.py
│   │   └── utils.py
│   ├── Main.py
│   ├── Main_Test.py
│   ├── README.md
│   ├── Scheduler.py
│   ├── data/
│   │   └── readme.txt
│   ├── inception-score-pytorch/
│   │   ├── LICENSE.md
│   │   ├── README.md
│   │   └── inception_score.py
│   ├── released_models/
│   │   └── readme.txt
│   ├── requirements.txt
│   ├── tools/
│   │   ├── resive_cvc.py
│   │   ├── resize_busi.py
│   │   └── resize_glas.py
│   └── training_scripts/
│       ├── busi.sh
│       ├── cvc.sh
│       └── glas.sh
├── README.md
└── Seg_UKAN/
    ├── LICENSE
    ├── archs.py
    ├── config.py
    ├── dataset.py
    ├── environment.yml
    ├── kan.py
    ├── losses.py
    ├── metrics.py
    ├── requirements.txt
    ├── scripts.sh
    ├── train.py
    ├── utils.py
    └── val.py

================================================
FILE CONTENTS
================================================

================================================
FILE: Diffusion_UKAN/Diffusion/Diffusion.py
================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm


def extract(v, t, x_shape):
    """Gather per-timestep coefficients and reshape them for broadcasting.

    Picks ``v[t]`` for each batch element and views the result as
    ``[batch_size, 1, 1, ...]`` so it broadcasts against tensors of shape
    ``x_shape``.
    """
    gathered = torch.gather(v, dim=0, index=t).float().to(t.device)
    broadcast_shape = [t.shape[0]] + [1] * (len(x_shape) - 1)
    return gathered.view(broadcast_shape)


class GaussianDiffusionTrainer(nn.Module):
    """DDPM training wrapper (Ho et al., Algorithm 1).

    Precomputes the closed-form coefficients of q(x_t | x_0) for a linear
    beta schedule; the forward pass corrupts a clean batch at random
    timesteps and returns the per-element noise-prediction MSE.
    """

    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        self.model = model
        self.T = T

        betas = torch.linspace(beta_1, beta_T, T).double()
        self.register_buffer('betas', betas)
        alphas_bar = torch.cumprod(1. - betas, dim=0)

        # Coefficients of q(x_t | x_0): x_t = sqrt(a_bar)*x_0 + sqrt(1-a_bar)*eps.
        self.register_buffer('sqrt_alphas_bar', alphas_bar.sqrt())
        self.register_buffer('sqrt_one_minus_alphas_bar', (1. - alphas_bar).sqrt())

    def forward(self, x_0):
        """Sample a random t per batch element, noise x_0, and return the
        unreduced MSE between predicted and true noise (Algorithm 1)."""
        batch = x_0.shape[0]
        t = torch.randint(self.T, size=(batch,), device=x_0.device)
        noise = torch.randn_like(x_0)
        mean_coef = extract(self.sqrt_alphas_bar, t, x_0.shape)
        std_coef = extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape)
        x_t = mean_coef * x_0 + std_coef * noise
        return F.mse_loss(self.model(x_t, t), noise, reduction='none')


class GaussianDiffusionSampler(nn.Module):
    """DDPM ancestral sampler (Ho et al., Algorithm 2).

    Walks t = T-1 .. 0, predicting noise with ``model`` at each step and
    drawing x_{t-1} from the precomputed reverse-process mean/variance.
    """

    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        # Linear beta schedule, kept in double precision for the cumprods.
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar at t=-1 defined as 1.
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]

        # Reverse-step mean coefficients: mean = coeff1 * x_t - coeff2 * eps.
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))

        # Variance of the true posterior q(x_{t-1} | x_t, x_0).
        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Posterior mean of x_{t-1} given x_t and the predicted noise ``eps``."""
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        """Return (mean, variance) of p(x_{t-1} | x_t) at timesteps ``t``."""
        # Variance choice: beta_t for t >= 1, and posterior_var[1] at t == 0
        # (the concatenation swaps in posterior_var[1] for index 0 only).
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        eps = self.model(x_t, t)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)

        return xt_prev_mean, var

    def forward(self, x_T):
        """Run the full reverse chain from noise x_T; returns images clipped to [-1, 1]."""
        x_t = x_T
        print('Start Sampling')
        for time_step in tqdm(reversed(range(self.T))):
            # Same timestep for the whole batch.
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            mean, var= self.p_mean_variance(x_t=x_t, t=t)
            # no noise when t == 0
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)




================================================
FILE: Diffusion_UKAN/Diffusion/Model.py
================================================

   
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class Swish(nn.Module):
    """Swish/SiLU activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        return torch.sigmoid(x) * x

class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by a small MLP.

    Precomputes a fixed (T, d_model) sin/cos table, then projects it to
    ``dim`` via Linear -> Swish -> Linear. Input ``t`` is a LongTensor of
    timestep indices.
    """

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Frequencies 10000^(-2i/d_model) for i in [0, d_model/2).
        freqs = torch.exp(-(torch.arange(0, d_model, step=2) / d_model * math.log(10000)))
        positions = torch.arange(T).float()
        angles = positions[:, None] * freqs[None, :]
        assert list(angles.shape) == [T, d_model // 2]
        # Interleave sin/cos along the feature axis: [sin0, cos0, sin1, cos1, ...].
        table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        assert list(table.shape) == [T, d_model // 2, 2]
        table = table.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(table),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        # Xavier-init the projection layers; the embedding table stays fixed.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        return self.timembedding(t)


class DownSample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # `temb` is accepted only for a uniform block interface; unused here.
        return self.main(x)


class UpSample(nn.Module):
    """Double spatial resolution: nearest-neighbor upsample then 3x3 conv."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # `temb` is accepted only for a uniform block interface; unused here.
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.main(upsampled)


class AttnBlock(nn.Module):
    """Single-head self-attention over all spatial positions, with residual add."""

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        # 1x1 convolutions act as per-position linear projections for q/k/v/out.
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        # Deliberately re-init the output projection with a tiny gain so the
        # attention branch starts near zero and the block begins ~identity.
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        # Attention weights between all H*W positions, scaled by 1/sqrt(C).
        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        # Weighted sum of values, then back to (B, C, H, W).
        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)

        # Residual connection.
        return x + h


class ResBlock(nn.Module):
    """Pre-activation residual block with timestep conditioning.

    GroupNorm -> Swish -> Conv, add the projected time embedding, then
    GroupNorm -> Swish -> Dropout -> Conv. A 1x1 shortcut matches channel
    counts when needed, and optional self-attention follows the residual add.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        # Projects the time embedding (tdim) to a per-channel bias (out_ch).
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()
        self.initialize()

    def initialize(self):
        # Xavier-init every conv/linear, then re-init the final conv of
        # block2 with a tiny gain so the residual branch starts near zero.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)
        # Broadcast the time-embedding bias over the spatial dimensions.
        h += self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)

        h = h + self.shortcut(x)
        h = self.attn(h)
        return h

class KANLinear(torch.nn.Module):
    """Kolmogorov-Arnold linear layer: a base linear path plus a learnable
    B-spline path.

    The output is ``F.linear(act(x), base_weight) +
    F.linear(B(x), scaled_spline_weight)`` where ``B(x)`` are B-spline basis
    functions evaluated on a per-input-feature grid.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,          # number of grid intervals per input feature
        spline_order=3,       # B-spline order (cubic by default)
        scale_noise=0.1,      # magnitude of the random curves used to init spline weights
        scale_base=1.0,       # gain factor for the base-weight kaiming init
        scale_spline=1.0,     # gain factor for the spline weight / scaler init
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,        # blend between uniform and data-adaptive grid in update_grid
        grid_range=[-1, 1],   # NOTE(review): mutable default; treated as read-only below
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot grid over grid_range, extended by spline_order knots
        # on each side so order-k bases are defined across the whole range.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            # Optional per-(out, in) multiplier applied to spline_weight.
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Initialize the spline weights by least-squares fitting small
            # random curves sampled at the interior grid points.
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the knot interval containing x, then
        # raise the order one step at a time via the standard B-spline recursion.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        # Batched least-squares solve: one system per input feature.
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        # spline_weight optionally multiplied by the per-(out, in) scaler.
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Apply base path + spline path to a (batch, in_features) input."""
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        # Flatten (batch, in, coeff) bases so the spline path is one F.linear.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the knot grid to the data distribution of ``x`` and re-solve
        the spline weights so the layer's function is preserved."""
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        # Evaluate the current spline outputs before moving the grid, so the
        # new coefficients can be fitted to reproduce them.
        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        # Blend uniform and adaptive grids, then re-extend by spline_order
        # knots on each side.
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )

class Ukan(nn.Module):
    """DDPM U-Net whose middle (bottleneck) blocks are replaced by a KAN MLP.

    The encoder/decoder follow the standard DDPM U-Net layout (ResBlock
    stacks with optional attention, DownSample/UpSample between stages). At
    the bottleneck, the (B, C, H, W) feature map is flattened to B*H*W
    tokens of width C and passed through two KANLinear layers.

    Args:
        T: number of diffusion timesteps (size of the time-embedding table).
        ch: base channel count.
        ch_mult: per-stage channel multipliers.
        attn: iterable of stage indices whose ResBlocks use attention.
        num_res_blocks: ResBlocks per stage.
        dropout: dropout rate inside ResBlocks.
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # output channels of each down block, consumed by the skip connections
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        # Channel width at the bottleneck. This was previously hard-coded to
        # 512, which only matched configs with ch * ch_mult[-1] == 512;
        # deriving it here generalizes to any channel configuration.
        bottleneck_ch = now_ch

        # NOTE: these middle ResBlocks are registered (their parameters exist
        # in the state dict) but are not called in forward(); the KAN MLP
        # below replaces them. Kept for checkpoint compatibility.
        self.middleblocks1 = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
        ])
        self.middleblocks2 = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

        # KAN bottleneck hyperparameters (KANLinear defaults, spelled out).
        grid_size = 5
        spline_order = 3
        scale_noise = 0.1
        scale_base = 1.0
        scale_spline = 1.0
        base_activation = torch.nn.SiLU
        grid_eps = 0.02
        grid_range = [-1, 1]

        # Token MLP over bottleneck features: C -> 2C -> C.
        self.fc1 = KANLinear(
            bottleneck_ch,
            bottleneck_ch * 2,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )
        # Kept for checkpoint/interface compatibility; unused in forward().
        self.act = nn.GELU()
        self.fc2 = KANLinear(
            bottleneck_ch * 2,
            bottleneck_ch,
            grid_size=grid_size,
            spline_order=spline_order,
            scale_noise=scale_noise,
            scale_base=scale_base,
            scale_spline=scale_spline,
            base_activation=base_activation,
            grid_eps=grid_eps,
            grid_range=grid_range,
        )

        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        # Tiny gain on the output conv so predictions start near zero.
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding.
        temb = self.time_embedding(t)

        # Downsampling path; keep every activation for the skip connections.
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)

        # KAN bottleneck: flatten (B, C, H, W) -> (B*H*W, C) tokens and back.
        B, C, H, W = h.shape
        h = h.permute(0, 2, 3, 1).reshape(B * H * W, C)
        h = self.fc2(self.fc1(h))
        h = h.reshape(B, H, W, C).permute(0, 3, 1, 2)

        # Upsampling path with skip connections.
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h

class DW_bn_relu(nn.Module):
    """Depthwise 3x3 conv + BatchNorm + ReLU applied to a token sequence.

    Input and output are (B, N, C) token tensors with N = H * W; tokens are
    reshaped onto the (H, W) grid for the convolution and flattened back.
    """

    def __init__(self, dim=768):
        super(DW_bn_relu, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        self.bn = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU()

    def forward(self, x, H, W):
        B, N, C = x.shape
        grid = x.transpose(1, 2).view(B, C, H, W)
        grid = self.relu(self.bn(self.dwconv(grid)))
        return grid.flatten(2).transpose(1, 2)

class kan(nn.Module):
    """Token MLP built from KANLinear layers interleaved with depthwise convs.

    Operates on (B, N, C) token tensors (N = H * W). Several wiring
    "versions" are kept for reference, but ``self.version`` is hard-coded to
    4 below, so the active path is fc1 -> dwconv_1 -> fc2 -> dwconv_2 ->
    fc3 -> dwconv_3.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5, version=4):
        super().__init__()
        # NOTE(review): the `version` constructor argument is ignored;
        # self.version is hard-coded to 4 further down.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features

        # KANLinear hyperparameters shared by fc1/fc2/fc3.
        grid_size=5
        spline_order=3
        scale_noise=0.1
        scale_base=1.0
        scale_spline=1.0
        base_activation=torch.nn.SiLU
        grid_eps=0.02
        grid_range=[-1, 1]

        # self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc1 = KANLinear(
                    in_features,
                    hidden_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )

        # self.fc2 = nn.Linear(hidden_features, out_features)
        self.fc2 = KANLinear(
                    hidden_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )

        self.fc3 = KANLinear(
                    hidden_features,
                    out_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )   

        # ##############################################
        self.version = 4 # hard-coded to version 4; the constructor argument is ignored
        
        # ##############################################
        # NOTE(review): only versions 3/4/8 can be constructed here — the
        # other branches reference DWConv, which is not defined in this
        # file. With version hard-coded to 4 those branches are dead code.
        if self.version == 1:
            self.dwconv_1 = DWConv(hidden_features)
            self.act_1 = act_layer()

            self.dwconv_2 = DWConv(hidden_features)
            self.act_2 = act_layer()

            self.dwconv_3 = DWConv(hidden_features)
            self.act_3 = act_layer()

            self.dwconv_4 = DWConv(hidden_features)
            self.act_4 = act_layer()
        elif self.version == 2:
            self.dwconv_1 = DWConv(hidden_features)
            self.act_1 = act_layer()

            self.dwconv_2 = DWConv(hidden_features)
            self.act_2 = act_layer()

            self.dwconv_3 = DWConv(hidden_features)
            self.act_3 = act_layer()

        elif self.version == 3:
            self.dwconv_1 = DW_bn_relu(hidden_features)
            self.dwconv_2 = DW_bn_relu(hidden_features)
            self.dwconv_3 = DW_bn_relu(hidden_features)
        elif self.version == 4:
            self.dwconv_1 = DW_bn_relu(hidden_features)
            self.dwconv_2 = DW_bn_relu(hidden_features)
            self.dwconv_3 = DW_bn_relu(hidden_features)
        elif self.version == 5:
            self.dwconv_1 = DWConv(hidden_features)
            self.act_1 = act_layer()

            self.dwconv_2 = DWConv(hidden_features)
            self.act_2 = act_layer()

            self.dwconv_3 = DWConv(hidden_features)
            self.act_3 = act_layer()
        elif self.version == 6:
            self.dwconv_1 = DWConv(hidden_features)
            self.act_1 = act_layer()

            self.dwconv_2 = DWConv(hidden_features)
            self.act_2 = act_layer()

            self.dwconv_3 = DWConv(hidden_features)
            self.act_3 = act_layer()
        elif self.version == 7:
            self.dwconv_1 = DWConv(hidden_features)
            self.act_1 = act_layer()

            self.dwconv_2 = DWConv(hidden_features)
            self.act_2 = act_layer()

            self.dwconv_3 = DWConv(hidden_features)
            self.act_3 = act_layer()
        elif self.version == 8:
            self.dwconv_1 = DW_bn_relu(hidden_features)
            self.dwconv_2 = DW_bn_relu(hidden_features)
            self.dwconv_3 = DW_bn_relu(hidden_features)

    
        self.drop = nn.Dropout(drop)

        # shift_size/pad are stored but not used in any forward branch below.
        self.shift_size = shift_size
        self.pad = shift_size // 2

        
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # NOTE(review): trunc_normal_ is not imported in this file's visible
        # header; the nn.Linear branch would raise NameError if any nn.Linear
        # submodule existed. fc1/fc2/fc3 are KANLinear, so in practice only
        # the Conv2d branch (inside DW_bn_relu) runs.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            # Fan-out-normalized normal init for convs (per-group).
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
    

    def forward(self, x, H, W):
        """Apply the version-selected KAN/conv pipeline to (B, N, C) tokens.

        KANLinear only accepts 2-D input, so tokens are flattened to
        (B*N, C) around each fc call and reshaped back afterwards.
        """
        B, N, C = x.shape

        if self.version == 1:
            x = self.dwconv_1(x, H, W)
            x = self.act_1(x) 

            x = self.fc1(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_2(x, H, W)
            x = self.act_2(x) 

            x = self.fc2(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_3(x, H, W)
            x = self.act_3(x) 

            x = self.fc3(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_4(x, H, W)
            x = self.act_4(x) 
        elif self.version == 2:
            
            x = self.dwconv_1(x, H, W)
            x = self.act_1(x) 

            x = self.fc1(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_2(x, H, W)
            x = self.act_2(x) 

            x = self.fc2(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_3(x, H, W)
            x = self.act_3(x) 

            x = self.fc3(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

        elif self.version == 3:
            x = self.dwconv_1(x, H, W)
            x = self.fc1(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.fc2(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()
            x = self.dwconv_3(x, H, W)
            x = self.fc3(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()
        elif self.version == 4:
            # Active path: fc -> dwconv, three times.
            x = self.fc1(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()
            x = self.dwconv_1(x, H, W)
            x = self.fc2(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.fc3(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()
            x = self.dwconv_3(x, H, W)
        elif self.version == 5:

            x = self.fc1(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_1(x, H, W)
            x = self.act_1(x) 

            x = self.fc2(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_2(x, H, W)
            x = self.act_2(x) 

            x = self.fc3(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_3(x, H, W)
            x = self.act_3(x) 
        elif self.version == 6:

            x = self.fc1(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_1(x, H, W)
            x = self.act_1(x) 

            x = self.fc2(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_2(x, H, W)
            x = self.act_2(x) 

            x = self.fc3(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_3(x, H, W)
            x = self.act_3(x) 
        elif self.version == 7:

            x = self.fc1(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_1(x, H, W)
            x = self.act_1(x) 
            x = self.drop(x)

            x = self.fc2(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_2(x, H, W)
            x = self.act_2(x) 
            x = self.drop(x)

            x = self.fc3(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()

            x = self.dwconv_3(x, H, W)
            x = self.act_3(x) 
            x = self.drop(x)
        elif self.version == 8:
            x = self.dwconv_1(x, H, W)
            x = self.drop(x)
            x = self.fc1(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()
            x = self.dwconv_2(x, H, W)
            x = self.drop(x)
            x = self.fc2(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()
            x = self.dwconv_3(x, H, W)
            x = self.drop(x)
            x = self.fc3(x.reshape(B*N,C))
            x = x.reshape(B,N,C).contiguous()
        return x


class Ukan_v3(nn.Module):
    """DDPM-style U-Net whose bottleneck is two tokenised KAN layers.

    The encoder/decoder mirrors the baseline ``UNet``; at the lowest
    resolution the feature map is flattened to ``(B, H*W, C)`` tokens and
    passed through ``kan1`` then ``kan2`` instead of residual blocks.

    Args:
        T: number of diffusion timesteps (size of the embedding table).
        ch: base channel count.
        ch_mult: per-scale channel multipliers.
        attn: indices of scales that use attention in their ResBlocks.
        num_res_blocks: ResBlocks per scale in the encoder.
        dropout: dropout rate inside each ResBlock.
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # skip-connection channel counts, consumed by the decoder
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        # NOTE(review): middleblocks1/2 are registered but never used in
        # forward() (the KAN bottleneck replaced them). They are kept so that
        # existing checkpoints still load without key mismatches.
        self.middleblocks1 = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
        ])
        self.middleblocks2 = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

        # Bottleneck: two stacked KAN blocks operating on flattened tokens.
        # NOTE(review): kan_c is hard-coded; it must equal the bottleneck
        # channel count ch * ch_mult[-1] (512 for the default ch=128,
        # ch_mult=[1, 2, 2, 2] configuration) — confirm for other configs.
        kan_c = 512
        self.kan1 = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=4)
        self.kan2 = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=4)

        self.initialize()

    def initialize(self):
        """Xavier-init the head; near-zero init for the output conv so the
        network starts by predicting approximately zero noise."""
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        """Predict noise for images ``x`` at diffusion timesteps ``t``."""
        temb = self.time_embedding(t)
        # Encoder; every intermediate feature is kept as a skip connection.
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Bottleneck: (B, C, H, W) -> (B, H*W, C) tokens -> two KAN layers.
        B, C, H, W = h.shape
        h = h.reshape(B, C, H * W).permute(0, 2, 1)
        h = self.kan1(h, H, W)
        h = self.kan2(h, H, W)
        h = h.permute(0, 2, 1).reshape(B, C, H, W)
        # Decoder with skip connections (one pop per ResBlock).
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h

class Ukan_v2(nn.Module):
    """DDPM-style U-Net with a single tokenised KAN layer as its bottleneck.

    Identical layout to ``Ukan_v3`` but with one KAN block instead of two;
    the KAN variant is selected via ``version``.

    Args:
        T: number of diffusion timesteps (size of the embedding table).
        ch: base channel count.
        ch_mult: per-scale channel multipliers.
        attn: indices of scales that use attention in their ResBlocks.
        num_res_blocks: ResBlocks per scale in the encoder.
        dropout: dropout rate inside each ResBlock.
        version: which ``kan`` forward variant to use (default 4).
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout, version=4):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # skip-connection channel counts, consumed by the decoder
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        # NOTE(review): middleblocks1/2 are registered but never used in
        # forward() (the KAN bottleneck replaced them). Kept for checkpoint
        # compatibility.
        self.middleblocks1 = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
        ])
        self.middleblocks2 = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

        # Bottleneck KAN on flattened tokens. NOTE(review): kan_c is
        # hard-coded; it must equal the bottleneck channel count
        # ch * ch_mult[-1] — confirm for non-default configurations.
        kan_c = 512
        self.kan = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=version)

        self.initialize()

    def initialize(self):
        """Xavier-init the head; near-zero init for the output conv so the
        network starts by predicting approximately zero noise."""
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        """Predict noise for images ``x`` at diffusion timesteps ``t``."""
        temb = self.time_embedding(t)
        # Encoder; every intermediate feature is kept as a skip connection.
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Bottleneck: (B, C, H, W) -> (B, H*W, C) tokens -> KAN -> back.
        B, C, H, W = h.shape
        h = h.reshape(B, C, H * W).permute(0, 2, 1)
        h = self.kan(h, H, W)
        h = h.permute(0, 2, 1).reshape(B, C, H, W)
        # Decoder with skip connections (one pop per ResBlock).
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h

class UNet(nn.Module):
    """Baseline DDPM U-Net noise predictor (no KAN bottleneck).

    Encoder/bottleneck/decoder of ResBlocks with timestep conditioning;
    every encoder feature is fed to the decoder as a skip connection.
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)

        # Encoder: num_res_blocks ResBlocks per scale, downsampling between scales.
        self.downblocks = nn.ModuleList()
        skip_chs = [ch]  # channel count of each skip feature, popped by the decoder
        cur_ch = ch
        for stage, mult in enumerate(ch_mult):
            stage_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=cur_ch, out_ch=stage_ch, tdim=tdim,
                    dropout=dropout, attn=(stage in attn)))
                cur_ch = stage_ch
                skip_chs.append(cur_ch)
            if stage != len(ch_mult) - 1:
                self.downblocks.append(DownSample(cur_ch))
                skip_chs.append(cur_ch)

        # Bottleneck: one attention ResBlock followed by a plain one.
        self.middleblocks = nn.ModuleList([
            ResBlock(cur_ch, cur_ch, tdim, dropout, attn=True),
            ResBlock(cur_ch, cur_ch, tdim, dropout, attn=False),
        ])

        # Decoder: mirrors the encoder, consuming one skip feature per ResBlock.
        self.upblocks = nn.ModuleList()
        for stage, mult in reversed(list(enumerate(ch_mult))):
            stage_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=skip_chs.pop() + cur_ch, out_ch=stage_ch, tdim=tdim,
                    dropout=dropout, attn=(stage in attn)))
                cur_ch = stage_ch
            if stage != 0:
                self.upblocks.append(UpSample(cur_ch))
        assert len(skip_chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, cur_ch),
            Swish(),
            nn.Conv2d(cur_ch, 3, 3, stride=1, padding=1)
        )

        self.initialize()

    def initialize(self):
        # Xavier for the stem; near-zero output conv so the net starts by
        # predicting roughly zero noise.
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        """Predict noise for images ``x`` at diffusion timesteps ``t``."""
        temb = self.time_embedding(t)

        feat = self.head(x)
        skips = [feat]
        for layer in self.downblocks:
            feat = layer(feat, temb)
            skips.append(feat)

        for layer in self.middleblocks:
            feat = layer(feat, temb)

        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                feat = torch.cat([feat, skips.pop()], dim=1)
            feat = layer(feat, temb)
        feat = self.tail(feat)

        assert len(skips) == 0
        return feat





class UNet_MLP(nn.Module):
    """DDPM U-Net whose bottleneck is a per-pixel two-layer MLP.

    The ResBlock bottleneck is replaced by fc1 -> GELU -> fc2 applied
    independently at each spatial position of the lowest-resolution map.
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)

        # Encoder: num_res_blocks ResBlocks per scale, downsampling between scales.
        self.downblocks = nn.ModuleList()
        skip_chs = [ch]  # channel count of each skip feature, popped by the decoder
        cur_ch = ch
        for stage, mult in enumerate(ch_mult):
            stage_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=cur_ch, out_ch=stage_ch, tdim=tdim,
                    dropout=dropout, attn=(stage in attn)))
                cur_ch = stage_ch
                skip_chs.append(cur_ch)
            if stage != len(ch_mult) - 1:
                self.downblocks.append(DownSample(cur_ch))
                skip_chs.append(cur_ch)

        # NOTE: registered but unused in forward() (the MLP replaced them);
        # kept so existing checkpoints load unchanged.
        self.middleblocks1 = nn.ModuleList([
            ResBlock(cur_ch, cur_ch, tdim, dropout, attn=True),
        ])
        self.middleblocks2 = nn.ModuleList([
            ResBlock(cur_ch, cur_ch, tdim, dropout, attn=False),
        ])

        # Decoder: mirrors the encoder, consuming one skip feature per ResBlock.
        self.upblocks = nn.ModuleList()
        for stage, mult in reversed(list(enumerate(ch_mult))):
            stage_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=skip_chs.pop() + cur_ch, out_ch=stage_ch, tdim=tdim,
                    dropout=dropout, attn=(stage in attn)))
                cur_ch = stage_ch
            if stage != 0:
                self.upblocks.append(UpSample(cur_ch))
        assert len(skip_chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, cur_ch),
            Swish(),
            nn.Conv2d(cur_ch, 3, 3, stride=1, padding=1)
        )

        # Per-pixel bottleneck MLP; kan_c must match the bottleneck channels.
        kan_c = 512
        self.fc1 = nn.Linear(
                    kan_c,
                    kan_c * 2,
                )
        self.act = nn.GELU()
        self.fc2 = nn.Linear(
                    kan_c * 2,
                    kan_c,
                )

        self.initialize()

    def initialize(self):
        # Xavier for the stem; near-zero output conv so the net starts by
        # predicting roughly zero noise.
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        """Predict noise for images ``x`` at diffusion timesteps ``t``."""
        temb = self.time_embedding(t)

        feat = self.head(x)
        skips = [feat]
        for layer in self.downblocks:
            feat = layer(feat, temb)
            skips.append(feat)

        # Bottleneck MLP per spatial position: (B, C, H, W) -> (B*H*W, C).
        B, C, H, W = feat.shape
        feat = feat.permute(0, 2, 3, 1).reshape(B * H * W, C)
        feat = self.fc2(self.act(self.fc1(feat)))
        feat = feat.reshape(B, H, W, C).permute(0, 3, 1, 2)

        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                feat = torch.cat([feat, skips.pop()], dim=1)
            feat = layer(feat, temb)
        feat = self.tail(feat)

        assert len(skips) == 0
        return feat

if __name__ == '__main__':
    # Smoke test: push a random CIFAR-sized batch through the baseline UNet.
    batch_size = 8
    net = UNet(
        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1)
    images = torch.randn(batch_size, 3, 32, 32)
    steps = torch.randint(1000, (batch_size, ))
    out = net(images, steps)
    print(out.shape)


================================================
FILE: Diffusion_UKAN/Diffusion/Model_ConvKan.py
================================================

   
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from Diffusion.kan_utils.fastkanconv import FastKANConvLayer, SplineConv2D

class Swish(nn.Module):
    """Swish/SiLU activation: x * sigmoid(x)."""

    def forward(self, x):
        # Written out explicitly rather than via F.silu.
        return torch.sigmoid(x) * x

class TimeEmbedding(nn.Module):
    """Fixed sinusoidal timestep embedding followed by a learned 2-layer MLP.

    Builds a (T, d_model) sin/cos table once, then maps it through
    Linear -> Swish -> Linear to dimension ``dim``.
    """

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Frequencies 10000^(-2i/d_model) for i = 0 .. d_model/2 - 1.
        freqs = torch.exp(-(torch.arange(0, d_model, step=2) / d_model * math.log(10000)))
        steps = torch.arange(T).float()
        table = steps[:, None] * freqs[None, :]
        assert list(table.shape) == [T, d_model // 2]
        # Interleave sin and cos per frequency, then flatten to (T, d_model).
        table = torch.stack([torch.sin(table), torch.cos(table)], dim=-1)
        assert list(table.shape) == [T, d_model // 2, 2]
        table = table.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(table),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        # Xavier-init the MLP; the sinusoidal table itself stays fixed.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        """Look up and project the embedding for integer timesteps ``t``."""
        return self.timembedding(t)


class DownSample(nn.Module):
    """Halve spatial resolution with a stride-2 KAN convolution."""

    def __init__(self, in_ch):
        super().__init__()
        # Replaces the original nn.Conv2d; self.initialize() is deliberately
        # not called (the KAN layer presumably initialises itself — confirm).
        self.main = FastKANConvLayer(in_ch, in_ch, 3, stride=2, padding=1)

    def initialize(self):
        # NOTE(review): unused; assumes self.main exposes .weight/.bias,
        # which a FastKANConvLayer may not — verify before enabling.
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # temb is accepted only for interface parity with ResBlock; unused.
        return self.main(x)


class UpSample(nn.Module):
    """Double spatial resolution (nearest-neighbour) then apply a KAN conv."""

    def __init__(self, in_ch):
        super().__init__()
        # Replaces the original nn.Conv2d; self.initialize() is deliberately
        # not called (the KAN layer presumably initialises itself — confirm).
        self.main = FastKANConvLayer(in_ch, in_ch, 3, stride=1, padding=1)

    def initialize(self):
        # NOTE(review): unused; assumes self.main exposes .weight/.bias,
        # which a FastKANConvLayer may not — verify before enabling.
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # temb is accepted only for interface parity with ResBlock; unused.
        # (Removed a dead `_, _, H, W = x.shape` unpack whose values were
        # never read.)
        x = F.interpolate(
            x, scale_factor=2, mode='nearest')
        return self.main(x)


class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection.

    Q/K/V are 1x1 conv projections over a group-normalised input; the output
    projection is initialised near zero so the block starts close to identity.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        # Xavier for every projection, then shrink the output projection so
        # the residual branch contributes ~nothing at the start of training.
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        normed = self.group_norm(x)

        # Queries as (B, HW, C) rows, keys as (B, C, HW) columns.
        q = self.proj_q(normed).permute(0, 2, 3, 1).view(B, H * W, C)
        k = self.proj_k(normed).view(B, C, H * W)
        scores = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(scores.shape) == [B, H * W, H * W]
        scores = F.softmax(scores, dim=-1)

        v = self.proj_v(normed).permute(0, 2, 3, 1).view(B, H * W, C)
        mixed = torch.bmm(scores, v)
        assert list(mixed.shape) == [B, H * W, C]
        mixed = mixed.view(B, H, W, C).permute(0, 3, 1, 2)

        return x + self.proj(mixed)


class ResBlock(nn.Module):
    """Residual block built from KAN convolutions with timestep conditioning.

    Path: GroupNorm -> KAN conv, add projected timestep embedding,
    GroupNorm -> Dropout -> KAN conv, add shortcut, optional attention.
    (The usual Swish activations are intentionally absent here — the
    original had them commented out.)
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            FastKANConvLayer(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            nn.Dropout(dropout),
            FastKANConvLayer(out_ch, out_ch, 3, stride=1, padding=1),
        )
        # 1x1 KAN conv to match channels when they differ, else identity.
        if in_ch != out_ch:
            self.shortcut = FastKANConvLayer(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()
        self.initialize()

    def initialize(self):
        # Xavier-init only plain Conv2d/Linear layers; SplineConv2D modules
        # keep their own initialisation.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)) and not isinstance(module, (SplineConv2D)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, x, temb):
        out = self.block1(x)
        # Broadcast the (B, out_ch) timestep embedding over space.
        out = out + self.temb_proj(temb)[:, :, None, None]
        out = self.block2(out)
        out = out + self.shortcut(x)
        return self.attn(out)

class UNet_ConvKan(nn.Module):
    """DDPM-style U-Net in which the convolutions inside the blocks (and the
    output layer) are FastKANConvLayers.

    Same encoder/bottleneck/decoder layout as the baseline UNet: ResBlocks
    per scale with skip connections, two bottleneck ResBlocks (one with
    attention), and a mirrored decoder.

    Args:
        T: number of diffusion timesteps (size of the embedding table).
        ch: base channel count.
        ch_mult: per-scale channel multipliers.
        attn: indices of scales whose ResBlocks use attention.
        num_res_blocks: ResBlocks per scale in the encoder.
        dropout: dropout rate inside each ResBlock.
    """
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channel when dowmsample for upsample
        now_ch = ch
        # Encoder: num_res_blocks ResBlocks per scale, DownSample between scales.
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        # Bottleneck: attention ResBlock followed by a plain one.
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
    
        # Decoder: mirrors the encoder, consuming one skip channel (chs.pop())
        # per ResBlock; UpSample between scales.
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            # Swish(),
            # nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
            FastKANConvLayer(now_ch, 3, 3, stride=1, padding=1)
        )
        
        # NOTE(review): initialize() is deliberately NOT called here (kept
        # commented out). If re-enabled, it would fail on self.tail[-1],
        # which is now a FastKANConvLayer and may not expose .weight/.bias.
        # self.initialize()

    def initialize(self):
        # Currently unused — see the note in __init__ before enabling.
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        """Predict noise for images x at diffusion timesteps t."""
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling; every intermediate feature is kept as a skip.
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling with skip connections (one pop per ResBlock).
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h

if __name__ == '__main__':
    # Smoke test: push a random CIFAR-sized batch through the ConvKAN UNet.
    batch_size = 8
    net = UNet_ConvKan(
        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1)
    images = torch.randn(batch_size, 3, 32, 32)
    steps = torch.randint(1000, (batch_size, ))
    out = net(images, steps)
    print(out.shape)


================================================
FILE: Diffusion_UKAN/Diffusion/Model_UKAN_Hybrid.py
================================================

   
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

class KANLinear(torch.nn.Module):
    """Kolmogorov-Arnold linear layer.

    Each input-output connection applies a learned univariate function,
    parameterised as a B-spline plus a scaled base activation (SiLU by
    default), so forward() returns base_linear(act(x)) + spline_linear(x).

    Args:
        in_features, out_features: layer widths.
        grid_size: number of spline intervals per connection.
        spline_order: B-spline order (3 = cubic).
        scale_noise: magnitude of random noise used to initialise the splines.
        scale_base, scale_spline: init scaling for the base / spline weights.
        enable_standalone_scale_spline: if True, learn a per-connection scaler
            for the spline weights instead of baking scale_spline into them.
        base_activation: activation class applied on the base path.
        grid_eps: blend between a uniform and data-adaptive grid in
            update_grid (1.0 = fully uniform).
        grid_range: initial input range covered by the grid.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot grid over grid_range, extended by spline_order knots on
        # each side; shape (in_features, grid_size + 2 * spline_order + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        # Base path: Kaiming-uniform scaled by scale_base. Spline path: fit
        # spline coefficients to small random noise sampled at the grid points.
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the grid interval containing x, then
        # raise the order via the Cox-de Boor recursion.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        # Least-squares fit of spline coefficients per input feature.
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        # Spline weights with the optional per-connection scaler applied.
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Apply the layer to x of shape (batch_size, in_features); returns
        (batch_size, out_features) as base path + spline path."""
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Refit the knot grid to the distribution of x (blending uniform and
        data-adaptive spacing by grid_eps), then refit the spline weights so
        the layer's spline outputs are preserved on x."""
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Re-extend the grid by spline_order knots on each side.
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    """A sequential stack of KANLinear layers.

    ``layers_hidden`` lists the width of every layer, e.g. ``[64, 128, 10]``
    builds two KANLinear layers (64->128 and 128->10). All remaining
    arguments are forwarded verbatim to each KANLinear.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # One KANLinear per consecutive pair of widths.
        self.layers = torch.nn.ModuleList(
            KANLinear(
                n_in,
                n_out,
                grid_size=grid_size,
                spline_order=spline_order,
                scale_noise=scale_noise,
                scale_base=scale_base,
                scale_spline=scale_spline,
                base_activation=base_activation,
                grid_eps=grid_eps,
                grid_range=grid_range,
            )
            for n_in, n_out in zip(layers_hidden, layers_hidden[1:])
        )

    def forward(self, x: torch.Tensor, update_grid=False):
        """Apply every layer in order; optionally refresh each grid first."""
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Total regularisation loss accumulated over all layers."""
        total = 0
        for layer in self.layers:
            total = total + layer.regularization_loss(regularize_activation, regularize_entropy)
        return total


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution.

    Fix: the `stride` argument was previously ignored (stride was hard-coded
    to 1); it is now forwarded to the convolution, matching the torchvision
    `conv1x1` helper this mirrors.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def shift(dim):
            # NOTE(review): dead/broken code — `xs`, `self`, `H` and `W` are not
            # defined in this scope, so calling shift() raises NameError. Looks
            # like a leftover closure/method from a shifted-MLP block; kept
            # byte-identical, confirm before removing.
            x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
            x_cat = torch.cat(x_shift, 1)
            x_cat = torch.narrow(x_cat, 2, self.pad, H)
            x_cat = torch.narrow(x_cat, 3, self.pad, W)
            return x_cat


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Projects an image to a token sequence with a strided convolution whose
    kernel (patch_size) is larger than its stride, so neighboring patches
    overlap. Returns the flattened, LayerNorm-ed tokens plus the spatial
    grid size (H, W) of the convolution output.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        # NOTE(review): H/W are derived from patch_size rather than stride, so
        # num_patches does not match the real token count when
        # stride != patch_size; neither attribute is used in forward().
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # Padding of patch_size // 2 keeps output spatial dims ~ input/stride.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # timm-style init: trunc-normal linears, unit LayerNorm, He-normal convs.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # (B, C, H_in, W_in) -> (B, N, embed_dim), N = H * W of the conv output.
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class Swish(nn.Module):
    """Swish / SiLU activation: x * sigmoid(x)."""

    def forward(self, x):
        return torch.sigmoid(x) * x
def swish(x):
    """Functional Swish / SiLU: x * sigmoid(x)."""
    return torch.sigmoid(x).mul(x)


class TimeEmbedding(nn.Module):
    """Sinusoidal timestep table (T, d_model) followed by a 2-layer MLP to `dim`."""

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Frequencies 10000^(-2i/d_model) for i = 0 .. d_model/2 - 1.
        freqs = torch.exp(-(torch.arange(0, d_model, step=2) / d_model * math.log(10000)))
        positions = torch.arange(T).float()
        angles = positions[:, None] * freqs[None, :]
        assert list(angles.shape) == [T, d_model // 2]
        # Interleave sin/cos along the last axis, then flatten to (T, d_model).
        table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        assert list(table.shape) == [T, d_model // 2, 2]
        table = table.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(table),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        """Xavier-init both Linear layers, zero their biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        """t: (B,) integer timesteps -> (B, dim) embeddings."""
        return self.timembedding(t)


class DownSample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        """Xavier weights, zero bias."""
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # `temb` is accepted only to match the shared (x, temb) block interface.
        return self.main(x)


class UpSample(nn.Module):
    """Double spatial resolution: nearest-neighbor upsample then 3x3 conv."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        """Xavier weights, zero bias."""
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # `temb` is unused; kept for the shared (x, temb) block interface.
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.main(upsampled)
    
class kan(nn.Module):
    """Token-wise KAN layer: applies a single KANLinear to every token.

    Despite the MLP-style (in/hidden/out) signature, only `fc1`
    (in_features -> hidden_features) exists; `out_features` is unused.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None):
        super().__init__()
        out_features = out_features or in_features  # NOTE(review): not used below
        hidden_features = hidden_features or in_features
        self.dim = in_features
        
        # Fixed KANLinear hyperparameters (Swish base activation).
        grid_size=5
        spline_order=3
        scale_noise=0.1
        scale_base=1.0
        scale_spline=1.0
        base_activation=Swish
        grid_eps=0.02
        grid_range=[-1, 1]

        self.fc1 = KANLinear(
                    in_features,
                    hidden_features,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # timm-style init; only affects plain nn modules, not KANLinear's params.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
    

    def forward(self, x, H, W):
        # x: (B, N, C) token sequence; H, W are accepted for interface parity
        # but unused. NOTE(review): the output is reshaped back to C =
        # in_features, which only works when hidden_features == in_features
        # (callers construct this with mlp_ratio=1) — confirm.
        B, N, C = x.shape
        x = self.fc1(x.reshape(B*N,C))
        x = x.reshape(B,N,C).contiguous()

        return x

class shiftedBlock(nn.Module):
    """Pre-norm KAN token block with additive time-embedding conditioning."""

    def __init__(self, dim,  mlp_ratio=4.,drop_path=0.,norm_layer=nn.LayerNorm):
        super().__init__()

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        # Projects the time embedding to a per-token additive bias.
        # Assumes temb has 256 features (tdim = ch * 4 with ch = 64) — confirm.
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(256, dim),
        )

        self.kan = kan(in_features=dim, hidden_features=mlp_hidden_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # timm-style init for any plain nn submodules.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W, temb):
        # NOTE(review): there is no identity/residual term from `x` — the
        # output is drop_path(kan(norm(x))) + temb, not x + ...; confirm intended.

        temb = self.temb_proj(temb)
        x = self.drop_path(self.kan(self.norm2(x), H, W))
        x = x + temb.unsqueeze(1)

        return x

class DWConv(nn.Module):
    """Depth-wise 3x3 convolution applied to a flattened token sequence."""

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        """(B, N, C) tokens -> conv on the (H, W) grid -> (B, N, C)."""
        batch, _, channels = x.shape
        grid = x.transpose(1, 2).view(batch, channels, H, W)
        grid = self.dwconv(grid)
        return grid.flatten(2).transpose(1, 2)

class DW_bn_relu(nn.Module):
    """Depth-wise 3x3 conv + GroupNorm + Swish over a token sequence.

    (Name is historical: the norm is GroupNorm and the activation Swish.)
    """

    def __init__(self, dim=768):
        super(DW_bn_relu, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        self.bn = nn.GroupNorm(32, dim)

    def forward(self, x, H, W):
        """(B, N, C) tokens -> conv/norm/act on the (H, W) grid -> (B, N, C)."""
        batch, _, channels = x.shape
        grid = x.transpose(1, 2).view(batch, channels, H, W)
        grid = swish(self.bn(self.dwconv(grid)))
        return grid.flatten(2).transpose(1, 2)

class SingleConv(nn.Module):
    """GroupNorm -> Swish -> 3x3 conv, with an additive time-embedding bias."""

    def __init__(self, in_ch, h_ch):
        super(SingleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, h_ch, 3, padding=1),
        )
        # 256-dim time embedding -> per-channel additive bias.
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(256, h_ch),
        )

    def forward(self, input, temb):
        """input: (B, in_ch, H, W); temb: (B, 256) -> (B, h_ch, H, W)."""
        bias = self.temb_proj(temb)[:, :, None, None]
        return self.conv(input) + bias


class DoubleConv(nn.Module):
    """Two 3x3 conv stages (each GroupNorm + Swish) plus a time-embedding bias."""

    def __init__(self, in_ch, h_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, h_ch, 3, padding=1),
            nn.GroupNorm(32, h_ch),
            Swish(),
            nn.Conv2d(h_ch, h_ch, 3, padding=1),
            nn.GroupNorm(32, h_ch),
            Swish(),
        )
        # 256-dim time embedding -> per-channel additive bias.
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(256, h_ch),
        )

    def forward(self, input, temb):
        """input: (B, in_ch, H, W); temb: (B, 256) -> (B, h_ch, H, W)."""
        bias = self.temb_proj(temb)[:, :, None, None]
        return self.conv(input) + bias


class D_SingleConv(nn.Module):
    """Decoder conv: GroupNorm -> Swish -> 3x3 conv, plus time-embedding bias."""

    def __init__(self, in_ch, h_ch):
        super(D_SingleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32,in_ch),
            Swish(),
            nn.Conv2d(in_ch, h_ch, 3, padding=1),
        )
        # 256-dim time embedding -> per-channel additive bias.
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(256, h_ch),
        )

    def forward(self, input, temb):
        """input: (B, in_ch, H, W); temb: (B, 256) -> (B, h_ch, H, W)."""
        bias = self.temb_proj(temb)[:, :, None, None]
        return self.conv(input) + bias


class D_DoubleConv(nn.Module):
    """Decoder double conv (in->in, then in->h_ch), plus time-embedding bias."""

    def __init__(self, in_ch, h_ch):
        super(D_DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, 3, padding=1),
            nn.GroupNorm(32,in_ch),
            Swish(),
            nn.Conv2d(in_ch, h_ch, 3, padding=1),
            nn.GroupNorm(32,h_ch),
            Swish(),
        )
        # 256-dim time embedding -> per-channel additive bias.
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(256, h_ch),
        )

    def forward(self, input,temb):
        """input: (B, in_ch, H, W); temb: (B, 256) -> (B, h_ch, H, W)."""
        bias = self.temb_proj(temb)[:, :, None, None]
        return self.conv(input) + bias

class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions, with residual add."""

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        """Xavier for all projections; output projection re-scaled near zero."""
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        normed = self.group_norm(x)
        q = self.proj_q(normed).permute(0, 2, 3, 1).view(B, H * W, C)
        k = self.proj_k(normed).view(B, C, H * W)
        v = self.proj_v(normed).permute(0, 2, 3, 1).view(B, H * W, C)

        # Scaled dot-product attention over the H*W positions.
        attn = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(attn.shape) == [B, H * W, H * W]
        attn = F.softmax(attn, dim=-1)

        out = torch.bmm(attn, v)
        assert list(out.shape) == [B, H * W, C]
        out = out.view(B, H, W, C).permute(0, 3, 1, 2)
        return x + self.proj(out)


class ResBlock(nn.Module):
    """Residual block with time-embedding injection and optional attention."""

    def __init__(self, in_ch, h_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, h_ch, 3, stride=1, padding=1),
        )
        # Time embedding -> per-channel bias, added between the two conv stages.
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, h_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, h_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(h_ch, h_ch, 3, stride=1, padding=1),
        )
        # 1x1 projection when channel count changes, identity otherwise.
        self.shortcut = (
            nn.Conv2d(in_ch, h_ch, 1, stride=1, padding=0)
            if in_ch != h_ch
            else nn.Identity()
        )
        self.attn = AttnBlock(h_ch) if attn else nn.Identity()
        self.initialize()

    def initialize(self):
        """Xavier everywhere; last conv re-scaled near zero for a stable start."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)
        h = h + self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)
        h = h + self.shortcut(x)
        return self.attn(h)


class UKan_Hybrid(nn.Module):
    """Hybrid U-Net/KAN diffusion backbone.

    Conv ResBlocks form the outer encoder/decoder; a tokenized KAN bottleneck
    (two OverlapPatchEmbed stages with shiftedBlocks) sits in the middle.
    The patch-embed img_size arguments assume a 64x64 input — TODO confirm.
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index h of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        # NOTE(review): this overwrites the `attn` constructor argument, so
        # attention is disabled everywhere regardless of the caller (and the
        # assert above is moot); confirm intended.
        attn = []
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels at each downsample, for skip connections
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            h_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, h_ch=h_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = h_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        # Decoder mirrors the encoder; each ResBlock consumes one skip tensor.
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            h_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, h_ch=h_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = h_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

        # KAN bottleneck: tokenize at 1/4 and 1/8 of the feature resolution.
        embed_dims = [256, 320, 512]
        norm_layer = nn.LayerNorm
        dpr = [0.0, 0.0, 0.0]
        self.patch_embed3 = OverlapPatchEmbed(img_size=64 // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed4 = OverlapPatchEmbed(img_size=64 // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])

        self.norm3 = norm_layer(embed_dims[1])
        self.norm4 = norm_layer(embed_dims[2])
        self.dnorm3 = norm_layer(embed_dims[1])

        self.kan_block1 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[1],  mlp_ratio=1, drop_path=dpr[0], norm_layer=norm_layer)])

        self.kan_block2 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[2],  mlp_ratio=1, drop_path=dpr[1], norm_layer=norm_layer)])

        self.kan_dblock1 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[1], mlp_ratio=1, drop_path=dpr[0], norm_layer=norm_layer)])

        self.decoder1 = D_SingleConv(embed_dims[2], embed_dims[1])  
        self.decoder2 = D_SingleConv(embed_dims[1], embed_dims[0])  

        self.initialize()

    def initialize(self):
        # Xavier head; near-zero-scaled output conv for a stable start.
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling (conv encoder); every intermediate is kept as a skip.
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
    
        t3 = h  # skip tensor for KAN stage 3

        B = x.shape[0]
        h, H, W = self.patch_embed3(h)
 
        for i, blk in enumerate(self.kan_block1):
            h = blk(h, H, W, temb)
        h = self.norm3(h)
        # Tokens (B, N, C) back to a feature map (B, C, H, W).
        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t4 = h  # skip tensor for KAN stage 4

        h, H, W= self.patch_embed4(h)
        for i, blk in enumerate(self.kan_block2):
            h = blk(h, H, W, temb)
        h = self.norm4(h)
        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        ### Stage 4
        h = swish(F.interpolate(self.decoder1(h, temb), scale_factor=(2,2), mode ='bilinear'))

        h = torch.add(h, t4)

        _, _, H, W = h.shape
        h = h.flatten(2).transpose(1,2)
        for i, blk in enumerate(self.kan_dblock1):
            h = blk(h, H, W, temb)

            
        ### Stage 3
        h = self.dnorm3(h)
        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        h = swish(F.interpolate(self.decoder2(h, temb),scale_factor=(2,2),mode ='bilinear'))

        h = torch.add(h,t3)

        # Upsampling (conv decoder); ResBlocks concatenate the saved skips.
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h


if __name__ == '__main__':
    # Smoke test: construct the hybrid U-KAN with DDPM-style defaults.
    batch_size = 8  # NOTE(review): unused — no forward pass is run here
    model = UKan_Hybrid(
        T=1000, ch=64, ch_mult=[1, 2, 2, 2], attn=[],
        num_res_blocks=2, dropout=0.1)



================================================
FILE: Diffusion_UKAN/Diffusion/Model_UMLP.py
================================================

   
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

class KANLinear(torch.nn.Module):
    """Kolmogorov-Arnold linear layer ("efficient KAN" formulation).

    Each output is a sum of a base path (base_activation followed by an
    ordinary linear map) and a spline path (learnable B-spline coefficients
    per (out_feature, in_feature) pair, evaluated on a per-input knot grid).
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform knot vector over grid_range, extended spline_order steps
        # beyond each end: shape (in_features, grid_size + 2*spline_order + 1).
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            # Per-(out, in) multiplicative scale applied to spline_weight.
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        # Base path: Kaiming-uniform scaled by scale_base; spline path: fit
        # coefficients to small uniform noise sampled at the interior knots.
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the knot interval containing x, then
        # raise the order via the Cox-de Boor recursion.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        # Least-squares fit of spline coefficients to (x, y) samples, solved
        # independently per input feature.
        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline weights with the optional per-(out, in) scaler applied."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """x: (batch, in_features) -> (batch, out_features); base + spline paths."""
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        # Refit the knot grid to the empirical distribution of x (blending a
        # uniform grid with data quantiles via grid_eps), then recompute the
        # spline coefficients so the represented function is preserved.
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Re-extend the grid by spline_order uniform steps on both sides.
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    """Stack of KANLinear layers.

    Consecutive entries of `layers_hidden` define each layer's
    (in_features, out_features) pair.
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Pair up successive widths: (h0, h1), (h1, h2), ...
        self.layers = torch.nn.ModuleList(
            [
                KANLinear(
                    d_in,
                    d_out,
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_noise=scale_noise,
                    scale_base=scale_base,
                    scale_spline=scale_spline,
                    base_activation=base_activation,
                    grid_eps=grid_eps,
                    grid_range=grid_range,
                )
                for d_in, d_out in zip(layers_hidden, layers_hidden[1:])
            ]
        )

    def forward(self, x: torch.Tensor, update_grid=False):
        """Apply each layer in turn, optionally refitting its grid first."""
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Sum of the per-layer regularization losses."""
        total = 0
        for layer in self.layers:
            total = total + layer.regularization_loss(
                regularize_activation, regularize_entropy
            )
        return total


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution.

    Fix: the `stride` argument was previously ignored (stride was hard-coded
    to 1); it is now forwarded to the convolution, matching the torchvision
    `conv1x1` helper this mirrors.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


def shift(dim):
            # NOTE(review): dead/broken code — `xs`, `self`, `H` and `W` are not
            # defined in this scope, so calling shift() raises NameError. Looks
            # like a leftover closure/method from a shifted-MLP block; kept
            # byte-identical, confirm before removing.
            x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]
            x_cat = torch.cat(x_shift, 1)
            x_cat = torch.narrow(x_cat, 2, self.pad, H)
            x_cat = torch.narrow(x_cat, 3, self.pad, W)
            return x_cat


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding

    Projects an image to a token sequence with a strided convolution whose
    kernel (patch_size) is larger than its stride, so neighboring patches
    overlap. Returns the flattened, LayerNorm-ed tokens plus the spatial
    grid size (H, W) of the convolution output.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        # NOTE(review): H/W are derived from patch_size rather than stride, so
        # num_patches does not match the real token count when
        # stride != patch_size; neither attribute is used in forward().
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # Padding of patch_size // 2 keeps output spatial dims ~ input/stride.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # timm-style init: trunc-normal linears, unit LayerNorm, He-normal convs.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # (B, C, H_in, W_in) -> (B, N, embed_dim), N = H * W of the conv output.
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class Swish(nn.Module):
    """Swish / SiLU activation: x * sigmoid(x)."""

    def forward(self, x):
        return torch.sigmoid(x) * x
def swish(x):
    """Functional Swish / SiLU: x * sigmoid(x)."""
    return torch.sigmoid(x).mul(x)


class TimeEmbedding(nn.Module):
    """Sinusoidal timestep table (T, d_model) followed by a 2-layer MLP to `dim`."""

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Frequencies 10000^(-2i/d_model) for i = 0 .. d_model/2 - 1.
        freqs = torch.exp(-(torch.arange(0, d_model, step=2) / d_model * math.log(10000)))
        positions = torch.arange(T).float()
        angles = positions[:, None] * freqs[None, :]
        assert list(angles.shape) == [T, d_model // 2]
        # Interleave sin/cos along the last axis, then flatten to (T, d_model).
        table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        assert list(table.shape) == [T, d_model // 2, 2]
        table = table.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(table),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        """Xavier-init both Linear layers, zero their biases."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        """t: (B,) integer timesteps -> (B, dim) embeddings."""
        return self.timembedding(t)


class DownSample(nn.Module):
    """Halve spatial resolution with a stride-2 3x3 convolution."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        """Xavier weights, zero bias."""
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # `temb` is accepted only to match the shared (x, temb) block interface.
        return self.main(x)


class UpSample(nn.Module):
    """Double spatial resolution: nearest-neighbor upsample then 3x3 conv."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        """Xavier weights, zero bias."""
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # `temb` is unused; kept for the shared (x, temb) block interface.
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.main(upsampled)
    
class kan(nn.Module):
    """Token-wise MLP block used in place of a true KAN layer.

    Despite the name, this applies a plain per-token MLP: tokens are
    flattened to (B*N, C), passed through ``fc1``, and reshaped back.
    NOTE: the reshape back to (B, N, C) requires out_features to equal
    in_features (the default when ``out_features`` is None).

    ``act_layer``, ``version`` and the dropout module are accepted/created
    for signature compatibility but not used in ``forward``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5, version=4, kan_val=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features

        # (Removed a block of dead KAN-config locals — grid_size, spline_order,
        # scale_noise, scale_base, scale_spline, base_activation, grid_eps,
        # grid_range — that were assigned but never read.)

        if kan_val:
            # Plain linear layers; fc2/fc3 are created but unused in forward.
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.fc2 = nn.Linear(hidden_features, out_features)
            self.fc3 = nn.Linear(hidden_features, out_features)
        else:
            self.fc1 = nn.Sequential(
                nn.Linear(in_features, hidden_features),
                Swish(),
                nn.Linear(hidden_features, out_features))

        self.drop = nn.Dropout(drop)

        self.shift_size = shift_size
        self.pad = shift_size // 2

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear, unit/zero for LayerNorm, He-style fan-out for Conv2d."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply fc1 token-wise; H and W are accepted for interface parity but unused."""
        B, N, C = x.shape
        x = self.fc1(x.reshape(B * N, C))
        return x.reshape(B, N, C).contiguous()

class shiftedBlock(nn.Module):
    """Token-mixer block: LayerNorm -> kan MLP (+ drop-path), plus a timestep bias.

    Most constructor arguments (num_heads, qkv_bias, qk_scale, attn_drop,
    sr_ratio, version) are accepted for signature compatibility but unused.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, version=1, kan_val=False):
        super().__init__()

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)

        # Projects the 256-d timestep embedding down to one bias per channel.
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(256, dim),
        )
        # self.mlp = shiftmlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.kan = kan(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, kan_val=kan_val)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal for Linear, unit/zero for LayerNorm, He-style fan-out for Conv2d.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W, temb):
        """Mix tokens (B, N, C) and add the projected timestep embedding.

        NOTE(review): the residual connection is commented out below, so the
        mixer output REPLACES x instead of being added — confirm intended.
        """
        temb = self.temb_proj(temb)
        # x = x + self.drop_path(self.kan(self.norm2(x), H, W))
        x = self.drop_path(self.kan(self.norm2(x), H, W))
        x = x + temb.unsqueeze(1)

        return x

class DWConv(nn.Module):
    """Depthwise 3x3 convolution applied to a token sequence.

    Tokens (B, N, C) are unflattened to (B, C, H, W), convolved per-channel
    (groups=dim), and flattened back to (B, N, C).
    """

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        spatial = x.transpose(1, 2).view(B, C, H, W)
        spatial = self.dwconv(spatial)
        return spatial.flatten(2).transpose(1, 2)

class DW_bn_relu(nn.Module):
    """Depthwise 3x3 conv + GroupNorm + Swish over a token sequence."""

    def __init__(self, dim=768):
        super(DW_bn_relu, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        # Named `bn` for historical reasons, but this is GroupNorm(32, dim).
        self.bn = nn.GroupNorm(32, dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        feat = x.transpose(1, 2).view(B, C, H, W)
        feat = self.bn(self.dwconv(feat))
        feat = swish(feat)
        return feat.flatten(2).transpose(1, 2)

class SingleConv(nn.Module):
    """GroupNorm -> Swish -> 3x3 conv, plus an additive per-channel timestep bias."""

    def __init__(self, in_ch, h_ch):
        super(SingleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, h_ch, 3, padding=1),
        )
        # Projects the 256-d timestep embedding to one bias per output channel.
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(256, h_ch),
        )

    def forward(self, input, temb):
        bias = self.temb_proj(temb)[:, :, None, None]
        return self.conv(input) + bias


class DoubleConv(nn.Module):
    """Two Conv-GN-Swish stages, plus an additive per-channel timestep bias."""

    def __init__(self, in_ch, h_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, h_ch, 3, padding=1),
            nn.GroupNorm(32, h_ch),
            Swish(),
            nn.Conv2d(h_ch, h_ch, 3, padding=1),
            nn.GroupNorm(32, h_ch),
            Swish()
        )
        # Projects the 256-d timestep embedding to one bias per output channel.
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(256, h_ch),
        )

    def forward(self, input, temb):
        bias = self.temb_proj(temb)[:, :, None, None]
        return self.conv(input) + bias


class D_SingleConv(nn.Module):
    """Decoder stage: GroupNorm -> Swish -> 3x3 conv, plus a timestep bias."""

    def __init__(self, in_ch, h_ch):
        super(D_SingleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, h_ch, 3, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(256, h_ch),
        )

    def forward(self, input, temb):
        shift = self.temb_proj(temb)[:, :, None, None]
        return self.conv(input) + shift


class D_DoubleConv(nn.Module):
    """Decoder double conv: Conv(in->in)-GN-Swish then Conv(in->h)-GN-Swish, plus a timestep bias."""

    def __init__(self, in_ch, h_ch):
        super(D_DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, 3, padding=1),
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, h_ch, 3, padding=1),
            nn.GroupNorm(32, h_ch),
            Swish()
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(256, h_ch),
        )

    def forward(self, input, temb):
        shift = self.temb_proj(temb)[:, :, None, None]
        return self.conv(input) + shift

class AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection."""

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        # Re-init output projection near zero so the block starts close to identity.
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        normed = self.group_norm(x)
        q = self.proj_q(normed).permute(0, 2, 3, 1).view(B, H * W, C)
        k = self.proj_k(normed).view(B, C, H * W)
        v = self.proj_v(normed).permute(0, 2, 3, 1).view(B, H * W, C)

        # Scaled dot-product attention across all H*W positions.
        attn = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(attn.shape) == [B, H * W, H * W]
        attn = F.softmax(attn, dim=-1)

        out = torch.bmm(attn, v)
        assert list(out.shape) == [B, H * W, C]
        out = out.view(B, H, W, C).permute(0, 3, 1, 2)
        return x + self.proj(out)


class ResBlock(nn.Module):
    """DDPM residual block with timestep conditioning and optional attention.

    Path: (GN-Swish-Conv) + time bias -> (GN-Swish-Dropout-Conv), added to a
    shortcut (1x1 conv when channel counts differ), then optional AttnBlock.
    """

    def __init__(self, in_ch, h_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, h_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, h_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, h_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(h_ch, h_ch, 3, stride=1, padding=1),
        )
        # Learned 1x1 projection only when the channel count changes.
        self.shortcut = nn.Conv2d(in_ch, h_ch, 1, stride=1, padding=0) if in_ch != h_ch else nn.Identity()
        self.attn = AttnBlock(h_ch) if attn else nn.Identity()
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        # Last conv starts near zero so the block begins close to identity.
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        out = self.block1(x)
        out = out + self.temb_proj(temb)[:, :, None, None]
        out = self.block2(out)
        out = out + self.shortcut(x)
        return self.attn(out)


class UMLP(nn.Module):
    """DDPM noise-prediction U-Net whose two deepest stages are token-MLP blocks.

    The convolutional encoder/decoder follows the DDPM U-Net; in place of
    middle ResBlocks, features are tokenized with OverlapPatchEmbed and mixed
    by `shiftedBlock` layers, then decoded by D_SingleConv stages with
    additive skip connections (t3, t4).
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index h of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        # NOTE(review): this overwrites the `attn` argument, so attention is
        # disabled in every ResBlock regardless of the config value.
        attn = []
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # channel history of the encoder, consumed by the decoder skips
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            h_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, h_ch=h_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = h_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            h_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, h_ch=h_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = h_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )

        # Token-mixer bottleneck configuration; spatial sizes below assume
        # 64x64 inputs (img_size=64//4 and 64//8) — TODO confirm for other sizes.
        embed_dims = [256, 320, 512]
        drop_rate = 0.0
        attn_drop_rate = 0.0
        kan_val = False
        version = 4
        sr_ratios = [8, 4, 2, 1]
        num_heads=[1, 2, 4, 8]
        qkv_bias=False
        qk_scale=None
        norm_layer = nn.LayerNorm
        dpr = [0.0, 0.0, 0.0]
        self.patch_embed3 = OverlapPatchEmbed(img_size=64 // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed4 = OverlapPatchEmbed(img_size=64 // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])

        self.norm3 = norm_layer(embed_dims[1])
        self.norm4 = norm_layer(embed_dims[2])

        self.dnorm3 = norm_layer(embed_dims[1])
        self.dnorm4 = norm_layer(embed_dims[0])

        self.kan_block1 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)])

        self.kan_block2 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)])

        self.kan_dblock1 = nn.ModuleList([shiftedBlock(
            dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)])

        # self.kan_dblock2 = nn.ModuleList([shiftedBlock(
        #     dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,
        #     drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,
        #     sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)])

        self.decoder1 = D_SingleConv(embed_dims[2], embed_dims[1])
        self.decoder2 = D_SingleConv(embed_dims[1], embed_dims[0])

        self.initialize()

    def initialize(self):
        # Xavier-init the head; near-zero tail so the initial prediction is small.
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        """Predict the noise in image batch `x` at integer timesteps `t`."""
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)

        t3 = h  # skip feature added back after the stage-3 decoder

        B = x.shape[0]
        h, H, W = self.patch_embed3(h)

        for i, blk in enumerate(self.kan_block1):
            h = blk(h, H, W, temb)
        h = self.norm3(h)
        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t4 = h  # skip feature added back after the stage-4 decoder

        h, H, W= self.patch_embed4(h)
        for i, blk in enumerate(self.kan_block2):
            h = blk(h, H, W, temb)
        h = self.norm4(h)
        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        ### Stage 4: decode, 2x bilinear upsample, add skip
        h = swish(F.interpolate(self.decoder1(h, temb), scale_factor=(2,2), mode ='bilinear'))

        h = torch.add(h, t4)

        _, _, H, W = h.shape
        h = h.flatten(2).transpose(1,2)
        for i, blk in enumerate(self.kan_dblock1):
            h = blk(h, H, W, temb)

        ### Stage 3: decode, 2x bilinear upsample, add skip
        h = self.dnorm3(h)
        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        h = swish(F.interpolate(self.decoder2(h, temb),scale_factor=(2,2),mode ='bilinear'))

        h = torch.add(h,t3)

        # Upsampling: concat the matching encoder feature before each ResBlock.
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h


if __name__ == '__main__':
    # Smoke test: construct the model to verify the architecture assembles.
    batch_size = 8
    model = UMLP(
        T=1000, ch=64, ch_mult=[1, 2, 2, 2], attn=[],
        num_res_blocks=2, dropout=0.1)
    



================================================
FILE: Diffusion_UKAN/Diffusion/Train.py
================================================

import os
from typing import Dict
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms, transforms
# from torchvision.datasets import CIFAR10
from torchvision.utils import save_image
from Diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer
from Diffusion.UNet import UNet, UNet_Baseline
from Diffusion.Model_ConvKan import UNet_ConvKan
from Diffusion.Model_UMLP import UMLP
from Diffusion.Model_UKAN_Hybrid import UKan_Hybrid
from Scheduler import GradualWarmupScheduler
from skimage import io
import os
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import Dataset
import sys


# Registry: config "model" string -> network class.
model_dict = {
    'UNet': UNet,
    'UNet_ConvKan': UNet_ConvKan, # does not work
    'UMLP': UMLP,
    'UKan_Hybrid': UKan_Hybrid,
    'UNet_Baseline': UNet_Baseline,
}

class UnlabeledDataset(Dataset):
    """Yield (image, dummy label) pairs for every file in `folder`.

    NOTE(review): `repeat_n` is accepted but currently has no effect; an
    earlier version repeated the file list `repeat_n` times — confirm intent.
    """

    def __init__(self, folder, transform=None, repeat_n=1):
        self.folder = folder
        self.transform = transform
        self.image_files = os.listdir(folder)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        full_path = os.path.join(self.folder, self.image_files[idx])
        img = io.imread(full_path)
        if self.transform:
            img = self.transform(img)
        # Unlabeled data: a constant zero tensor stands in for the label.
        return img, torch.Tensor([0])


def train(modelConfig: Dict):
    """Run DDPM training as specified by `modelConfig`.

    Loads one of the 64x64 image folders, builds the network named by
    modelConfig["model"], and trains with AdamW + warmup/cosine LR schedule.
    Every 50 epochs the weights are saved and `eval_tmp` renders a sample
    grid (and may delete older intermediate checkpoints).
    """
    device = torch.device(modelConfig["device"])
    log_print = True
    if log_print:
        # NOTE(review): stdout is redirected into the log file and only
        # restored at the very end; an exception mid-training leaves stdout
        # redirected and the file unclosed.
        file = open(modelConfig["save_weight_dir"]+'log.txt', "w")
        sys.stdout = file
    # Flip augmentations; Normalize maps images into [-1, 1].
    transform = Compose([
        ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    if modelConfig["dataset"] == 'cvc':
        dataset = UnlabeledDataset('data/cvc/images_64/', transform=transform, repeat_n=modelConfig["dataset_repeat"])
    elif modelConfig["dataset"] == 'glas':
        dataset = UnlabeledDataset('data/glas/images_64/', transform=transform, repeat_n=modelConfig["dataset_repeat"])
    elif modelConfig["dataset"] == 'glas_resize':
        dataset = UnlabeledDataset('data/glas/images_64_resize/', transform=transform, repeat_n=modelConfig["dataset_repeat"])
    elif modelConfig["dataset"] == 'busi':
        dataset = UnlabeledDataset('data/busi/images_64/', transform=transform, repeat_n=modelConfig["dataset_repeat"])
    else:
        raise ValueError('dataset not found')

    print('modelConfig: ')
    for key, value in modelConfig.items():
        print(key, ' : ', value)

    dataloader = DataLoader(
        dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)

    print('Using {}'.format(modelConfig["model"]))
    # model setup
    net_model =model_dict[modelConfig["model"]](T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                    num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)

    if modelConfig["training_load_weight"] is not None:
        net_model.load_state_dict(torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["training_load_weight"]), map_location=device))

    optimizer = torch.optim.AdamW(
        net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
    # Cosine decay over all epochs, entered after a linear warmup of epoch//10.
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
    warmUpScheduler = GradualWarmupScheduler(
        optimizer=optimizer, multiplier=modelConfig["multiplier"], warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)

    trainer = GaussianDiffusionTrainer(
        net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)

    # start training
    for e in range(1,modelConfig["epoch"]+1):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                # train
                optimizer.zero_grad()
                x_0 = images.to(device)

                # Summed per-element loss scaled by a fixed 1/1000.
                loss = trainer(x_0).sum() / 1000.
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"])
                optimizer.step()
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                })
                # print version
                if log_print:
                    print("epoch: ", e, "loss: ", loss.item(), "img shape: ", x_0.shape, "LR: ", optimizer.state_dict()['param_groups'][0]["lr"])
        warmUpScheduler.step()
        # Periodic checkpoint + sample grid; eval_tmp may delete old ckpts.
        if e % 50 ==0:
            torch.save(net_model.state_dict(), os.path.join(
                modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))
            modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(e)
            eval_tmp(modelConfig, e)

    torch.save(net_model.state_dict(), os.path.join(
        modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))
    if log_print:
        file.close()
        sys.stdout = sys.__stdout__
    
def eval_tmp(modelConfig: Dict, nme: int):
    """Sample a grid of images from checkpoint `test_load_weight` at epoch `nme`.

    Saves the grid under the 'Tmp' variant of sampled_dir, then deletes the
    checkpoint unless it falls within the final 5% of training epochs.
    """
    # load model and evaluate
    with torch.no_grad():
        device = torch.device(modelConfig["device"])
        # dropout=0. for evaluation.
        model = model_dict[modelConfig["model"]](T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=0.)
        ckpt = torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device)

        model.load_state_dict(ckpt)

        print("model load weight done.")
        model.eval()
        # NOTE(review): the model is never explicitly moved to `device`;
        # presumably the sampler's .to(device) moves it as a submodule — verify.
        sampler = GaussianDiffusionSampler(
            model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
        # Sampled from standard normal distribution
        noisyImage = torch.randn(
            size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
        # saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
        # save_image(saveNoisy, os.path.join(
            # modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
        sampledImgs = sampler(noisyImage)
        sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]

        save_root = modelConfig["sampled_dir"].replace('Gens','Tmp')
        os.makedirs(save_root, exist_ok=True)
        save_image(sampledImgs, os.path.join(
            save_root,  modelConfig["sampledImgName"].replace('.png','_{}.png').format(nme)), nrow=modelConfig["nrow"])
        # Drop intermediate checkpoints to save disk space.
        if nme < 0.95 * modelConfig["epoch"]:
            os.remove(os.path.join(
                modelConfig["save_weight_dir"], modelConfig["test_load_weight"]))

def eval(modelConfig: Dict):
    """Sample `batch_size` images from checkpoint `test_load_weight` and save
    each image separately under sampled_dir.

    NOTE(review): this function shadows the builtin `eval`; callers import it
    by name, so the name is kept.
    """
    # load model and evaluate
    with torch.no_grad():
        device = torch.device(modelConfig["device"])

        model = model_dict[modelConfig["model"]](T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                    num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)

        ckpt = torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device)

        model.load_state_dict(ckpt)
        print("model load weight done.")
        model.eval()
        sampler = GaussianDiffusionSampler(
            model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
        # Sampled from standard normal distribution
        noisyImage = torch.randn(
            size=[modelConfig["batch_size"], 3, modelConfig["img_size"], modelConfig["img_size"]], device=device)
        # saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
        # save_image(saveNoisy, os.path.join(
        #     modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
        sampledImgs = sampler(noisyImage)
        sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]

        # One file per sample: name.png -> name_0.png, name_1.png, ...
        for i, image in enumerate(sampledImgs):

            save_image(image, os.path.join(modelConfig["sampled_dir"],  modelConfig["sampledImgName"].replace('.png','_{}.png').format(i)), nrow=modelConfig["nrow"])


================================================
FILE: Diffusion_UKAN/Diffusion/UNet.py
================================================

   
import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class Swish(nn.Module):
    """Swish / SiLU activation."""

    def forward(self, x):
        return torch.sigmoid(x) * x


class TimeEmbedding(nn.Module):
    """Lookup-table sinusoidal timestep embedding + two-layer Swish MLP."""

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # inv_freq[i] = 10000^(-2i/d_model), i in [0, d_model/2)
        inv_freq = torch.exp(-(torch.arange(0, d_model, step=2) / d_model * math.log(10000)))
        angles = torch.arange(T).float()[:, None] * inv_freq[None, :]
        assert list(angles.shape) == [T, d_model // 2]
        pe = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        assert list(pe.shape) == [T, d_model // 2, 2]
        pe = pe.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(pe),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        """Xavier-initialize the Linear layers; the sin/cos table is fixed."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        return self.timembedding(t)


class DownSample(nn.Module):
    """2x spatial downsampling via a strided 3x3 convolution."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # The timestep embedding is ignored; kept for a uniform layer API.
        return self.main(x)


class UpSample(nn.Module):
    """2x spatial upsampling: nearest-neighbor interpolation then a 3x3 conv."""

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # The timestep embedding is ignored; kept for a uniform layer API.
        _, _, H, W = x.shape
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.main(upsampled)


class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions, with residual add."""

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        # Output projection re-initialized near zero: block starts ~identity.
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        n = self.group_norm(x)
        queries = self.proj_q(n).permute(0, 2, 3, 1).view(B, H * W, C)
        keys = self.proj_k(n).view(B, C, H * W)
        values = self.proj_v(n).permute(0, 2, 3, 1).view(B, H * W, C)

        # Scaled dot-product attention weights over all positions.
        weights = torch.bmm(queries, keys) * (int(C) ** (-0.5))
        assert list(weights.shape) == [B, H * W, H * W]
        weights = F.softmax(weights, dim=-1)

        mixed = torch.bmm(weights, values)
        assert list(mixed.shape) == [B, H * W, C]
        mixed = mixed.view(B, H, W, C).permute(0, 3, 1, 2)
        return x + self.proj(mixed)


class ResBlock(nn.Module):
    """DDPM residual block with timestep conditioning and optional attention.

    Path: (GN-Swish-Conv) + time bias -> (GN-Swish-Dropout-Conv), added to a
    shortcut (1x1 conv when channels differ), then optional AttnBlock.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        # Learned 1x1 projection only when the channel count changes.
        self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0) if in_ch != out_ch else nn.Identity()
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        # Final conv starts near zero so the block begins close to identity.
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        out = self.block1(x)
        out = out + self.temb_proj(temb)[:, :, None, None]
        out = self.block2(out)
        out = out + self.shortcut(x)
        return self.attn(out)


class UNet(nn.Module):
    """Standard DDPM U-Net noise predictor: head -> down -> middle -> up -> tail."""

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        """T: diffusion steps; ch: base width; ch_mult: per-stage width
        multipliers; attn: stage indices that get attention; num_res_blocks:
        ResBlocks per stage; dropout: ResBlock dropout rate."""
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        chs = [ch]  # channel history of the encoder, consumed by the decoder skips
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)

        # Bottleneck: one attention ResBlock followed by a plain one.
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
        self.initialize()

    def initialize(self):
        # Xavier-init head; near-zero tail so the initial prediction is small.
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        """Predict the noise in image batch `x` at integer timesteps `t`."""
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling: concat the matching encoder feature before each ResBlock.
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h
    
class UNet_Baseline(nn.Module):
    """Ablation variant of ``UNet`` with the bottleneck (middle) blocks removed.

    Same encoder/decoder layout otherwise: ``num_res_blocks`` ResBlocks per
    resolution on the way down (with a DownSample between resolutions),
    mirrored on the way up with skip connections, then a
    GroupNorm/Swish/Conv output head.
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all(idx < len(ch_mult) for idx in attn), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)

        # Encoder: record the channel count of every stored skip feature.
        self.downblocks = nn.ModuleList()
        skip_chs = [ch]
        cur_ch = ch
        for level, mult in enumerate(ch_mult):
            target_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=cur_ch, out_ch=target_ch, tdim=tdim,
                    dropout=dropout, attn=(level in attn)))
                cur_ch = target_ch
                skip_chs.append(cur_ch)
            if level != len(ch_mult) - 1:
                self.downblocks.append(DownSample(cur_ch))
                skip_chs.append(cur_ch)

        # Decoder: consume the recorded skip channels in reverse order.
        self.upblocks = nn.ModuleList()
        for level, mult in reversed(list(enumerate(ch_mult))):
            target_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=skip_chs.pop() + cur_ch, out_ch=target_ch, tdim=tdim,
                    dropout=dropout, attn=(level in attn)))
                cur_ch = target_ch
            if level != 0:
                self.upblocks.append(UpSample(cur_ch))
        assert len(skip_chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, cur_ch),
            Swish(),
            nn.Conv2d(cur_ch, 3, 3, stride=1, padding=1)
        )
        self.initialize()

    def initialize(self):
        """Xavier init for stem and output conv (tiny gain on the output)."""
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)

    def forward(self, x, t):
        """Predict noise for batch ``x`` at timesteps ``t`` (no middle blocks)."""
        temb = self.time_embedding(t)

        feat = self.head(x)
        skips = [feat]
        for block in self.downblocks:
            feat = block(feat, temb)
            skips.append(feat)

        for block in self.upblocks:
            if isinstance(block, ResBlock):
                feat = torch.cat([feat, skips.pop()], dim=1)
            feat = block(feat, temb)
        out = self.tail(feat)

        assert len(skips) == 0
        return out



if __name__ == '__main__':
    # Smoke construction of the default UNet configuration.
    # NOTE(review): batch_size is unused here — presumably leftover from a
    # removed forward-pass smoke test; confirm before relying on this block.
    batch_size = 8
    model = UNet(
        T=1000, ch=64, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1)

================================================
FILE: Diffusion_UKAN/Diffusion/__init__.py
================================================
from .Diffusion import *
from .UNet import *
from .Train import *


================================================
FILE: Diffusion_UKAN/Diffusion/kan_utils/__init__.py
================================================
from .kan import *
from .fastkanconv import *
# from .kan_convolutional import *


================================================
FILE: Diffusion_UKAN/Diffusion/kan_utils/fastkanconv.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Union


class PolynomialFunction(nn.Module):
    """Monomial basis: maps x to [x**0, x**1, ..., x**(degree-1)].

    The basis values are stacked on a new trailing dimension.
    """

    def __init__(self, 
                 degree: int = 3):
        super().__init__()
        self.degree = degree

    def forward(self, x):
        powers = [x ** exp for exp in range(self.degree)]
        return torch.stack(powers, dim=-1)
    
class BSplineFunction(nn.Module):
    """B-spline basis expansion via the Cox-de Boor recursion.

    Maps each scalar input to ``num_basis`` B-spline basis values computed
    over a uniform knot vector on ``[grid_min, grid_max]``.
    """

    def __init__(self, grid_min: float = -2.,
        grid_max: float = 2., degree: int = 3, num_basis: int = 8):
        super(BSplineFunction, self).__init__()
        self.degree = degree
        self.num_basis = num_basis
        # NOTE(review): plain tensor attribute (not a registered buffer), so
        # it will not follow the module across .to(device)/.cuda() — confirm
        # this is run on the same device as the inputs.
        self.knots = torch.linspace(grid_min, grid_max, num_basis + degree + 1)  # Uniform knots

    def basis_function(self, i, k, t):
        # Cox-de Boor recursion. Degree-0 case is the indicator of the
        # half-open knot interval [knots[i], knots[i+1]); inputs exactly at
        # the right boundary therefore yield a zero basis value.
        if k == 0:
            return ((self.knots[i] <= t) & (t < self.knots[i + 1])).float()
        else:
            # Each term guards against a zero denominator (repeated knots)
            # by substituting 0 for the whole term.
            left_num = (t - self.knots[i]) * self.basis_function(i, k - 1, t)
            left_den = self.knots[i + k] - self.knots[i]
            left = left_num / left_den if left_den != 0 else 0

            right_num = (self.knots[i + k + 1] - t) * self.basis_function(i + 1, k - 1, t)
            right_den = self.knots[i + k + 1] - self.knots[i + 1]
            right = right_num / right_den if right_den != 0 else 0
            return left + right 
    
    def forward(self, x):
        x = x.squeeze()  # Assuming x is of shape (B, 1)
        # Stack the num_basis degree-`self.degree` basis values on a new
        # trailing dimension.
        basis_functions = torch.stack([self.basis_function(i, self.degree, x) for i in range(self.num_basis)], dim=-1)
        return basis_functions

class ChebyshevFunction(nn.Module):
    """Chebyshev polynomial basis T_0 .. T_{degree-1}, stacked on a new last dim."""

    def __init__(self, degree: int = 4):
        super(ChebyshevFunction, self).__init__()
        self.degree = degree

    def forward(self, x):
        # Recurrence: T_0 = 1, T_1 = x, T_n = 2*x*T_{n-1} - T_{n-2}.
        polys = [torch.ones_like(x), x]
        while len(polys) < self.degree:
            polys.append(2 * x * polys[-1] - polys[-2])
        return torch.stack(polys, dim=-1)

class FourierBasisFunction(nn.Module):
    """Fourier basis: sin/cos features at integer multiples of a base frequency."""

    def __init__(self, 
                 num_frequencies: int = 4, 
                 period: float = 1.0):
        super(FourierBasisFunction, self).__init__()
        assert num_frequencies % 2 == 0, "num_frequencies must be even"
        self.num_frequencies = num_frequencies
        # Fixed (non-trainable) period, stored as a parameter so it travels
        # with the module's state.
        self.period = nn.Parameter(torch.Tensor([period]), requires_grad=False)

    def forward(self, x):
        k = torch.arange(1, self.num_frequencies // 2 + 1, device=x.device)
        phase = 2 * torch.pi * k * x[..., None] / self.period
        # First half of the trailing dim is sines, second half cosines.
        return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1)
        
class RadialBasisFunction(nn.Module):
    """Gaussian RBF features centred on a fixed uniform grid."""

    def __init__(
        self,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 4,
        denominator: float = None,
    ):
        super().__init__()
        centres = torch.linspace(grid_min, grid_max, num_grids)
        self.grid = torch.nn.Parameter(centres, requires_grad=False)
        # Default bandwidth: the spacing between adjacent grid points.
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x):
        z = (x[..., None] - self.grid) / self.denominator
        return torch.exp(-z ** 2)
    

    
    
class SplineConv2D(nn.Conv2d):
    """Conv2d whose weights start from a small truncated normal.

    ``init_scale`` is the std of the truncated-normal weight init; the bias
    (if present) starts at zero.
    """

    def __init__(self, 
                 in_channels: int, 
                 out_channels: int, 
                 kernel_size: Union[int, Tuple[int, int]] = 3,
                 stride: Union[int, Tuple[int, int]] = 1, 
                 padding: Union[int, Tuple[int, int]] = 0, 
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1, 
                 bias: bool = True, 
                 init_scale: float = 0.1, 
                 padding_mode: str = "zeros", 
                 **kw
                 ) -> None:
        # Must be set before super().__init__, which calls reset_parameters().
        self.init_scale = init_scale
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias, padding_mode, **kw)

    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)
        if self.bias is not None:
            nn.init.zeros_(self.bias)


class FastKANConvLayer(nn.Module):
    """KAN-style convolution: basis-expand each channel, then convolve.

    Each input channel is expanded into ``num_grids`` basis responses
    (RBF / Fourier / Poly / Chebyshev / BSpline), which are mixed by
    ``SplineConv2D``. Optionally a plain "base" convolution over the
    activated input is added, following the KAN formulation of a base
    function plus a learned spline correction.

    Args:
        in_channels, out_channels: convolution channel counts.
        kernel_size, stride, padding, dilation, groups, bias, padding_mode:
            forwarded to both convolutions.
        grid_min, grid_max: support of the basis grid (RBF / BSpline).
        num_grids: number of basis functions per input channel.
        use_base_update: add ``base_conv(base_activation(x))`` to the output.
        base_activation: activation for the base branch (default: SiLU).
        spline_weight_init_scale: init std for the spline conv weights.
        kan_type: one of "RBF", "Fourier", "Poly", "Chebyshev", "BSpline".

    Raises:
        ValueError: if ``kan_type`` is not a supported name.
    """

    def __init__(self, 
                 in_channels: int, 
                 out_channels: int, 
                 kernel_size: Union[int, Tuple[int, int]] = 3,
                 stride: Union[int, Tuple[int, int]] = 1, 
                 padding: Union[int, Tuple[int, int]] = 0, 
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1, 
                 bias: bool = True, 
                 grid_min: float = -2., 
                 grid_max: float = 2.,
                 num_grids: int = 4, 
                 use_base_update: bool = True, 
                 base_activation = F.silu,
                 spline_weight_init_scale: float = 0.1, 
                 padding_mode: str = "zeros",
                 kan_type: str = "BSpline",
                 ) -> None:
        
        super().__init__()
        if kan_type == "RBF":
            self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        elif kan_type == "Fourier":
            self.rbf = FourierBasisFunction(num_grids)
        elif kan_type == "Poly":
            self.rbf = PolynomialFunction(num_grids)
        elif kan_type == "Chebyshev":
            self.rbf = ChebyshevFunction(num_grids)
        elif kan_type == "BSpline":
            self.rbf = BSplineFunction(grid_min, grid_max, 3, num_grids)
        else:
            # Previously an unknown name silently left ``self.rbf`` unset and
            # failed later with an opaque AttributeError; fail fast instead.
            raise ValueError(f"Unsupported kan_type: {kan_type!r}")

        # Conv over the basis-expanded input (in_channels * num_grids maps).
        self.spline_conv = SplineConv2D(in_channels * num_grids, 
                                        out_channels, 
                                        kernel_size,
                                        stride, 
                                        padding, 
                                        dilation, 
                                        groups, 
                                        bias,
                                        spline_weight_init_scale, 
                                        padding_mode)
        
        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_conv = nn.Conv2d(in_channels, 
                                       out_channels, 
                                       kernel_size, 
                                       stride, 
                                       padding, 
                                       dilation, 
                                       groups, 
                                       bias, 
                                       padding_mode)

    def forward(self, x):
        """Expand channels with the basis, convolve, optionally add base conv."""
        batch_size, channels, height, width = x.shape
        # (B, C, H, W) -> basis values on a new trailing dim -> fold that dim
        # into the channel axis: (B, C * num_grids, H, W).
        x_rbf = self.rbf(x.view(batch_size, channels, -1)).view(batch_size, channels, height, width, -1)
        x_rbf = x_rbf.permute(0, 4, 1, 2, 3).contiguous().view(batch_size, -1, height, width)
        
        # Apply spline convolution
        ret = self.spline_conv(x_rbf)
         
        if self.use_base_update:
            # Residual "base" path: plain conv on the activated input.
            base = self.base_conv(self.base_activation(x))
            ret = ret + base
        
        return ret


================================================
FILE: Diffusion_UKAN/Diffusion/kan_utils/kan.py
================================================
import torch
import torch.nn.functional as F
import math


class KANLinear(torch.nn.Module):
    """Kolmogorov-Arnold linear layer: a base linear path plus learnable
    per-connection B-splines evaluated on a per-feature grid.

    forward(x) = F.linear(base_activation(x), base_weight)
               + F.linear(b_splines(x), scaled_spline_weight)
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KANLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Uniform grid over grid_range, extended by spline_order knots on
        # each side so every B-spline basis is defined at the boundaries.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
        self.spline_weight = torch.nn.Parameter(
            torch.Tensor(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            # Optional per-connection multiplier applied to spline_weight.
            self.spline_scaler = torch.nn.Parameter(
                torch.Tensor(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-init the base weight; fit spline weights to small noise."""
        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
        with torch.no_grad():
            # Random curve values at the interior grid points, converted to
            # spline coefficients via least squares (curve2coeff).
            noise = (
                (
                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            self.spline_weight.data.copy_(
                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)

    def b_splines(self, x: torch.Tensor):
        """
        Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        grid: torch.Tensor = (
            self.grid
        )  # (in_features, grid_size + 2 * spline_order + 1)
        x = x.unsqueeze(-1)
        # Cox-de Boor recursion, vectorized: start from degree-0 indicator
        # bases and raise the order one step per loop iteration.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """
        Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        A = self.b_splines(x).transpose(
            0, 1
        )  # (in_features, batch_size, grid_size + spline_order)
        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
        # Batched least squares: one independent fit per input feature.
        solution = torch.linalg.lstsq(
            A, B
        ).solution  # (in_features, grid_size + spline_order, out_features)
        result = solution.permute(
            2, 0, 1
        )  # (out_features, in_features, grid_size + spline_order)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline weight, optionally scaled per-connection by spline_scaler."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Return base + spline outputs for ``x`` of shape (batch, in_features)."""
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        # The per-connection spline sum is expressed as one flat F.linear.
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Refit the spline grid to the data distribution of ``x`` in-place,
        then re-solve the spline weights so outputs are (approximately)
        preserved on the given batch."""
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        # Current spline outputs per (batch, in, out), computed BEFORE the
        # grid moves so the new coefficients can reproduce them.
        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # sort each channel individually to collect data distribution
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        # Blend uniform and data-adaptive grids, then pad spline_order extra
        # knots on each side (mirroring the constructor's grid layout).
        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        grid = torch.concatenate(
            [
                grid[:1]
                - uniform_step
                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """
        Compute the regularization loss.

        This is a dumb simulation of the original L1 regularization as stated in the
        paper, since the original one requires computing absolutes and entropy from the
        expanded (batch, in_features, out_features) intermediate tensor, which is hidden
        behind the F.linear function if we want an memory efficient implementation.

        The L1 regularization is now computed as mean absolute value of the spline
        weights. The authors implementation also includes this term in addition to the
        sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class KAN(torch.nn.Module):
    """A stack of KANLinear layers sharing one set of hyper-parameters.

    ``layers_hidden`` lists the widths, e.g. ``[64, 32, 10]`` builds two
    KANLinear layers (64 -> 32 and 32 -> 10).
    """

    def __init__(
        self,
        layers_hidden,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=[-1, 1],
    ):
        super(KAN, self).__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Consecutive width pairs define the individual layers.
        width_pairs = zip(layers_hidden, layers_hidden[1:])
        self.layers = torch.nn.ModuleList(
            KANLinear(
                n_in,
                n_out,
                grid_size=grid_size,
                spline_order=spline_order,
                scale_noise=scale_noise,
                scale_base=scale_base,
                scale_spline=scale_spline,
                base_activation=base_activation,
                grid_eps=grid_eps,
                grid_range=grid_range,
            )
            for n_in, n_out in width_pairs
        )

    def forward(self, x: torch.Tensor, update_grid=False):
        """Run ``x`` through every layer, optionally refitting each grid first."""
        for layer in self.layers:
            if update_grid:
                layer.update_grid(x)
            x = layer(x)
        return x

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Sum of the per-layer KANLinear regularization losses."""
        total = 0
        for layer in self.layers:
            total = total + layer.regularization_loss(regularize_activation, regularize_entropy)
        return total


================================================
FILE: Diffusion_UKAN/Diffusion/utils.py
================================================
import argparse
import torch.nn as nn

class qkv_transform(nn.Conv1d):
    """Conv1d for qkv_transform; behaves exactly like ``nn.Conv1d``.

    NOTE(review): the distinct subclass presumably exists so callers can
    single these layers out (e.g. via isinstance) — confirm at usage sites.
    """

def str2bool(v):
    """argparse-friendly boolean parser.

    Accepts 'true'/'false' case-insensitively, plus '1'/'0'. The original
    compared ``v.lower()`` against the integers 1 and 0, which a string can
    never equal, so the '1'/'0' spellings were silently rejected.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if v.lower() in ['true', '1']:
        return True
    elif v.lower() in ['false', '0']:
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def count_params(model):
    """Number of trainable (requires_grad) parameters in ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


class AverageMeter(object):
    """Tracks the most recent value and a running (weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = 0    # most recent value
        self.avg = 0    # running average: sum / count
        self.sum = 0    # weighted sum of observed values
        self.count = 0  # total weight observed so far

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count


================================================
FILE: Diffusion_UKAN/Main.py
================================================
from Diffusion.Train import train, eval
import os
import argparse
import torch
import numpy as np

def main(model_config = None):
    """Train and/or sample according to ``model_config``.

    In "train" state: train, then sample 32 image grids with the final
    checkpoint. In any other state: only sample 32 grids.

    Raises:
        ValueError: if ``model_config`` is None (the original fell through
            to a NameError on the never-assigned local).
    """
    if model_config is None:
        raise ValueError('model_config must be provided')
    modelConfig = model_config
    if modelConfig["state"] == "train":
        train(modelConfig)
        # Sample with a larger batch using the freshly trained weights.
        modelConfig['batch_size'] = 64
        modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(modelConfig['epoch'])
        for i in range(32):
            modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i)
            eval(modelConfig)
    else:
        for i in range(32):
            modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i)
            eval(modelConfig)

def seed_all(args):
    """Seed torch (CPU + CUDA) and numpy, and force deterministic cuDNN."""
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)

if __name__ == '__main__':
    # Entry point: parse CLI args, assemble the config dict consumed by
    # train()/eval(), back up the model/training code, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--state', type=str, default='train') # train or eval
    parser.add_argument('--dataset', type=str, default='cvc') # busi, glas, cvc
    parser.add_argument('--epoch', type=int, default=1000) # 1000 for cvc/glas, 5000 for busi
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--T', type=int, default=1000)
    parser.add_argument('--channel', type=int, default=64) # 64 or 128
    parser.add_argument('--test_load_weight', type=str, default='ckpt_1000_.pt')
    parser.add_argument('--num_res_blocks', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.15)
    parser.add_argument('--lr', type=float, default=2e-4)
    # NOTE(review): --img_size is parsed but the config below hard-codes
    # img_size to 64 — this flag has no effect; confirm before using it.
    parser.add_argument('--img_size', type=float, default=64) 
    parser.add_argument('--dataset_repeat', type=int, default=1) # did not use
    parser.add_argument('--seed', type=int, default=0) # did not use
    parser.add_argument('--model', type=str, default='UKAN_Hybrid')
    parser.add_argument('--exp_nme', type=str, default='UKAN_Hybrid')
    parser.add_argument('--save_root', type=str, default='./Output/') 
    args = parser.parse_args()

    save_root = args.save_root
    # Seeding is opt-in: seed 0 (the default) leaves RNGs unseeded.
    if args.seed != 0:
        seed_all(args)

    modelConfig = {
        "dataset": args.dataset, 
        "state": args.state, # or eval
        "epoch": args.epoch,
        "batch_size": args.batch_size,
        "T": args.T,
        "channel": args.channel,
        "channel_mult": [1, 2, 3, 4],
        "attn": [2],
        "num_res_blocks": args.num_res_blocks,
        "dropout": args.dropout,
        "lr": args.lr,
        "multiplier": 2.,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": 64,
        "grad_clip": 1.,
        "device": "cuda", ### MAKE SURE YOU HAVE A GPU !!!
        "training_load_weight": None,
        "save_weight_dir": os.path.join(save_root, args.exp_nme, "Weights"),
        "sampled_dir": os.path.join(save_root, args.exp_nme, "Gens"),
        "test_load_weight": args.test_load_weight,
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8,
        "model":args.model,
        "version": 1,
        "dataset_repeat": args.dataset_repeat,
        "seed": args.seed,
        "save_root": args.save_root,
        }

    os.makedirs(modelConfig["save_weight_dir"], exist_ok=True)
    os.makedirs(modelConfig["sampled_dir"], exist_ok=True)

    # backup: snapshot the model and training code next to the outputs so
    # each experiment records the exact source it ran with.
    import shutil
    shutil.copy("Diffusion/Model_UKAN_Hybrid.py", os.path.join(save_root, args.exp_nme))
    shutil.copy("Diffusion/Train.py", os.path.join(save_root, args.exp_nme))

    main(modelConfig)


================================================
FILE: Diffusion_UKAN/Main_Test.py
================================================
from Diffusion.Train import train, eval, eval_tmp
import os
import argparse
import torch
def main(model_config = None):
    """Train-and-sample, or (default here) run the grid-visualization eval.

    Raises:
        ValueError: if ``model_config`` is None (the original fell through
            to a NameError on the never-assigned local).
    """
    if model_config is None:
        raise ValueError('model_config must be provided')
    modelConfig = model_config
    if modelConfig["state"] == "train":
        train(modelConfig)
        # Sample with a larger batch using the freshly trained weights.
        modelConfig['batch_size'] = 64
        modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(modelConfig['epoch'])
        for i in range(32):
            modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i)
            eval(modelConfig)
    else:
        for i in range(1):
            modelConfig["sampledImgName"] = "sampledImgName{}.png".format(i)
            eval_tmp(modelConfig,1000) # for grid visualization
            # eval(modelConfig) # for metric evaluation

def seed_all(args):
    """Seed torch (CPU + CUDA) and numpy; force deterministic cuDNN."""
    import numpy as np
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable cuDNN autotuning so runs are reproducible.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)

if __name__ == '__main__':
    # Entry point for testing released checkpoints: parse CLI args, build
    # the config dict, then run main() (default state is 'eval').
    parser = argparse.ArgumentParser()
    parser.add_argument('--state', type=str, default='eval')
    parser.add_argument('--dataset', type=str, default='cvc') # busi, glas, cvc
    parser.add_argument('--epoch', type=int, default=1000)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--T', type=int, default=1000)
    parser.add_argument('--channel', type=int, default=64)
    parser.add_argument('--test_load_weight', type=str, default='ckpt_1000_.pt')
    parser.add_argument('--num_res_blocks', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.15)
    parser.add_argument('--lr', type=float, default=1e-4)
    # NOTE(review): --img_size is parsed but the config below hard-codes
    # img_size to 64 — this flag has no effect; confirm before using it.
    parser.add_argument('--img_size', type=float, default=64) # 64 or 128
    parser.add_argument('--dataset_repeat', type=int, default=1) # did not use
    parser.add_argument('--seed', type=int, default=0) 
    parser.add_argument('--model', type=str, default='UKan_Hybrid')
    parser.add_argument('--exp_nme', type=str, default='./')

    # Pick the released checkpoint directory for the dataset under test.
    parser.add_argument('--save_root', type=str, default='released_models/ukan_cvc') 
    # parser.add_argument('--save_root', type=str, default='released_models/ukan_glas') 
    # parser.add_argument('--save_root', type=str, default='released_models/ukan_busi') 
    args = parser.parse_args()

    save_root = args.save_root
    # Seeding is opt-in: seed 0 (the default) leaves RNGs unseeded.
    if args.seed != 0:
        seed_all(args)

    modelConfig = {
        "dataset": args.dataset, 
        "state": args.state, # or eval
        "epoch": args.epoch,
        "batch_size": args.batch_size,
        "T": args.T,
        "channel": args.channel,
        "channel_mult": [1, 2, 3, 4],
        "attn": [2],
        "num_res_blocks": args.num_res_blocks,
        "dropout": args.dropout,
        "lr": args.lr,
        "multiplier": 2.,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": 64,
        "grad_clip": 1.,
        "device": "cuda", ### MAKE SURE YOU HAVE A GPU !!!
        "training_load_weight": None,
        "save_weight_dir": os.path.join(save_root, args.exp_nme, "Weights"),
        "sampled_dir": os.path.join(save_root, args.exp_nme, "FinalCheck"),
        "test_load_weight": args.test_load_weight,
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8,
        "model":args.model, 
        "version": 1,
        "dataset_repeat": args.dataset_repeat,
        "seed": args.seed,
        "save_root": args.save_root,
        }

    os.makedirs(modelConfig["save_weight_dir"], exist_ok=True)
    os.makedirs(modelConfig["sampled_dir"], exist_ok=True)

    main(modelConfig)


================================================
FILE: Diffusion_UKAN/README.md
================================================
# Diffusion UKAN (arxiv)

> [**U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**](https://arxiv.org/abs/2406.02918)<br>
> [Chenxin Li](https://xggnet.github.io/)\*, [Xinyu Liu](https://xinyuliu-jeffrey.github.io/)\*, [Wuyang Li](https://wymancv.github.io/wuyang.github.io/)\*, [Cheng Wang](https://scholar.google.com/citations?user=AM7gvyUAAAAJ&hl=en)\*, [Hengyu Liu](), [Yixuan Yuan](https://www.ee.cuhk.edu.hk/~yxyuan/people/people.htm)<sup>✉</sup><br>The Chinese University of Hong Kong

Contact: wuyangli@cuhk.edu.hk

## 💡 Environment 
You can change the torch and Cuda versions to satisfy your device.
```bash
conda create --name UKAN python=3.10
conda activate UKAN
conda install cudatoolkit=11.3
pip install -r requirements.txt
```

## 🖼️ Gallery of Diffusion UKAN 

![image](./assets/gen.png)

## 📚 Prepare datasets
Download the pre-processed dataset from [Onedrive](https://gocuhk-my.sharepoint.com/:u:/g/personal/wuyangli_cuhk_edu_hk/ESqX-V_eLSBEuaJXAzf64JMB16xF9kz3661pJSwQ-hOspg?e=XdABCH) and unzip it into the project folder. The data is pre-processed by the scripts in [tools](./tools).
```
Diffusion_UKAN
|    data
|    └─ cvc
|        └─ images_64
|    └─ busi
|        └─ images_64
|    └─ glas
|        └─ images_64
```
## 📦 Prepare pre-trained models

Download released_models from [Onedrive](https://gocuhk-my.sharepoint.com/:u:/g/personal/wuyangli_cuhk_edu_hk/EUVSH8QFUmpJlxyoEj8Pr2IB8PzGbVJg53rc6GcqxGgLDg?e=a4glNt) and unzip it in the project folder.
```
Diffusion_UKAN
|    released_models
|    └─ ukan_cvc
|        └─ FinalCheck   # generated toy images (see next section)
|        └─ Gens         # the generated images used for evaluation in our paper
|        └─ Tmp          # saved generated images during model training with a 50-epoch interval
|        └─ Weights      # The final checkpoint
|        └─ FID.txt      # raw evaluation data 
|        └─ IS.txt       # raw evaluation data  
|    └─ ukan_busi
|    └─ ukan_glas
```
## 🧸 Toy example
Images will be generated in `released_models/ukan_cvc/FinalCheck` by running this:

```python
python Main_Test.py
```
## 🔥 Training
<!-- You may need to modify the dirs slightly. -->
Please refer to the [training_scripts](./training_scripts) folder. Besides, you can play with different network variations by modifying `MODEL` according to the following dictionary,

```python
model_dict = {
    'UNet': UNet,
    'UNet_ConvKan': UNet_ConvKan,
    'UMLP': UMLP,
    'UKan_Hybrid': UKan_Hybrid,
    'UNet_Baseline': UNet_Baseline,
}
```


## 🤞 Acknowledgement 
We sincerely appreciate these excellent projects:
- [Simple DDPM](https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-) 
- [Kolmogorov-Arnold Network](https://github.com/mintisan/awesome-kan) 
- [Efficient Kolmogorov-Arnold Network](https://github.com/Blealtan/efficient-kan.git)


## 📜Citation
If you find this work helpful for your project, please consider citing the following paper:
```
@article{li2024ukan,
  title={U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation},
  author={Li, Chenxin and Liu, Xinyu and Li, Wuyang and Wang, Cheng and Liu, Hengyu and Yuan, Yixuan},
  journal={arXiv preprint arXiv:2406.02918},
  year={2024}
}
```



================================================
FILE: Diffusion_UKAN/Scheduler.py
================================================
from torch.optim.lr_scheduler import _LRScheduler

class GradualWarmupScheduler(_LRScheduler):
    """Linearly ramp the learning rate from base_lr to base_lr * multiplier
    over ``warm_epoch`` epochs, then (optionally) hand over to ``after_scheduler``.

    After the warm-up phase, the wrapped scheduler runs with its base LRs
    scaled once by ``multiplier``.
    """

    def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
        # Target LR factor reached at the end of warm-up.
        self.multiplier = multiplier
        self.total_epoch = warm_epoch
        # Scheduler that takes over once warm-up is complete (e.g. cosine annealing).
        self.after_scheduler = after_scheduler
        self.finished = False
        # NOTE(review): both attributes are immediately re-populated by
        # _LRScheduler.__init__ below; these assignments only pre-declare them.
        self.last_epoch = None
        self.base_lrs = None
        super().__init__(optimizer)

    def get_lr(self):
        # Past the warm-up window: defer to the wrapped scheduler, scaling its
        # base LRs exactly once by `multiplier`.
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        # During warm-up: linear interpolation from 1.0 to `multiplier`.
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]


    def step(self, epoch=None, metrics=None):
        # `metrics` is accepted for API compatibility but unused here.
        if self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                # Shift the epoch count so the wrapped scheduler starts at 0.
                self.after_scheduler.step(epoch - self.total_epoch)
        else:
            return super(GradualWarmupScheduler, self).step(epoch)

================================================
FILE: Diffusion_UKAN/data/readme.txt
================================================
download data.zip and unzip here

================================================
FILE: Diffusion_UKAN/inception-score-pytorch/LICENSE.md
================================================
Copyright 2017 Shane T. Barratt

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: Diffusion_UKAN/inception-score-pytorch/README.md
================================================
# Inception Score Pytorch

Pytorch was lacking code to calculate the Inception Score for GANs. This repository fills this gap.
However, we do not recommend using the Inception Score to evaluate generative models, see [our note](https://arxiv.org/abs/1801.01973) for why.

## Getting Started

Clone the repository and navigate to it:
```
$ git clone git@github.com:sbarratt/inception-score-pytorch.git
$ cd inception-score-pytorch
```

To generate random 64x64 images and calculate the inception score, do the following:
```
$ python inception_score.py
```

The only function is `inception_score`. It takes a list of numpy images normalized to the range [0,1] and a set of arguments and then calculates the inception score. Please assure your images are 3x299x299 and if not (e.g. your GAN was trained on CIFAR), pass `resize=True` to the function to have it automatically resize using bilinear interpolation before passing the images to the inception network.

```python
def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs
    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    splits -- number of splits
    """
```

### Prerequisites

You will need [torch](http://pytorch.org/), [torchvision](https://github.com/pytorch/vision), [numpy/scipy](https://scipy.org/).

## License

This project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details

## Acknowledgments

* Inception Score from [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498)


================================================
FILE: Diffusion_UKAN/inception-score-pytorch/inception_score.py
================================================
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
import torch.utils.data
from torchvision.models.inception import inception_v3
import os
from skimage import io
import cv2
import os
import numpy as np
from scipy.stats import entropy
import torchvision.datasets as dset
import torchvision.transforms as transforms

import argparse
def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=32):
    """Computes the inception score of the generated images imgs

    imgs -- Torch dataset of (3xHxW) images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    resize -- bilinearly upsample inputs to 299x299 before scoring
    splits -- number of splits used to estimate the score's spread

    Returns (mean, std) of the per-split inception scores.
    """
    N = len(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

    # Set up dataloader
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    # Load inception model (downloads pretrained ImageNet weights on first use)
    inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
    inception_model.eval()
    # align_corners=False is the default for bilinear; stated explicitly to
    # silence the interpolation warning without changing behavior.
    up = nn.Upsample(size=(299, 299), mode='bilinear', align_corners=False).type(dtype)

    def get_pred(x):
        # Inference only: no_grad avoids building the autograd graph, which
        # previously grew memory across batches.
        with torch.no_grad():
            if resize:
                x = up(x)
            logits = inception_model(x)
            # dim=1: softmax over the 1000 ImageNet class logits. The old
            # dim-less call was deprecated and ambiguous.
            return F.softmax(logits, dim=1).cpu().numpy()

    # Get predictions
    preds = np.zeros((N, 1000))

    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batch_size_i = batch.size()[0]

        preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batch)

    # Now compute the mean KL-divergence per split: exp(E_x[KL(p(y|x) || p(y))])
    split_scores = []

    for k in range(splits):
        part = preds[k * (N // splits): (k + 1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = [entropy(part[i, :], py) for i in range(part.shape[0])]
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)

class UnlabeledDataset(torch.utils.data.Dataset):
    """Dataset over every file in a flat image folder, with no labels.

    Each item is the loaded image, passed through ``transform`` when given.
    """

    def __init__(self, folder, transform=None):
        self.folder = folder
        self.transform = transform
        # Snapshot the directory listing once at construction time.
        self.image_files = os.listdir(folder)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        path = os.path.join(self.folder, self.image_files[idx])
        image = io.imread(path)
        if self.transform:
            image = self.transform(image)
        return image
    
class IgnoreLabelDataset(torch.utils.data.Dataset):
    """Wraps a labelled dataset and exposes only the first element
    (the image) of each (image, label) sample."""

    def __init__(self, orig):
        self.orig = orig

    def __getitem__(self, index):
        sample = self.orig[index]
        return sample[0]

    def __len__(self):
        return len(self.orig)


if __name__ == '__main__':

    # cifar = dset.CIFAR10(root='data/', download=True,
    #                          transform=transforms.Compose([
    #                              transforms.Resize(32),``
    #                              transforms.ToTensor(),
    #                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    #                          ])
    # )

    # Map images to [-1, 1], the normalization inception_score documents.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])


    # set args
    parser = argparse.ArgumentParser()
    # Flat folder of generated images to score.
    parser.add_argument('--data-root', type=str, default='/data/wyli/code/TinyDDPM/Output/unet_busi/Gens/')

    args = parser.parse_args()

    dataset = UnlabeledDataset(args.data_root, transform=transform)

    print ("Calculating Inception Score...")
    # batch_size=1 tolerates variable image sizes; resize=True upsamples to 299x299.
    print (inception_score(dataset, cuda=True, batch_size=1, resize=True, splits=10))




================================================
FILE: Diffusion_UKAN/released_models/readme.txt
================================================
download released_models.zip and unzip here

================================================
FILE: Diffusion_UKAN/requirements.txt
================================================
pytorch-fid==0.30.0
torch==2.3.0
torchvision==0.18.0
tqdm
timm==0.9.16
scikit-image==0.23.1

================================================
FILE: Diffusion_UKAN/tools/resive_cvc.py
================================================
"""Center-crop CVC-ClinicDB originals to 288x288 and downsample them to 64x64.

Every file in `src_dir` is cropped around its center and written to `dst_dir`
under the same filename.
"""
import os
from skimage import io, transform
from skimage.util import img_as_ubyte
import numpy as np

# Define the source and destination directories
src_dir = '/data/wyli/data/CVC-ClinicDB/Original/'
dst_dir = '/data/wyli/data/cvc/images_64/'

os.makedirs(dst_dir, exist_ok=True)

# Get a list of all the image files in the source directory
image_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]

# Define the size of the crop box
crop_size = np.array([288 ,288])

# Define the size of the resized image
resize_size = (64, 64)

for image_file in image_files:
    # Load the image
    image = io.imread(os.path.join(src_dir, image_file))
    # print(image.shape)

    # Calculate the center of the image
    center = np.array(image.shape[:2]) // 2

    # Calculate the start and end points of the crop box
    # (assumes every source image is at least 288x288 -- TODO confirm)
    start = center - crop_size // 2
    end = start + crop_size

    # Crop the image
    cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]])

    # Resize the cropped image; transform.resize returns floats in [0, 1],
    # converted back to uint8 below.
    resized_image = transform.resize(cropped_image, resize_size, mode='reflect')

    # Save the resized image to the destination directory
    io.imsave(os.path.join(dst_dir, image_file), img_as_ubyte(resized_image))

================================================
FILE: Diffusion_UKAN/tools/resize_busi.py
================================================
"""Center-crop BUSI images to 400x400 and downsample them to 64x64.

Every file in `src_dir` is cropped around its center and written to `dst_dir`
under the same filename.
"""
import os
from skimage import io, transform
from skimage.util import img_as_ubyte
import numpy as np

# Define the source and destination directories
src_dir = '/data/wyli/data/busi/images/'
dst_dir = '/data/wyli/data/busi/images_64/'

os.makedirs(dst_dir, exist_ok=True)

# Get a list of all the image files in the source directory
image_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]

# Define the size of the crop box
crop_size = np.array([400 ,400])

# Define the size of the resized image
# resize_size = (64, 64)
resize_size = (64, 64)

for image_file in image_files:
    # Load the image
    image = io.imread(os.path.join(src_dir, image_file))
    print(image.shape)


    # Calculate the center of the image
    center = np.array(image.shape[:2]) // 2

    # Calculate the start and end points of the crop box
    # (assumes every source image is at least 400x400 -- TODO confirm)
    start = center - crop_size // 2
    end = start + crop_size

    # Crop the image
    cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]])

    # Resize the cropped image; transform.resize returns floats in [0, 1],
    # converted back to uint8 below.
    resized_image = transform.resize(cropped_image, resize_size, mode='reflect')

    # Save the resized image to the destination directory
    io.imsave(os.path.join(dst_dir, image_file), img_as_ubyte(resized_image))

================================================
FILE: Diffusion_UKAN/tools/resize_glas.py
================================================
"""Extract K random 64x64 crops from each GLAS image.

Each crop is saved as `<original filename>_<i>.png` in `dst_dir`; note the
original extension stays embedded in the output name (e.g. "x.bmp_0.png").
"""
import os
from skimage import io, transform
from skimage.util import img_as_ubyte
import numpy as np
import random

# Define the source and destination directories
src_dir = '/data/wyli/data/glas/images/'
dst_dir = '/data/wyli/data/glas/images_64/'

os.makedirs(dst_dir, exist_ok=True)

# Get a list of all the image files in the source directory
image_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]

# Define the size of the crop box
crop_size = np.array([64, 64])

# Define the number of crops per image
K = 5

for image_file in image_files:
    # Load the image
    image = io.imread(os.path.join(src_dir, image_file))

    # Get the size of the image
    image_size = np.array(image.shape[:2])

    for i in range(K):
        # Calculate a random start point for the crop box
        # (assumes the image is at least 64x64 in both dimensions)
        start = np.array([random.randint(0, image_size[0] - crop_size[0]), random.randint(0, image_size[1] - crop_size[1])])

        # Calculate the end point of the crop box
        end = start + crop_size

        # Crop the image
        cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]])

        # Save the cropped image to the destination directory
        io.imsave(os.path.join(dst_dir, f"{image_file}_{i}.png"), cropped_image)

================================================
FILE: Diffusion_UKAN/training_scripts/busi.sh
================================================
#!/bin/bash
# Train Diffusion U-KAN on BUSI, then compute FID and IS on the generated images.
source ~/miniconda3/etc/profile.d/conda.sh

conda activate kan

GPU=0
MODEL='UKan_Hybrid'
# Experiment name matches the dataset below (was 'UKan_cvc', a copy-paste slip
# that wrote BUSI outputs into a cvc-named directory).
EXP_NME='UKan_busi'
SAVE_ROOT='./Output/'
DATASET='busi'

cd ../

CUDA_VISIBLE_DEVICES=${GPU} python Main.py \
--model ${MODEL} \
--exp_nme ${EXP_NME}  \
--batch_size 32  \
--channel 64 \
--dataset ${DATASET} \
--epoch 5000 \
--save_root ${SAVE_ROOT}
# --lr 1e-4

# calculate FID and IS
CUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid "data/${DATASET}/images_64/" "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/FID.txt" 2>&1

cd inception-score-pytorch

CUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root "${SAVE_ROOT}/${EXP_NME}/Gens"  > "${SAVE_ROOT}/${EXP_NME}/IS.txt" 2>&1
================================================
FILE: Diffusion_UKAN/training_scripts/cvc.sh
================================================
#!/bin/bash
# Train Diffusion U-KAN on CVC-ClinicDB, then compute FID and IS on the generated images.
# ('##!' was not a valid shebang; fixed so the script runs under bash when executed directly.)
source ~/miniconda3/etc/profile.d/conda.sh

conda activate kan

GPU=0
MODEL='UKan_Hybrid'
EXP_NME='UKan_cvc'
SAVE_ROOT='./Output/'
DATASET='cvc'

cd ../

CUDA_VISIBLE_DEVICES=${GPU} python Main.py \
--model ${MODEL} \
--exp_nme ${EXP_NME}  \
--batch_size 32  \
--channel 64 \
--dataset ${DATASET} \
--epoch 1000 \
--save_root ${SAVE_ROOT}
# --lr 1e-4

# calculate FID and IS
CUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid "data/${DATASET}/images_64/" "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/FID.txt" 2>&1

cd inception-score-pytorch

CUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root "${SAVE_ROOT}/${EXP_NME}/Gens"  > "${SAVE_ROOT}/${EXP_NME}/IS.txt" 2>&1


================================================
FILE: Diffusion_UKAN/training_scripts/glas.sh
================================================
#!/bin/bash
# Train Diffusion U-KAN on GLAS, then compute FID and IS on the generated images.
# ('##!' was not a valid shebang; fixed so the script runs under bash when executed directly.)
source ~/miniconda3/etc/profile.d/conda.sh

conda activate kan

GPU=0
MODEL='UKan_Hybrid'
EXP_NME='UKan_glas'
SAVE_ROOT='./Output/'
DATASET='glas'

cd ../

CUDA_VISIBLE_DEVICES=${GPU} python Main.py \
--model ${MODEL} \
--exp_nme ${EXP_NME}  \
--batch_size 32  \
--channel 64 \
--dataset ${DATASET} \
--epoch 1000 \
--save_root ${SAVE_ROOT}
# --lr 1e-4

# calculate FID and IS
CUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid "data/${DATASET}/images_64/" "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/FID.txt" 2>&1

cd inception-score-pytorch

CUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root "${SAVE_ROOT}/${EXP_NME}/Gens"  > "${SAVE_ROOT}/${EXP_NME}/IS.txt" 2>&1


================================================
FILE: README.md
================================================
# U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation

:pushpin: This is an official PyTorch implementation of **U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**

[[`Project Page`](https://yes-u-kan.github.io/)] [[`arXiv`](https://arxiv.org/abs/2406.02918)] [[`BibTeX`](#citation)]

<p align="center">
  <img src="./assets/logo_1.png" alt="" width="120" height="120">
</p>

> [**U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**](https://arxiv.org/abs/2406.02918)<br>
> [Chenxin Li](https://xggnet.github.io/)<sup>1\*</sup>, [Xinyu Liu](https://xinyuliu-jeffrey.github.io/)<sup>1\*</sup>, [Wuyang Li](https://wymancv.github.io/wuyang.github.io/)<sup>1\*</sup>, [Cheng Wang](https://scholar.google.com/citations?user=AM7gvyUAAAAJ&hl=en)<sup>1\*</sup>, [Hengyu Liu](https://liuhengyu321.github.io/)<sup>1</sup>, [Yifan Liu](https://yifliu3.github.io/)<sup>1</sup>, [Chen Zhen](https://franciszchen.github.io/)<sup>2</sup>, [Yixuan Yuan](https://www.ee.cuhk.edu.hk/~yxyuan/people/people.htm)<sup>1✉</sup><br> <sup>1</sup>The Chinese University of Hong Kong, <sup>2</sup>Centre for Artificial Intelligence and Robotics, Hong Kong

We explore the untapped potential of Kolmogorov-Arnold Network (aka. KAN) in improving backbones for vision tasks. We investigate, modify and re-design the established U-Net pipeline by integrating the dedicated KAN layers on the tokenized intermediate representation, termed U-KAN. Rigorous medical image segmentation benchmarks verify the superiority of U-KAN by higher accuracy even with less computation cost. We further delved into the potential of U-KAN as an alternative U-Net noise predictor in diffusion models, demonstrating its applicability in generating task-oriented model architectures. These endeavours unveil valuable insights and shed light on the prospect that with U-KAN, you can make strong backbone for medical image segmentation and generation.

<div align="center">
    <img width="100%" alt="UKAN overview" src="assets/framework-1.jpg"/>
</div>

## 📰News

 **[NOTE]** Random seed is essential for eval metric, and all reported results are calculated over three random runs with seeds of 2981, 6142, 1187, following rolling-UNet. We think most issues are related to this.

**[2024.10]** U-KAN is accepted by AAAI-25. 

**[2024.6]** Some modifications are done in Seg_UKAN for better performance reproduction. The previous code can be quickly updated by replacing the contents of train.py and archs.py with the new ones.

**[2024.6]** Model checkpoints and training logs are released!

**[2024.6]** Code and paper of U-KAN are released!

## 💡Key Features
- The first effort to incorporate the advantage of emerging KAN to improve established U-Net pipeline to be more **accurate, efficient and interpretable**.
- A Segmentation U-KAN with **tokenized KAN block to effectively steer the KAN operators** to be compatible with the exiting convolution-based designs.
- A Diffusion U-KAN as an **improved noise predictor** demonstrates its potential in backboning generative tasks and broader vision settings.

## 🛠Setup

```bash
git clone https://github.com/CUHK-AIM-Group/U-KAN.git
cd U-KAN
conda create -n ukan python=3.10
conda activate ukan
cd Seg_UKAN && pip install -r requirements.txt
```

**Tips A**: We test the framework using pytorch=1.13.0, and the CUDA compile version=11.6. Other versions should be also fine but not totally ensured.


## 📚Data Preparation
**BUSI**:  The dataset can be found [here](https://www.kaggle.com/datasets/aryashah2k/breast-ultrasound-images-dataset). 

**GLAS**:  The dataset can be found [here](https://websignon.warwick.ac.uk/origin/slogin?shire=https%3A%2F%2Fwarwick.ac.uk%2Fsitebuilder2%2Fshire-read&providerId=urn%3Awarwick.ac.uk%3Asitebuilder2%3Aread%3Aservice&target=https%3A%2F%2Fwarwick.ac.uk%2Ffac%2Fcross_fac%2Ftia%2Fdata%2Fglascontest&status=notloggedin). 
<!-- You can directly use the [processed GLAS data]() without further data processing. -->
**CVC-ClinicDB**:  The dataset can be found [here](https://www.dropbox.com/s/p5qe9eotetjnbmq/CVC-ClinicDB.rar?e=3&dl=0). 
<!-- You can directly use the [processed CVC-ClinicDB data]() without further data processing. -->

We also provide all the [pre-processed dataset](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/ErDlT-t0WoBNlKhBlbYfReYB-iviSCmkNRb1GqZ90oYjJA?e=hrPNWD) without requiring any further data processing. You can directly download and put them into the data dir.



The resulted file structure is as follows.
```
Seg_UKAN
├── inputs
│   ├── busi
│     ├── images
│           ├── malignant (1).png
|           ├── ...
|     ├── masks
│        ├── 0
│           ├── malignant (1)_mask.png
|           ├── ...
│   ├── GLAS
│     ├── images
│           ├── 0.png
|           ├── ...
|     ├── masks
│        ├── 0
│           ├── 0.png
|           ├── ...
│   ├── CVC-ClinicDB
│     ├── images
│           ├── 0.png
|           ├── ...
|     ├── masks
│        ├── 0
│           ├── 0.png
|           ├── ...
```

## 🔖Evaluating Segmentation U-KAN

You can directly evaluate U-KAN from the checkpoint model. Here is an example for quick usage for using our **pre-trained models** in [Segmentation Model Zoo](#segmentation-model-zoo):
1. Download the pre-trained weights and put them to ```{args.output_dir}/{args.name}/model.pth```
2. Run the following scripts to 
```bash
cd Seg_UKAN
python val.py --name ${dataset}_UKAN --output_dir [YOUR_OUTPUT_DIR] 
```

## ⏳Training Segmentation U-KAN

You can simply train U-KAN on a single GPU by specifying the dataset name ```--dataset``` and input size ```--input_size```.
```bash
cd Seg_UKAN
python train.py --arch UKAN --dataset {dataset} --input_w {input_size} --input_h {input_size} --name {dataset}_UKAN  --data_dir [YOUR_DATA_DIR]
```
For example, train U-KAN with the resolution of 256x256 with a single GPU on the BUSI dataset in the ```inputs``` dir:
```bash
cd Seg_UKAN
python train.py --arch UKAN --dataset busi --input_w 256 --input_h 256 --name busi_UKAN  --data_dir ./inputs
```
Please see Seg_UKAN/scripts.sh for more details.
Note that the resolution of glas is 512x512, differing from other datasets (256x256).

**[Quick Update]** Please follow the seeds of 2981, 6142, 1187 to fully reproduce the paper experimental results. All compared methods are evaluated on the same seed setting.

## 🎪Segmentation Model Zoo
We provide all the pre-trained model [checkpoints](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/Ej6yZBSIrU5Ds9q-gQdhXqwBbpov5_MaWF483uZHm2lccA?e=rmlHMo)
Here is an overview of the released performance&checkpoints. Note that results on a single run and the reported average results in the paper differ.
|Method| Dataset | IoU | F1  | Checkpoints |
|-----|------|-----|-----|-----|
|Seg U-KAN| BUSI | 65.26 | 78.75 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EjktWkXytkZEgN3EzN2sJKIBfHCeEnJnCnazC68pWCy7kQ?e=4JBLIc)|
|Seg U-KAN| GLAS | 87.51 | 93.33 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EunQ9KRf6n1AqCJ40FWZF-QB25GMOoF7hoIwU15fefqFbw?e=m7kXwe)|
|Seg U-KAN| CVC-ClinicDB | 85.61 | 92.19 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/Ekhb3PEmwZZMumSG69wPRRQBymYIi0PFNuLJcVNmmK1fjA?e=5XzVSi)|

The parameter `--no_kan` denotes the baseline model in which the KAN layers are replaced with MLP layers. Please note: it is reasonable to find occasional inconsistencies in performance, and the average results over multiple runs can reveal a more obvious trend.
|Method| Layer Type | IoU | F1  | Checkpoints |
|-----|------|-----|-----|-----|
|Seg U-KAN (--no_kan)| MLP Layer  | 63.49 |	77.07 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EmEH_qokqIFNtP59yU7vY_4Bq4Yc424zuYufwaJuiAGKiw?e=IJ3clx)|
|Seg U-KAN| KAN Layer |  65.26 | 78.75  | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EjktWkXytkZEgN3EzN2sJKIBfHCeEnJnCnazC68pWCy7kQ?e=4JBLIc)|

## 🎇Medical Image Generation with Diffusion U-KAN

Please refer to [Diffusion_UKAN](./Diffusion_UKAN/README.md)


## 🛒TODO List
- [X] Release code for Seg U-KAN.
- [X] Release code for Diffusion U-KAN.
- [X] Upload the pretrained checkpoints.


## 🎈Acknowledgements
Greatly appreciate the tremendous effort for the following projects!
- [CKAN](https://github.com/AntonioTepsich/Convolutional-KANs)


## 📜Citation
If you find this work helpful for your project, please consider citing the following paper:
```
@article{li2024ukan,
  title={U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation},
  author={Li, Chenxin and Liu, Xinyu and Li, Wuyang and Wang, Cheng and Liu, Hengyu and Yuan, Yixuan},
  journal={arXiv preprint arXiv:2406.02918},
  year={2024}
}
```


================================================
FILE: Seg_UKAN/LICENSE
================================================
MIT License

Copyright (c) 2022 Jeya Maria Jose

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: Seg_UKAN/archs.py
================================================
import torch
from torch import nn
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
from utils import *

import timm
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import types
import math
from abc import ABCMeta, abstractmethod
# from mmcv.cnn import ConvModule
from pdb import set_trace as st

from kan import KANLinear, KAN
from torch.nn import init


class KANLayer(nn.Module):
    """Tokenized KAN layer: three KANLinear (or plain Linear, when no_kan=True)
    projections, each followed by a depthwise-conv + BN + ReLU applied on the
    restored 2-D feature map.

    Note: every stage reshapes its output back to (B, N, C) with the *input*
    channel count C, so the layer requires
    in_features == hidden_features == out_features (the defaults below, both
    falling back to ``in_features``, guarantee this).
    `act_layer` is accepted but unused, and `self.drop` is constructed but
    never applied in ``forward``.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., no_kan=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features

        # Spline/grid hyper-parameters forwarded to every KANLinear.
        grid_size=5
        spline_order=3
        scale_noise=0.1
        scale_base=1.0
        scale_spline=1.0
        base_activation=torch.nn.SiLU
        grid_eps=0.02
        grid_range=[-1, 1]

        if not no_kan:
            self.fc1 = KANLinear(
                        in_features,
                        hidden_features,
                        grid_size=grid_size,
                        spline_order=spline_order,
                        scale_noise=scale_noise,
                        scale_base=scale_base,
                        scale_spline=scale_spline,
                        base_activation=base_activation,
                        grid_eps=grid_eps,
                        grid_range=grid_range,
                    )
            self.fc2 = KANLinear(
                        hidden_features,
                        out_features,
                        grid_size=grid_size,
                        spline_order=spline_order,
                        scale_noise=scale_noise,
                        scale_base=scale_base,
                        scale_spline=scale_spline,
                        base_activation=base_activation,
                        grid_eps=grid_eps,
                        grid_range=grid_range,
                    )
            self.fc3 = KANLinear(
                        hidden_features,
                        out_features,
                        grid_size=grid_size,
                        spline_order=spline_order,
                        scale_noise=scale_noise,
                        scale_base=scale_base,
                        scale_spline=scale_spline,
                        base_activation=base_activation,
                        grid_eps=grid_eps,
                        grid_range=grid_range,
                    )
            # # TODO   
            # self.fc4 = KANLinear(
            #             hidden_features,
            #             out_features,
            #             grid_size=grid_size,
            #             spline_order=spline_order,
            #             scale_noise=scale_noise,
            #             scale_base=scale_base,
            #             scale_spline=scale_spline,
            #             base_activation=base_activation,
            #             grid_eps=grid_eps,
            #             grid_range=grid_range,
            #         )   

        else:
            # MLP baseline (paper's --no_kan ablation): same topology, plain Linear.
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.fc2 = nn.Linear(hidden_features, out_features)
            self.fc3 = nn.Linear(hidden_features, out_features)

        # TODO
        # self.fc1 = nn.Linear(in_features, hidden_features)


        # One depthwise-conv + BN + ReLU block per projection stage.
        self.dwconv_1 = DW_bn_relu(hidden_features)
        self.dwconv_2 = DW_bn_relu(hidden_features)
        self.dwconv_3 = DW_bn_relu(hidden_features)

        # # TODO
        # self.dwconv_4 = DW_bn_relu(hidden_features)

        # NOTE(review): declared but never used in forward().
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Standard trunc-normal / fan-out init, applied recursively to submodules.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()


    def forward(self, x, H, W):
        # x: (B, N, C) token sequence; assumes N == H * W -- TODO confirm at callers.
        # pdb.set_trace()
        B, N, C = x.shape

        # Each stage: flatten tokens to (B*N, C) for the (KAN)Linear, restore
        # the (B, N, C) layout, then mix spatially via depthwise conv on HxW.
        x = self.fc1(x.reshape(B*N,C))
        x = x.reshape(B,N,C).contiguous()
        x = self.dwconv_1(x, H, W)
        x = self.fc2(x.reshape(B*N,C))
        x = x.reshape(B,N,C).contiguous()
        x = self.dwconv_2(x, H, W)
        x = self.fc3(x.reshape(B*N,C))
        x = x.reshape(B,N,C).contiguous()
        x = self.dwconv_3(x, H, W)

        # # TODO
        # x = x.reshape(B,N,C).contiguous()
        # x = self.dwconv_4(x, H, W)

        return x

class KANBlock(nn.Module):
    """Pre-norm residual block: x + DropPath(KANLayer(LayerNorm(x)))."""

    def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, no_kan=False):
        super().__init__()

        # Stochastic depth; identity when the drop-path rate is zero.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim)

        self.layer = KANLayer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, no_kan=no_kan)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-submodule init dispatched through ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            receptive = m.kernel_size[0] * m.kernel_size[1]
            fan_out = (receptive * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        """Apply the normalized KAN layer and add the residual."""
        residual = x
        out = self.layer(self.norm2(x), H, W)
        return residual + self.drop_path(out)


class DWConv(nn.Module):
    """Depthwise 3x3 convolution applied to a (B, N, C) token sequence.

    Tokens are laid out on an H x W grid, convolved per-channel
    (``groups=dim``), then flattened back to a token sequence.
    """

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        batch, n_tokens, channels = x.shape
        # tokens -> spatial feature map
        feat = x.permute(0, 2, 1).reshape(batch, channels, H, W)
        feat = self.dwconv(feat)
        # spatial feature map -> tokens
        return feat.reshape(batch, channels, H * W).permute(0, 2, 1)

class DW_bn_relu(nn.Module):
    """Depthwise 3x3 conv + BatchNorm + ReLU over a (B, N, C) token sequence.

    Same token<->grid round trip as DWConv, with BN and ReLU applied on the
    spatial layout before flattening back.
    """

    def __init__(self, dim=768):
        super(DW_bn_relu, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        self.bn = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU()

    def forward(self, x, H, W):
        batch, n_tokens, channels = x.shape
        feat = x.permute(0, 2, 1).reshape(batch, channels, H, W)
        feat = self.relu(self.bn(self.dwconv(feat)))
        return feat.reshape(batch, channels, H * W).permute(0, 2, 1)

class PatchEmbed(nn.Module):
    """Image to Patch Embedding via an overlapping strided convolution.

    Projects (B, C, H, W) images to a LayerNorm-ed token sequence of shape
    (B, N, embed_dim) and also returns the post-projection grid size.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        # Nominal grid size for the configured img_size; forward() recomputes
        # the actual H, W from the conv output.
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-submodule init dispatched through ``self.apply``."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
            return
        if isinstance(m, nn.Conv2d):
            receptive = m.kernel_size[0] * m.kernel_size[1]
            fan_out = (receptive * m.out_channels) // m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        """Return (tokens, H, W) where tokens is (B, H*W, embed_dim)."""
        feat = self.proj(x)
        H, W = feat.shape[2], feat.shape[3]
        tokens = feat.flatten(2).transpose(1, 2)
        return self.norm(tokens), H, W


class ConvLayer(nn.Module):
    """Two 3x3 Conv-BN-ReLU stages; channels go in_ch -> out_ch -> out_ch."""

    def __init__(self, in_ch, out_ch):
        super(ConvLayer, self).__init__()
        stages = []
        for c_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(c_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        return self.conv(input)

class D_ConvLayer(nn.Module):
    """Decoder conv block: 3x3 Conv-BN-ReLU at in_ch, then project to out_ch."""

    def __init__(self, in_ch, out_ch):
        super(D_ConvLayer, self).__init__()
        stages = []
        for c_in, c_out in ((in_ch, in_ch), (in_ch, out_ch)):
            stages.extend([
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, input):
        return self.conv(input)



class UKAN(nn.Module):
    def __init__(self, num_classes, input_channels=3, deep_supervision=False, img_size=224, patch_size=16, in_chans=3, embed_dims=[256, 320, 512], no_kan=False,
    drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[1, 1, 1], **kwargs):
        """Build the U-KAN segmentation network.

        Encoder: three Conv-BN-ReLU stages, then two tokenized KAN stages
        (patch_embed3/4 + KANBlocks).  Decoder mirrors this with two KAN
        stages and three conv stages, ending in a 1x1 classifier head.

        Args:
            num_classes: number of output channels of the final 1x1 conv.
            input_channels: channels of the input image (default 3 = RGB).
            embed_dims: widths of the three deepest stages.
            drop_rate / drop_path_rate: dropout and stochastic-depth rates.
            norm_layer: normalization layer for the tokenized stages.
            depths: per-stage block counts; only its sum is used for the
                drop-path schedule here (each stage holds one block).

        NOTE(review): deep_supervision, patch_size, in_chans and act-related
        kwargs are accepted but unused in this constructor — confirm before
        relying on them.
        """
        super().__init__()

        kan_input_dim = embed_dims[0]

        # BUG FIX: this previously hard-coded 3 input channels, silently
        # ignoring the `input_channels` parameter.  The default (3) keeps
        # prior behavior; non-RGB inputs now work as advertised.
        self.encoder1 = ConvLayer(input_channels, kan_input_dim//8)
        self.encoder2 = ConvLayer(kan_input_dim//8, kan_input_dim//4)
        self.encoder3 = ConvLayer(kan_input_dim//4, kan_input_dim)

        self.norm3 = norm_layer(embed_dims[1])
        self.norm4 = norm_layer(embed_dims[2])

        self.dnorm3 = norm_layer(embed_dims[1])
        self.dnorm4 = norm_layer(embed_dims[0])

        # Linearly increasing stochastic-depth rates across all blocks.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        self.block1 = nn.ModuleList([KANBlock(
            dim=embed_dims[1], 
            drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer
            )])

        self.block2 = nn.ModuleList([KANBlock(
            dim=embed_dims[2],
            drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer
            )])

        self.dblock1 = nn.ModuleList([KANBlock(
            dim=embed_dims[1], 
            drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer
            )])

        self.dblock2 = nn.ModuleList([KANBlock(
            dim=embed_dims[0], 
            drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer
            )])

        # Overlapping patch embeddings that halve the spatial resolution.
        self.patch_embed3 = PatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
        self.patch_embed4 = PatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])

        self.decoder1 = D_ConvLayer(embed_dims[2], embed_dims[1])  
        self.decoder2 = D_ConvLayer(embed_dims[1], embed_dims[0])  
        self.decoder3 = D_ConvLayer(embed_dims[0], embed_dims[0]//4) 
        self.decoder4 = D_ConvLayer(embed_dims[0]//4, embed_dims[0]//8)
        self.decoder5 = D_ConvLayer(embed_dims[0]//8, embed_dims[0]//8)

        self.final = nn.Conv2d(embed_dims[0]//8, num_classes, kernel_size=1)
        self.soft = nn.Softmax(dim =1)

    def forward(self, x):
        
        B = x.shape[0]
        ### Encoder
        ### Conv Stage

        ### Stage 1
        out = F.relu(F.max_pool2d(self.encoder1(x), 2, 2))
        t1 = out
        ### Stage 2
        out = F.relu(F.max_pool2d(self.encoder2(out), 2, 2))
        t2 = out
        ### Stage 3
        out = F.relu(F.max_pool2d(self.encoder3(out), 2, 2))
        t3 = out

        ### Tokenized KAN Stage
        ### Stage 4

        out, H, W = self.patch_embed3(out)
        for i, blk in enumerate(self.block1):
            out = blk(out, H, W)
        out = self.norm3(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        t4 = out

        ### Bottleneck

        out, H, W= self.patch_embed4(out)
        for i, blk in enumerate(self.block2):
            out = blk(out, H, W)
        out = self.norm4(out)
        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        ### Stage 4
        out = F.relu(F.interpolate(self.decoder1(out), scale_factor=(2,2), mode ='bilinear'))

        out = torch.add(out, t4)
        _, _, H, W = out.shape
        out = out.flatten(2).transpose(1,2)
        for i, blk in enumerate(self.dblock1):
            out = blk(out, H, W)

        ### Stage 3
        out = self.dnorm3(out
Download .txt
gitextract_suz9sp40/

├── Diffusion_UKAN/
│   ├── Diffusion/
│   │   ├── Diffusion.py
│   │   ├── Model.py
│   │   ├── Model_ConvKan.py
│   │   ├── Model_UKAN_Hybrid.py
│   │   ├── Model_UMLP.py
│   │   ├── Train.py
│   │   ├── UNet.py
│   │   ├── __init__.py
│   │   ├── kan_utils/
│   │   │   ├── __init__.py
│   │   │   ├── fastkanconv.py
│   │   │   └── kan.py
│   │   └── utils.py
│   ├── Main.py
│   ├── Main_Test.py
│   ├── README.md
│   ├── Scheduler.py
│   ├── data/
│   │   └── readme.txt
│   ├── inception-score-pytorch/
│   │   ├── LICENSE.md
│   │   ├── README.md
│   │   └── inception_score.py
│   ├── released_models/
│   │   └── readme.txt
│   ├── requirements.txt
│   ├── tools/
│   │   ├── resive_cvc.py
│   │   ├── resize_busi.py
│   │   └── resize_glas.py
│   └── training_scripts/
│       ├── busi.sh
│       ├── cvc.sh
│       └── glas.sh
├── README.md
└── Seg_UKAN/
    ├── LICENSE
    ├── archs.py
    ├── config.py
    ├── dataset.py
    ├── environment.yml
    ├── kan.py
    ├── losses.py
    ├── metrics.py
    ├── requirements.txt
    ├── scripts.sh
    ├── train.py
    ├── utils.py
    └── val.py
Download .txt
SYMBOL INDEX (405 symbols across 23 files)

FILE: Diffusion_UKAN/Diffusion/Diffusion.py
  function extract (line 9) | def extract(v, t, x_shape):
  class GaussianDiffusionTrainer (line 19) | class GaussianDiffusionTrainer(nn.Module):
    method __init__ (line 20) | def __init__(self, model, beta_1, beta_T, T):
    method forward (line 37) | def forward(self, x_0):
  class GaussianDiffusionSampler (line 50) | class GaussianDiffusionSampler(nn.Module):
    method __init__ (line 51) | def __init__(self, model, beta_1, beta_T, T):
    method predict_xt_prev_mean_from_eps (line 67) | def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
    method p_mean_variance (line 74) | def p_mean_variance(self, x_t, t):
    method forward (line 84) | def forward(self, x_T):

FILE: Diffusion_UKAN/Diffusion/Model.py
  class Swish (line 10) | class Swish(nn.Module):
    method forward (line 11) | def forward(self, x):
  class TimeEmbedding (line 14) | class TimeEmbedding(nn.Module):
    method __init__ (line 15) | def __init__(self, T, d_model, dim):
    method initialize (line 35) | def initialize(self):
    method forward (line 41) | def forward(self, t):
  class DownSample (line 46) | class DownSample(nn.Module):
    method __init__ (line 47) | def __init__(self, in_ch):
    method initialize (line 52) | def initialize(self):
    method forward (line 56) | def forward(self, x, temb):
  class UpSample (line 61) | class UpSample(nn.Module):
    method __init__ (line 62) | def __init__(self, in_ch):
    method initialize (line 67) | def initialize(self):
    method forward (line 71) | def forward(self, x, temb):
  class AttnBlock (line 79) | class AttnBlock(nn.Module):
    method __init__ (line 80) | def __init__(self, in_ch):
    method initialize (line 89) | def initialize(self):
    method forward (line 95) | def forward(self, x):
  class ResBlock (line 117) | class ResBlock(nn.Module):
    method __init__ (line 118) | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
    method initialize (line 145) | def initialize(self):
    method forward (line 152) | def forward(self, x, temb):
  class KANLinear (line 162) | class KANLinear(torch.nn.Module):
    method __init__ (line 163) | def __init__(
    method reset_parameters (line 212) | def reset_parameters(self):
    method b_splines (line 234) | def b_splines(self, x: torch.Tensor):
    method curve2coeff (line 269) | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
    method scaled_spline_weight (line 302) | def scaled_spline_weight(self):
    method forward (line 309) | def forward(self, x: torch.Tensor):
    method update_grid (line 322) | def update_grid(self, x: torch.Tensor, margin=0.01):
    method regularization_loss (line 370) | def regularization_loss(self, regularize_activation=1.0, regularize_en...
  class Ukan (line 392) | class Ukan(nn.Module):
    method __init__ (line 393) | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
    method initialize (line 486) | def initialize(self):
    method forward (line 492) | def forward(self, x, t):
  class DW_bn_relu (line 524) | class DW_bn_relu(nn.Module):
    method __init__ (line 525) | def __init__(self, dim=768):
    method forward (line 531) | def forward(self, x, H, W):
  class kan (line 541) | class kan(nn.Module):
    method __init__ (line 542) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method _init_weights (line 674) | def _init_weights(self, m):
    method forward (line 690) | def forward(self, x, H, W):
  class Ukan_v3 (line 831) | class Ukan_v3(nn.Module):
    method __init__ (line 832) | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
    method initialize (line 900) | def initialize(self):
    method forward (line 906) | def forward(self, x, t):
  class Ukan_v2 (line 944) | class Ukan_v2(nn.Module):
    method __init__ (line 945) | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout,versi...
    method initialize (line 1012) | def initialize(self):
    method forward (line 1018) | def forward(self, x, t):
  class UNet (line 1055) | class UNet(nn.Module):
    method __init__ (line 1056) | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
    method initialize (line 1103) | def initialize(self):
    method forward (line 1109) | def forward(self, x, t):
  class UNet_MLP (line 1137) | class UNet_MLP(nn.Module):
    method __init__ (line 1138) | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
    method initialize (line 1203) | def initialize(self):
    method forward (line 1209) | def forward(self, x, t):

FILE: Diffusion_UKAN/Diffusion/Model_ConvKan.py
  class Swish (line 10) | class Swish(nn.Module):
    method forward (line 11) | def forward(self, x):
  class TimeEmbedding (line 14) | class TimeEmbedding(nn.Module):
    method __init__ (line 15) | def __init__(self, T, d_model, dim):
    method initialize (line 35) | def initialize(self):
    method forward (line 41) | def forward(self, t):
  class DownSample (line 46) | class DownSample(nn.Module):
    method __init__ (line 47) | def __init__(self, in_ch):
    method initialize (line 53) | def initialize(self):
    method forward (line 57) | def forward(self, x, temb):
  class UpSample (line 62) | class UpSample(nn.Module):
    method __init__ (line 63) | def __init__(self, in_ch):
    method initialize (line 69) | def initialize(self):
    method forward (line 73) | def forward(self, x, temb):
  class AttnBlock (line 81) | class AttnBlock(nn.Module):
    method __init__ (line 82) | def __init__(self, in_ch):
    method initialize (line 91) | def initialize(self):
    method forward (line 97) | def forward(self, x):
  class ResBlock (line 119) | class ResBlock(nn.Module):
    method __init__ (line 120) | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
    method initialize (line 150) | def initialize(self):
    method forward (line 157) | def forward(self, x, temb):
  class UNet_ConvKan (line 166) | class UNet_ConvKan(nn.Module):
    method __init__ (line 167) | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
    method initialize (line 215) | def initialize(self):
    method forward (line 221) | def forward(self, x, t):

FILE: Diffusion_UKAN/Diffusion/Model_UKAN_Hybrid.py
  class KANLinear (line 10) | class KANLinear(torch.nn.Module):
    method __init__ (line 11) | def __init__(
    method reset_parameters (line 60) | def reset_parameters(self):
    method b_splines (line 82) | def b_splines(self, x: torch.Tensor):
    method curve2coeff (line 117) | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
    method scaled_spline_weight (line 150) | def scaled_spline_weight(self):
    method forward (line 157) | def forward(self, x: torch.Tensor):
    method update_grid (line 168) | def update_grid(self, x: torch.Tensor, margin=0.01):
    method regularization_loss (line 216) | def regularization_loss(self, regularize_activation=1.0, regularize_en...
  class KAN (line 239) | class KAN(torch.nn.Module):
    method __init__ (line 240) | def __init__(
    method forward (line 273) | def forward(self, x: torch.Tensor, update_grid=False):
    method regularization_loss (line 280) | def regularization_loss(self, regularize_activation=1.0, regularize_en...
  function conv1x1 (line 287) | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
  function shift (line 292) | def shift(dim):
  class OverlapPatchEmbed (line 300) | class OverlapPatchEmbed(nn.Module):
    method __init__ (line 304) | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, e...
    method _init_weights (line 319) | def _init_weights(self, m):
    method forward (line 334) | def forward(self, x):
  class Swish (line 343) | class Swish(nn.Module):
    method forward (line 344) | def forward(self, x):
  function swish (line 346) | def swish(x):
  class TimeEmbedding (line 351) | class TimeEmbedding(nn.Module):
    method __init__ (line 352) | def __init__(self, T, d_model, dim):
    method initialize (line 372) | def initialize(self):
    method forward (line 378) | def forward(self, t):
  class DownSample (line 383) | class DownSample(nn.Module):
    method __init__ (line 384) | def __init__(self, in_ch):
    method initialize (line 389) | def initialize(self):
    method forward (line 393) | def forward(self, x, temb):
  class UpSample (line 398) | class UpSample(nn.Module):
    method __init__ (line 399) | def __init__(self, in_ch):
    method initialize (line 404) | def initialize(self):
    method forward (line 408) | def forward(self, x, temb):
  class kan (line 415) | class kan(nn.Module):
    method __init__ (line 416) | def __init__(self, in_features, hidden_features=None, out_features=None):
    method _init_weights (line 446) | def _init_weights(self, m):
    method forward (line 462) | def forward(self, x, H, W):
  class shiftedBlock (line 469) | class shiftedBlock(nn.Module):
    method __init__ (line 470) | def __init__(self, dim,  mlp_ratio=4.,drop_path=0.,norm_layer=nn.Layer...
    method _init_weights (line 486) | def _init_weights(self, m):
    method forward (line 501) | def forward(self, x, H, W, temb):
  class DWConv (line 509) | class DWConv(nn.Module):
    method __init__ (line 510) | def __init__(self, dim=768):
    method forward (line 514) | def forward(self, x, H, W):
  class DW_bn_relu (line 522) | class DW_bn_relu(nn.Module):
    method __init__ (line 523) | def __init__(self, dim=768):
    method forward (line 529) | def forward(self, x, H, W):
  class SingleConv (line 539) | class SingleConv(nn.Module):
    method __init__ (line 540) | def __init__(self, in_ch, h_ch):
    method forward (line 552) | def forward(self, input, temb):
  class DoubleConv (line 556) | class DoubleConv(nn.Module):
    method __init__ (line 557) | def __init__(self, in_ch, h_ch):
    method forward (line 571) | def forward(self, input, temb):
  class D_SingleConv (line 575) | class D_SingleConv(nn.Module):
    method __init__ (line 576) | def __init__(self, in_ch, h_ch):
    method forward (line 587) | def forward(self, input, temb):
  class D_DoubleConv (line 591) | class D_DoubleConv(nn.Module):
    method __init__ (line 592) | def __init__(self, in_ch, h_ch):
    method forward (line 606) | def forward(self, input,temb):
  class AttnBlock (line 609) | class AttnBlock(nn.Module):
    method __init__ (line 610) | def __init__(self, in_ch):
    method initialize (line 619) | def initialize(self):
    method forward (line 625) | def forward(self, x):
  class ResBlock (line 647) | class ResBlock(nn.Module):
    method __init__ (line 648) | def __init__(self, in_ch, h_ch, tdim, dropout, attn=False):
    method initialize (line 675) | def initialize(self):
    method forward (line 682) | def forward(self, x, temb):
  class UKan_Hybrid (line 692) | class UKan_Hybrid(nn.Module):
    method __init__ (line 693) | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
    method initialize (line 758) | def initialize(self):
    method forward (line 764) | def forward(self, x, t):

FILE: Diffusion_UKAN/Diffusion/Model_UMLP.py
  class KANLinear (line 10) | class KANLinear(torch.nn.Module):
    method __init__ (line 11) | def __init__(
    method reset_parameters (line 60) | def reset_parameters(self):
    method b_splines (line 82) | def b_splines(self, x: torch.Tensor):
    method curve2coeff (line 117) | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
    method scaled_spline_weight (line 150) | def scaled_spline_weight(self):
    method forward (line 157) | def forward(self, x: torch.Tensor):
    method update_grid (line 168) | def update_grid(self, x: torch.Tensor, margin=0.01):
    method regularization_loss (line 216) | def regularization_loss(self, regularize_activation=1.0, regularize_en...
  class KAN (line 239) | class KAN(torch.nn.Module):
    method __init__ (line 240) | def __init__(
    method forward (line 273) | def forward(self, x: torch.Tensor, update_grid=False):
    method regularization_loss (line 280) | def regularization_loss(self, regularize_activation=1.0, regularize_en...
  function conv1x1 (line 287) | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
  function shift (line 292) | def shift(dim):
  class OverlapPatchEmbed (line 300) | class OverlapPatchEmbed(nn.Module):
    method __init__ (line 304) | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, e...
    method _init_weights (line 319) | def _init_weights(self, m):
    method forward (line 334) | def forward(self, x):
  class Swish (line 343) | class Swish(nn.Module):
    method forward (line 344) | def forward(self, x):
  function swish (line 346) | def swish(x):
  class TimeEmbedding (line 351) | class TimeEmbedding(nn.Module):
    method __init__ (line 352) | def __init__(self, T, d_model, dim):
    method initialize (line 372) | def initialize(self):
    method forward (line 378) | def forward(self, t):
  class DownSample (line 383) | class DownSample(nn.Module):
    method __init__ (line 384) | def __init__(self, in_ch):
    method initialize (line 389) | def initialize(self):
    method forward (line 393) | def forward(self, x, temb):
  class UpSample (line 398) | class UpSample(nn.Module):
    method __init__ (line 399) | def __init__(self, in_ch):
    method initialize (line 404) | def initialize(self):
    method forward (line 408) | def forward(self, x, temb):
  class kan (line 415) | class kan(nn.Module):
    method __init__ (line 416) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method _init_weights (line 450) | def _init_weights(self, m):
    method forward (line 466) | def forward(self, x, H, W):
  class shiftedBlock (line 475) | class shiftedBlock(nn.Module):
    method __init__ (line 476) | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_sc...
    method _init_weights (line 492) | def _init_weights(self, m):
    method forward (line 507) | def forward(self, x, H, W, temb):
  class DWConv (line 516) | class DWConv(nn.Module):
    method __init__ (line 517) | def __init__(self, dim=768):
    method forward (line 521) | def forward(self, x, H, W):
  class DW_bn_relu (line 529) | class DW_bn_relu(nn.Module):
    method __init__ (line 530) | def __init__(self, dim=768):
    method forward (line 536) | def forward(self, x, H, W):
  class SingleConv (line 546) | class SingleConv(nn.Module):
    method __init__ (line 547) | def __init__(self, in_ch, h_ch):
    method forward (line 559) | def forward(self, input, temb):
  class DoubleConv (line 563) | class DoubleConv(nn.Module):
    method __init__ (line 564) | def __init__(self, in_ch, h_ch):
    method forward (line 578) | def forward(self, input, temb):
  class D_SingleConv (line 582) | class D_SingleConv(nn.Module):
    method __init__ (line 583) | def __init__(self, in_ch, h_ch):
    method forward (line 594) | def forward(self, input, temb):
  class D_DoubleConv (line 598) | class D_DoubleConv(nn.Module):
    method __init__ (line 599) | def __init__(self, in_ch, h_ch):
    method forward (line 613) | def forward(self, input,temb):
  class AttnBlock (line 616) | class AttnBlock(nn.Module):
    method __init__ (line 617) | def __init__(self, in_ch):
    method initialize (line 626) | def initialize(self):
    method forward (line 632) | def forward(self, x):
  class ResBlock (line 654) | class ResBlock(nn.Module):
    method __init__ (line 655) | def __init__(self, in_ch, h_ch, tdim, dropout, attn=False):
    method initialize (line 682) | def initialize(self):
    method forward (line 689) | def forward(self, x, temb):
  class UMLP (line 699) | class UMLP(nn.Module):
    method __init__ (line 700) | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
    method initialize (line 788) | def initialize(self):
    method forward (line 794) | def forward(self, x, t):

FILE: Diffusion_UKAN/Diffusion/Train.py
  class UnlabeledDataset (line 32) | class UnlabeledDataset(Dataset):
    method __init__ (line 33) | def __init__(self, folder, transform=None, repeat_n=1):
    method __len__ (line 39) | def __len__(self):
    method __getitem__ (line 42) | def __getitem__(self, idx):
  function train (line 51) | def train(modelConfig: Dict):
  function eval_tmp (line 136) | def eval_tmp(modelConfig: Dict, nme: int):
  function eval (line 168) | def eval(modelConfig: Dict):

FILE: Diffusion_UKAN/Diffusion/UNet.py
  class Swish (line 10) | class Swish(nn.Module):
    method forward (line 11) | def forward(self, x):
  class TimeEmbedding (line 15) | class TimeEmbedding(nn.Module):
    method __init__ (line 16) | def __init__(self, T, d_model, dim):
    method initialize (line 36) | def initialize(self):
    method forward (line 42) | def forward(self, t):
  class DownSample (line 47) | class DownSample(nn.Module):
    method __init__ (line 48) | def __init__(self, in_ch):
    method initialize (line 53) | def initialize(self):
    method forward (line 57) | def forward(self, x, temb):
  class UpSample (line 62) | class UpSample(nn.Module):
    method __init__ (line 63) | def __init__(self, in_ch):
    method initialize (line 68) | def initialize(self):
    method forward (line 72) | def forward(self, x, temb):
  class AttnBlock (line 80) | class AttnBlock(nn.Module):
    method __init__ (line 81) | def __init__(self, in_ch):
    method initialize (line 90) | def initialize(self):
    method forward (line 96) | def forward(self, x):
  class ResBlock (line 118) | class ResBlock(nn.Module):
    method __init__ (line 119) | def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
    method initialize (line 146) | def initialize(self):
    method forward (line 153) | def forward(self, x, temb):
  class UNet (line 163) | class UNet(nn.Module):
    method __init__ (line 164) | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
    method initialize (line 210) | def initialize(self):
    method forward (line 216) | def forward(self, x, t):
  class UNet_Baseline (line 240) | class UNet_Baseline(nn.Module):
    method __init__ (line 242) | def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
    method initialize (line 284) | def initialize(self):
    method forward (line 290) | def forward(self, x, t):

FILE: Diffusion_UKAN/Diffusion/kan_utils/fastkanconv.py
  class PolynomialFunction (line 7) | class PolynomialFunction(nn.Module):
    method __init__ (line 8) | def __init__(self,
    method forward (line 13) | def forward(self, x):
  class BSplineFunction (line 16) | class BSplineFunction(nn.Module):
    method __init__ (line 17) | def __init__(self, grid_min: float = -2.,
    method basis_function (line 24) | def basis_function(self, i, k, t):
    method forward (line 37) | def forward(self, x):
  class ChebyshevFunction (line 42) | class ChebyshevFunction(nn.Module):
    method __init__ (line 43) | def __init__(self, degree: int = 4):
    method forward (line 47) | def forward(self, x):
  class FourierBasisFunction (line 53) | class FourierBasisFunction(nn.Module):
    method __init__ (line 54) | def __init__(self,
    method forward (line 62) | def forward(self, x):
  class RadialBasisFunction (line 69) | class RadialBasisFunction(nn.Module):
    method __init__ (line 70) | def __init__(
    method forward (line 82) | def forward(self, x):
  class SplineConv2D (line 88) | class SplineConv2D(nn.Conv2d):
    method __init__ (line 89) | def __init__(self,
    method reset_parameters (line 115) | def reset_parameters(self) -> None:
  class FastKANConvLayer (line 121) | class FastKANConvLayer(nn.Module):
    method __init__ (line 122) | def __init__(self,
    method forward (line 178) | def forward(self, x):

FILE: Diffusion_UKAN/Diffusion/kan_utils/kan.py
  class KANLinear (line 6) | class KANLinear(torch.nn.Module):
    method __init__ (line 7) | def __init__(
    method reset_parameters (line 56) | def reset_parameters(self):
    method b_splines (line 78) | def b_splines(self, x: torch.Tensor):
    method curve2coeff (line 113) | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
    method scaled_spline_weight (line 146) | def scaled_spline_weight(self):
    method forward (line 153) | def forward(self, x: torch.Tensor):
    method update_grid (line 164) | def update_grid(self, x: torch.Tensor, margin=0.01):
    method regularization_loss (line 212) | def regularization_loss(self, regularize_activation=1.0, regularize_en...
  class KAN (line 235) | class KAN(torch.nn.Module):
    method __init__ (line 236) | def __init__(
    method forward (line 269) | def forward(self, x: torch.Tensor, update_grid=False):
    method regularization_loss (line 276) | def regularization_loss(self, regularize_activation=1.0, regularize_en...

FILE: Diffusion_UKAN/Diffusion/utils.py
  class qkv_transform (line 4) | class qkv_transform(nn.Conv1d):
  function str2bool (line 7) | def str2bool(v):
  function count_params (line 16) | def count_params(model):
  class AverageMeter (line 20) | class AverageMeter(object):
    method __init__ (line 23) | def __init__(self):
    method reset (line 26) | def reset(self):
    method update (line 32) | def update(self, val, n=1):

FILE: Diffusion_UKAN/Main.py
  function main (line 7) | def main(model_config = None):
  function seed_all (line 23) | def seed_all(args):

FILE: Diffusion_UKAN/Main_Test.py
  function main (line 5) | def main(model_config = None):
  function seed_all (line 22) | def seed_all(args):

FILE: Diffusion_UKAN/Scheduler.py
  class GradualWarmupScheduler (line 3) | class GradualWarmupScheduler(_LRScheduler):
    method __init__ (line 4) | def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=...
    method get_lr (line 13) | def get_lr(self):
    method step (line 24) | def step(self, epoch=None, metrics=None):

FILE: Diffusion_UKAN/inception-score-pytorch/inception_score.py
  function inception_score (line 17) | def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits...
  class UnlabeledDataset (line 75) | class UnlabeledDataset(torch.utils.data.Dataset):
    method __init__ (line 76) | def __init__(self, folder, transform=None):
    method __len__ (line 81) | def __len__(self):
    method __getitem__ (line 84) | def __getitem__(self, idx):
  class IgnoreLabelDataset (line 93) | class IgnoreLabelDataset(torch.utils.data.Dataset):
    method __init__ (line 94) | def __init__(self, orig):
    method __getitem__ (line 97) | def __getitem__(self, index):
    method __len__ (line 100) | def __len__(self):

FILE: Seg_UKAN/archs.py
  class KANLayer (line 27) | class KANLayer(nn.Module):
    method __init__ (line 28) | def __init__(self, in_features, hidden_features=None, out_features=Non...
    method _init_weights (line 114) | def _init_weights(self, m):
    method forward (line 130) | def forward(self, x, H, W):
  class KANBlock (line 150) | class KANBlock(nn.Module):
    method __init__ (line 151) | def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, norm...
    method _init_weights (line 162) | def _init_weights(self, m):
    method forward (line 177) | def forward(self, x, H, W):
  class DWConv (line 183) | class DWConv(nn.Module):
    method __init__ (line 184) | def __init__(self, dim=768):
    method forward (line 188) | def forward(self, x, H, W):
  class DW_bn_relu (line 196) | class DW_bn_relu(nn.Module):
    method __init__ (line 197) | def __init__(self, dim=768):
    method forward (line 203) | def forward(self, x, H, W):
  class PatchEmbed (line 213) | class PatchEmbed(nn.Module):
    method __init__ (line 217) | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, e...
    method _init_weights (line 232) | def _init_weights(self, m):
    method forward (line 247) | def forward(self, x):
  class ConvLayer (line 256) | class ConvLayer(nn.Module):
    method __init__ (line 257) | def __init__(self, in_ch, out_ch):
    method forward (line 268) | def forward(self, input):
  class D_ConvLayer (line 271) | class D_ConvLayer(nn.Module):
    method __init__ (line 272) | def __init__(self, in_ch, out_ch):
    method forward (line 283) | def forward(self, input):
  class UKAN (line 288) | class UKAN(nn.Module):
    method __init__ (line 289) | def __init__(self, num_classes, input_channels=3, deep_supervision=Fal...
    method forward (line 339) | def forward(self, x):

FILE: Seg_UKAN/config.py
  function _update_config_from_file (line 175) | def _update_config_from_file(config, cfg_file):
  function update_config (line 190) | def update_config(config, args):
  function get_config (line 222) | def get_config(args):

FILE: Seg_UKAN/dataset.py
  class Dataset (line 9) | class Dataset(torch.utils.data.Dataset):
    method __init__ (line 10) | def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_...
    method __len__ (line 52) | def __len__(self):
    method __getitem__ (line 55) | def __getitem__(self, idx):

FILE: Seg_UKAN/kan.py
  class KANLinear (line 6) | class KANLinear(torch.nn.Module):
    method __init__ (line 7) | def __init__(
    method reset_parameters (line 56) | def reset_parameters(self):
    method b_splines (line 78) | def b_splines(self, x: torch.Tensor):
    method curve2coeff (line 113) | def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
    method scaled_spline_weight (line 146) | def scaled_spline_weight(self):
    method forward (line 153) | def forward(self, x: torch.Tensor):
    method update_grid (line 164) | def update_grid(self, x: torch.Tensor, margin=0.01):
    method regularization_loss (line 212) | def regularization_loss(self, regularize_activation=1.0, regularize_en...
  class KAN (line 235) | class KAN(torch.nn.Module):
    method __init__ (line 236) | def __init__(
    method forward (line 269) | def forward(self, x: torch.Tensor, update_grid=False):
    method regularization_loss (line 276) | def regularization_loss(self, regularize_activation=1.0, regularize_en...

FILE: Seg_UKAN/losses.py
  class BCEDiceLoss (line 13) | class BCEDiceLoss(nn.Module):
    method __init__ (line 14) | def __init__(self):
    method forward (line 17) | def forward(self, input, target):
  class LovaszHingeLoss (line 30) | class LovaszHingeLoss(nn.Module):
    method __init__ (line 31) | def __init__(self):
    method forward (line 34) | def forward(self, input, target):

FILE: Seg_UKAN/metrics.py
  function iou_score (line 9) | def iou_score(output, target):
  function dice_coef (line 31) | def dice_coef(output, target):
  function indicators (line 41) | def indicators(output, target):

FILE: Seg_UKAN/train.py
  function list_type (line 47) | def list_type(s):
  function parse_args (line 53) | def parse_args():
  function train (line 138) | def train(config, train_loader, model, criterion, optimizer):
  function validate (line 186) | def validate(config, val_loader, model, criterion):
  function seed_torch (line 231) | def seed_torch(seed=1029):
  function main (line 242) | def main():

FILE: Seg_UKAN/utils.py
  class qkv_transform (line 4) | class qkv_transform(nn.Conv1d):
  function str2bool (line 7) | def str2bool(v):
  function count_params (line 16) | def count_params(model):
  class AverageMeter (line 20) | class AverageMeter(object):
    method __init__ (line 23) | def __init__(self):
    method reset (line 26) | def reset(self):
    method update (line 32) | def update(self, val, n=1):

FILE: Seg_UKAN/val.py
  function parse_args (line 28) | def parse_args():
  function seed_torch (line 38) | def seed_torch(seed=1029):
  function main (line 49) | def main():
Condensed preview — 42 files, each showing its path, character count, and a short content snippet. Download the .json file, or copy it to your clipboard, to get the full structured content (253K chars).
[
  {
    "path": "Diffusion_UKAN/Diffusion/Diffusion.py",
    "chars": 3370,
    "preview": "\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom tqdm import tqdm\n\n\ndef extra"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/Model.py",
    "chars": 41630,
    "preview": "\n   \nimport math\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\n\n\nclas"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/Model_ConvKan.py",
    "chars": 8321,
    "preview": "\n   \nimport math\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom D"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/Model_UKAN_Hybrid.py",
    "chars": 27672,
    "preview": "\n   \nimport math\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom t"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/Model_UMLP.py",
    "chars": 29295,
    "preview": "\n   \nimport math\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom t"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/Train.py",
    "chars": 8809,
    "preview": "\nimport os\nfrom typing import Dict\nimport torch\nimport torch.optim as optim\nfrom tqdm import tqdm\nfrom torch.utils.data "
  },
  {
    "path": "Diffusion_UKAN/Diffusion/UNet.py",
    "chars": 10066,
    "preview": "\n   \nimport math\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\n\n\nclas"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/__init__.py",
    "chars": 66,
    "preview": "from .Diffusion import *\nfrom .UNet import *\nfrom .Train import *\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/kan_utils/__init__.py",
    "chars": 81,
    "preview": "from .kan import *\nfrom .fastkanconv import *\n# from .kan_convolutional import *\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/kan_utils/fastkanconv.py",
    "chars": 7611,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import List, Tuple, Union\n\n\nclass Polynom"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/kan_utils/kan.py",
    "chars": 10046,
    "preview": "import torch\nimport torch.nn.functional as F\nimport math\n\n\nclass KANLinear(torch.nn.Module):\n    def __init__(\n        s"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/utils.py",
    "chars": 805,
    "preview": "import argparse\nimport torch.nn as nn\n\nclass qkv_transform(nn.Conv1d):\n    \"\"\"Conv1d for qkv_transform\"\"\"\n\ndef str2bool("
  },
  {
    "path": "Diffusion_UKAN/Main.py",
    "chars": 3670,
    "preview": "from Diffusion.Train import train, eval\nimport os\nimport argparse\nimport torch\nimport numpy as np\n\ndef main(model_config"
  },
  {
    "path": "Diffusion_UKAN/Main_Test.py",
    "chars": 3698,
    "preview": "from Diffusion.Train import train, eval, eval_tmp\nimport os\nimport argparse\nimport torch\ndef main(model_config = None):\n"
  },
  {
    "path": "Diffusion_UKAN/README.md",
    "chars": 3254,
    "preview": "# Diffusion UKAN (arxiv)\n\n> [**U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**](https://arxi"
  },
  {
    "path": "Diffusion_UKAN/Scheduler.py",
    "chars": 1314,
    "preview": "from torch.optim.lr_scheduler import _LRScheduler\n\nclass GradualWarmupScheduler(_LRScheduler):\n    def __init__(self, op"
  },
  {
    "path": "Diffusion_UKAN/data/readme.txt",
    "chars": 32,
    "preview": "download data.zip and unzip here"
  },
  {
    "path": "Diffusion_UKAN/inception-score-pytorch/LICENSE.md",
    "chars": 1055,
    "preview": "Copyright 2017 Shane T. Barratt\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this so"
  },
  {
    "path": "Diffusion_UKAN/inception-score-pytorch/README.md",
    "chars": 1717,
    "preview": "# Inception Score Pytorch\n\nPytorch was lacking code to calculate the Inception Score for GANs. This repository fills thi"
  },
  {
    "path": "Diffusion_UKAN/inception-score-pytorch/inception_score.py",
    "chars": 3945,
    "preview": "import torch\nfrom torch import nn\nfrom torch.autograd import Variable\nfrom torch.nn import functional as F\nimport torch."
  },
  {
    "path": "Diffusion_UKAN/released_models/readme.txt",
    "chars": 43,
    "preview": "download released_models.zip and unzip here"
  },
  {
    "path": "Diffusion_UKAN/requirements.txt",
    "chars": 91,
    "preview": "pytorch-fid==0.30.0\ntorch==2.3.0\ntorchvision==0.18.0\ntqdm\ntimm==0.9.16\nscikit-image==0.23.1"
  },
  {
    "path": "Diffusion_UKAN/tools/resive_cvc.py",
    "chars": 1251,
    "preview": "import os\nfrom skimage import io, transform\nfrom skimage.util import img_as_ubyte\nimport numpy as np\n\n# Define the sourc"
  },
  {
    "path": "Diffusion_UKAN/tools/resize_busi.py",
    "chars": 1266,
    "preview": "import os\nfrom skimage import io, transform\nfrom skimage.util import img_as_ubyte\nimport numpy as np\n\n# Define the sourc"
  },
  {
    "path": "Diffusion_UKAN/tools/resize_glas.py",
    "chars": 1273,
    "preview": "import os\nfrom skimage import io, transform\nfrom skimage.util import img_as_ubyte\nimport numpy as np\nimport random\n\n# De"
  },
  {
    "path": "Diffusion_UKAN/training_scripts/busi.sh",
    "chars": 708,
    "preview": "##!/bin/bash\nsource ~/miniconda3/etc/profile.d/conda.sh\n\nconda activate kan\n\nGPU=0\nMODEL='UKan_Hybrid'\nEXP_NME='UKan_cvc"
  },
  {
    "path": "Diffusion_UKAN/training_scripts/cvc.sh",
    "chars": 707,
    "preview": "##!/bin/bash\nsource ~/miniconda3/etc/profile.d/conda.sh\n\nconda activate kan\n\nGPU=0\nMODEL='UKan_Hybrid'\nEXP_NME='UKan_cvc"
  },
  {
    "path": "Diffusion_UKAN/training_scripts/glas.sh",
    "chars": 709,
    "preview": "##!/bin/bash\nsource ~/miniconda3/etc/profile.d/conda.sh\n\nconda activate kan\n\nGPU=0\nMODEL='UKan_Hybrid'\nEXP_NME='UKan_gla"
  },
  {
    "path": "README.md",
    "chars": 8889,
    "preview": "# U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation\n\n:pushpin: This is an official PyTorch imple"
  },
  {
    "path": "Seg_UKAN/LICENSE",
    "chars": 1072,
    "preview": "MIT License\n\nCopyright (c) 2022 Jeya Maria Jose\n\nPermission is hereby granted, free of charge, to any person obtaining a"
  },
  {
    "path": "Seg_UKAN/archs.py",
    "chars": 13973,
    "preview": "import torch\nfrom torch import nn\nimport torch\nimport torchvision\nfrom torch import nn\nfrom torch.autograd import Variab"
  },
  {
    "path": "Seg_UKAN/config.py",
    "chars": 7355,
    "preview": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed "
  },
  {
    "path": "Seg_UKAN/dataset.py",
    "chars": 2519,
    "preview": "import os\n\nimport cv2\nimport numpy as np\nimport torch\nimport torch.utils.data\n\n\nclass Dataset(torch.utils.data.Dataset):"
  },
  {
    "path": "Seg_UKAN/environment.yml",
    "chars": 1069,
    "preview": "name: ukan\nchannels:\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=4.5=1_gnu\n  - ca-certificat"
  },
  {
    "path": "Seg_UKAN/kan.py",
    "chars": 10046,
    "preview": "import torch\nimport torch.nn.functional as F\nimport math\n\n\nclass KANLinear(torch.nn.Module):\n    def __init__(\n        s"
  },
  {
    "path": "Seg_UKAN/losses.py",
    "chars": 1036,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ntry:\n    from LovaszSoftmax.pytorch.lovasz_losses im"
  },
  {
    "path": "Seg_UKAN/metrics.py",
    "chars": 1582,
    "preview": "import numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom medpy.metric.binary import jc, dc, hd, hd95, recal"
  },
  {
    "path": "Seg_UKAN/requirements.txt",
    "chars": 379,
    "preview": "addict==2.4.0\ndataclasses\npandas\npyyaml\nalbumentations\ntqdm\ntensorboardX\n# mmcv-full==1.2.7\nnumpy\nopencv-python\nperceptu"
  },
  {
    "path": "Seg_UKAN/scripts.sh",
    "chars": 650,
    "preview": "dataset=busi\ninput_size=256\npython train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_s"
  },
  {
    "path": "Seg_UKAN/train.py",
    "chars": 15409,
    "preview": "import argparse\nimport os\nfrom collections import OrderedDict\nfrom glob import glob\nimport random\nimport numpy as np\n\nim"
  },
  {
    "path": "Seg_UKAN/utils.py",
    "chars": 805,
    "preview": "import argparse\nimport torch.nn as nn\n\nclass qkv_transform(nn.Conv1d):\n    \"\"\"Conv1d for qkv_transform\"\"\"\n\ndef str2bool("
  },
  {
    "path": "Seg_UKAN/val.py",
    "chars": 4879,
    "preview": "#! /data/cxli/miniconda3/envs/th200/bin/python\nimport argparse\nimport os\nfrom glob import glob\nimport random\nimport nump"
  }
]

About this extraction

This page contains the full source code of the CUHK-AIM-Group/U-KAN GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 42 files (236.5 KB, approximately 64.2k tokens) and a symbol index covering 405 extracted functions, classes, methods, constants, and types. You can use it with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input, either by copying the full output to your clipboard or by downloading it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!