[
  {
    "path": "Diffusion_UKAN/Diffusion/Diffusion.py",
    "content": "\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom tqdm import tqdm\n\n\ndef extract(v, t, x_shape):\n    \"\"\"\n    Extract some coefficients at specified timesteps, then reshape to\n    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.\n    \"\"\"\n    device = t.device\n    out = torch.gather(v, index=t, dim=0).float().to(device)\n    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))\n\n\nclass GaussianDiffusionTrainer(nn.Module):\n    def __init__(self, model, beta_1, beta_T, T):\n        super().__init__()\n\n        self.model = model\n        self.T = T\n\n        self.register_buffer(\n            'betas', torch.linspace(beta_1, beta_T, T).double())\n        alphas = 1. - self.betas\n        alphas_bar = torch.cumprod(alphas, dim=0)\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer(\n            'sqrt_alphas_bar', torch.sqrt(alphas_bar))\n        self.register_buffer(\n            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))\n\n    def forward(self, x_0):\n        \"\"\"\n        Algorithm 1.\n        \"\"\"\n        t = torch.randint(self.T, size=(x_0.shape[0], ), device=x_0.device)\n        noise = torch.randn_like(x_0)\n        x_t = (\n            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +\n            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)\n        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')\n        return loss\n\n\nclass GaussianDiffusionSampler(nn.Module):\n    def __init__(self, model, beta_1, beta_T, T):\n        super().__init__()\n\n        self.model = model\n        self.T = T\n\n        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())\n        alphas = 1. 
- self.betas\n        alphas_bar = torch.cumprod(alphas, dim=0)\n        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]\n\n        self.register_buffer('coeff1', torch.sqrt(1. / alphas))\n        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))\n\n        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))\n\n    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):\n        assert x_t.shape == eps.shape\n        return (\n            extract(self.coeff1, t, x_t.shape) * x_t -\n            extract(self.coeff2, t, x_t.shape) * eps\n        )\n\n    def p_mean_variance(self, x_t, t):\n        # below: only log_variance is used in the KL computations\n        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])\n        var = extract(var, t, x_t.shape)\n\n        eps = self.model(x_t, t)\n        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)\n\n        return xt_prev_mean, var\n\n    def forward(self, x_T):\n        \"\"\"\n        Algorithm 2.\n        \"\"\"\n        x_t = x_T\n        print('Start Sampling')\n        for time_step in tqdm(reversed(range(self.T))):\n            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step\n            mean, var= self.p_mean_variance(x_t=x_t, t=t)\n            # no noise when t == 0\n            if time_step > 0:\n                noise = torch.randn_like(x_t)\n            else:\n                noise = 0\n            x_t = mean + torch.sqrt(var) * noise\n            assert torch.isnan(x_t).int().sum() == 0, \"nan in tensor.\"\n        x_0 = x_t\n        return torch.clip(x_0, -1, 1)   \n\n\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/Model.py",
    "content": "\n   \nimport math\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\n\n\nclass Swish(nn.Module):\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\nclass TimeEmbedding(nn.Module):\n    def __init__(self, T, d_model, dim):\n        assert d_model % 2 == 0\n        super().__init__()\n        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)\n        emb = torch.exp(-emb)\n        pos = torch.arange(T).float()\n        emb = pos[:, None] * emb[None, :]\n        assert list(emb.shape) == [T, d_model // 2]\n        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)\n        assert list(emb.shape) == [T, d_model // 2, 2]\n        emb = emb.view(T, d_model)\n\n        self.timembedding = nn.Sequential(\n            nn.Embedding.from_pretrained(emb),\n            nn.Linear(d_model, dim),\n            Swish(),\n            nn.Linear(dim, dim),\n        )\n        self.initialize()\n\n    def initialize(self):\n        for module in self.modules():\n            if isinstance(module, nn.Linear):\n                init.xavier_uniform_(module.weight)\n                init.zeros_(module.bias)\n\n    def forward(self, t):\n        emb = self.timembedding(t)\n        return emb\n\n\nclass DownSample(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.main.weight)\n        init.zeros_(self.main.bias)\n\n    def forward(self, x, temb):\n        x = self.main(x)\n        return x\n\n\nclass UpSample(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.main.weight)\n        init.zeros_(self.main.bias)\n\n    def 
forward(self, x, temb):\n        _, _, H, W = x.shape\n        x = F.interpolate(\n            x, scale_factor=2, mode='nearest')\n        x = self.main(x)\n        return x\n\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.group_norm = nn.GroupNorm(32, in_ch)\n        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.initialize()\n\n    def initialize(self):\n        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:\n            init.xavier_uniform_(module.weight)\n            init.zeros_(module.bias)\n        init.xavier_uniform_(self.proj.weight, gain=1e-5)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        h = self.group_norm(x)\n        q = self.proj_q(h)\n        k = self.proj_k(h)\n        v = self.proj_v(h)\n\n        q = q.permute(0, 2, 3, 1).view(B, H * W, C)\n        k = k.view(B, C, H * W)\n        w = torch.bmm(q, k) * (int(C) ** (-0.5))\n        assert list(w.shape) == [B, H * W, H * W]\n        w = F.softmax(w, dim=-1)\n\n        v = v.permute(0, 2, 3, 1).view(B, H * W, C)\n        h = torch.bmm(w, v)\n        assert list(h.shape) == [B, H * W, C]\n        h = h.view(B, H, W, C).permute(0, 3, 1, 2)\n        h = self.proj(h)\n\n        return x + h\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):\n        super().__init__()\n        self.block1 = nn.Sequential(\n            nn.GroupNorm(32, in_ch),\n            Swish(),\n            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(tdim, out_ch),\n        )\n        self.block2 = nn.Sequential(\n            nn.GroupNorm(32, out_ch),\n   
         Swish(),\n            nn.Dropout(dropout),\n            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),\n        )\n        if in_ch != out_ch:\n            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)\n        else:\n            self.shortcut = nn.Identity()\n        if attn:\n            self.attn = AttnBlock(out_ch)\n        else:\n            self.attn = nn.Identity()\n        self.initialize()\n\n    def initialize(self):\n        for module in self.modules():\n            if isinstance(module, (nn.Conv2d, nn.Linear)):\n                init.xavier_uniform_(module.weight)\n                init.zeros_(module.bias)\n        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)\n\n    def forward(self, x, temb):\n        h = self.block1(x)\n        h += self.temb_proj(temb)[:, :, None, None]\n        h = self.block2(h)\n\n        h = h + self.shortcut(x)\n        h = self.attn(h)\n        return h\n        # return x\n\nclass KANLinear(torch.nn.Module):\n    def __init__(\n        self,\n        in_features,\n        out_features,\n        grid_size=5,\n        spline_order=3,\n        scale_noise=0.1,\n        scale_base=1.0,\n        scale_spline=1.0,\n        enable_standalone_scale_spline=True,\n        base_activation=torch.nn.SiLU,\n        grid_eps=0.02,\n        grid_range=[-1, 1],\n    ):\n        super(KANLinear, self).__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.grid_size = grid_size\n        self.spline_order = spline_order\n\n        h = (grid_range[1] - grid_range[0]) / grid_size\n        grid = (\n            (\n                torch.arange(-spline_order, grid_size + spline_order + 1) * h\n                + grid_range[0]\n            )\n            .expand(in_features, -1)\n            .contiguous()\n        )\n        self.register_buffer(\"grid\", grid)\n\n        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))\n        
self.spline_weight = torch.nn.Parameter(\n            torch.Tensor(out_features, in_features, grid_size + spline_order)\n        )\n        if enable_standalone_scale_spline:\n            self.spline_scaler = torch.nn.Parameter(\n                torch.Tensor(out_features, in_features)\n            )\n\n        self.scale_noise = scale_noise\n        self.scale_base = scale_base\n        self.scale_spline = scale_spline\n        self.enable_standalone_scale_spline = enable_standalone_scale_spline\n        self.base_activation = base_activation()\n        self.grid_eps = grid_eps\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)\n        with torch.no_grad():\n            noise = (\n                (\n                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)\n                    - 1 / 2\n                )\n                * self.scale_noise\n                / self.grid_size\n            )\n            self.spline_weight.data.copy_(\n                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)\n                * self.curve2coeff(\n                    self.grid.T[self.spline_order : -self.spline_order],\n                    noise,\n                )\n            )\n            if self.enable_standalone_scale_spline:\n                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)\n                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)\n\n    def b_splines(self, x: torch.Tensor):\n        \"\"\"\n        Compute the B-spline bases for the given input tensor.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n\n        Returns:\n            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).\n        \"\"\"\n        assert x.dim() == 2 and x.size(1) 
== self.in_features\n\n        grid: torch.Tensor = (\n            self.grid\n        )  # (in_features, grid_size + 2 * spline_order + 1)\n        x = x.unsqueeze(-1)\n        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)\n        for k in range(1, self.spline_order + 1):\n            bases = (\n                (x - grid[:, : -(k + 1)])\n                / (grid[:, k:-1] - grid[:, : -(k + 1)])\n                * bases[:, :, :-1]\n            ) + (\n                (grid[:, k + 1 :] - x)\n                / (grid[:, k + 1 :] - grid[:, 1:(-k)])\n                * bases[:, :, 1:]\n            )\n\n        assert bases.size() == (\n            x.size(0),\n            self.in_features,\n            self.grid_size + self.spline_order,\n        )\n        return bases.contiguous()\n\n    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):\n        \"\"\"\n        Compute the coefficients of the curve that interpolates the given points.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).\n\n        Returns:\n            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).\n        \"\"\"\n        assert x.dim() == 2 and x.size(1) == self.in_features\n        assert y.size() == (x.size(0), self.in_features, self.out_features)\n\n        A = self.b_splines(x).transpose(\n            0, 1\n        )  # (in_features, batch_size, grid_size + spline_order)\n        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)\n        solution = torch.linalg.lstsq(\n            A, B\n        ).solution  # (in_features, grid_size + spline_order, out_features)\n        result = solution.permute(\n            2, 0, 1\n        )  # (out_features, in_features, grid_size + spline_order)\n\n        assert result.size() == (\n            self.out_features,\n            
self.in_features,\n            self.grid_size + self.spline_order,\n        )\n        return result.contiguous()\n\n    @property\n    def scaled_spline_weight(self):\n        return self.spline_weight * (\n            self.spline_scaler.unsqueeze(-1)\n            if self.enable_standalone_scale_spline\n            else 1.0\n        )\n\n    def forward(self, x: torch.Tensor):\n        assert x.dim() == 2 and x.size(1) == self.in_features\n\n        \n\n        base_output = F.linear(self.base_activation(x), self.base_weight)\n        spline_output = F.linear(\n            self.b_splines(x).view(x.size(0), -1),\n            self.scaled_spline_weight.view(self.out_features, -1),\n        )\n        return base_output + spline_output\n\n    @torch.no_grad()\n    def update_grid(self, x: torch.Tensor, margin=0.01):\n        assert x.dim() == 2 and x.size(1) == self.in_features\n        batch = x.size(0)\n\n        splines = self.b_splines(x)  # (batch, in, coeff)\n        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)\n        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)\n        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)\n        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)\n        unreduced_spline_output = unreduced_spline_output.permute(\n            1, 0, 2\n        )  # (batch, in, out)\n\n        # sort each channel individually to collect data distribution\n        x_sorted = torch.sort(x, dim=0)[0]\n        grid_adaptive = x_sorted[\n            torch.linspace(\n                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device\n            )\n        ]\n\n        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size\n        grid_uniform = (\n            torch.arange(\n                self.grid_size + 1, dtype=torch.float32, device=x.device\n            ).unsqueeze(1)\n            * uniform_step\n            + x_sorted[0]\n            - margin\n   
     )\n\n        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive\n        grid = torch.concatenate(\n            [\n                grid[:1]\n                - uniform_step\n                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),\n                grid,\n                grid[-1:]\n                + uniform_step\n                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),\n            ],\n            dim=0,\n        )\n\n        self.grid.copy_(grid.T)\n        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))\n\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n        \"\"\"\n        Compute the regularization loss.\n\n        This is a dumb simulation of the original L1 regularization as stated in the\n        paper, since the original one requires computing absolutes and entropy from the\n        expanded (batch, in_features, out_features) intermediate tensor, which is hidden\n        behind the F.linear function if we want an memory efficient implementation.\n\n        The L1 regularization is now computed as mean absolute value of the spline\n        weights. 
The authors implementation also includes this term in addition to the\n        sample-based regularization.\n        \"\"\"\n        l1_fake = self.spline_weight.abs().mean(-1)\n        regularization_loss_activation = l1_fake.sum()\n        p = l1_fake / regularization_loss_activation\n        regularization_loss_entropy = -torch.sum(p * p.log())\n        return (\n            regularize_activation * regularization_loss_activation\n            + regularize_entropy * regularization_loss_entropy\n        )\n\nclass Ukan(nn.Module):\n    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n        super().__init__()\n        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'\n        tdim = ch * 4\n        self.time_embedding = TimeEmbedding(T, ch, tdim)\n\n        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)\n        self.downblocks = nn.ModuleList()\n        chs = [ch]  # record output channel when dowmsample for upsample\n        now_ch = ch\n        for i, mult in enumerate(ch_mult):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks):\n                self.downblocks.append(ResBlock(\n                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n                chs.append(now_ch)\n            if i != len(ch_mult) - 1:\n                self.downblocks.append(DownSample(now_ch))\n                chs.append(now_ch)\n\n        self.middleblocks1 = nn.ModuleList([\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            # ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n    \n        self.middleblocks2 = nn.ModuleList([\n            # ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n    \n        self.upblocks = nn.ModuleList()\n        for i, mult in 
reversed(list(enumerate(ch_mult))):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks + 1):\n                self.upblocks.append(ResBlock(\n                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n            if i != 0:\n                self.upblocks.append(UpSample(now_ch))\n        assert len(chs) == 0\n\n        self.tail = nn.Sequential(\n            nn.GroupNorm(32, now_ch),\n            Swish(),\n            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)\n        )\n\n\n\n        grid_size=5\n        spline_order=3\n        scale_noise=0.1\n        scale_base=1.0\n        scale_spline=1.0\n        base_activation=torch.nn.SiLU\n        grid_eps=0.02\n        grid_range=[-1, 1]\n\n\n        kan_c=512\n        self.fc1 = KANLinear(\n                    kan_c,\n                    kan_c *2,\n                    grid_size=grid_size,\n                    spline_order=spline_order,\n                    scale_noise=scale_noise,\n                    scale_base=scale_base,\n                    scale_spline=scale_spline,\n                    base_activation=base_activation,\n                    grid_eps=grid_eps,\n                    grid_range=grid_range,\n                )\n        # print(now_ch)\n        # self.dwconv = DWConv(kan_c *2)\n        self.act = nn.GELU()\n        self.fc2 = KANLinear(\n                    kan_c *2,\n                    kan_c,\n                    grid_size=grid_size,\n                    spline_order=spline_order,\n                    scale_noise=scale_noise,\n                    scale_base=scale_base,\n                    scale_spline=scale_spline,\n                    base_activation=base_activation,\n                    grid_eps=grid_eps,\n                    grid_range=grid_range,\n                )\n        \n        self.initialize()\n\n    def initialize(self):\n        
init.xavier_uniform_(self.head.weight)\n        init.zeros_(self.head.bias)\n        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n        init.zeros_(self.tail[-1].bias)\n\n    def forward(self, x, t):\n        # Timestep embedding\n        temb = self.time_embedding(t)\n        # Downsampling\n        h = self.head(x)\n        hs = [h]\n        for layer in self.downblocks:\n            h = layer(h, temb)\n            hs.append(h)\n        # Middle\n        # for layer in self.middleblocks1:\n            # h = layer(h, temb)\n        B, C, H, W = h.shape\n        # transform  B, C, H, W into B*H*W, C\n        h = h.permute(0, 2, 3, 1).reshape(B*H*W, C)\n        h =self.fc2( self.fc1(h))\n        # transform B*H*W, C  into  B, C, H, W\n        h = h.reshape(B, H, W, C).permute(0, 3, 1, 2)\n\n        # for layer in self.middleblocks2:\n        #     h = layer(h, temb)\n        ### Stage 3\n        # Upsampling\n        for layer in self.upblocks:\n            if isinstance(layer, ResBlock):\n                h = torch.cat([h, hs.pop()], dim=1)\n            h = layer(h, temb)\n        h = self.tail(h)\n\n        assert len(hs) == 0\n        return h\n\nclass DW_bn_relu(nn.Module):\n    def __init__(self, dim=768):\n        super(DW_bn_relu, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n        self.bn = nn.BatchNorm2d(dim)\n        self.relu = nn.ReLU()\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        x = x.transpose(1, 2).view(B, C, H, W)\n        x = self.dwconv(x)\n        x = self.bn(x)\n        x = self.relu(x)\n        x = x.flatten(2).transpose(1, 2)\n\n        return x\n\nclass kan(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5, version=4):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.dim = 
in_features\n        \n\n        grid_size=5\n        spline_order=3\n        scale_noise=0.1\n        scale_base=1.0\n        scale_spline=1.0\n        base_activation=torch.nn.SiLU\n        grid_eps=0.02\n        grid_range=[-1, 1]\n\n        # self.fc1 = nn.Linear(in_features, hidden_features)\n        self.fc1 = KANLinear(\n                    in_features,\n                    hidden_features,\n                    grid_size=grid_size,\n                    spline_order=spline_order,\n                    scale_noise=scale_noise,\n                    scale_base=scale_base,\n                    scale_spline=scale_spline,\n                    base_activation=base_activation,\n                    grid_eps=grid_eps,\n                    grid_range=grid_range,\n                )\n        \n        # self.fc2 = nn.Linear(hidden_features, out_features)\n        self.fc2 = KANLinear(\n                    hidden_features,\n                    out_features,\n                    grid_size=grid_size,\n                    spline_order=spline_order,\n                    scale_noise=scale_noise,\n                    scale_base=scale_base,\n                    scale_spline=scale_spline,\n                    base_activation=base_activation,\n                    grid_eps=grid_eps,\n                    grid_range=grid_range,\n                )\n\n        self.fc3 = KANLinear(\n                    hidden_features,\n                    out_features,\n                    grid_size=grid_size,\n                    spline_order=spline_order,\n                    scale_noise=scale_noise,\n                    scale_base=scale_base,\n                    scale_spline=scale_spline,\n                    base_activation=base_activation,\n                    grid_eps=grid_eps,\n                    grid_range=grid_range,\n                )   \n\n        # ##############################################\n        self.version = 4 # version 4 hard code ���ܶ�����\n        \n        # 
##############################################\n        if self.version == 1:\n            self.dwconv_1 = DWConv(hidden_features)\n            self.act_1 = act_layer()\n\n            self.dwconv_2 = DWConv(hidden_features)\n            self.act_2 = act_layer()\n\n            self.dwconv_3 = DWConv(hidden_features)\n            self.act_3 = act_layer()\n\n            self.dwconv_4 = DWConv(hidden_features)\n            self.act_4 = act_layer()\n        elif self.version == 2:\n            self.dwconv_1 = DWConv(hidden_features)\n            self.act_1 = act_layer()\n\n            self.dwconv_2 = DWConv(hidden_features)\n            self.act_2 = act_layer()\n\n            self.dwconv_3 = DWConv(hidden_features)\n            self.act_3 = act_layer()\n\n        elif self.version == 3:\n            self.dwconv_1 = DW_bn_relu(hidden_features)\n            self.dwconv_2 = DW_bn_relu(hidden_features)\n            self.dwconv_3 = DW_bn_relu(hidden_features)\n        elif self.version == 4:\n            self.dwconv_1 = DW_bn_relu(hidden_features)\n            self.dwconv_2 = DW_bn_relu(hidden_features)\n            self.dwconv_3 = DW_bn_relu(hidden_features)\n        elif self.version == 5:\n            self.dwconv_1 = DWConv(hidden_features)\n            self.act_1 = act_layer()\n\n            self.dwconv_2 = DWConv(hidden_features)\n            self.act_2 = act_layer()\n\n            self.dwconv_3 = DWConv(hidden_features)\n            self.act_3 = act_layer()\n        elif self.version == 6:\n            self.dwconv_1 = DWConv(hidden_features)\n            self.act_1 = act_layer()\n\n            self.dwconv_2 = DWConv(hidden_features)\n            self.act_2 = act_layer()\n\n            self.dwconv_3 = DWConv(hidden_features)\n            self.act_3 = act_layer()\n        elif self.version == 7:\n            self.dwconv_1 = DWConv(hidden_features)\n            self.act_1 = act_layer()\n\n            self.dwconv_2 = DWConv(hidden_features)\n            self.act_2 = 
act_layer()\n\n            self.dwconv_3 = DWConv(hidden_features)\n            self.act_3 = act_layer()\n        elif self.version == 8:\n            self.dwconv_1 = DW_bn_relu(hidden_features)\n            self.dwconv_2 = DW_bn_relu(hidden_features)\n            self.dwconv_3 = DW_bn_relu(hidden_features)\n\n    \n        self.drop = nn.Dropout(drop)\n\n        self.shift_size = shift_size\n        self.pad = shift_size // 2\n\n        \n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n    \n\n    def forward(self, x, H, W):\n        # pdb.set_trace()\n        B, N, C = x.shape\n\n        if self.version == 1:\n            x = self.dwconv_1(x, H, W)\n            x = self.act_1(x) \n\n            x = self.fc1(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_2(x, H, W)\n            x = self.act_2(x) \n\n            x = self.fc2(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_3(x, H, W)\n            x = self.act_3(x) \n\n            x = self.fc3(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_4(x, H, W)\n            x = self.act_4(x) \n        elif self.version == 2:\n            \n            x = self.dwconv_1(x, H, W)\n            x = self.act_1(x) \n\n            x = self.fc1(x.reshape(B*N,C))\n  
          x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_2(x, H, W)\n            x = self.act_2(x) \n\n            x = self.fc2(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_3(x, H, W)\n            x = self.act_3(x) \n\n            x = self.fc3(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n        elif self.version == 3:\n            x = self.dwconv_1(x, H, W)\n            x = self.fc1(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n            x = self.dwconv_2(x, H, W)\n            x = self.fc2(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n            x = self.dwconv_3(x, H, W)\n            x = self.fc3(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n        elif self.version == 4:\n            x = self.fc1(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n            x = self.dwconv_1(x, H, W)\n            x = self.fc2(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n            x = self.dwconv_2(x, H, W)\n            x = self.fc3(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n            x = self.dwconv_3(x, H, W)\n        elif self.version == 5:\n\n            x = self.fc1(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_1(x, H, W)\n            x = self.act_1(x) \n\n            x = self.fc2(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_2(x, H, W)\n            x = self.act_2(x) \n\n            x = self.fc3(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_3(x, H, W)\n            x = self.act_3(x) \n        elif self.version == 6:\n\n            x = self.fc1(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_1(x, H, W)\n            x = self.act_1(x) \n\n            x = 
self.fc2(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_2(x, H, W)\n            x = self.act_2(x) \n\n            x = self.fc3(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_3(x, H, W)\n            x = self.act_3(x) \n        elif self.version == 7:\n\n            x = self.fc1(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_1(x, H, W)\n            x = self.act_1(x) \n            x = self.drop(x)\n\n            x = self.fc2(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_2(x, H, W)\n            x = self.act_2(x) \n            x = self.drop(x)\n\n            x = self.fc3(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n\n            x = self.dwconv_3(x, H, W)\n            x = self.act_3(x) \n            x = self.drop(x)\n        elif self.version == 8:\n            x = self.dwconv_1(x, H, W)\n            x = self.drop(x)\n            x = self.fc1(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n            x = self.dwconv_2(x, H, W)\n            x = self.drop(x)\n            x = self.fc2(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n            x = self.dwconv_3(x, H, W)\n            x = self.drop(x)\n            x = self.fc3(x.reshape(B*N,C))\n            x = x.reshape(B,N,C).contiguous()\n        return x\n\n\nclass Ukan_v3(nn.Module):\n    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n        super().__init__()\n        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'\n        tdim = ch * 4\n        self.time_embedding = TimeEmbedding(T, ch, tdim)\n\n\n        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)\n        self.downblocks = nn.ModuleList()\n        chs = [ch]  # record output channel when dowmsample for upsample\n        now_ch = ch\n        for i, 
mult in enumerate(ch_mult):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks):\n                self.downblocks.append(ResBlock(\n                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n                chs.append(now_ch)\n            if i != len(ch_mult) - 1:\n                self.downblocks.append(DownSample(now_ch))\n                chs.append(now_ch)\n\n        self.middleblocks1 = nn.ModuleList([\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            # ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n    \n        self.middleblocks2 = nn.ModuleList([\n            # ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n    \n        self.upblocks = nn.ModuleList()\n        for i, mult in reversed(list(enumerate(ch_mult))):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks + 1):\n                self.upblocks.append(ResBlock(\n                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n            if i != 0:\n                self.upblocks.append(UpSample(now_ch))\n        assert len(chs) == 0\n\n        self.tail = nn.Sequential(\n            nn.GroupNorm(32, now_ch),\n            Swish(),\n            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)\n        )\n\n\n\n        grid_size=5\n        spline_order=3\n        scale_noise=0.1\n        scale_base=1.0\n        scale_spline=1.0\n        base_activation=torch.nn.SiLU\n        grid_eps=0.02\n        grid_range=[-1, 1]\n        kan_c=512\n        self.kan1 = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=4)\n        self.kan2 = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=4)\n\n      
  \n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.head.weight)\n        init.zeros_(self.head.bias)\n        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n        init.zeros_(self.tail[-1].bias)\n\n    def forward(self, x, t):\n        # Timestep embedding\n        temb = self.time_embedding(t)\n        # Downsampling\n        h = self.head(x)\n        hs = [h]\n        for layer in self.downblocks:\n            h = layer(h, temb)\n            hs.append(h)\n        # Middle\n        # for layer in self.middleblocks1:\n            # h = layer(h, temb)\n        B, C, H, W = h.shape\n        # transform  B, C, H, W into B*H*W, C\n        h = h.reshape(B,C, H*W).permute(0, 2, 1)\n        h = self.kan1(h, H, W)\n        h = self.kan2(h, H, W)\n        h = h.permute(0, 2, 1).reshape(B, C, H, W)\n\n\n        # h =self.fc2( self.fc1(h))\n        # transform B*H*W, C  into  B, C, H, W\n        # h = h.reshape(B, H, W, C).permute(0, 3, 1, 2)\n        # B, N, C = x.shape\n\n        # for layer in self.middleblocks2:\n        #     h = layer(h, temb)\n        ### Stage 3\n        # Upsampling\n        for layer in self.upblocks:\n            if isinstance(layer, ResBlock):\n                h = torch.cat([h, hs.pop()], dim=1)\n            h = layer(h, temb)\n        h = self.tail(h)\n\n        assert len(hs) == 0\n        return h\n\nclass Ukan_v2(nn.Module):\n    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout,version=4):\n        super().__init__()\n        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'\n        tdim = ch * 4\n        self.time_embedding = TimeEmbedding(T, ch, tdim)\n\n\n        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)\n        self.downblocks = nn.ModuleList()\n        chs = [ch]  # record output channel when dowmsample for upsample\n        now_ch = ch\n        for i, mult in enumerate(ch_mult):\n            out_ch = ch * mult\n            
for _ in range(num_res_blocks):\n                self.downblocks.append(ResBlock(\n                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n                chs.append(now_ch)\n            if i != len(ch_mult) - 1:\n                self.downblocks.append(DownSample(now_ch))\n                chs.append(now_ch)\n\n        self.middleblocks1 = nn.ModuleList([\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            # ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n    \n        self.middleblocks2 = nn.ModuleList([\n            # ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n    \n        self.upblocks = nn.ModuleList()\n        for i, mult in reversed(list(enumerate(ch_mult))):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks + 1):\n                self.upblocks.append(ResBlock(\n                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n            if i != 0:\n                self.upblocks.append(UpSample(now_ch))\n        assert len(chs) == 0\n\n        self.tail = nn.Sequential(\n            nn.GroupNorm(32, now_ch),\n            Swish(),\n            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)\n        )\n\n\n\n        grid_size=5\n        spline_order=3\n        scale_noise=0.1\n        scale_base=1.0\n        scale_spline=1.0\n        base_activation=torch.nn.SiLU\n        grid_eps=0.02\n        grid_range=[-1, 1]\n        kan_c=512\n        self.kan = kan(in_features=kan_c, hidden_features=kan_c, act_layer=nn.GELU, drop=0., version=version)\n\n        \n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.head.weight)\n        init.zeros_(self.head.bias)\n        
init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n        init.zeros_(self.tail[-1].bias)\n\n    def forward(self, x, t):\n        # Timestep embedding\n        temb = self.time_embedding(t)\n        # Downsampling\n        h = self.head(x)\n        hs = [h]\n        for layer in self.downblocks:\n            h = layer(h, temb)\n            hs.append(h)\n        # Middle\n        # for layer in self.middleblocks1:\n            # h = layer(h, temb)\n        B, C, H, W = h.shape\n        # transform  B, C, H, W into B*H*W, C\n        h = h.reshape(B,C, H*W).permute(0, 2, 1)\n        h = self.kan(h, H, W)\n        h = h.permute(0, 2, 1).reshape(B, C, H, W)\n\n\n        # h =self.fc2( self.fc1(h))\n        # transform B*H*W, C  into  B, C, H, W\n        # h = h.reshape(B, H, W, C).permute(0, 3, 1, 2)\n        # B, N, C = x.shape\n\n        # for layer in self.middleblocks2:\n        #     h = layer(h, temb)\n        ### Stage 3\n        # Upsampling\n        for layer in self.upblocks:\n            if isinstance(layer, ResBlock):\n                h = torch.cat([h, hs.pop()], dim=1)\n            h = layer(h, temb)\n        h = self.tail(h)\n\n        assert len(hs) == 0\n        return h\n\nclass UNet(nn.Module):\n    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n        super().__init__()\n        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'\n        tdim = ch * 4\n        self.time_embedding = TimeEmbedding(T, ch, tdim)\n\n        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)\n        self.downblocks = nn.ModuleList()\n        chs = [ch]  # record output channel when dowmsample for upsample\n        now_ch = ch\n        for i, mult in enumerate(ch_mult):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks):\n                self.downblocks.append(ResBlock(\n                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n 
               now_ch = out_ch\n                chs.append(now_ch)\n            if i != len(ch_mult) - 1:\n                self.downblocks.append(DownSample(now_ch))\n                chs.append(now_ch)\n\n        self.middleblocks = nn.ModuleList([\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n    \n        self.upblocks = nn.ModuleList()\n        for i, mult in reversed(list(enumerate(ch_mult))):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks + 1):\n                self.upblocks.append(ResBlock(\n                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n            if i != 0:\n                self.upblocks.append(UpSample(now_ch))\n        assert len(chs) == 0\n\n        self.tail = nn.Sequential(\n            nn.GroupNorm(32, now_ch),\n            Swish(),\n            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)\n        )\n        \n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.head.weight)\n        init.zeros_(self.head.bias)\n        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n        init.zeros_(self.tail[-1].bias)\n\n    def forward(self, x, t):\n        # Timestep embedding\n        temb = self.time_embedding(t)\n        # Downsampling\n        h = self.head(x)\n        hs = [h]\n        for layer in self.downblocks:\n            h = layer(h, temb)\n            hs.append(h)\n        # Middle\n        for layer in self.middleblocks:\n            h = layer(h, temb)\n\n\n        # Upsampling\n        for layer in self.upblocks:\n            if isinstance(layer, ResBlock):\n                h = torch.cat([h, hs.pop()], dim=1)\n            h = layer(h, temb)\n        h = self.tail(h)\n\n        assert len(hs) == 0\n        return h\n\n\n\n\n\nclass UNet_MLP(nn.Module):\n    def 
__init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n        super().__init__()\n        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'\n        tdim = ch * 4\n        self.time_embedding = TimeEmbedding(T, ch, tdim)\n\n        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)\n        self.downblocks = nn.ModuleList()\n        chs = [ch]  # record output channel when dowmsample for upsample\n        now_ch = ch\n        for i, mult in enumerate(ch_mult):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks):\n                self.downblocks.append(ResBlock(\n                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n                chs.append(now_ch)\n            if i != len(ch_mult) - 1:\n                self.downblocks.append(DownSample(now_ch))\n                chs.append(now_ch)\n\n        self.middleblocks1 = nn.ModuleList([\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            # ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n    \n        self.middleblocks2 = nn.ModuleList([\n            # ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n    \n        self.upblocks = nn.ModuleList()\n        for i, mult in reversed(list(enumerate(ch_mult))):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks + 1):\n                self.upblocks.append(ResBlock(\n                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n            if i != 0:\n                self.upblocks.append(UpSample(now_ch))\n        assert len(chs) == 0\n\n        self.tail = nn.Sequential(\n            nn.GroupNorm(32, now_ch),\n            Swish(),\n            nn.Conv2d(now_ch, 3, 3, 
stride=1, padding=1)\n        )\n\n\n        kan_c=512\n        self.fc1 = nn.Linear(\n                    kan_c,\n                    kan_c *2,\n                )\n        self.act = nn.GELU()\n\n        self.fc2 = nn.Linear(\n                    kan_c *2,\n                    kan_c,\n                )\n        \n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.head.weight)\n        init.zeros_(self.head.bias)\n        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n        init.zeros_(self.tail[-1].bias)\n\n    def forward(self, x, t):\n        # Timestep embedding\n        temb = self.time_embedding(t)\n        # Downsampling\n        h = self.head(x)\n        hs = [h]\n        for layer in self.downblocks:\n            h = layer(h, temb)\n            hs.append(h)\n        # Middle\n        # for layer in self.middleblocks1:\n            # h = layer(h, temb)\n        B, C, H, W = h.shape\n        # transform  B, C, H, W into B*H*W, C\n        h = h.permute(0, 2, 3, 1).reshape(B*H*W, C)\n        h =self.fc2(self.act(self.fc1(h)))\n        # transform B*H*W, C  into  B, C, H, W\n        h = h.reshape(B, H, W, C).permute(0, 3, 1, 2)\n\n        # for layer in self.middleblocks2:\n        #     h = layer(h, temb)\n        ### Stage 3\n        # Upsampling\n        for layer in self.upblocks:\n            if isinstance(layer, ResBlock):\n                h = torch.cat([h, hs.pop()], dim=1)\n            h = layer(h, temb)\n        h = self.tail(h)\n\n        assert len(hs) == 0\n        return h\n\nif __name__ == '__main__':\n    batch_size = 8\n    model = UNet(\n        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],\n        num_res_blocks=2, dropout=0.1)\n    x = torch.randn(batch_size, 3, 32, 32)\n    t = torch.randint(1000, (batch_size, ))\n    y = model(x, t)\n    print(y.shape)\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/Model_ConvKan.py",
    "content": "\n   \nimport math\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom Diffusion.kan_utils.fastkanconv import FastKANConvLayer, SplineConv2D\n\nclass Swish(nn.Module):\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\nclass TimeEmbedding(nn.Module):\n    def __init__(self, T, d_model, dim):\n        assert d_model % 2 == 0\n        super().__init__()\n        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)\n        emb = torch.exp(-emb)\n        pos = torch.arange(T).float()\n        emb = pos[:, None] * emb[None, :]\n        assert list(emb.shape) == [T, d_model // 2]\n        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)\n        assert list(emb.shape) == [T, d_model // 2, 2]\n        emb = emb.view(T, d_model)\n\n        self.timembedding = nn.Sequential(\n            nn.Embedding.from_pretrained(emb),\n            nn.Linear(d_model, dim),\n            Swish(),\n            nn.Linear(dim, dim),\n        )\n        self.initialize()\n\n    def initialize(self):\n        for module in self.modules():\n            if isinstance(module, nn.Linear):\n                init.xavier_uniform_(module.weight)\n                init.zeros_(module.bias)\n\n    def forward(self, t):\n        emb = self.timembedding(t)\n        return emb\n\n\nclass DownSample(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        # self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)\n        self.main = FastKANConvLayer(in_ch, in_ch, 3, stride=2, padding=1)\n        # self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.main.weight)\n        init.zeros_(self.main.bias)\n\n    def forward(self, x, temb):\n        x = self.main(x)\n        return x\n\n\nclass UpSample(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        # self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)\n        
self.main = FastKANConvLayer(in_ch, in_ch, 3, stride=1, padding=1)\n        # self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.main.weight)\n        init.zeros_(self.main.bias)\n\n    def forward(self, x, temb):\n        _, _, H, W = x.shape\n        x = F.interpolate(\n            x, scale_factor=2, mode='nearest')\n        x = self.main(x)\n        return x\n\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.group_norm = nn.GroupNorm(32, in_ch)\n        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.initialize()\n\n    def initialize(self):\n        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:\n            init.xavier_uniform_(module.weight)\n            init.zeros_(module.bias)\n        init.xavier_uniform_(self.proj.weight, gain=1e-5)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        h = self.group_norm(x)\n        q = self.proj_q(h)\n        k = self.proj_k(h)\n        v = self.proj_v(h)\n\n        q = q.permute(0, 2, 3, 1).view(B, H * W, C)\n        k = k.view(B, C, H * W)\n        w = torch.bmm(q, k) * (int(C) ** (-0.5))\n        assert list(w.shape) == [B, H * W, H * W]\n        w = F.softmax(w, dim=-1)\n\n        v = v.permute(0, 2, 3, 1).view(B, H * W, C)\n        h = torch.bmm(w, v)\n        assert list(h.shape) == [B, H * W, C]\n        h = h.view(B, H, W, C).permute(0, 3, 1, 2)\n        h = self.proj(h)\n\n        return x + h\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):\n        super().__init__()\n        self.block1 = nn.Sequential(\n            nn.GroupNorm(32, in_ch),\n            # Swish(),\n            # nn.Conv2d(in_ch, out_ch, 3, 
stride=1, padding=1),\n            FastKANConvLayer(in_ch, out_ch, 3, stride=1, padding=1),\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(tdim, out_ch),\n        )\n        self.block2 = nn.Sequential(\n            nn.GroupNorm(32, out_ch),\n            # Swish(),\n            nn.Dropout(dropout),\n            # nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),\n            FastKANConvLayer(out_ch, out_ch, 3, stride=1, padding=1),\n        )\n        if in_ch != out_ch:\n            # self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)\n            self.shortcut = FastKANConvLayer(in_ch, out_ch, 1, stride=1, padding=0)\n        else:\n            self.shortcut = nn.Identity()\n        if attn:\n            self.attn = AttnBlock(out_ch)\n        else:\n            self.attn = nn.Identity()\n        self.initialize()\n\n    def initialize(self):\n        for module in self.modules():\n            if isinstance(module, (nn.Conv2d, nn.Linear)) and not isinstance(module, (SplineConv2D)):\n                init.xavier_uniform_(module.weight)\n                init.zeros_(module.bias)\n        # init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)\n\n    def forward(self, x, temb):\n        h = self.block1(x)\n        h += self.temb_proj(temb)[:, :, None, None]\n        h = self.block2(h)\n        h = h + self.shortcut(x)\n        h = self.attn(h)\n        return h\n        # return x\n\nclass UNet_ConvKan(nn.Module):\n    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n        super().__init__()\n        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'\n        tdim = ch * 4\n        self.time_embedding = TimeEmbedding(T, ch, tdim)\n\n        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)\n        self.downblocks = nn.ModuleList()\n        chs = [ch]  # record output channel when dowmsample for upsample\n        now_ch = ch\n        for i, mult 
in enumerate(ch_mult):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks):\n                self.downblocks.append(ResBlock(\n                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n                chs.append(now_ch)\n            if i != len(ch_mult) - 1:\n                self.downblocks.append(DownSample(now_ch))\n                chs.append(now_ch)\n\n        self.middleblocks = nn.ModuleList([\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n    \n        self.upblocks = nn.ModuleList()\n        for i, mult in reversed(list(enumerate(ch_mult))):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks + 1):\n                self.upblocks.append(ResBlock(\n                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n            if i != 0:\n                self.upblocks.append(UpSample(now_ch))\n        assert len(chs) == 0\n\n        self.tail = nn.Sequential(\n            nn.GroupNorm(32, now_ch),\n            # Swish(),\n            # nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)\n            FastKANConvLayer(now_ch, 3, 3, stride=1, padding=1)\n        )\n        \n        # self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.head.weight)\n        init.zeros_(self.head.bias)\n        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n        init.zeros_(self.tail[-1].bias)\n\n    def forward(self, x, t):\n        # Timestep embedding\n        temb = self.time_embedding(t)\n        # Downsampling\n        h = self.head(x)\n        hs = [h]\n        for layer in self.downblocks:\n            h = layer(h, temb)\n            hs.append(h)\n        # Middle\n        for layer in self.middleblocks:\n         
   h = layer(h, temb)\n        # Upsampling\n        for layer in self.upblocks:\n            if isinstance(layer, ResBlock):\n                h = torch.cat([h, hs.pop()], dim=1)\n            h = layer(h, temb)\n        h = self.tail(h)\n\n        assert len(hs) == 0\n        return h\n\nif __name__ == '__main__':\n    batch_size = 8\n    model = UNet_ConvKan(\n        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],\n        num_res_blocks=2, dropout=0.1)\n    x = torch.randn(batch_size, 3, 32, 32)\n    t = torch.randint(1000, (batch_size, ))\n    y = model(x, t)\n    print(y.shape)\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/Model_UKAN_Hybrid.py",
    "content": "\n   \nimport math\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nclass KANLinear(torch.nn.Module):\n    def __init__(\n        self,\n        in_features,\n        out_features,\n        grid_size=5,\n        spline_order=3,\n        scale_noise=0.1,\n        scale_base=1.0,\n        scale_spline=1.0,\n        enable_standalone_scale_spline=True,\n        base_activation=torch.nn.SiLU,\n        grid_eps=0.02,\n        grid_range=[-1, 1],\n    ):\n        super(KANLinear, self).__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.grid_size = grid_size\n        self.spline_order = spline_order\n\n        h = (grid_range[1] - grid_range[0]) / grid_size\n        grid = (\n            (\n                torch.arange(-spline_order, grid_size + spline_order + 1) * h\n                + grid_range[0]\n            )\n            .expand(in_features, -1)\n            .contiguous()\n        )\n        self.register_buffer(\"grid\", grid)\n\n        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))\n        self.spline_weight = torch.nn.Parameter(\n            torch.Tensor(out_features, in_features, grid_size + spline_order)\n        )\n        if enable_standalone_scale_spline:\n            self.spline_scaler = torch.nn.Parameter(\n                torch.Tensor(out_features, in_features)\n            )\n\n        self.scale_noise = scale_noise\n        self.scale_base = scale_base\n        self.scale_spline = scale_spline\n        self.enable_standalone_scale_spline = enable_standalone_scale_spline\n        self.base_activation = base_activation()\n        self.grid_eps = grid_eps\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)\n        with 
torch.no_grad():\n            noise = (\n                (\n                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)\n                    - 1 / 2\n                )\n                * self.scale_noise\n                / self.grid_size\n            )\n            self.spline_weight.data.copy_(\n                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)\n                * self.curve2coeff(\n                    self.grid.T[self.spline_order : -self.spline_order],\n                    noise,\n                )\n            )\n            if self.enable_standalone_scale_spline:\n                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)\n                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)\n\n    def b_splines(self, x: torch.Tensor):\n        \"\"\"\n        Compute the B-spline bases for the given input tensor.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n\n        Returns:\n            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).\n        \"\"\"\n        assert x.dim() == 2 and x.size(1) == self.in_features\n\n        grid: torch.Tensor = (\n            self.grid\n        )  # (in_features, grid_size + 2 * spline_order + 1)\n        x = x.unsqueeze(-1)\n        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)\n        for k in range(1, self.spline_order + 1):\n            bases = (\n                (x - grid[:, : -(k + 1)])\n                / (grid[:, k:-1] - grid[:, : -(k + 1)])\n                * bases[:, :, :-1]\n            ) + (\n                (grid[:, k + 1 :] - x)\n                / (grid[:, k + 1 :] - grid[:, 1:(-k)])\n                * bases[:, :, 1:]\n            )\n\n        assert bases.size() == (\n            x.size(0),\n            self.in_features,\n            self.grid_size + self.spline_order,\n       
 )\n        return bases.contiguous()\n\n    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):\n        \"\"\"\n        Compute the coefficients of the curve that interpolates the given points.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).\n\n        Returns:\n            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).\n        \"\"\"\n        assert x.dim() == 2 and x.size(1) == self.in_features\n        assert y.size() == (x.size(0), self.in_features, self.out_features)\n\n        A = self.b_splines(x).transpose(\n            0, 1\n        )  # (in_features, batch_size, grid_size + spline_order)\n        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)\n        solution = torch.linalg.lstsq(\n            A, B\n        ).solution  # (in_features, grid_size + spline_order, out_features)\n        result = solution.permute(\n            2, 0, 1\n        )  # (out_features, in_features, grid_size + spline_order)\n\n        assert result.size() == (\n            self.out_features,\n            self.in_features,\n            self.grid_size + self.spline_order,\n        )\n        return result.contiguous()\n\n    @property\n    def scaled_spline_weight(self):\n        return self.spline_weight * (\n            self.spline_scaler.unsqueeze(-1)\n            if self.enable_standalone_scale_spline\n            else 1.0\n        )\n\n    def forward(self, x: torch.Tensor):\n        assert x.dim() == 2 and x.size(1) == self.in_features\n\n        base_output = F.linear(self.base_activation(x), self.base_weight)\n        spline_output = F.linear(\n            self.b_splines(x).view(x.size(0), -1),\n            self.scaled_spline_weight.view(self.out_features, -1),\n        )\n        return base_output + spline_output\n\n    @torch.no_grad()\n    def 
update_grid(self, x: torch.Tensor, margin=0.01):\n        assert x.dim() == 2 and x.size(1) == self.in_features\n        batch = x.size(0)\n\n        splines = self.b_splines(x)  # (batch, in, coeff)\n        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)\n        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)\n        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)\n        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)\n        unreduced_spline_output = unreduced_spline_output.permute(\n            1, 0, 2\n        )  # (batch, in, out)\n\n        # sort each channel individually to collect data distribution\n        x_sorted = torch.sort(x, dim=0)[0]\n        grid_adaptive = x_sorted[\n            torch.linspace(\n                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device\n            )\n        ]\n\n        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size\n        grid_uniform = (\n            torch.arange(\n                self.grid_size + 1, dtype=torch.float32, device=x.device\n            ).unsqueeze(1)\n            * uniform_step\n            + x_sorted[0]\n            - margin\n        )\n\n        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive\n        grid = torch.concatenate(\n            [\n                grid[:1]\n                - uniform_step\n                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),\n                grid,\n                grid[-1:]\n                + uniform_step\n                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),\n            ],\n            dim=0,\n        )\n\n        self.grid.copy_(grid.T)\n        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))\n\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n        \"\"\"\n        Compute the regularization 
loss.\n\n        This is a dumb simulation of the original L1 regularization as stated in the\n        paper, since the original one requires computing absolutes and entropy from the\n        expanded (batch, in_features, out_features) intermediate tensor, which is hidden\n        behind the F.linear function if we want an memory efficient implementation.\n\n        The L1 regularization is now computed as mean absolute value of the spline\n        weights. The authors implementation also includes this term in addition to the\n        sample-based regularization.\n        \"\"\"\n        l1_fake = self.spline_weight.abs().mean(-1)\n        regularization_loss_activation = l1_fake.sum()\n        p = l1_fake / regularization_loss_activation\n        regularization_loss_entropy = -torch.sum(p * p.log())\n        return (\n            regularize_activation * regularization_loss_activation\n            + regularize_entropy * regularization_loss_entropy\n        )\n\n\nclass KAN(torch.nn.Module):\n    def __init__(\n        self,\n        layers_hidden,\n        grid_size=5,\n        spline_order=3,\n        scale_noise=0.1,\n        scale_base=1.0,\n        scale_spline=1.0,\n        base_activation=torch.nn.SiLU,\n        grid_eps=0.02,\n        grid_range=[-1, 1],\n    ):\n        super(KAN, self).__init__()\n        self.grid_size = grid_size\n        self.spline_order = spline_order\n\n        self.layers = torch.nn.ModuleList()\n        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):\n            self.layers.append(\n                KANLinear(\n                    in_features,\n                    out_features,\n                    grid_size=grid_size,\n                    spline_order=spline_order,\n                    scale_noise=scale_noise,\n                    scale_base=scale_base,\n                    scale_spline=scale_spline,\n                    base_activation=base_activation,\n                    grid_eps=grid_eps,\n              
      grid_range=grid_range,\n                )\n            )\n\n    def forward(self, x: torch.Tensor, update_grid=False):\n        for layer in self.layers:\n            if update_grid:\n                layer.update_grid(x)\n            x = layer(x)\n        return x\n\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n        return sum(\n            layer.regularization_loss(regularize_activation, regularize_entropy)\n            for layer in self.layers\n        )\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False)\n\n\ndef shift(dim):\n            x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]\n            x_cat = torch.cat(x_shift, 1)\n            x_cat = torch.narrow(x_cat, 2, self.pad, H)\n            x_cat = torch.narrow(x_cat, 3, self.pad, W)\n            return x_cat\n\n\nclass OverlapPatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]\n        self.num_patches = self.H * self.W\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,\n                              padding=(patch_size[0] // 2, patch_size[1] // 2))\n        self.norm = nn.LayerNorm(embed_dim)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                
nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x):\n        x = self.proj(x)\n        _, _, H, W = x.shape\n        x = x.flatten(2).transpose(1, 2)\n        x = self.norm(x)\n\n        return x, H, W\n\n\nclass Swish(nn.Module):\n    def forward(self, x):\n        return x * torch.sigmoid(x)\ndef swish(x):\n    \n    return x * torch.sigmoid(x)\n\n\nclass TimeEmbedding(nn.Module):\n    def __init__(self, T, d_model, dim):\n        assert d_model % 2 == 0\n        super().__init__()\n        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)\n        emb = torch.exp(-emb)\n        pos = torch.arange(T).float()\n        emb = pos[:, None] * emb[None, :]\n        assert list(emb.shape) == [T, d_model // 2]\n        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)\n        assert list(emb.shape) == [T, d_model // 2, 2]\n        emb = emb.view(T, d_model)\n\n        self.timembedding = nn.Sequential(\n            nn.Embedding.from_pretrained(emb),\n            nn.Linear(d_model, dim),\n            Swish(),\n            nn.Linear(dim, dim),\n        )\n        self.initialize()\n\n    def initialize(self):\n        for module in self.modules():\n            if isinstance(module, nn.Linear):\n                init.xavier_uniform_(module.weight)\n                init.zeros_(module.bias)\n\n    def forward(self, t):\n        emb = self.timembedding(t)\n        return emb\n\n\nclass DownSample(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, 
padding=1)\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.main.weight)\n        init.zeros_(self.main.bias)\n\n    def forward(self, x, temb):\n        x = self.main(x)\n        return x\n\n\nclass UpSample(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.main.weight)\n        init.zeros_(self.main.bias)\n\n    def forward(self, x, temb):\n        _, _, H, W = x.shape\n        x = F.interpolate(\n            x, scale_factor=2, mode='nearest')\n        x = self.main(x)\n        return x\n    \nclass kan(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.dim = in_features\n        \n        grid_size=5\n        spline_order=3\n        scale_noise=0.1\n        scale_base=1.0\n        scale_spline=1.0\n        base_activation=Swish\n        grid_eps=0.02\n        grid_range=[-1, 1]\n\n        self.fc1 = KANLinear(\n                    in_features,\n                    hidden_features,\n                    grid_size=grid_size,\n                    spline_order=spline_order,\n                    scale_noise=scale_noise,\n                    scale_base=scale_base,\n                    scale_spline=scale_spline,\n                    base_activation=base_activation,\n                    grid_eps=grid_eps,\n                    grid_range=grid_range,\n                )\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, 
nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n    \n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        x = self.fc1(x.reshape(B*N,C))\n        x = x.reshape(B,N,C).contiguous()\n\n        return x\n\nclass shiftedBlock(nn.Module):\n    def __init__(self, dim,  mlp_ratio=4.,drop_path=0.,norm_layer=nn.LayerNorm):\n        super().__init__()\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(256, dim),\n        )\n\n        self.kan = kan(in_features=dim, hidden_features=mlp_hidden_dim)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x, H, W, temb):\n\n        temb = self.temb_proj(temb)\n        x = self.drop_path(self.kan(self.norm2(x), H, W))\n        x = x + temb.unsqueeze(1)\n\n        return x\n\nclass DWConv(nn.Module):\n    def __init__(self, 
dim=768):\n        super(DWConv, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        x = x.transpose(1, 2).view(B, C, H, W)\n        x = self.dwconv(x)\n        x = x.flatten(2).transpose(1, 2)\n\n        return x\n\nclass DW_bn_relu(nn.Module):\n    def __init__(self, dim=768):\n        super(DW_bn_relu, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n        self.bn = nn.GroupNorm(32, dim)\n        # self.relu = Swish()\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        x = x.transpose(1, 2).view(B, C, H, W)\n        x = self.dwconv(x)\n        x = self.bn(x)\n        x = swish(x)\n        x = x.flatten(2).transpose(1, 2)\n\n        return x\n\nclass SingleConv(nn.Module):\n    def __init__(self, in_ch, h_ch):\n        super(SingleConv, self).__init__()\n        self.conv = nn.Sequential(\n            nn.GroupNorm(32, in_ch),\n            Swish(),\n            nn.Conv2d(in_ch, h_ch, 3, padding=1),\n        )\n\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(256, h_ch),\n        )\n    def forward(self, input, temb):\n        return self.conv(input) + self.temb_proj(temb)[:,:,None, None]\n\n\nclass DoubleConv(nn.Module):\n    def __init__(self, in_ch, h_ch):\n        super(DoubleConv, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(in_ch, h_ch, 3, padding=1),\n            nn.GroupNorm(32, h_ch),\n            Swish(),\n            nn.Conv2d(h_ch, h_ch, 3, padding=1),\n            nn.GroupNorm(32, h_ch),\n            Swish()\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(256, h_ch),\n        )\n    def forward(self, input, temb):\n        return self.conv(input) + self.temb_proj(temb)[:,:,None, None]\n\n\nclass D_SingleConv(nn.Module):\n    def __init__(self, in_ch, h_ch):\n      
  super(D_SingleConv, self).__init__()\n        self.conv = nn.Sequential(\n            nn.GroupNorm(32,in_ch),\n            Swish(),\n            nn.Conv2d(in_ch, h_ch, 3, padding=1),\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(256, h_ch),\n        )\n    def forward(self, input, temb):\n        return self.conv(input) + self.temb_proj(temb)[:,:,None, None]\n\n\nclass D_DoubleConv(nn.Module):\n    def __init__(self, in_ch, h_ch):\n        super(D_DoubleConv, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(in_ch, in_ch, 3, padding=1),\n            nn.GroupNorm(32,in_ch),\n            Swish(),\n            nn.Conv2d(in_ch, h_ch, 3, padding=1),\n             nn.GroupNorm(32,h_ch),\n            Swish()\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(256, h_ch),\n        )\n    def forward(self, input,temb):\n        return self.conv(input) + self.temb_proj(temb)[:,:,None, None]\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.group_norm = nn.GroupNorm(32, in_ch)\n        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.initialize()\n\n    def initialize(self):\n        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:\n            init.xavier_uniform_(module.weight)\n            init.zeros_(module.bias)\n        init.xavier_uniform_(self.proj.weight, gain=1e-5)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        h = self.group_norm(x)\n        q = self.proj_q(h)\n        k = self.proj_k(h)\n        v = self.proj_v(h)\n\n        q = q.permute(0, 2, 3, 1).view(B, H * W, C)\n        k = k.view(B, C, H * W)\n        w = 
torch.bmm(q, k) * (int(C) ** (-0.5))\n        assert list(w.shape) == [B, H * W, H * W]\n        w = F.softmax(w, dim=-1)\n\n        v = v.permute(0, 2, 3, 1).view(B, H * W, C)\n        h = torch.bmm(w, v)\n        assert list(h.shape) == [B, H * W, C]\n        h = h.view(B, H, W, C).permute(0, 3, 1, 2)\n        h = self.proj(h)\n\n        return x + h\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, in_ch, h_ch, tdim, dropout, attn=False):\n        super().__init__()\n        self.block1 = nn.Sequential(\n            nn.GroupNorm(32, in_ch),\n            Swish(),\n            nn.Conv2d(in_ch, h_ch, 3, stride=1, padding=1),\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(tdim, h_ch),\n        )\n        self.block2 = nn.Sequential(\n            nn.GroupNorm(32, h_ch),\n            Swish(),\n            nn.Dropout(dropout),\n            nn.Conv2d(h_ch, h_ch, 3, stride=1, padding=1),\n        )\n        if in_ch != h_ch:\n            self.shortcut = nn.Conv2d(in_ch, h_ch, 1, stride=1, padding=0)\n        else:\n            self.shortcut = nn.Identity()\n        if attn:\n            self.attn = AttnBlock(h_ch)\n        else:\n            self.attn = nn.Identity()\n        self.initialize()\n\n    def initialize(self):\n        for module in self.modules():\n            if isinstance(module, (nn.Conv2d, nn.Linear)):\n                init.xavier_uniform_(module.weight)\n                init.zeros_(module.bias)\n        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)\n\n    def forward(self, x, temb):\n        h = self.block1(x)\n        h += self.temb_proj(temb)[:, :, None, None]\n        h = self.block2(h)\n\n        h = h + self.shortcut(x)\n        h = self.attn(h)\n        return h\n\n\nclass UKan_Hybrid(nn.Module):\n    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n        super().__init__()\n        assert all([i < len(ch_mult) for i in attn]), 'attn index h of bound'\n        
tdim = ch * 4\n        self.time_embedding = TimeEmbedding(T, ch, tdim)\n        attn = []\n        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)\n        self.downblocks = nn.ModuleList()\n        chs = [ch]  # record hput channel when dowmsample for upsample\n        now_ch = ch\n        for i, mult in enumerate(ch_mult):\n            h_ch = ch * mult\n            for _ in range(num_res_blocks):\n                self.downblocks.append(ResBlock(\n                    in_ch=now_ch, h_ch=h_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = h_ch\n                chs.append(now_ch)\n            if i != len(ch_mult) - 1:\n                self.downblocks.append(DownSample(now_ch))\n                chs.append(now_ch)\n\n        self.upblocks = nn.ModuleList()\n        for i, mult in reversed(list(enumerate(ch_mult))):\n            h_ch = ch * mult\n            for _ in range(num_res_blocks + 1):\n                self.upblocks.append(ResBlock(\n                    in_ch=chs.pop() + now_ch, h_ch=h_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = h_ch\n            if i != 0:\n                self.upblocks.append(UpSample(now_ch))\n        assert len(chs) == 0\n\n        self.tail = nn.Sequential(\n            nn.GroupNorm(32, now_ch),\n            Swish(),\n            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)\n        )\n\n        # \n        embed_dims = [256, 320, 512]\n        norm_layer = nn.LayerNorm\n        dpr = [0.0, 0.0, 0.0]\n        self.patch_embed3 = OverlapPatchEmbed(img_size=64 // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])\n        self.patch_embed4 = OverlapPatchEmbed(img_size=64 // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])\n\n        self.norm3 = norm_layer(embed_dims[1])\n        self.norm4 = norm_layer(embed_dims[2])\n        self.dnorm3 = norm_layer(embed_dims[1])\n\n     
   self.kan_block1 = nn.ModuleList([shiftedBlock(\n            dim=embed_dims[1],  mlp_ratio=1, drop_path=dpr[0], norm_layer=norm_layer)])\n\n        self.kan_block2 = nn.ModuleList([shiftedBlock(\n            dim=embed_dims[2],  mlp_ratio=1, drop_path=dpr[1], norm_layer=norm_layer)])\n\n        self.kan_dblock1 = nn.ModuleList([shiftedBlock(\n            dim=embed_dims[1], mlp_ratio=1, drop_path=dpr[0], norm_layer=norm_layer)])\n\n        self.decoder1 = D_SingleConv(embed_dims[2], embed_dims[1])  \n        self.decoder2 = D_SingleConv(embed_dims[1], embed_dims[0])  \n\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.head.weight)\n        init.zeros_(self.head.bias)\n        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n        init.zeros_(self.tail[-1].bias)\n\n    def forward(self, x, t):\n        # Timestep embedding\n        temb = self.time_embedding(t)\n        # Downsampling\n        h = self.head(x)\n        hs = [h]\n        for layer in self.downblocks:\n            h = layer(h, temb)\n            hs.append(h)\n    \n        t3 = h\n\n        B = x.shape[0]\n        h, H, W = self.patch_embed3(h)\n \n        for i, blk in enumerate(self.kan_block1):\n            h = blk(h, H, W, temb)\n        h = self.norm3(h)\n        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        t4 = h\n\n        h, H, W= self.patch_embed4(h)\n        for i, blk in enumerate(self.kan_block2):\n            h = blk(h, H, W, temb)\n        h = self.norm4(h)\n        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n\n        ### Stage 4\n        h = swish(F.interpolate(self.decoder1(h, temb), scale_factor=(2,2), mode ='bilinear'))\n\n        h = torch.add(h, t4)\n\n        _, _, H, W = h.shape\n        h = h.flatten(2).transpose(1,2)\n        for i, blk in enumerate(self.kan_dblock1):\n            h = blk(h, H, W, temb)\n\n            \n        ### Stage 3\n        h = self.dnorm3(h)\n        h = 
h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        h = swish(F.interpolate(self.decoder2(h, temb),scale_factor=(2,2),mode ='bilinear'))\n\n        h = torch.add(h,t3)\n\n        # Upsampling\n        for layer in self.upblocks:\n            if isinstance(layer, ResBlock):\n                h = torch.cat([h, hs.pop()], dim=1)\n            h = layer(h, temb)\n        h = self.tail(h)\n\n        assert len(hs) == 0\n        return h\n\n\nif __name__ == '__main__':\n    batch_size = 8\n    model = UKan_Hybrid(\n        T=1000, ch=64, ch_mult=[1, 2, 2, 2], attn=[],\n        num_res_blocks=2, dropout=0.1)\n\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/Model_UMLP.py",
    "content": "\n   \nimport math\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\n\nclass KANLinear(torch.nn.Module):\n    def __init__(\n        self,\n        in_features,\n        out_features,\n        grid_size=5,\n        spline_order=3,\n        scale_noise=0.1,\n        scale_base=1.0,\n        scale_spline=1.0,\n        enable_standalone_scale_spline=True,\n        base_activation=torch.nn.SiLU,\n        grid_eps=0.02,\n        grid_range=[-1, 1],\n    ):\n        super(KANLinear, self).__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.grid_size = grid_size\n        self.spline_order = spline_order\n\n        h = (grid_range[1] - grid_range[0]) / grid_size\n        grid = (\n            (\n                torch.arange(-spline_order, grid_size + spline_order + 1) * h\n                + grid_range[0]\n            )\n            .expand(in_features, -1)\n            .contiguous()\n        )\n        self.register_buffer(\"grid\", grid)\n\n        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))\n        self.spline_weight = torch.nn.Parameter(\n            torch.Tensor(out_features, in_features, grid_size + spline_order)\n        )\n        if enable_standalone_scale_spline:\n            self.spline_scaler = torch.nn.Parameter(\n                torch.Tensor(out_features, in_features)\n            )\n\n        self.scale_noise = scale_noise\n        self.scale_base = scale_base\n        self.scale_spline = scale_spline\n        self.enable_standalone_scale_spline = enable_standalone_scale_spline\n        self.base_activation = base_activation()\n        self.grid_eps = grid_eps\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)\n        with 
torch.no_grad():\n            noise = (\n                (\n                    torch.rand(self.grid_size + 1, self.in_features, self.out_features)\n                    - 1 / 2\n                )\n                * self.scale_noise\n                / self.grid_size\n            )\n            self.spline_weight.data.copy_(\n                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)\n                * self.curve2coeff(\n                    self.grid.T[self.spline_order : -self.spline_order],\n                    noise,\n                )\n            )\n            if self.enable_standalone_scale_spline:\n                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)\n                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)\n\n    def b_splines(self, x: torch.Tensor):\n        \"\"\"\n        Compute the B-spline bases for the given input tensor.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n\n        Returns:\n            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).\n        \"\"\"\n        assert x.dim() == 2 and x.size(1) == self.in_features\n\n        grid: torch.Tensor = (\n            self.grid\n        )  # (in_features, grid_size + 2 * spline_order + 1)\n        x = x.unsqueeze(-1)\n        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)\n        for k in range(1, self.spline_order + 1):\n            bases = (\n                (x - grid[:, : -(k + 1)])\n                / (grid[:, k:-1] - grid[:, : -(k + 1)])\n                * bases[:, :, :-1]\n            ) + (\n                (grid[:, k + 1 :] - x)\n                / (grid[:, k + 1 :] - grid[:, 1:(-k)])\n                * bases[:, :, 1:]\n            )\n\n        assert bases.size() == (\n            x.size(0),\n            self.in_features,\n            self.grid_size + self.spline_order,\n       
 )\n        return bases.contiguous()\n\n    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):\n        \"\"\"\n        Compute the coefficients of the curve that interpolates the given points.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).\n\n        Returns:\n            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).\n        \"\"\"\n        assert x.dim() == 2 and x.size(1) == self.in_features\n        assert y.size() == (x.size(0), self.in_features, self.out_features)\n\n        A = self.b_splines(x).transpose(\n            0, 1\n        )  # (in_features, batch_size, grid_size + spline_order)\n        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)\n        solution = torch.linalg.lstsq(\n            A, B\n        ).solution  # (in_features, grid_size + spline_order, out_features)\n        result = solution.permute(\n            2, 0, 1\n        )  # (out_features, in_features, grid_size + spline_order)\n\n        assert result.size() == (\n            self.out_features,\n            self.in_features,\n            self.grid_size + self.spline_order,\n        )\n        return result.contiguous()\n\n    @property\n    def scaled_spline_weight(self):\n        return self.spline_weight * (\n            self.spline_scaler.unsqueeze(-1)\n            if self.enable_standalone_scale_spline\n            else 1.0\n        )\n\n    def forward(self, x: torch.Tensor):\n        assert x.dim() == 2 and x.size(1) == self.in_features\n\n        base_output = F.linear(self.base_activation(x), self.base_weight)\n        spline_output = F.linear(\n            self.b_splines(x).view(x.size(0), -1),\n            self.scaled_spline_weight.view(self.out_features, -1),\n        )\n        return base_output + spline_output\n\n    @torch.no_grad()\n    def 
update_grid(self, x: torch.Tensor, margin=0.01):\n        assert x.dim() == 2 and x.size(1) == self.in_features\n        batch = x.size(0)\n\n        splines = self.b_splines(x)  # (batch, in, coeff)\n        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)\n        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)\n        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)\n        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)\n        unreduced_spline_output = unreduced_spline_output.permute(\n            1, 0, 2\n        )  # (batch, in, out)\n\n        # sort each channel individually to collect data distribution\n        x_sorted = torch.sort(x, dim=0)[0]\n        grid_adaptive = x_sorted[\n            torch.linspace(\n                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device\n            )\n        ]\n\n        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size\n        grid_uniform = (\n            torch.arange(\n                self.grid_size + 1, dtype=torch.float32, device=x.device\n            ).unsqueeze(1)\n            * uniform_step\n            + x_sorted[0]\n            - margin\n        )\n\n        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive\n        grid = torch.concatenate(\n            [\n                grid[:1]\n                - uniform_step\n                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),\n                grid,\n                grid[-1:]\n                + uniform_step\n                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),\n            ],\n            dim=0,\n        )\n\n        self.grid.copy_(grid.T)\n        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))\n\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n        \"\"\"\n        Compute the regularization 
loss.\n\n        This is a dumb simulation of the original L1 regularization as stated in the\n        paper, since the original one requires computing absolutes and entropy from the\n        expanded (batch, in_features, out_features) intermediate tensor, which is hidden\n        behind the F.linear function if we want an memory efficient implementation.\n\n        The L1 regularization is now computed as mean absolute value of the spline\n        weights. The authors implementation also includes this term in addition to the\n        sample-based regularization.\n        \"\"\"\n        l1_fake = self.spline_weight.abs().mean(-1)\n        regularization_loss_activation = l1_fake.sum()\n        p = l1_fake / regularization_loss_activation\n        regularization_loss_entropy = -torch.sum(p * p.log())\n        return (\n            regularize_activation * regularization_loss_activation\n            + regularize_entropy * regularization_loss_entropy\n        )\n\n\nclass KAN(torch.nn.Module):\n    def __init__(\n        self,\n        layers_hidden,\n        grid_size=5,\n        spline_order=3,\n        scale_noise=0.1,\n        scale_base=1.0,\n        scale_spline=1.0,\n        base_activation=torch.nn.SiLU,\n        grid_eps=0.02,\n        grid_range=[-1, 1],\n    ):\n        super(KAN, self).__init__()\n        self.grid_size = grid_size\n        self.spline_order = spline_order\n\n        self.layers = torch.nn.ModuleList()\n        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):\n            self.layers.append(\n                KANLinear(\n                    in_features,\n                    out_features,\n                    grid_size=grid_size,\n                    spline_order=spline_order,\n                    scale_noise=scale_noise,\n                    scale_base=scale_base,\n                    scale_spline=scale_spline,\n                    base_activation=base_activation,\n                    grid_eps=grid_eps,\n              
      grid_range=grid_range,\n                )\n            )\n\n    def forward(self, x: torch.Tensor, update_grid=False):\n        for layer in self.layers:\n            if update_grid:\n                layer.update_grid(x)\n            x = layer(x)\n        return x\n\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n        return sum(\n            layer.regularization_loss(regularize_activation, regularize_entropy)\n            for layer in self.layers\n        )\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, bias=False)\n\n\ndef shift(dim):\n            x_shift = [ torch.roll(x_c, shift, dim) for x_c, shift in zip(xs, range(-self.pad, self.pad+1))]\n            x_cat = torch.cat(x_shift, 1)\n            x_cat = torch.narrow(x_cat, 2, self.pad, H)\n            x_cat = torch.narrow(x_cat, 3, self.pad, W)\n            return x_cat\n\n\nclass OverlapPatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]\n        self.num_patches = self.H * self.W\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,\n                              padding=(patch_size[0] // 2, patch_size[1] // 2))\n        self.norm = nn.LayerNorm(embed_dim)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x):\n        x = self.proj(x)\n        _, _, H, W = x.shape\n        x = x.flatten(2).transpose(1, 2)\n        x = self.norm(x)\n\n        return x, H, W\n\n\nclass Swish(nn.Module):\n    def forward(self, x):\n        return x * torch.sigmoid(x)\ndef swish(x):\n    \n    return x * torch.sigmoid(x)\n\n\nclass TimeEmbedding(nn.Module):\n    def __init__(self, T, d_model, dim):\n        assert d_model % 2 == 0\n        super().__init__()\n        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)\n        emb = torch.exp(-emb)\n        pos = torch.arange(T).float()\n        emb = pos[:, None] * emb[None, :]\n        assert list(emb.shape) == [T, d_model // 2]\n        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)\n        assert list(emb.shape) == [T, d_model // 2, 2]\n        emb = emb.view(T, d_model)\n\n        self.timembedding = nn.Sequential(\n            nn.Embedding.from_pretrained(emb),\n            nn.Linear(d_model, dim),\n            Swish(),\n            nn.Linear(dim, dim),\n        )\n        self.initialize()\n\n    def initialize(self):\n        for module in self.modules():\n            if isinstance(module, nn.Linear):\n                init.xavier_uniform_(module.weight)\n                init.zeros_(module.bias)\n\n    def forward(self, t):\n        emb = self.timembedding(t)\n        return emb\n\n\nclass DownSample(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.main.weight)\n        init.zeros_(self.main.bias)\n\n    def forward(self, x, temb):\n        x = self.main(x)\n        return x\n\n\nclass UpSample(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.main.weight)\n        init.zeros_(self.main.bias)\n\n    def forward(self, x, temb):\n        _, _, H, W = x.shape\n        x = F.interpolate(\n            x, scale_factor=2, mode='nearest')\n        x = self.main(x)\n        return x\n    \nclass kan(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., shift_size=5, version=4, kan_val=False):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.dim = in_features\n        \n        grid_size=5\n        spline_order=3\n        scale_noise=0.1\n        scale_base=1.0\n        scale_spline=1.0\n        base_activation=Swish\n        grid_eps=0.02\n        grid_range=[-1, 1]\n\n        if kan_val:\n            self.fc1 = nn.Linear(in_features, hidden_features)\n            self.fc2 = nn.Linear(hidden_features, out_features)\n            self.fc3 = nn.Linear(hidden_features, out_features)\n        else:\n            self.fc1 = nn.Sequential(\n                nn.Linear(in_features, hidden_features),\n                Swish(),\n                nn.Linear(hidden_features, out_features))\n            \n \n        self.drop = nn.Dropout(drop)\n\n        self.shift_size = shift_size\n        self.pad = shift_size // 2\n\n        \n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n    \n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n\n        x = self.fc1(x.reshape(B*N,C))\n\n        x = x.reshape(B,N,C).contiguous()\n\n        return x\n\nclass shiftedBlock(nn.Module):\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, version=1, kan_val=False):\n        super().__init__()\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(256, dim),\n        )\n        # self.mlp = shiftmlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n        self.kan = kan(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, kan_val=kan_val)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x, H, W, temb):\n\n        temb = self.temb_proj(temb)\n        # x = x + self.drop_path(self.kan(self.norm2(x), H, W))\n        x = self.drop_path(self.kan(self.norm2(x), H, W))\n        x = x + temb.unsqueeze(1)\n\n        return x\n\nclass DWConv(nn.Module):\n    def __init__(self, dim=768):\n        super(DWConv, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        x = x.transpose(1, 2).view(B, C, H, W)\n        x = self.dwconv(x)\n        x = x.flatten(2).transpose(1, 2)\n\n        return x\n\nclass DW_bn_relu(nn.Module):\n    def __init__(self, dim=768):\n        super(DW_bn_relu, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n        self.bn = nn.GroupNorm(32, dim)\n        # self.relu = Swish()\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        x = x.transpose(1, 2).view(B, C, H, W)\n        x = self.dwconv(x)\n        x = self.bn(x)\n        x = swish(x)\n        x = x.flatten(2).transpose(1, 2)\n\n        return x\n\nclass SingleConv(nn.Module):\n    def __init__(self, in_ch, h_ch):\n        super(SingleConv, self).__init__()\n        self.conv = nn.Sequential(\n            nn.GroupNorm(32, in_ch),\n            Swish(),\n            nn.Conv2d(in_ch, h_ch, 3, padding=1),\n        )\n\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(256, h_ch),\n        )\n    def forward(self, input, temb):\n        return self.conv(input) + self.temb_proj(temb)[:,:,None, None]\n\n\nclass DoubleConv(nn.Module):\n    def __init__(self, in_ch, h_ch):\n        super(DoubleConv, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(in_ch, h_ch, 3, padding=1),\n            nn.GroupNorm(32, h_ch),\n            Swish(),\n            nn.Conv2d(h_ch, h_ch, 3, padding=1),\n            nn.GroupNorm(32, h_ch),\n            Swish()\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(256, h_ch),\n        )\n    def forward(self, input, temb):\n        return self.conv(input) + self.temb_proj(temb)[:,:,None, None]\n\n\nclass D_SingleConv(nn.Module):\n    def __init__(self, in_ch, h_ch):\n        super(D_SingleConv, self).__init__()\n        self.conv = nn.Sequential(\n            nn.GroupNorm(32,in_ch),\n            Swish(),\n            nn.Conv2d(in_ch, h_ch, 3, padding=1),\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(256, h_ch),\n        )\n    def forward(self, input, temb):\n        return self.conv(input) + self.temb_proj(temb)[:,:,None, None]\n\n\nclass D_DoubleConv(nn.Module):\n    def __init__(self, in_ch, h_ch):\n        super(D_DoubleConv, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(in_ch, in_ch, 3, padding=1),\n            nn.GroupNorm(32,in_ch),\n            Swish(),\n            nn.Conv2d(in_ch, h_ch, 3, padding=1),\n             nn.GroupNorm(32,h_ch),\n            Swish()\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(256, h_ch),\n        )\n    def forward(self, input,temb):\n        return self.conv(input) + self.temb_proj(temb)[:,:,None, None]\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.group_norm = nn.GroupNorm(32, in_ch)\n        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.initialize()\n\n    def initialize(self):\n        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:\n            init.xavier_uniform_(module.weight)\n            init.zeros_(module.bias)\n        init.xavier_uniform_(self.proj.weight, gain=1e-5)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        h = self.group_norm(x)\n        q = self.proj_q(h)\n        k = self.proj_k(h)\n        v = self.proj_v(h)\n\n        q = q.permute(0, 2, 3, 1).view(B, H * W, C)\n        k = k.view(B, C, H * W)\n        w = torch.bmm(q, k) * (int(C) ** (-0.5))\n        assert list(w.shape) == [B, H * W, H * W]\n        w = F.softmax(w, dim=-1)\n\n        v = v.permute(0, 2, 3, 1).view(B, H * W, C)\n        h = torch.bmm(w, v)\n        assert list(h.shape) == [B, H * W, C]\n        h = h.view(B, H, W, C).permute(0, 3, 1, 2)\n        h = self.proj(h)\n\n        return x + h\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, in_ch, h_ch, tdim, dropout, attn=False):\n        super().__init__()\n        self.block1 = nn.Sequential(\n            nn.GroupNorm(32, in_ch),\n            Swish(),\n            nn.Conv2d(in_ch, h_ch, 3, stride=1, padding=1),\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(tdim, h_ch),\n        )\n        self.block2 = nn.Sequential(\n            nn.GroupNorm(32, h_ch),\n            Swish(),\n            nn.Dropout(dropout),\n            nn.Conv2d(h_ch, h_ch, 3, stride=1, padding=1),\n        )\n        if in_ch != h_ch:\n            self.shortcut = nn.Conv2d(in_ch, h_ch, 1, stride=1, padding=0)\n        else:\n            self.shortcut = nn.Identity()\n        if attn:\n            self.attn = AttnBlock(h_ch)\n        else:\n            self.attn = nn.Identity()\n        self.initialize()\n\n    def initialize(self):\n        for module in self.modules():\n            if isinstance(module, (nn.Conv2d, nn.Linear)):\n                init.xavier_uniform_(module.weight)\n                init.zeros_(module.bias)\n        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)\n\n    def forward(self, x, temb):\n        h = self.block1(x)\n        h += self.temb_proj(temb)[:, :, None, None]\n        h = self.block2(h)\n\n        h = h + self.shortcut(x)\n        h = self.attn(h)\n        return h\n\n\nclass UMLP(nn.Module):\n    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n        super().__init__()\n        assert all([i < len(ch_mult) for i in attn]), 'attn index h of bound'\n        tdim = ch * 4\n        self.time_embedding = TimeEmbedding(T, ch, tdim)\n        attn = []\n        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)\n        self.downblocks = nn.ModuleList()\n        chs = [ch]  # record output channel when downsample for upsample\n        now_ch = ch\n        for i, mult in enumerate(ch_mult):\n            h_ch = ch * mult\n            for _ in range(num_res_blocks):\n                self.downblocks.append(ResBlock(\n                    in_ch=now_ch, h_ch=h_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = h_ch\n                chs.append(now_ch)\n            if i != len(ch_mult) - 1:\n                self.downblocks.append(DownSample(now_ch))\n                chs.append(now_ch)\n\n        self.upblocks = nn.ModuleList()\n        for i, mult in reversed(list(enumerate(ch_mult))):\n            h_ch = ch * mult\n            for _ in range(num_res_blocks + 1):\n                self.upblocks.append(ResBlock(\n                    in_ch=chs.pop() + now_ch, h_ch=h_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = h_ch\n            if i != 0:\n                self.upblocks.append(UpSample(now_ch))\n        assert len(chs) == 0\n\n        self.tail = nn.Sequential(\n            nn.GroupNorm(32, now_ch),\n            Swish(),\n            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)\n        )\n\n        # \n        embed_dims = [256, 320, 512]\n        drop_rate = 0.0\n        attn_drop_rate = 0.0\n        kan_val = False\n        version = 4\n        sr_ratios = [8, 4, 2, 1]\n        num_heads=[1, 2, 4, 8]\n        qkv_bias=False\n        qk_scale=None\n        norm_layer = nn.LayerNorm\n        dpr = [0.0, 0.0, 0.0]\n        self.patch_embed3 = OverlapPatchEmbed(img_size=64 // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])\n        self.patch_embed4 = OverlapPatchEmbed(img_size=64 // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])\n\n        \n        self.norm3 = norm_layer(embed_dims[1])\n        self.norm4 = norm_layer(embed_dims[2])\n\n        self.dnorm3 = norm_layer(embed_dims[1])\n        self.dnorm4 = norm_layer(embed_dims[0])\n\n\n        self.kan_block1 = nn.ModuleList([shiftedBlock(\n            dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,\n            sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)])\n\n        self.kan_block2 = nn.ModuleList([shiftedBlock(\n            dim=embed_dims[2], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,\n            sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)])\n\n        self.kan_dblock1 = nn.ModuleList([shiftedBlock(\n            dim=embed_dims[1], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,\n            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[0], norm_layer=norm_layer,\n            sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)])\n\n        # self.kan_dblock2 = nn.ModuleList([shiftedBlock(\n        #     dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=1, qkv_bias=qkv_bias, qk_scale=qk_scale,\n        #     drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[1], norm_layer=norm_layer,\n        #     sr_ratio=sr_ratios[0], version=version, kan_val=kan_val)])\n\n        self.decoder1 = D_SingleConv(embed_dims[2], embed_dims[1])  \n        self.decoder2 = D_SingleConv(embed_dims[1], embed_dims[0])  \n\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.head.weight)\n        init.zeros_(self.head.bias)\n        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n        init.zeros_(self.tail[-1].bias)\n\n    def forward(self, x, t):\n        # Timestep embedding\n        temb = self.time_embedding(t)\n        # Downsampling\n        h = self.head(x)\n        hs = [h]\n        for layer in self.downblocks:\n            h = layer(h, temb)\n            hs.append(h)\n    \n        t3 = h\n\n        B = x.shape[0]\n        h, H, W = self.patch_embed3(h)\n \n        for i, blk in enumerate(self.kan_block1):\n            h = blk(h, H, W, temb)\n        h = self.norm3(h)\n        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        t4 = h\n\n        h, H, W= self.patch_embed4(h)\n        for i, blk in enumerate(self.kan_block2):\n            h = blk(h, H, W, temb)\n        h = self.norm4(h)\n        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n\n        ### Stage 4\n        h = swish(F.interpolate(self.decoder1(h, temb), scale_factor=(2,2), mode ='bilinear'))\n\n        h = torch.add(h, t4)\n\n        _, _, H, W = h.shape\n        h = h.flatten(2).transpose(1,2)\n        for i, blk in enumerate(self.kan_dblock1):\n            h = blk(h, H, W, temb)\n\n            \n        ### Stage 3\n        h = self.dnorm3(h)\n        h = h.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        h = swish(F.interpolate(self.decoder2(h, temb),scale_factor=(2,2),mode ='bilinear'))\n\n        h = torch.add(h,t3)\n\n        # Upsampling\n        for layer in self.upblocks:\n            if isinstance(layer, ResBlock):\n                h = torch.cat([h, hs.pop()], dim=1)\n            h = layer(h, temb)\n        h = self.tail(h)\n\n        assert len(hs) == 0\n        return h\n\n\nif __name__ == '__main__':\n    batch_size = 8\n    model = UMLP(\n        T=1000, ch=64, ch_mult=[1, 2, 2, 2], attn=[],\n        num_res_blocks=2, dropout=0.1)\n    \n\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/Train.py",
    "content": "\nimport os\nfrom typing import Dict\nimport torch\nimport torch.optim as optim\nfrom tqdm import tqdm\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms, transforms\n# from torchvision.datasets import CIFAR10\nfrom torchvision.utils import save_image\nfrom Diffusion import GaussianDiffusionSampler, GaussianDiffusionTrainer\nfrom Diffusion.UNet import UNet, UNet_Baseline\nfrom Diffusion.Model_ConvKan import UNet_ConvKan\nfrom Diffusion.Model_UMLP import UMLP\nfrom Diffusion.Model_UKAN_Hybrid import UKan_Hybrid\nfrom Scheduler import GradualWarmupScheduler\nfrom skimage import io\nimport os\nfrom torchvision.transforms import ToTensor, Normalize, Compose\nfrom torch.utils.data import Dataset\nimport sys\n\n\nmodel_dict = {\n    'UNet': UNet,\n    'UNet_ConvKan': UNet_ConvKan, # dose not work\n    'UMLP': UMLP,\n    'UKan_Hybrid': UKan_Hybrid,\n    'UNet_Baseline': UNet_Baseline,\n}\n\nclass UnlabeledDataset(Dataset):\n    def __init__(self, folder, transform=None, repeat_n=1):\n        self.folder = folder\n        self.transform = transform\n        # self.image_files = os.listdir(folder) * repeat_n\n        self.image_files = os.listdir(folder) \n\n    def __len__(self):\n        return len(self.image_files)\n\n    def __getitem__(self, idx):\n        image_file = self.image_files[idx]\n        image_path = os.path.join(self.folder, image_file)\n        image = io.imread(image_path)\n        if self.transform:\n            image = self.transform(image)\n        return image, torch.Tensor([0])\n\n\ndef train(modelConfig: Dict):\n    device = torch.device(modelConfig[\"device\"])\n    log_print = True\n    if log_print:\n        file = open(modelConfig[\"save_weight_dir\"]+'log.txt', \"w\")\n        sys.stdout = file\n    transform = Compose([\n        ToTensor(),\n        transforms.RandomHorizontalFlip(),\n        transforms.RandomVerticalFlip(),\n        Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n        ])\n\n    if 
modelConfig[\"dataset\"] == 'cvc':\n        dataset = UnlabeledDataset('data/cvc/images_64/', transform=transform, repeat_n=modelConfig[\"dataset_repeat\"])\n    elif modelConfig[\"dataset\"] == 'glas':\n        dataset = UnlabeledDataset('data/glas/images_64/', transform=transform, repeat_n=modelConfig[\"dataset_repeat\"])\n    elif modelConfig[\"dataset\"] == 'glas_resize':\n        dataset = UnlabeledDataset('data/glas/images_64_resize/', transform=transform, repeat_n=modelConfig[\"dataset_repeat\"])\n    elif modelConfig[\"dataset\"] == 'busi':\n        dataset = UnlabeledDataset('data/busi/images_64/', transform=transform, repeat_n=modelConfig[\"dataset_repeat\"])\n    else:\n        raise ValueError('dataset not found')\n\n    print('modelConfig: ')\n    for key, value in modelConfig.items():\n        print(key, ' : ', value)\n        \n    dataloader = DataLoader(\n        dataset, batch_size=modelConfig[\"batch_size\"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)\n    \n    print('Using {}'.format(modelConfig[\"model\"]))\n    # model setup\n    net_model =model_dict[modelConfig[\"model\"]](T=modelConfig[\"T\"], ch=modelConfig[\"channel\"], ch_mult=modelConfig[\"channel_mult\"], attn=modelConfig[\"attn\"],\n                    num_res_blocks=modelConfig[\"num_res_blocks\"], dropout=modelConfig[\"dropout\"]).to(device)\n\n    if modelConfig[\"training_load_weight\"] is not None:\n        net_model.load_state_dict(torch.load(os.path.join(\n            modelConfig[\"save_weight_dir\"], modelConfig[\"training_load_weight\"]), map_location=device))\n        \n    optimizer = torch.optim.AdamW(\n        net_model.parameters(), lr=modelConfig[\"lr\"], weight_decay=1e-4)\n    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(\n        optimizer=optimizer, T_max=modelConfig[\"epoch\"], eta_min=0, last_epoch=-1)\n    warmUpScheduler = GradualWarmupScheduler(\n        optimizer=optimizer, multiplier=modelConfig[\"multiplier\"], 
warm_epoch=modelConfig[\"epoch\"] // 10, after_scheduler=cosineScheduler)\n\n    trainer = GaussianDiffusionTrainer(\n        net_model, modelConfig[\"beta_1\"], modelConfig[\"beta_T\"], modelConfig[\"T\"]).to(device)\n\n    # start training\n    for e in range(1,modelConfig[\"epoch\"]+1):\n        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:\n            for images, labels in tqdmDataLoader:\n                # train\n                optimizer.zero_grad()\n                x_0 = images.to(device)\n                \n                loss = trainer(x_0).sum() / 1000.\n                loss.backward()\n                torch.nn.utils.clip_grad_norm_(\n                    net_model.parameters(), modelConfig[\"grad_clip\"])\n                optimizer.step()\n                tqdmDataLoader.set_postfix(ordered_dict={\n                    \"epoch\": e,\n                    \"loss: \": loss.item(),\n                    \"img shape: \": x_0.shape,\n                    \"LR\": optimizer.state_dict()['param_groups'][0][\"lr\"]\n                })\n                # print version\n                if log_print:\n                    print(\"epoch: \", e, \"loss: \", loss.item(), \"img shape: \", x_0.shape, \"LR: \", optimizer.state_dict()['param_groups'][0][\"lr\"])\n        warmUpScheduler.step()\n        if e % 50 ==0:\n            torch.save(net_model.state_dict(), os.path.join(\n                modelConfig[\"save_weight_dir\"], 'ckpt_' + str(e) + \"_.pt\"))\n            modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(e)\n            eval_tmp(modelConfig, e)\n\n    torch.save(net_model.state_dict(), os.path.join(\n        modelConfig[\"save_weight_dir\"], 'ckpt_' + str(e) + \"_.pt\"))\n    if log_print:\n        file.close()\n        sys.stdout = sys.__stdout__\n    \ndef eval_tmp(modelConfig: Dict, nme: int):\n    # load model and evaluate\n    with torch.no_grad():\n        device = torch.device(modelConfig[\"device\"])\n        model = 
model_dict[modelConfig[\"model\"]](T=modelConfig[\"T\"], ch=modelConfig[\"channel\"], ch_mult=modelConfig[\"channel_mult\"], attn=modelConfig[\"attn\"],\n                     num_res_blocks=modelConfig[\"num_res_blocks\"], dropout=0.)\n        ckpt = torch.load(os.path.join(\n            modelConfig[\"save_weight_dir\"], modelConfig[\"test_load_weight\"]), map_location=device)\n    \n        model.load_state_dict(ckpt)\n        \n        print(\"model load weight done.\")\n        model.eval()\n        sampler = GaussianDiffusionSampler(\n            model, modelConfig[\"beta_1\"], modelConfig[\"beta_T\"], modelConfig[\"T\"]).to(device)\n        # Sampled from standard normal distribution\n        noisyImage = torch.randn(\n            size=[modelConfig[\"batch_size\"], 3, modelConfig[\"img_size\"], modelConfig[\"img_size\"]], device=device)\n        # saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)\n        # save_image(saveNoisy, os.path.join(\n            # modelConfig[\"sampled_dir\"], modelConfig[\"sampledNoisyImgName\"]), nrow=modelConfig[\"nrow\"])\n        sampledImgs = sampler(noisyImage)\n        sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]\n\n        save_root = modelConfig[\"sampled_dir\"].replace('Gens','Tmp')\n        os.makedirs(save_root, exist_ok=True)\n        save_image(sampledImgs, os.path.join(\n            save_root,  modelConfig[\"sampledImgName\"].replace('.png','_{}.png').format(nme)), nrow=modelConfig[\"nrow\"])\n        if nme < 0.95 * modelConfig[\"epoch\"]:\n            os.remove(os.path.join(\n                modelConfig[\"save_weight_dir\"], modelConfig[\"test_load_weight\"]))\n\ndef eval(modelConfig: Dict):\n    # load model and evaluate\n    with torch.no_grad():\n        device = torch.device(modelConfig[\"device\"])\n\n        model = model_dict[modelConfig[\"model\"]](T=modelConfig[\"T\"], ch=modelConfig[\"channel\"], ch_mult=modelConfig[\"channel_mult\"], attn=modelConfig[\"attn\"],\n                    
num_res_blocks=modelConfig[\"num_res_blocks\"], dropout=modelConfig[\"dropout\"]).to(device)\n    \n        ckpt = torch.load(os.path.join(\n            modelConfig[\"save_weight_dir\"], modelConfig[\"test_load_weight\"]), map_location=device)\n\n        model.load_state_dict(ckpt)\n        print(\"model load weight done.\")\n        model.eval()\n        sampler = GaussianDiffusionSampler(\n            model, modelConfig[\"beta_1\"], modelConfig[\"beta_T\"], modelConfig[\"T\"]).to(device)\n        # Sampled from standard normal distribution\n        noisyImage = torch.randn(\n            size=[modelConfig[\"batch_size\"], 3, modelConfig[\"img_size\"], modelConfig[\"img_size\"]], device=device)     \n        # saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)\n        # save_image(saveNoisy, os.path.join(\n        #     modelConfig[\"sampled_dir\"], modelConfig[\"sampledNoisyImgName\"]), nrow=modelConfig[\"nrow\"])\n        sampledImgs = sampler(noisyImage)\n        sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]\n\n        for i, image in enumerate(sampledImgs):\n    \n            save_image(image, os.path.join(modelConfig[\"sampled_dir\"],  modelConfig[\"sampledImgName\"].replace('.png','_{}.png').format(i)), nrow=modelConfig[\"nrow\"])\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/UNet.py",
    "content": "\n   \nimport math\nimport torch\nfrom torch import nn\nfrom torch.nn import init\nfrom torch.nn import functional as F\n\n\nclass Swish(nn.Module):\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\n\nclass TimeEmbedding(nn.Module):\n    def __init__(self, T, d_model, dim):\n        assert d_model % 2 == 0\n        super().__init__()\n        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)\n        emb = torch.exp(-emb)\n        pos = torch.arange(T).float()\n        emb = pos[:, None] * emb[None, :]\n        assert list(emb.shape) == [T, d_model // 2]\n        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)\n        assert list(emb.shape) == [T, d_model // 2, 2]\n        emb = emb.view(T, d_model)\n\n        self.timembedding = nn.Sequential(\n            nn.Embedding.from_pretrained(emb),\n            nn.Linear(d_model, dim),\n            Swish(),\n            nn.Linear(dim, dim),\n        )\n        self.initialize()\n\n    def initialize(self):\n        for module in self.modules():\n            if isinstance(module, nn.Linear):\n                init.xavier_uniform_(module.weight)\n                init.zeros_(module.bias)\n\n    def forward(self, t):\n        emb = self.timembedding(t)\n        return emb\n\n\nclass DownSample(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.main.weight)\n        init.zeros_(self.main.bias)\n\n    def forward(self, x, temb):\n        x = self.main(x)\n        return x\n\n\nclass UpSample(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.main.weight)\n        init.zeros_(self.main.bias)\n\n    def 
forward(self, x, temb):\n        _, _, H, W = x.shape\n        x = F.interpolate(\n            x, scale_factor=2, mode='nearest')\n        x = self.main(x)\n        return x\n\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_ch):\n        super().__init__()\n        self.group_norm = nn.GroupNorm(32, in_ch)\n        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)\n        self.initialize()\n\n    def initialize(self):\n        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:\n            init.xavier_uniform_(module.weight)\n            init.zeros_(module.bias)\n        init.xavier_uniform_(self.proj.weight, gain=1e-5)\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        h = self.group_norm(x)\n        q = self.proj_q(h)\n        k = self.proj_k(h)\n        v = self.proj_v(h)\n\n        q = q.permute(0, 2, 3, 1).view(B, H * W, C)\n        k = k.view(B, C, H * W)\n        w = torch.bmm(q, k) * (int(C) ** (-0.5))\n        assert list(w.shape) == [B, H * W, H * W]\n        w = F.softmax(w, dim=-1)\n\n        v = v.permute(0, 2, 3, 1).view(B, H * W, C)\n        h = torch.bmm(w, v)\n        assert list(h.shape) == [B, H * W, C]\n        h = h.view(B, H, W, C).permute(0, 3, 1, 2)\n        h = self.proj(h)\n\n        return x + h\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):\n        super().__init__()\n        self.block1 = nn.Sequential(\n            nn.GroupNorm(32, in_ch),\n            Swish(),\n            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),\n        )\n        self.temb_proj = nn.Sequential(\n            Swish(),\n            nn.Linear(tdim, out_ch),\n        )\n        self.block2 = nn.Sequential(\n            nn.GroupNorm(32, out_ch),\n   
         Swish(),\n            nn.Dropout(dropout),\n            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),\n        )\n        if in_ch != out_ch:\n            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)\n        else:\n            self.shortcut = nn.Identity()\n        if attn:\n            self.attn = AttnBlock(out_ch)\n        else:\n            self.attn = nn.Identity()\n        self.initialize()\n\n    def initialize(self):\n        for module in self.modules():\n            if isinstance(module, (nn.Conv2d, nn.Linear)):\n                init.xavier_uniform_(module.weight)\n                init.zeros_(module.bias)\n        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)\n\n    def forward(self, x, temb):\n        h = self.block1(x)\n        h += self.temb_proj(temb)[:, :, None, None]\n        h = self.block2(h)\n\n        h = h + self.shortcut(x)\n        h = self.attn(h)\n        return h\n\n\nclass UNet(nn.Module):\n    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n        super().__init__()\n        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'\n        tdim = ch * 4\n        self.time_embedding = TimeEmbedding(T, ch, tdim)\n\n        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)\n        self.downblocks = nn.ModuleList()\n        chs = [ch]  # record output channel when dowmsample for upsample\n        now_ch = ch\n        for i, mult in enumerate(ch_mult):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks):\n                self.downblocks.append(ResBlock(\n                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n                chs.append(now_ch)\n            if i != len(ch_mult) - 1:\n                self.downblocks.append(DownSample(now_ch))\n                chs.append(now_ch)\n\n        self.middleblocks = nn.ModuleList([\n         
   ResBlock(now_ch, now_ch, tdim, dropout, attn=True),\n            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),\n        ])\n\n        self.upblocks = nn.ModuleList()\n        for i, mult in reversed(list(enumerate(ch_mult))):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks + 1):\n                self.upblocks.append(ResBlock(\n                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n            if i != 0:\n                self.upblocks.append(UpSample(now_ch))\n        assert len(chs) == 0\n\n        self.tail = nn.Sequential(\n            nn.GroupNorm(32, now_ch),\n            Swish(),\n            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)\n        )\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.head.weight)\n        init.zeros_(self.head.bias)\n        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n        init.zeros_(self.tail[-1].bias)\n\n    def forward(self, x, t):\n        # Timestep embedding\n        temb = self.time_embedding(t)\n        # Downsampling\n        h = self.head(x)\n        hs = [h]\n        for layer in self.downblocks:\n            h = layer(h, temb)\n            hs.append(h)\n        # Middle \n        # torch.Size([8, 512, 4, 4])\n        for layer in self.middleblocks:\n            h = layer(h, temb)\n        # torch.Size([8, 512, 4, 4])\n        # Upsampling\n        for layer in self.upblocks:\n            if isinstance(layer, ResBlock):\n                h = torch.cat([h, hs.pop()], dim=1)\n            h = layer(h, temb)\n        h = self.tail(h)\n\n        assert len(hs) == 0\n        return h\n    \nclass UNet_Baseline(nn.Module):\n    # Remove the middle blocks\n    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):\n        super().__init__()\n        assert all([i < len(ch_mult) for i in attn]), 'attn index out of 
bound'\n        tdim = ch * 4\n        self.time_embedding = TimeEmbedding(T, ch, tdim)\n\n        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)\n        self.downblocks = nn.ModuleList()\n        chs = [ch]  # record output channel when dowmsample for upsample\n        now_ch = ch\n        for i, mult in enumerate(ch_mult):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks):\n                self.downblocks.append(ResBlock(\n                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n                chs.append(now_ch)\n            if i != len(ch_mult) - 1:\n                self.downblocks.append(DownSample(now_ch))\n                chs.append(now_ch)\n\n\n        self.upblocks = nn.ModuleList()\n        for i, mult in reversed(list(enumerate(ch_mult))):\n            out_ch = ch * mult\n            for _ in range(num_res_blocks + 1):\n                self.upblocks.append(ResBlock(\n                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,\n                    dropout=dropout, attn=(i in attn)))\n                now_ch = out_ch\n            if i != 0:\n                self.upblocks.append(UpSample(now_ch))\n        assert len(chs) == 0\n\n        self.tail = nn.Sequential(\n            nn.GroupNorm(32, now_ch),\n            Swish(),\n            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)\n        )\n        self.initialize()\n\n    def initialize(self):\n        init.xavier_uniform_(self.head.weight)\n        init.zeros_(self.head.bias)\n        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)\n        init.zeros_(self.tail[-1].bias)\n\n    def forward(self, x, t):\n        # Timestep embedding\n        temb = self.time_embedding(t)\n        # Downsampling\n        h = self.head(x)\n        hs = [h]\n        for layer in self.downblocks:\n            h = layer(h, temb)\n            hs.append(h)\n        # 
Upsampling\n        for layer in self.upblocks:\n            if isinstance(layer, ResBlock):\n                h = torch.cat([h, hs.pop()], dim=1)\n            h = layer(h, temb)\n        h = self.tail(h)\n\n        assert len(hs) == 0\n        return h\n\n\n\nif __name__ == '__main__':\n    batch_size = 8\n    model = UNet(\n        T=1000, ch=64, ch_mult=[1, 2, 2, 2], attn=[1],\n        num_res_blocks=2, dropout=0.1)"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/__init__.py",
    "content": "from .Diffusion import *\nfrom .UNet import *\nfrom .Train import *\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/kan_utils/__init__.py",
    "content": "from .kan import *\nfrom .fastkanconv import *\n# from .kan_convolutional import *\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/kan_utils/fastkanconv.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom typing import List, Tuple, Union\n\n\nclass PolynomialFunction(nn.Module):\n    def __init__(self, \n                 degree: int = 3):\n        super().__init__()\n        self.degree = degree\n\n    def forward(self, x):\n        return torch.stack([x ** i for i in range(self.degree)], dim=-1)\n    \nclass BSplineFunction(nn.Module):\n    def __init__(self, grid_min: float = -2.,\n        grid_max: float = 2., degree: int = 3, num_basis: int = 8):\n        super(BSplineFunction, self).__init__()\n        self.degree = degree\n        self.num_basis = num_basis\n        self.knots = torch.linspace(grid_min, grid_max, num_basis + degree + 1)  # Uniform knots\n\n    def basis_function(self, i, k, t):\n        if k == 0:\n            return ((self.knots[i] <= t) & (t < self.knots[i + 1])).float()\n        else:\n            left_num = (t - self.knots[i]) * self.basis_function(i, k - 1, t)\n            left_den = self.knots[i + k] - self.knots[i]\n            left = left_num / left_den if left_den != 0 else 0\n\n            right_num = (self.knots[i + k + 1] - t) * self.basis_function(i + 1, k - 1, t)\n            right_den = self.knots[i + k + 1] - self.knots[i + 1]\n            right = right_num / right_den if right_den != 0 else 0\n            return left + right \n    \n    def forward(self, x):\n        x = x.squeeze()  # Assuming x is of shape (B, 1)\n        basis_functions = torch.stack([self.basis_function(i, self.degree, x) for i in range(self.num_basis)], dim=-1)\n        return basis_functions\n\nclass ChebyshevFunction(nn.Module):\n    def __init__(self, degree: int = 4):\n        super(ChebyshevFunction, self).__init__()\n        self.degree = degree\n\n    def forward(self, x):\n        chebyshev_polynomials = [torch.ones_like(x), x]\n        for n in range(2, self.degree):\n            chebyshev_polynomials.append(2 * x * chebyshev_polynomials[-1] - 
chebyshev_polynomials[-2])\n        return torch.stack(chebyshev_polynomials, dim=-1)\n\nclass FourierBasisFunction(nn.Module):\n    def __init__(self, \n                 num_frequencies: int = 4, \n                 period: float = 1.0):\n        super(FourierBasisFunction, self).__init__()\n        assert num_frequencies % 2 == 0, \"num_frequencies must be even\"\n        self.num_frequencies = num_frequencies\n        self.period = nn.Parameter(torch.Tensor([period]), requires_grad=False)\n\n    def forward(self, x):\n        frequencies = torch.arange(1, self.num_frequencies // 2 + 1, device=x.device)\n        sin_components = torch.sin(2 * torch.pi * frequencies * x[..., None] / self.period)\n        cos_components = torch.cos(2 * torch.pi * frequencies * x[..., None] / self.period)\n        basis_functions = torch.cat([sin_components, cos_components], dim=-1)\n        return basis_functions\n        \nclass RadialBasisFunction(nn.Module):\n    def __init__(\n        self,\n        grid_min: float = -2.,\n        grid_max: float = 2.,\n        num_grids: int = 4,\n        denominator: float = None,\n    ):\n        super().__init__()\n        grid = torch.linspace(grid_min, grid_max, num_grids)\n        self.grid = torch.nn.Parameter(grid, requires_grad=False)\n        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)\n\n    def forward(self, x):\n        return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)\n    \n\n    \n    \nclass SplineConv2D(nn.Conv2d):\n    def __init__(self, \n                 in_channels: int, \n                 out_channels: int, \n                 kernel_size: Union[int, Tuple[int, int]] = 3,\n                 stride: Union[int, Tuple[int, int]] = 1, \n                 padding: Union[int, Tuple[int, int]] = 0, \n                 dilation: Union[int, Tuple[int, int]] = 1,\n                 groups: int = 1, \n                 bias: bool = True, \n                 init_scale: float = 0.1, \n     
            padding_mode: str = \"zeros\", \n                 **kw\n                 ) -> None:\n        self.init_scale = init_scale\n        super().__init__(in_channels, \n                         out_channels, \n                         kernel_size, \n                         stride, \n                         padding, \n                         dilation, \n                         groups, \n                         bias, \n                         padding_mode, \n                         **kw\n                         )\n\n    def reset_parameters(self) -> None:\n        nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n\n\nclass FastKANConvLayer(nn.Module):\n    def __init__(self, \n                 in_channels: int, \n                 out_channels: int, \n                 kernel_size: Union[int, Tuple[int, int]] = 3,\n                 stride: Union[int, Tuple[int, int]] = 1, \n                 padding: Union[int, Tuple[int, int]] = 0, \n                 dilation: Union[int, Tuple[int, int]] = 1,\n                 groups: int = 1, \n                 bias: bool = True, \n                 grid_min: float = -2., \n                 grid_max: float = 2.,\n                 num_grids: int = 4, \n                 use_base_update: bool = True, \n                 base_activation = F.silu,\n                 spline_weight_init_scale: float = 0.1, \n                 padding_mode: str = \"zeros\",\n                 kan_type: str = \"BSpline\",\n                #  kan_type: str = \"RBF\",\n                 ) -> None:\n        \n        super().__init__()\n        if kan_type == \"RBF\":\n            self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)\n        elif kan_type == \"Fourier\":\n            self.rbf = FourierBasisFunction(num_grids)\n        elif kan_type == \"Poly\":\n            self.rbf = PolynomialFunction(num_grids)\n        elif kan_type == \"Chebyshev\":\n    
        self.rbf = ChebyshevFunction(num_grids)\n        elif kan_type == \"BSpline\":\n            self.rbf = BSplineFunction(grid_min, grid_max, 3, num_grids)\n\n        self.spline_conv = SplineConv2D(in_channels * num_grids, \n                                        out_channels, \n                                        kernel_size,\n                                        stride, \n                                        padding, \n                                        dilation, \n                                        groups, \n                                        bias,\n                                        spline_weight_init_scale, \n                                        padding_mode)\n        \n        self.use_base_update = use_base_update\n        if use_base_update:\n            self.base_activation = base_activation\n            self.base_conv = nn.Conv2d(in_channels, \n                                       out_channels, \n                                       kernel_size, \n                                       stride, \n                                       padding, \n                                       dilation, \n                                       groups, \n                                       bias, \n                                       padding_mode)\n\n    def forward(self, x):\n        batch_size, channels, height, width = x.shape\n        x_rbf = self.rbf(x.view(batch_size, channels, -1)).view(batch_size, channels, height, width, -1)\n        x_rbf = x_rbf.permute(0, 4, 1, 2, 3).contiguous().view(batch_size, -1, height, width)\n        \n        # Apply spline convolution\n        ret = self.spline_conv(x_rbf)\n         \n        if self.use_base_update:\n            base = self.base_conv(self.base_activation(x))\n            ret = ret + base\n        \n        return ret\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/kan_utils/kan.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport math\n\n\nclass KANLinear(torch.nn.Module):\n    def __init__(\n        self,\n        in_features,\n        out_features,\n        grid_size=5,\n        spline_order=3,\n        scale_noise=0.1,\n        scale_base=1.0,\n        scale_spline=1.0,\n        enable_standalone_scale_spline=True,\n        base_activation=torch.nn.SiLU,\n        grid_eps=0.02,\n        grid_range=[-1, 1],\n    ):\n        super(KANLinear, self).__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.grid_size = grid_size\n        self.spline_order = spline_order\n\n        h = (grid_range[1] - grid_range[0]) / grid_size\n        grid = (\n            (\n                torch.arange(-spline_order, grid_size + spline_order + 1) * h\n                + grid_range[0]\n            )\n            .expand(in_features, -1)\n            .contiguous()\n        )\n        self.register_buffer(\"grid\", grid)\n\n        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))\n        self.spline_weight = torch.nn.Parameter(\n            torch.Tensor(out_features, in_features, grid_size + spline_order)\n        )\n        if enable_standalone_scale_spline:\n            self.spline_scaler = torch.nn.Parameter(\n                torch.Tensor(out_features, in_features)\n            )\n\n        self.scale_noise = scale_noise\n        self.scale_base = scale_base\n        self.scale_spline = scale_spline\n        self.enable_standalone_scale_spline = enable_standalone_scale_spline\n        self.base_activation = base_activation()\n        self.grid_eps = grid_eps\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)\n        with torch.no_grad():\n            noise = (\n                (\n                    torch.rand(self.grid_size + 1, self.in_features, 
self.out_features)\n                    - 1 / 2\n                )\n                * self.scale_noise\n                / self.grid_size\n            )\n            self.spline_weight.data.copy_(\n                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)\n                * self.curve2coeff(\n                    self.grid.T[self.spline_order : -self.spline_order],\n                    noise,\n                )\n            )\n            if self.enable_standalone_scale_spline:\n                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)\n                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)\n\n    def b_splines(self, x: torch.Tensor):\n        \"\"\"\n        Compute the B-spline bases for the given input tensor.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n\n        Returns:\n            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).\n        \"\"\"\n        assert x.dim() == 2 and x.size(1) == self.in_features\n\n        grid: torch.Tensor = (\n            self.grid\n        )  # (in_features, grid_size + 2 * spline_order + 1)\n        x = x.unsqueeze(-1)\n        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)\n        for k in range(1, self.spline_order + 1):\n            bases = (\n                (x - grid[:, : -(k + 1)])\n                / (grid[:, k:-1] - grid[:, : -(k + 1)])\n                * bases[:, :, :-1]\n            ) + (\n                (grid[:, k + 1 :] - x)\n                / (grid[:, k + 1 :] - grid[:, 1:(-k)])\n                * bases[:, :, 1:]\n            )\n\n        assert bases.size() == (\n            x.size(0),\n            self.in_features,\n            self.grid_size + self.spline_order,\n        )\n        return bases.contiguous()\n\n    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):\n        \"\"\"\n        
Compute the coefficients of the curve that interpolates the given points.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).\n\n        Returns:\n            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).\n        \"\"\"\n        assert x.dim() == 2 and x.size(1) == self.in_features\n        assert y.size() == (x.size(0), self.in_features, self.out_features)\n\n        A = self.b_splines(x).transpose(\n            0, 1\n        )  # (in_features, batch_size, grid_size + spline_order)\n        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)\n        solution = torch.linalg.lstsq(\n            A, B\n        ).solution  # (in_features, grid_size + spline_order, out_features)\n        result = solution.permute(\n            2, 0, 1\n        )  # (out_features, in_features, grid_size + spline_order)\n\n        assert result.size() == (\n            self.out_features,\n            self.in_features,\n            self.grid_size + self.spline_order,\n        )\n        return result.contiguous()\n\n    @property\n    def scaled_spline_weight(self):\n        return self.spline_weight * (\n            self.spline_scaler.unsqueeze(-1)\n            if self.enable_standalone_scale_spline\n            else 1.0\n        )\n\n    def forward(self, x: torch.Tensor):\n        assert x.dim() == 2 and x.size(1) == self.in_features\n\n        base_output = F.linear(self.base_activation(x), self.base_weight)\n        spline_output = F.linear(\n            self.b_splines(x).view(x.size(0), -1),\n            self.scaled_spline_weight.view(self.out_features, -1),\n        )\n        return base_output + spline_output\n\n    @torch.no_grad()\n    def update_grid(self, x: torch.Tensor, margin=0.01):\n        assert x.dim() == 2 and x.size(1) == self.in_features\n        batch = 
x.size(0)\n\n        splines = self.b_splines(x)  # (batch, in, coeff)\n        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)\n        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)\n        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)\n        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)\n        unreduced_spline_output = unreduced_spline_output.permute(\n            1, 0, 2\n        )  # (batch, in, out)\n\n        # sort each channel individually to collect data distribution\n        x_sorted = torch.sort(x, dim=0)[0]\n        grid_adaptive = x_sorted[\n            torch.linspace(\n                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device\n            )\n        ]\n\n        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size\n        grid_uniform = (\n            torch.arange(\n                self.grid_size + 1, dtype=torch.float32, device=x.device\n            ).unsqueeze(1)\n            * uniform_step\n            + x_sorted[0]\n            - margin\n        )\n\n        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive\n        grid = torch.concatenate(\n            [\n                grid[:1]\n                - uniform_step\n                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),\n                grid,\n                grid[-1:]\n                + uniform_step\n                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),\n            ],\n            dim=0,\n        )\n\n        self.grid.copy_(grid.T)\n        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))\n\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n        \"\"\"\n        Compute the regularization loss.\n\n        This is a dumb simulation of the original L1 regularization as stated in the\n        paper, since the original one 
requires computing absolutes and entropy from the\n        expanded (batch, in_features, out_features) intermediate tensor, which is hidden\n        behind the F.linear function if we want an memory efficient implementation.\n\n        The L1 regularization is now computed as mean absolute value of the spline\n        weights. The authors implementation also includes this term in addition to the\n        sample-based regularization.\n        \"\"\"\n        l1_fake = self.spline_weight.abs().mean(-1)\n        regularization_loss_activation = l1_fake.sum()\n        p = l1_fake / regularization_loss_activation\n        regularization_loss_entropy = -torch.sum(p * p.log())\n        return (\n            regularize_activation * regularization_loss_activation\n            + regularize_entropy * regularization_loss_entropy\n        )\n\n\nclass KAN(torch.nn.Module):\n    def __init__(\n        self,\n        layers_hidden,\n        grid_size=5,\n        spline_order=3,\n        scale_noise=0.1,\n        scale_base=1.0,\n        scale_spline=1.0,\n        base_activation=torch.nn.SiLU,\n        grid_eps=0.02,\n        grid_range=[-1, 1],\n    ):\n        super(KAN, self).__init__()\n        self.grid_size = grid_size\n        self.spline_order = spline_order\n\n        self.layers = torch.nn.ModuleList()\n        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):\n            self.layers.append(\n                KANLinear(\n                    in_features,\n                    out_features,\n                    grid_size=grid_size,\n                    spline_order=spline_order,\n                    scale_noise=scale_noise,\n                    scale_base=scale_base,\n                    scale_spline=scale_spline,\n                    base_activation=base_activation,\n                    grid_eps=grid_eps,\n                    grid_range=grid_range,\n                )\n            )\n\n    def forward(self, x: torch.Tensor, update_grid=False):\n       
 for layer in self.layers:\n            if update_grid:\n                layer.update_grid(x)\n            x = layer(x)\n        return x\n\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n        return sum(\n            layer.regularization_loss(regularize_activation, regularize_entropy)\n            for layer in self.layers\n        )\n"
  },
  {
    "path": "Diffusion_UKAN/Diffusion/utils.py",
    "content": "import argparse\nimport torch.nn as nn\n\nclass qkv_transform(nn.Conv1d):\n    \"\"\"Conv1d for qkv_transform\"\"\"\n\ndef str2bool(v):\n    if v.lower() in ['true', 1]:\n        return True\n    elif v.lower() in ['false', 0]:\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\n\n\ndef count_params(model):\n    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n"
  },
  {
    "path": "Diffusion_UKAN/Main.py",
    "content": "from Diffusion.Train import train, eval\nimport os\nimport argparse\nimport torch\nimport numpy as np\n\ndef main(model_config = None):\n\n    if model_config is not None:\n        modelConfig = model_config\n    if modelConfig[\"state\"] == \"train\":\n        train(modelConfig)\n        modelConfig['batch_size'] = 64\n        modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(modelConfig['epoch'])\n        for i in range(32):\n            modelConfig[\"sampledImgName\"] = \"sampledImgName{}.png\".format(i)\n            eval(modelConfig)\n    else:\n        for i in range(32):\n            modelConfig[\"sampledImgName\"] = \"sampledImgName{}.png\".format(i)\n            eval(modelConfig)\n\ndef seed_all(args):\n    torch.manual_seed(args.seed)\n    torch.cuda.manual_seed(args.seed)\n    torch.cuda.manual_seed_all(args.seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n    np.random.seed(args.seed)\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--state', type=str, default='train') # train or eval\n    parser.add_argument('--dataset', type=str, default='cvc') # busi, glas, cvc\n    parser.add_argument('--epoch', type=int, default=1000) # 1000 for cvc/glas, 5000 for busi\n    parser.add_argument('--batch_size', type=int, default=32)\n    parser.add_argument('--T', type=int, default=1000)\n    parser.add_argument('--channel', type=int, default=64) # 64 or 128\n    parser.add_argument('--test_load_weight', type=str, default='ckpt_1000_.pt')\n    parser.add_argument('--num_res_blocks', type=int, default=2)\n    parser.add_argument('--dropout', type=float, default=0.15)\n    parser.add_argument('--lr', type=float, default=2e-4)\n    parser.add_argument('--img_size', type=float, default=64) \n    parser.add_argument('--dataset_repeat', type=int, default=1) # did not use\n    parser.add_argument('--seed', type=int, default=0) # did not use\n    
parser.add_argument('--model', type=str, default='UKAN_Hybrid')\n    parser.add_argument('--exp_nme', type=str, default='UKAN_Hybrid')\n    parser.add_argument('--save_root', type=str, default='./Output/') \n    args = parser.parse_args()\n\n    save_root = args.save_root\n    if args.seed != 0:\n        seed_all(args)\n\n    modelConfig = {\n        \"dataset\": args.dataset, \n        \"state\": args.state, # or eval\n        \"epoch\": args.epoch,\n        \"batch_size\": args.batch_size,\n        \"T\": args.T,\n        \"channel\": args.channel,\n        \"channel_mult\": [1, 2, 3, 4],\n        \"attn\": [2],\n        \"num_res_blocks\": args.num_res_blocks,\n        \"dropout\": args.dropout,\n        \"lr\": args.lr,\n        \"multiplier\": 2.,\n        \"beta_1\": 1e-4,\n        \"beta_T\": 0.02,\n        \"img_size\": 64,\n        \"grad_clip\": 1.,\n        \"device\": \"cuda\", ### MAKE SURE YOU HAVE A GPU !!!\n        \"training_load_weight\": None,\n        \"save_weight_dir\": os.path.join(save_root, args.exp_nme, \"Weights\"),\n        \"sampled_dir\": os.path.join(save_root, args.exp_nme, \"Gens\"),\n        \"test_load_weight\": args.test_load_weight,\n        \"sampledNoisyImgName\": \"NoisyNoGuidenceImgs.png\",\n        \"sampledImgName\": \"SampledNoGuidenceImgs.png\",\n        \"nrow\": 8,\n        \"model\":args.model,\n        \"version\": 1,\n        \"dataset_repeat\": args.dataset_repeat,\n        \"seed\": args.seed,\n        \"save_root\": args.save_root,\n        }\n\n    os.makedirs(modelConfig[\"save_weight_dir\"], exist_ok=True)\n    os.makedirs(modelConfig[\"sampled_dir\"], exist_ok=True)\n\n    # backup \n    import shutil\n    shutil.copy(\"Diffusion/Model_UKAN_Hybrid.py\", os.path.join(save_root, args.exp_nme))\n    shutil.copy(\"Diffusion/Train.py\", os.path.join(save_root, args.exp_nme))\n\n    main(modelConfig)\n"
  },
  {
    "path": "Diffusion_UKAN/Main_Test.py",
    "content": "from Diffusion.Train import train, eval, eval_tmp\nimport os\nimport argparse\nimport torch\ndef main(model_config = None):\n\n    if model_config is not None:\n        modelConfig = model_config\n    if modelConfig[\"state\"] == \"train\":\n        train(modelConfig)\n        modelConfig['batch_size'] = 64\n        modelConfig['test_load_weight'] = 'ckpt_{}_.pt'.format(modelConfig['epoch'])\n        for i in range(32):\n            modelConfig[\"sampledImgName\"] = \"sampledImgName{}.png\".format(i)\n            eval(modelConfig)\n    else:\n        for i in range(1):\n            modelConfig[\"sampledImgName\"] = \"sampledImgName{}.png\".format(i)\n            eval_tmp(modelConfig,1000) # for grid visualization\n            # eval(modelConfig) # for metric evaluation\n\ndef seed_all(args):\n    torch.manual_seed(args.seed)\n    torch.cuda.manual_seed(args.seed)\n    torch.cuda.manual_seed_all(args.seed)\n    torch.backends.cudnn.deterministic = True\n    torch.backends.cudnn.benchmark = False\n    import numpy as np\n    np.random.seed(args.seed)\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--state', type=str, default='eval')\n    parser.add_argument('--dataset', type=str, default='cvc') # busi, glas, cvc\n    parser.add_argument('--epoch', type=int, default=1000)\n    parser.add_argument('--batch_size', type=int, default=32)\n    parser.add_argument('--T', type=int, default=1000)\n    parser.add_argument('--channel', type=int, default=64)\n    parser.add_argument('--test_load_weight', type=str, default='ckpt_1000_.pt')\n    parser.add_argument('--num_res_blocks', type=int, default=2)\n    parser.add_argument('--dropout', type=float, default=0.15)\n    parser.add_argument('--lr', type=float, default=1e-4)\n    parser.add_argument('--img_size', type=float, default=64) # 64 or 128\n    parser.add_argument('--dataset_repeat', type=int, default=1) # didnot use\n    parser.add_argument('--seed', 
type=int, default=0) \n    parser.add_argument('--model', type=str, default='UKan_Hybrid')\n    parser.add_argument('--exp_nme', type=str, default='./')\n\n    parser.add_argument('--save_root', type=str, default='released_models/ukan_cvc') \n    # parser.add_argument('--save_root', type=str, default='released_models/ukan_glas') \n    # parser.add_argument('--save_root', type=str, default='released_models/ukan_busi') \n    args = parser.parse_args()\n\n    save_root = args.save_root\n    if args.seed != 0:\n        seed_all(args)\n\n    modelConfig = {\n        \"dataset\": args.dataset, \n        \"state\": args.state, # or eval\n        \"epoch\": args.epoch,\n        \"batch_size\": args.batch_size,\n        \"T\": args.T,\n        \"channel\": args.channel,\n        \"channel_mult\": [1, 2, 3, 4],\n        \"attn\": [2],\n        \"num_res_blocks\": args.num_res_blocks,\n        \"dropout\": args.dropout,\n        \"lr\": args.lr,\n        \"multiplier\": 2.,\n        \"beta_1\": 1e-4,\n        \"beta_T\": 0.02,\n        \"img_size\": 64,\n        \"grad_clip\": 1.,\n        \"device\": \"cuda\", ### MAKE SURE YOU HAVE A GPU !!!\n        \"training_load_weight\": None,\n        \"save_weight_dir\": os.path.join(save_root, args.exp_nme, \"Weights\"),\n        \"sampled_dir\": os.path.join(save_root, args.exp_nme, \"FinalCheck\"),\n        \"test_load_weight\": args.test_load_weight,\n        \"sampledNoisyImgName\": \"NoisyNoGuidenceImgs.png\",\n        \"sampledImgName\": \"SampledNoGuidenceImgs.png\",\n        \"nrow\": 8,\n        \"model\":args.model, \n        \"version\": 1,\n        \"dataset_repeat\": args.dataset_repeat,\n        \"seed\": args.seed,\n        \"save_root\": args.save_root,\n        }\n\n    os.makedirs(modelConfig[\"save_weight_dir\"], exist_ok=True)\n    os.makedirs(modelConfig[\"sampled_dir\"], exist_ok=True)\n\n    main(modelConfig)\n"
  },
  {
    "path": "Diffusion_UKAN/README.md",
    "content": "# Diffusion UKAN (arxiv)\n\n> [**U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**](https://arxiv.org/abs/2406.02918)<br>\n> [Chenxin Li](https://xggnet.github.io/)\\*, [Xinyu Liu](https://xinyuliu-jeffrey.github.io/)\\*, [Wuyang Li](https://wymancv.github.io/wuyang.github.io/)\\*, [Cheng Wang](https://scholar.google.com/citations?user=AM7gvyUAAAAJ&hl=en)\\*, [Hengyu Liu](), [Yixuan Yuan](https://www.ee.cuhk.edu.hk/~yxyuan/people/people.htm)<sup>✉</sup><br>The Chinese Univerisity of Hong Kong\n\nContact: wuyangli@cuhk.edu.hk\n\n## 💡 Environment \nYou can change the torch and Cuda versions to satisfy your device.\n```bash\nconda create --name UKAN python=3.10\nconda activate UKAN\nconda install cudatoolkit=11.3\npip install -r requirement.txt\n```\n\n## 🖼️ Gallery of Diffusion UKAN \n\n![image](./assets/gen.png)\n\n## 📚 Prepare datasets\nDownload the pre-processed dataset from [Onedrive](https://gocuhk-my.sharepoint.com/:u:/g/personal/wuyangli_cuhk_edu_hk/ESqX-V_eLSBEuaJXAzf64JMB16xF9kz3661pJSwQ-hOspg?e=XdABCH) and unzip it into the project folder. 
The data is pre-processed by the scripts in [tools](./tools).\n```\nDiffusion_UKAN\n|    data\n|    └─ cvc\n|        └─ images_64\n|    └─ busi\n|        └─ images_64\n|    └─ glas\n|        └─ images_64\n```\n## 📦 Prepare pre-trained models\n\nDownload released_models from [Onedrive](https://gocuhk-my.sharepoint.com/:u:/g/personal/wuyangli_cuhk_edu_hk/EUVSH8QFUmpJlxyoEj8Pr2IB8PzGbVJg53rc6GcqxGgLDg?e=a4glNt) and unzip it in the project folder.\n```\nDiffusion_UKAN\n|    released_models\n|    └─ ukan_cvc\n|        └─ FinalCheck   # generated toy images (see next section)\n|        └─ Gens         # the generated images used for evaluation in our paper\n|        └─ Tmp          # saved generated images during model training with a 50-epoch interval\n|        └─ Weights      # The final checkpoint\n|        └─ FID.txt      # raw evaluation data \n|        └─ IS.txt       # raw evaluation data  \n|    └─ ukan_busi\n|    └─ ukan_glas\n```\n## 🧸 Toy example\nImages will be generated in `released_models/ukan_cvc/FinalCheck` by running this:\n\n```python\npython Main_Test.py\n```\n## 🔥 Training\n<!-- You may need to modify the dirs slightly. -->\nPlease refer to the [training_scripts](./training_scripts) folder. 
Besides, you can play with different network variations by modifying `MODEL` according to the following dictionary,\n\n```python\nmodel_dict = {\n    'UNet': UNet,\n    'UNet_ConvKan': UNet_ConvKan,\n    'UMLP': UMLP,\n    'UKan_Hybrid': UKan_Hybrid,\n    'UNet_Baseline': UNet_Baseline,\n}\n```\n\n\n## 🤞 Acknowledgement \nWe mainly appreciate these excellent projects:\n- [Simple DDPM](https://github.com/zoubohao/DenoisingDiffusionProbabilityModel-ddpm-) \n- [Kolmogorov-Arnold Network](https://github.com/mintisan/awesome-kan) \n- [Efficient Kolmogorov-Arnold Network](https://github.com/Blealtan/efficient-kan.git)\n\n\n## 📜Citation\nIf you find this work helpful for your project, please consider citing the following paper:\n```\n@article{li2024ukan,\n  title={U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation},\n  author={Li, Chenxin and Liu, Xinyu and Li, Wuyang and Wang, Cheng and Liu, Hengyu and Yuan, Yixuan},\n  journal={arXiv preprint arXiv:2406.02918},\n  year={2024}\n}\n```\n\n"
  },
  {
    "path": "Diffusion_UKAN/Scheduler.py",
    "content": "from torch.optim.lr_scheduler import _LRScheduler\n\nclass GradualWarmupScheduler(_LRScheduler):\n    def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):\n        self.multiplier = multiplier\n        self.total_epoch = warm_epoch\n        self.after_scheduler = after_scheduler\n        self.finished = False\n        self.last_epoch = None\n        self.base_lrs = None\n        super().__init__(optimizer)\n\n    def get_lr(self):\n        if self.last_epoch > self.total_epoch:\n            if self.after_scheduler:\n                if not self.finished:\n                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]\n                    self.finished = True\n                return self.after_scheduler.get_lr()\n            return [base_lr * self.multiplier for base_lr in self.base_lrs]\n        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]\n\n\n    def step(self, epoch=None, metrics=None):\n        if self.finished and self.after_scheduler:\n            if epoch is None:\n                self.after_scheduler.step(None)\n            else:\n                self.after_scheduler.step(epoch - self.total_epoch)\n        else:\n            return super(GradualWarmupScheduler, self).step(epoch)"
  },
  {
    "path": "Diffusion_UKAN/data/readme.txt",
    "content": "download data.zip and unzip here"
  },
  {
    "path": "Diffusion_UKAN/inception-score-pytorch/LICENSE.md",
    "content": "Copyright 2017 Shane T. Barratt\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE."
  },
  {
    "path": "Diffusion_UKAN/inception-score-pytorch/README.md",
    "content": "# Inception Score Pytorch\n\nPytorch was lacking code to calculate the Inception Score for GANs. This repository fills this gap.\nHowever, we do not recommend using the Inception Score to evaluate generative models, see [our note](https://arxiv.org/abs/1801.01973) for why.\n\n## Getting Started\n\nClone the repository and navigate to it:\n```\n$ git clone git@github.com:sbarratt/inception-score-pytorch.git\n$ cd inception-score-pytorch\n```\n\nTo generate random 64x64 images and calculate the inception score, do the following:\n```\n$ python inception_score.py\n```\n\nThe only function is `inception_score`. It takes a list of numpy images normalized to the range [0,1] and a set of arguments and then calculates the inception score. Please assure your images are 3x299x299 and if not (e.g. your GAN was trained on CIFAR), pass `resize=True` to the function to have it automatically resize using bilinear interpolation before passing the images to the inception network.\n\n```python\ndef inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1):\n    \"\"\"Computes the inception score of the generated images imgs\n    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]\n    cuda -- whether or not to run on GPU\n    batch_size -- batch size for feeding into Inception v3\n    splits -- number of splits\n    \"\"\"\n```\n\n### Prerequisites\n\nYou will need [torch](http://pytorch.org/), [torchvision](https://github.com/pytorch/vision), [numpy/scipy](https://scipy.org/).\n\n## License\n\nThis project is licensed under the MIT License - see the [LICENSE.md](LICENSE.md) file for details\n\n## Acknowledgments\n\n* Inception Score from [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498)\n"
  },
  {
    "path": "Diffusion_UKAN/inception-score-pytorch/inception_score.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.autograd import Variable\nfrom torch.nn import functional as F\nimport torch.utils.data\nfrom torchvision.models.inception import inception_v3\nimport os\nfrom skimage import io\nimport cv2\nimport os\nimport numpy as np\nfrom scipy.stats import entropy\nimport torchvision.datasets as dset\nimport torchvision.transforms as transforms\n\nimport argparse\ndef inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=32):\n    \"\"\"Computes the inception score of the generated images imgs\n\n    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]\n    cuda -- whether or not to run on GPU\n    batch_size -- batch size for feeding into Inception v3\n    splits -- number of splits\n    \"\"\"\n    N = len(imgs)\n\n    assert batch_size > 0\n    assert N > batch_size\n\n    # Set up dtype\n    if cuda:\n        dtype = torch.cuda.FloatTensor\n    else:\n        if torch.cuda.is_available():\n            print(\"WARNING: You have a CUDA device, so you should probably set cuda=True\")\n        dtype = torch.FloatTensor\n\n    # Set up dataloader\n    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)\n\n    # Load inception model\n    inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)\n    inception_model.eval();\n    up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)\n    def get_pred(x):\n        if resize:\n            x = up(x)\n        x = inception_model(x)\n        return F.softmax(x).data.cpu().numpy()\n\n    # Get predictions\n    preds = np.zeros((N, 1000))\n\n    for i, batch in enumerate(dataloader, 0):\n        batch = batch.type(dtype)\n        batchv = Variable(batch)\n        batch_size_i = batch.size()[0]\n\n        preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)\n\n    # Now compute the mean kl-div\n    split_scores = []\n\n    for k in range(splits):\n        part = preds[k * 
(N // splits): (k+1) * (N // splits), :]\n        py = np.mean(part, axis=0)\n        scores = []\n        for i in range(part.shape[0]):\n            pyx = part[i, :]\n            scores.append(entropy(pyx, py))\n        split_scores.append(np.exp(np.mean(scores)))\n\n    return np.mean(split_scores), np.std(split_scores)\n\nclass UnlabeledDataset(torch.utils.data.Dataset):\n    def __init__(self, folder, transform=None):\n        self.folder = folder\n        self.transform = transform\n        self.image_files = os.listdir(folder)\n\n    def __len__(self):\n        return len(self.image_files)\n\n    def __getitem__(self, idx):\n        image_file = self.image_files[idx]\n        image_path = os.path.join(self.folder, image_file)\n        image = io.imread(image_path)\n  \n        if self.transform:\n            image = self.transform(image)\n        return image\n    \nclass IgnoreLabelDataset(torch.utils.data.Dataset):\n    def __init__(self, orig):\n        self.orig = orig\n\n    def __getitem__(self, index):\n        return self.orig[index][0]\n\n    def __len__(self):\n        return len(self.orig)\n\n\nif __name__ == '__main__':\n\n    # cifar = dset.CIFAR10(root='data/', download=True,\n    #                          transform=transforms.Compose([\n    #                              transforms.Resize(32),``\n    #                              transforms.ToTensor(),\n    #                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n    #                          ])\n    # )\n\n    transform = transforms.Compose([\n        transforms.ToTensor(),\n        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n        ])\n\n\n    # set args\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data-root', type=str, default='/data/wyli/code/TinyDDPM/Output/unet_busi/Gens/')\n\n    args = parser.parse_args()\n\n    dataset = UnlabeledDataset(args.data_root, transform=transform)\n    \n    print (\"Calculating Inception 
Score...\")\n    print (inception_score(dataset, cuda=True, batch_size=1, resize=True, splits=10))\n\n\n"
  },
  {
    "path": "Diffusion_UKAN/released_models/readme.txt",
    "content": "download released_models.zip and unzip here"
  },
  {
    "path": "Diffusion_UKAN/requirements.txt",
    "content": "pytorch-fid==0.30.0\ntorch==2.3.0\ntorchvision==0.18.0\ntqdm\ntimm==0.9.16\nscikit-image==0.23.1"
  },
  {
    "path": "Diffusion_UKAN/tools/resive_cvc.py",
    "content": "import os\nfrom skimage import io, transform\nfrom skimage.util import img_as_ubyte\nimport numpy as np\n\n# Define the source and destination directories\nsrc_dir = '/data/wyli/data/CVC-ClinicDB/Original/'\ndst_dir = '/data/wyli/data/cvc/images_64/'\n\nos.makedirs(dst_dir, exist_ok=True)\n\n# Get a list of all the image files in the source directory\nimage_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]\n\n# Define the size of the crop box\ncrop_size = np.array([288 ,288])\n\n# Define the size of the resized image\nresize_size = (64, 64)\n\nfor image_file in image_files:\n    # Load the image\n    image = io.imread(os.path.join(src_dir, image_file))\n    # print(image.shape)\n\n    # Calculate the center of the image\n    center = np.array(image.shape[:2]) // 2\n\n    # Calculate the start and end points of the crop box\n    start = center - crop_size // 2\n    end = start + crop_size\n\n    # Crop the image\n    cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]])\n\n    # Resize the cropped image\n    resized_image = transform.resize(cropped_image, resize_size, mode='reflect')\n\n    # Save the resized image to the destination directory\n    io.imsave(os.path.join(dst_dir, image_file), img_as_ubyte(resized_image))"
  },
  {
    "path": "Diffusion_UKAN/tools/resize_busi.py",
    "content": "import os\nfrom skimage import io, transform\nfrom skimage.util import img_as_ubyte\nimport numpy as np\n\n# Define the source and destination directories\nsrc_dir = '/data/wyli/data/busi/images/'\ndst_dir = '/data/wyli/data/busi/images_64/'\n\nos.makedirs(dst_dir, exist_ok=True)\n\n# Get a list of all the image files in the source directory\nimage_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]\n\n# Define the size of the crop box\ncrop_size = np.array([400 ,400])\n\n# Define the size of the resized image\n# resize_size = (64, 64)\nresize_size = (64, 64)\n\nfor image_file in image_files:\n    # Load the image\n    image = io.imread(os.path.join(src_dir, image_file))\n    print(image.shape)\n\n\n    # Calculate the center of the image\n    center = np.array(image.shape[:2]) // 2\n\n    # Calculate the start and end points of the crop box\n    start = center - crop_size // 2\n    end = start + crop_size\n\n    # Crop the image\n    cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]])\n\n    # Resize the cropped image\n    resized_image = transform.resize(cropped_image, resize_size, mode='reflect')\n\n    # Save the resized image to the destination directory\n    io.imsave(os.path.join(dst_dir, image_file), img_as_ubyte(resized_image))"
  },
  {
    "path": "Diffusion_UKAN/tools/resize_glas.py",
    "content": "import os\nfrom skimage import io, transform\nfrom skimage.util import img_as_ubyte\nimport numpy as np\nimport random\n\n# Define the source and destination directories\nsrc_dir = '/data/wyli/data/glas/images/'\ndst_dir = '/data/wyli/data/glas/images_64/'\n\nos.makedirs(dst_dir, exist_ok=True)\n\n# Get a list of all the image files in the source directory\nimage_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]\n\n# Define the size of the crop box\ncrop_size = np.array([64, 64])\n\n# Define the number of crops per image\nK = 5\n\nfor image_file in image_files:\n    # Load the image\n    image = io.imread(os.path.join(src_dir, image_file))\n\n    # Get the size of the image\n    image_size = np.array(image.shape[:2])\n\n    for i in range(K):\n        # Calculate a random start point for the crop box\n        start = np.array([random.randint(0, image_size[0] - crop_size[0]), random.randint(0, image_size[1] - crop_size[1])])\n\n        # Calculate the end point of the crop box\n        end = start + crop_size\n\n        # Crop the image\n        cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]])\n\n        # Save the cropped image to the destination directory\n        io.imsave(os.path.join(dst_dir, f\"{image_file}_{i}.png\"), cropped_image)"
  },
  {
    "path": "Diffusion_UKAN/training_scripts/busi.sh",
    "content": "##!/bin/bash\nsource ~/miniconda3/etc/profile.d/conda.sh\n\nconda activate kan\n\nGPU=0\nMODEL='UKan_Hybrid'\nEXP_NME='UKan_cvc'\nSAVE_ROOT='./Output/'\nDATASET='busi'\n\ncd ../\n\nCUDA_VISIBLE_DEVICES=${GPU} python Main.py \\\n--model ${MODEL} \\\n--exp_nme ${EXP_NME}  \\\n--batch_size 32  \\\n--channel 64 \\\n--dataset ${DATASET} \\\n--epoch 5000 \\\n--save_root ${SAVE_ROOT} \n# --lr 1e-4 \n\n# calcuate FID and IS\nCUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid \"data/${DATASET}/images_64/\" \"${SAVE_ROOT}/${EXP_NME}/Gens\" > \"${SAVE_ROOT}/${EXP_NME}/FID.txt\" 2>&1\n\ncd inception-score-pytorch\n\nCUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root \"${SAVE_ROOT}/${EXP_NME}/Gens\"  > \"${SAVE_ROOT}/${EXP_NME}/IS.txt\" 2>&1\n"
  },
  {
    "path": "Diffusion_UKAN/training_scripts/cvc.sh",
    "content": "##!/bin/bash\nsource ~/miniconda3/etc/profile.d/conda.sh\n\nconda activate kan\n\nGPU=0\nMODEL='UKan_Hybrid'\nEXP_NME='UKan_cvc'\nSAVE_ROOT='./Output/'\nDATASET='cvc'\n\ncd ../\n\nCUDA_VISIBLE_DEVICES=${GPU} python Main.py \\\n--model ${MODEL} \\\n--exp_nme ${EXP_NME}  \\\n--batch_size 32  \\\n--channel 64 \\\n--dataset ${DATASET} \\\n--epoch 1000 \\\n--save_root ${SAVE_ROOT} \n# --lr 1e-4 \n\n# calcuate FID and IS\nCUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid \"data/${DATASET}/images_64/\" \"${SAVE_ROOT}/${EXP_NME}/Gens\" > \"${SAVE_ROOT}/${EXP_NME}/FID.txt\" 2>&1\n\ncd inception-score-pytorch\n\nCUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root \"${SAVE_ROOT}/${EXP_NME}/Gens\"  > \"${SAVE_ROOT}/${EXP_NME}/IS.txt\" 2>&1\n"
  },
  {
    "path": "Diffusion_UKAN/training_scripts/glas.sh",
    "content": "##!/bin/bash\nsource ~/miniconda3/etc/profile.d/conda.sh\n\nconda activate kan\n\nGPU=0\nMODEL='UKan_Hybrid'\nEXP_NME='UKan_glas'\nSAVE_ROOT='./Output/'\nDATASET='glas'\n\ncd ../\n\nCUDA_VISIBLE_DEVICES=${GPU} python Main.py \\\n--model ${MODEL} \\\n--exp_nme ${EXP_NME}  \\\n--batch_size 32  \\\n--channel 64 \\\n--dataset ${DATASET} \\\n--epoch 1000 \\\n--save_root ${SAVE_ROOT} \n# --lr 1e-4 \n\n# calcuate FID and IS\nCUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid \"data/${DATASET}/images_64/\" \"${SAVE_ROOT}/${EXP_NME}/Gens\" > \"${SAVE_ROOT}/${EXP_NME}/FID.txt\" 2>&1\n\ncd inception-score-pytorch\n\nCUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root \"${SAVE_ROOT}/${EXP_NME}/Gens\"  > \"${SAVE_ROOT}/${EXP_NME}/IS.txt\" 2>&1\n"
  },
  {
    "path": "README.md",
    "content": "# U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation\n\n:pushpin: This is an official PyTorch implementation of **U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**\n\n[[`Project Page`](https://yes-u-kan.github.io/)] [[`arXiv`](https://arxiv.org/abs/2406.02918)] [[`BibTeX`](#citation)]\n\n<p align=\"center\">\n  <img src=\"./assets/logo_1.png\" alt=\"\" width=\"120\" height=\"120\">\n</p>\n\n> [**U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**](https://arxiv.org/abs/2406.02918)<br>\n> [Chenxin Li](https://xggnet.github.io/)<sup>1\\*</sup>, [Xinyu Liu](https://xinyuliu-jeffrey.github.io/)<sup>1\\*</sup>, [Wuyang Li](https://wymancv.github.io/wuyang.github.io/)<sup>1\\*</sup>, [Cheng Wang](https://scholar.google.com/citations?user=AM7gvyUAAAAJ&hl=en)<sup>1\\*</sup>, [Hengyu Liu](https://liuhengyu321.github.io/)<sup>1</sup>, [Yifan Liu](https://yifliu3.github.io/)<sup>1</sup>, [Chen Zhen](https://franciszchen.github.io/)<sup>2</sup>, [Yixuan Yuan](https://www.ee.cuhk.edu.hk/~yxyuan/people/people.htm)<sup>1✉</sup><br> <sup>1</sup>The Chinese Univerisity of Hong Kong, <sup>2</sup>Centre for Artificial Intelligence and Robotics, Hong Kong\n\nWe explore the untapped potential of Kolmogorov-Anold Network (aka. KAN) in improving backbones for vision tasks. We investigate, modify and re-design the established U-Net pipeline by integrating the dedicated KAN layers on the tokenized intermediate representation, termed U-KAN. Rigorous medical image segmentation benchmarks verify the superiority of U-KAN by higher accuracy even with less computation cost. We further delved into the potential of U-KAN as an alternative U-Net noise predictor in diffusion models, demonstrating its applicability in generating task-oriented model architectures. 
These endeavours unveil valuable insights and sheds light on the prospect that with U-KAN, you can make strong backbone for medical image segmentation and generation.\n\n<div align=\"center\">\n    <img width=\"100%\" alt=\"UKAN overview\" src=\"assets/framework-1.jpg\"/>\n</div>\n\n## 📰News\n\n **[NOTE]** Random seed is essential for eval metric, and all reported results are calculated over three random runs with seeds of 2981, 6142, 1187, following rolling-UNet. We think most issues are related with this.\n\n**[2024.10]** U-KAN is accepted by AAAI-25. \n\n**[2024.6]** Some modifications are done in Seg_UKAN for better performance reproduction. The previous code can be quickly updated by replacing the contents of train.py and archs.py with the new ones.\n\n**[2024.6]** Model checkpoints and training logs are released!\n\n**[2024.6]** Code and paper of U-KAN are released!\n\n## 💡Key Features\n- The first effort to incorporate the advantage of emerging KAN to improve established U-Net pipeline to be more **accurate, efficient and interpretable**.\n- A Segmentation U-KAN with **tokenized KAN block to effectively steer the KAN operators** to be compatible with the exiting convolution-based designs.\n- A Diffusion U-KAN as an **improved noise predictor** demonstrates its potential in backboning generative tasks and broader vision settings.\n\n## 🛠Setup\n\n```bash\ngit clone https://github.com/CUHK-AIM-Group/U-KAN.git\ncd U-KAN\nconda create -n ukan python=3.10\nconda activate ukan\ncd Seg_UKAN && pip install -r requirements.txt\n```\n\n**Tips A**: We test the framework using pytorch=1.13.0, and the CUDA compile version=11.6. Other versions should be also fine but not totally ensured.\n\n\n## 📚Data Preparation\n**BUSI**:  The dataset can be found [here](https://www.kaggle.com/datasets/aryashah2k/breast-ultrasound-images-dataset). 
\n\n**GLAS**:  The dataset can be found [here](https://websignon.warwick.ac.uk/origin/slogin?shire=https%3A%2F%2Fwarwick.ac.uk%2Fsitebuilder2%2Fshire-read&providerId=urn%3Awarwick.ac.uk%3Asitebuilder2%3Aread%3Aservice&target=https%3A%2F%2Fwarwick.ac.uk%2Ffac%2Fcross_fac%2Ftia%2Fdata%2Fglascontest&status=notloggedin). \n<!-- You can directly use the [processed GLAS data]() without further data processing. -->\n**CVC-ClinicDB**:  The dataset can be found [here](https://www.dropbox.com/s/p5qe9eotetjnbmq/CVC-ClinicDB.rar?e=3&dl=0). \n<!-- You can directly use the [processed CVC-ClinicDB data]() without further data processing. -->\n\nWe also provide all the [pre-processed dataset](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/ErDlT-t0WoBNlKhBlbYfReYB-iviSCmkNRb1GqZ90oYjJA?e=hrPNWD) without requiring any further data processing. You can directly download and put them into the data dir.\n\n\n\nThe resulted file structure is as follows.\n```\nSeg_UKAN\n├── inputs\n│   ├── busi\n│     ├── images\n│           ├── malignant (1).png\n|           ├── ...\n|     ├── masks\n│        ├── 0\n│           ├── malignant (1)_mask.png\n|           ├── ...\n│   ├── GLAS\n│     ├── images\n│           ├── 0.png\n|           ├── ...\n|     ├── masks\n│        ├── 0\n│           ├── 0.png\n|           ├── ...\n│   ├── CVC-ClinicDB\n│     ├── images\n│           ├── 0.png\n|           ├── ...\n|     ├── masks\n│        ├── 0\n│           ├── 0.png\n|           ├── ...\n```\n\n## 🔖Evaluating Segmentation U-KAN\n\nYou can directly evaluate U-KAN from the checkpoint model. Here is an example for quick usage for using our **pre-trained models** in [Segmentation Model Zoo](#segmentation-model-zoo):\n1. Download the pre-trained weights and put them to ```{args.output_dir}/{args.name}/model.pth```\n2. 
Run the following scripts to \n```bash\ncd Seg_UKAN\npython val.py --name ${dataset}_UKAN --output_dir [YOUR_OUTPUT_DIR] \n```\n\n## ⏳Training Segmentation U-KAN\n\nYou can simply train U-KAN on a single GPU by specifing the dataset name ```--dataset``` and input size ```--input_size```.\n```bash\ncd Seg_UKAN\npython train.py --arch UKAN --dataset {dataset} --input_w {input_size} --input_h {input_size} --name {dataset}_UKAN  --data_dir [YOUR_DATA_DIR]\n```\nFor example, train U-KAN with the resolution of 256x256 with a single GPU on the BUSI dataset in the ```inputs``` dir:\n```bash\ncd Seg_UKAN\npython train.py --arch UKAN --dataset busi --input_w 256 --input_h 256 --name busi_UKAN  --data_dir ./inputs\n```\nPlease see Seg_UKAN/scripts.sh for more details.\nNote that the resolution of glas is 512x512, differing with other datasets (256x256).\n\n**[Quick Update]** Please follow the seeds of 2981, 6142, 1187 to fully reproduce the paper experimental results. All compared methods are evaluated on the same seed setting.\n\n## 🎪Segmentation Model Zoo\nWe provide all the pre-trained model [checkpoints](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/Ej6yZBSIrU5Ds9q-gQdhXqwBbpov5_MaWF483uZHm2lccA?e=rmlHMo)\nHere is an overview of the released performance&checkpoints. 
Note that results on a single run and the reported average results in the paper differ.\n|Method| Dataset | IoU | F1  | Checkpoints |\n|-----|------|-----|-----|-----|\n|Seg U-KAN| BUSI | 65.26 | 78.75 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EjktWkXytkZEgN3EzN2sJKIBfHCeEnJnCnazC68pWCy7kQ?e=4JBLIc)|\n|Seg U-KAN| GLAS | 87.51 | 93.33 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EunQ9KRf6n1AqCJ40FWZF-QB25GMOoF7hoIwU15fefqFbw?e=m7kXwe)|\n|Seg U-KAN| CVC-ClinicDB | 85.61 | 92.19 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/Ekhb3PEmwZZMumSG69wPRRQBymYIi0PFNuLJcVNmmK1fjA?e=5XzVSi)|\n\nThe parameter ``--no_kan'' denotes the baseline model that is replaced the KAN layers with MLP layers. Please note: it is reasonable to find occasional inconsistencies in performance, and the average results over multiple runs can reveal a more obvious trend.\n|Method| Layer Type | IoU | F1  | Checkpoints |\n|-----|------|-----|-----|-----|\n|Seg U-KAN (--no_kan)| MLP Layer  | 63.49 |\t77.07 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EmEH_qokqIFNtP59yU7vY_4Bq4Yc424zuYufwaJuiAGKiw?e=IJ3clx)|\n|Seg U-KAN| KAN Layer |  65.26 | 78.75  | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EjktWkXytkZEgN3EzN2sJKIBfHCeEnJnCnazC68pWCy7kQ?e=4JBLIc)|\n\n## 🎇Medical Image Generation with Diffusion U-KAN\n\nPlease refer to [Diffusion_UKAN](./Diffusion_UKAN/README.md)\n\n\n## 🛒TODO List\n- [X] Release code for Seg U-KAN.\n- [X] Release code for Diffusion U-KAN.\n- [X] Upload the pretrained checkpoints.\n\n\n## 🎈Acknowledgements\nGreatly appreciate the tremendous effort for the following projects!\n- [CKAN](https://github.com/AntonioTepsich/Convolutional-KANs)\n\n\n## 📜Citation\nIf you find this work helpful for your project,please consider citing the following paper:\n```\n@article{li2024ukan,\n  title={U-KAN 
Makes Strong Backbone for Medical Image Segmentation and Generation},\n  author={Li, Chenxin and Liu, Xinyu and Li, Wuyang and Wang, Cheng and Liu, Hengyu and Yuan, Yixuan},\n  journal={arXiv preprint arXiv:2406.02918},\n  year={2024}\n}\n```\n"
  },
  {
    "path": "Seg_UKAN/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Jeya Maria Jose\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "Seg_UKAN/archs.py",
    "content": "import torch\nfrom torch import nn\nimport torch\nimport torchvision\nfrom torch import nn\nfrom torch.autograd import Variable\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms\nfrom torchvision.utils import save_image\nimport torch.nn.functional as F\nimport os\nimport matplotlib.pyplot as plt\nfrom utils import *\n\nimport timm\nfrom timm.models.layers import DropPath, to_2tuple, trunc_normal_\nimport types\nimport math\nfrom abc import ABCMeta, abstractmethod\n# from mmcv.cnn import ConvModule\nfrom pdb import set_trace as st\n\nfrom kan import KANLinear, KAN\nfrom torch.nn import init\n\n\nclass KANLayer(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., no_kan=False):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.dim = in_features\n        \n        grid_size=5\n        spline_order=3\n        scale_noise=0.1\n        scale_base=1.0\n        scale_spline=1.0\n        base_activation=torch.nn.SiLU\n        grid_eps=0.02\n        grid_range=[-1, 1]\n\n        if not no_kan:\n            self.fc1 = KANLinear(\n                        in_features,\n                        hidden_features,\n                        grid_size=grid_size,\n                        spline_order=spline_order,\n                        scale_noise=scale_noise,\n                        scale_base=scale_base,\n                        scale_spline=scale_spline,\n                        base_activation=base_activation,\n                        grid_eps=grid_eps,\n                        grid_range=grid_range,\n                    )\n            self.fc2 = KANLinear(\n                        hidden_features,\n                        out_features,\n                        grid_size=grid_size,\n                        spline_order=spline_order,\n                        
scale_noise=scale_noise,\n                        scale_base=scale_base,\n                        scale_spline=scale_spline,\n                        base_activation=base_activation,\n                        grid_eps=grid_eps,\n                        grid_range=grid_range,\n                    )\n            self.fc3 = KANLinear(\n                        hidden_features,\n                        out_features,\n                        grid_size=grid_size,\n                        spline_order=spline_order,\n                        scale_noise=scale_noise,\n                        scale_base=scale_base,\n                        scale_spline=scale_spline,\n                        base_activation=base_activation,\n                        grid_eps=grid_eps,\n                        grid_range=grid_range,\n                    )\n            # # TODO   \n            # self.fc4 = KANLinear(\n            #             hidden_features,\n            #             out_features,\n            #             grid_size=grid_size,\n            #             spline_order=spline_order,\n            #             scale_noise=scale_noise,\n            #             scale_base=scale_base,\n            #             scale_spline=scale_spline,\n            #             base_activation=base_activation,\n            #             grid_eps=grid_eps,\n            #             grid_range=grid_range,\n            #         )   \n\n        else:\n            self.fc1 = nn.Linear(in_features, hidden_features)\n            self.fc2 = nn.Linear(hidden_features, out_features)\n            self.fc3 = nn.Linear(hidden_features, out_features)\n\n        # TODO\n        # self.fc1 = nn.Linear(in_features, hidden_features)\n\n\n        self.dwconv_1 = DW_bn_relu(hidden_features)\n        self.dwconv_2 = DW_bn_relu(hidden_features)\n        self.dwconv_3 = DW_bn_relu(hidden_features)\n\n        # # TODO\n        # self.dwconv_4 = DW_bn_relu(hidden_features)\n    \n        self.drop = 
nn.Dropout(drop)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n    \n\n    def forward(self, x, H, W):\n        # pdb.set_trace()\n        B, N, C = x.shape\n\n        x = self.fc1(x.reshape(B*N,C))\n        x = x.reshape(B,N,C).contiguous()\n        x = self.dwconv_1(x, H, W)\n        x = self.fc2(x.reshape(B*N,C))\n        x = x.reshape(B,N,C).contiguous()\n        x = self.dwconv_2(x, H, W)\n        x = self.fc3(x.reshape(B*N,C))\n        x = x.reshape(B,N,C).contiguous()\n        x = self.dwconv_3(x, H, W)\n\n        # # TODO\n        # x = x.reshape(B,N,C).contiguous()\n        # x = self.dwconv_4(x, H, W)\n    \n        return x\n\nclass KANBlock(nn.Module):\n    def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, no_kan=False):\n        super().__init__()\n\n        self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim)\n\n        self.layer = KANLayer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, no_kan=no_kan)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x, H, W):\n        x = x + self.drop_path(self.layer(self.norm2(x), H, W))\n\n        return x\n\n\nclass DWConv(nn.Module):\n    def __init__(self, dim=768):\n        super(DWConv, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        x = x.transpose(1, 2).view(B, C, H, W)\n        x = self.dwconv(x)\n        x = x.flatten(2).transpose(1, 2)\n\n        return x\n\nclass DW_bn_relu(nn.Module):\n    def __init__(self, dim=768):\n        super(DW_bn_relu, self).__init__()\n        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)\n        self.bn = nn.BatchNorm2d(dim)\n        self.relu = nn.ReLU()\n\n    def forward(self, x, H, W):\n        B, N, C = x.shape\n        x = x.transpose(1, 2).view(B, C, H, W)\n        x = self.dwconv(x)\n        x = self.bn(x)\n        x = self.relu(x)\n        x = x.flatten(2).transpose(1, 2)\n\n        return x\n\nclass PatchEmbed(nn.Module):\n    \"\"\" Image to Patch Embedding\n    \"\"\"\n\n 
   def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]\n        self.num_patches = self.H * self.W\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,\n                              padding=(patch_size[0] // 2, patch_size[1] // 2))\n        self.norm = nn.LayerNorm(embed_dim)\n\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n        elif isinstance(m, nn.Conv2d):\n            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n            fan_out //= m.groups\n            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n            if m.bias is not None:\n                m.bias.data.zero_()\n\n    def forward(self, x):\n        x = self.proj(x)\n        _, _, H, W = x.shape\n        x = x.flatten(2).transpose(1, 2)\n        x = self.norm(x)\n\n        return x, H, W\n\n\nclass ConvLayer(nn.Module):\n    def __init__(self, in_ch, out_ch):\n        super(ConvLayer, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(in_ch, out_ch, 3, padding=1),\n            nn.BatchNorm2d(out_ch),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(out_ch, out_ch, 3, padding=1),\n            nn.BatchNorm2d(out_ch),\n            nn.ReLU(inplace=True)\n        )\n\n    def forward(self, input):\n        return self.conv(input)\n\nclass 
D_ConvLayer(nn.Module):\n    def __init__(self, in_ch, out_ch):\n        super(D_ConvLayer, self).__init__()\n        self.conv = nn.Sequential(\n            nn.Conv2d(in_ch, in_ch, 3, padding=1),\n            nn.BatchNorm2d(in_ch),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(in_ch, out_ch, 3, padding=1),\n            nn.BatchNorm2d(out_ch),\n            nn.ReLU(inplace=True)\n        )\n\n    def forward(self, input):\n        return self.conv(input)\n\n\n\nclass UKAN(nn.Module):\n    def __init__(self, num_classes, input_channels=3, deep_supervision=False, img_size=224, patch_size=16, in_chans=3, embed_dims=[256, 320, 512], no_kan=False,\n    drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[1, 1, 1], **kwargs):\n        super().__init__()\n\n        kan_input_dim = embed_dims[0]\n\n        self.encoder1 = ConvLayer(3, kan_input_dim//8)  \n        self.encoder2 = ConvLayer(kan_input_dim//8, kan_input_dim//4)  \n        self.encoder3 = ConvLayer(kan_input_dim//4, kan_input_dim)\n\n        self.norm3 = norm_layer(embed_dims[1])\n        self.norm4 = norm_layer(embed_dims[2])\n\n        self.dnorm3 = norm_layer(embed_dims[1])\n        self.dnorm4 = norm_layer(embed_dims[0])\n\n        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]\n\n        self.block1 = nn.ModuleList([KANBlock(\n            dim=embed_dims[1], \n            drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer\n            )])\n\n        self.block2 = nn.ModuleList([KANBlock(\n            dim=embed_dims[2],\n            drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer\n            )])\n\n        self.dblock1 = nn.ModuleList([KANBlock(\n            dim=embed_dims[1], \n            drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer\n            )])\n\n        self.dblock2 = nn.ModuleList([KANBlock(\n            dim=embed_dims[0], \n            drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer\n            )])\n\n        
self.patch_embed3 = PatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])\n        self.patch_embed4 = PatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])\n\n        self.decoder1 = D_ConvLayer(embed_dims[2], embed_dims[1])  \n        self.decoder2 = D_ConvLayer(embed_dims[1], embed_dims[0])  \n        self.decoder3 = D_ConvLayer(embed_dims[0], embed_dims[0]//4) \n        self.decoder4 = D_ConvLayer(embed_dims[0]//4, embed_dims[0]//8)\n        self.decoder5 = D_ConvLayer(embed_dims[0]//8, embed_dims[0]//8)\n\n        self.final = nn.Conv2d(embed_dims[0]//8, num_classes, kernel_size=1)\n        self.soft = nn.Softmax(dim =1)\n\n    def forward(self, x):\n        \n        B = x.shape[0]\n        ### Encoder\n        ### Conv Stage\n\n        ### Stage 1\n        out = F.relu(F.max_pool2d(self.encoder1(x), 2, 2))\n        t1 = out\n        ### Stage 2\n        out = F.relu(F.max_pool2d(self.encoder2(out), 2, 2))\n        t2 = out\n        ### Stage 3\n        out = F.relu(F.max_pool2d(self.encoder3(out), 2, 2))\n        t3 = out\n\n        ### Tokenized KAN Stage\n        ### Stage 4\n\n        out, H, W = self.patch_embed3(out)\n        for i, blk in enumerate(self.block1):\n            out = blk(out, H, W)\n        out = self.norm3(out)\n        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        t4 = out\n\n        ### Bottleneck\n\n        out, H, W= self.patch_embed4(out)\n        for i, blk in enumerate(self.block2):\n            out = blk(out, H, W)\n        out = self.norm4(out)\n        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n\n        ### Stage 4\n        out = F.relu(F.interpolate(self.decoder1(out), scale_factor=(2,2), mode ='bilinear'))\n\n        out = torch.add(out, t4)\n        _, _, H, W = out.shape\n        out = out.flatten(2).transpose(1,2)\n        for i, blk in enumerate(self.dblock1):\n      
      out = blk(out, H, W)\n\n        ### Stage 3\n        out = self.dnorm3(out)\n        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n        out = F.relu(F.interpolate(self.decoder2(out),scale_factor=(2,2),mode ='bilinear'))\n        out = torch.add(out,t3)\n        _,_,H,W = out.shape\n        out = out.flatten(2).transpose(1,2)\n        \n        for i, blk in enumerate(self.dblock2):\n            out = blk(out, H, W)\n\n        out = self.dnorm4(out)\n        out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()\n\n        out = F.relu(F.interpolate(self.decoder3(out),scale_factor=(2,2),mode ='bilinear'))\n        out = torch.add(out,t2)\n        out = F.relu(F.interpolate(self.decoder4(out),scale_factor=(2,2),mode ='bilinear'))\n        out = torch.add(out,t1)\n        out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear'))\n\n        return self.final(out)\n"
  },
  {
    "path": "Seg_UKAN/config.py",
    "content": "# --------------------------------------------------------\n# Swin Transformer\n# Copyright (c) 2021 Microsoft\n# Licensed under The MIT License [see LICENSE for details]\n# Written by Ze Liu\n# --------------------------------------------------------'\n\nimport os\nimport yaml\nfrom yacs.config import CfgNode as CN\n\n_C = CN()\n\n# Base config files\n_C.BASE = ['']\n\n# -----------------------------------------------------------------------------\n# Data settings\n# -----------------------------------------------------------------------------\n_C.DATA = CN()\n# Batch size for a single GPU, could be overwritten by command line argument\n_C.DATA.BATCH_SIZE = 1\n# Path to dataset, could be overwritten by command line argument\n_C.DATA.DATA_PATH = ''\n# Dataset name\n_C.DATA.DATASET = 'imagenet'\n# Input image size\n_C.DATA.IMG_SIZE = 256\n# Interpolation to resize image (random, bilinear, bicubic)\n_C.DATA.INTERPOLATION = 'bicubic'\n# Use zipped dataset instead of folder dataset\n# could be overwritten by command line argument\n_C.DATA.ZIP_MODE = False\n# Cache Data in Memory, could be overwritten by command line argument\n_C.DATA.CACHE_MODE = 'part'\n# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.\n_C.DATA.PIN_MEMORY = True\n# Number of data loading threads\n_C.DATA.NUM_WORKERS = 8\n\n# -----------------------------------------------------------------------------\n# Model settings\n# -----------------------------------------------------------------------------\n_C.MODEL = CN()\n# Model type\n_C.MODEL.TYPE = 'swin'\n# Model name\n_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'\n# Checkpoint to resume, could be overwritten by command line argument\n_C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth'\n_C.MODEL.RESUME = ''\n# Number of classes, overwritten in data preparation\n_C.MODEL.NUM_CLASSES = 1000\n# Dropout rate\n_C.MODEL.DROP_RATE = 0.0\n# Drop path rate\n_C.MODEL.DROP_PATH_RATE = 0.1\n# 
Label Smoothing\n_C.MODEL.LABEL_SMOOTHING = 0.1\n\n# Swin Transformer parameters\n_C.MODEL.SWIN = CN()\n_C.MODEL.SWIN.PATCH_SIZE = 4\n_C.MODEL.SWIN.IN_CHANS = 3\n_C.MODEL.SWIN.EMBED_DIM = 96\n_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]\n_C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]\n_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]\n_C.MODEL.SWIN.WINDOW_SIZE = 4\n_C.MODEL.SWIN.MLP_RATIO = 4.\n_C.MODEL.SWIN.QKV_BIAS = True\n_C.MODEL.SWIN.QK_SCALE = False\n_C.MODEL.SWIN.APE = False\n_C.MODEL.SWIN.PATCH_NORM = True\n_C.MODEL.SWIN.FINAL_UPSAMPLE= \"expand_first\"\n\n# -----------------------------------------------------------------------------\n# Training settings\n# -----------------------------------------------------------------------------\n_C.TRAIN = CN()\n_C.TRAIN.START_EPOCH = 0\n_C.TRAIN.EPOCHS = 300\n_C.TRAIN.WARMUP_EPOCHS = 20\n_C.TRAIN.WEIGHT_DECAY = 0.05\n_C.TRAIN.BASE_LR = 5e-4\n_C.TRAIN.WARMUP_LR = 5e-7\n_C.TRAIN.MIN_LR = 5e-6\n# Clip gradient norm\n_C.TRAIN.CLIP_GRAD = 5.0\n# Auto resume from latest checkpoint\n_C.TRAIN.AUTO_RESUME = True\n# Gradient accumulation steps\n# could be overwritten by command line argument\n_C.TRAIN.ACCUMULATION_STEPS = 0\n# Whether to use gradient checkpointing to save memory\n# could be overwritten by command line argument\n_C.TRAIN.USE_CHECKPOINT = False\n\n# LR scheduler\n_C.TRAIN.LR_SCHEDULER = CN()\n_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'\n# Epoch interval to decay LR, used in StepLRScheduler\n_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30\n# LR decay rate, used in StepLRScheduler\n_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1\n\n# Optimizer\n_C.TRAIN.OPTIMIZER = CN()\n_C.TRAIN.OPTIMIZER.NAME = 'adamw'\n# Optimizer Epsilon\n_C.TRAIN.OPTIMIZER.EPS = 1e-8\n# Optimizer Betas\n_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)\n# SGD momentum\n_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9\n\n# -----------------------------------------------------------------------------\n# Augmentation settings\n# 
-----------------------------------------------------------------------------\n_C.AUG = CN()\n# Color jitter factor\n_C.AUG.COLOR_JITTER = 0.4\n# Use AutoAugment policy. \"v0\" or \"original\"\n_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'\n# Random erase prob\n_C.AUG.REPROB = 0.25\n# Random erase mode\n_C.AUG.REMODE = 'pixel'\n# Random erase count\n_C.AUG.RECOUNT = 1\n# Mixup alpha, mixup enabled if > 0\n_C.AUG.MIXUP = 0.8\n# Cutmix alpha, cutmix enabled if > 0\n_C.AUG.CUTMIX = 1.0\n# Cutmix min/max ratio, overrides alpha and enables cutmix if set\n_C.AUG.CUTMIX_MINMAX = False\n# Probability of performing mixup or cutmix when either/both is enabled\n_C.AUG.MIXUP_PROB = 1.0\n# Probability of switching to cutmix when both mixup and cutmix enabled\n_C.AUG.MIXUP_SWITCH_PROB = 0.5\n# How to apply mixup/cutmix params. Per \"batch\", \"pair\", or \"elem\"\n_C.AUG.MIXUP_MODE = 'batch'\n\n# -----------------------------------------------------------------------------\n# Testing settings\n# -----------------------------------------------------------------------------\n_C.TEST = CN()\n# Whether to use center crop when testing\n_C.TEST.CROP = True\n\n# -----------------------------------------------------------------------------\n# Misc\n# -----------------------------------------------------------------------------\n# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')\n# overwritten by command line argument\n_C.AMP_OPT_LEVEL = ''\n# Path to output folder, overwritten by command line argument\n_C.OUTPUT = ''\n# Tag of experiment, overwritten by command line argument\n_C.TAG = 'default'\n# Frequency to save checkpoint\n_C.SAVE_FREQ = 1\n# Frequency to logging info\n_C.PRINT_FREQ = 10\n# Fixed random seed\n_C.SEED = 0\n# Perform evaluation only, overwritten by command line argument\n_C.EVAL_MODE = False\n# Test throughput only, overwritten by command line argument\n_C.THROUGHPUT_MODE = False\n# local rank for DistributedDataParallel, given by command line 
argument\n_C.LOCAL_RANK = 0\n\n\ndef _update_config_from_file(config, cfg_file):\n    config.defrost()\n    with open(cfg_file, 'r') as f:\n        yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)\n\n    for cfg in yaml_cfg.setdefault('BASE', ['']):\n        if cfg:\n            _update_config_from_file(\n                config, os.path.join(os.path.dirname(cfg_file), cfg)\n            )\n    print('=> merge config from {}'.format(cfg_file))\n    config.merge_from_file(cfg_file)\n    config.freeze()\n\n\ndef update_config(config, args):\n    _update_config_from_file(config, args.cfg)\n\n    config.defrost()\n    if args.opts:\n        config.merge_from_list(args.opts)\n\n    # merge from specific arguments\n    if args.batch_size:\n        config.DATA.BATCH_SIZE = args.batch_size\n    if args.zip:\n        config.DATA.ZIP_MODE = True\n    if args.cache_mode:\n        config.DATA.CACHE_MODE = args.cache_mode\n    if args.resume:\n        config.MODEL.RESUME = args.resume\n    if args.accumulation_steps:\n        config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps\n    if args.use_checkpoint:\n        config.TRAIN.USE_CHECKPOINT = True\n    if args.amp_opt_level:\n        config.AMP_OPT_LEVEL = args.amp_opt_level\n    if args.tag:\n        config.TAG = args.tag\n    if args.eval:\n        config.EVAL_MODE = True\n    if args.throughput:\n        config.THROUGHPUT_MODE = True\n\n    config.freeze()\n\n\ndef get_config(args):\n    \"\"\"Get a yacs CfgNode object with default values.\"\"\"\n    # Return a clone so that the defaults will not be altered\n    # This is for the \"local variable\" use pattern\n    config = _C.clone()\n    # update_config(config, args)\n\n    return config\n"
  },
  {
    "path": "Seg_UKAN/dataset.py",
    "content": "import os\n\nimport cv2\nimport numpy as np\nimport torch\nimport torch.utils.data\n\n\nclass Dataset(torch.utils.data.Dataset):\n    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):\n        \"\"\"\n        Args:\n            img_ids (list): Image ids.\n            img_dir: Image file directory.\n            mask_dir: Mask file directory.\n            img_ext (str): Image file extension.\n            mask_ext (str): Mask file extension.\n            num_classes (int): Number of classes.\n            transform (Compose, optional): Compose transforms of albumentations. Defaults to None.\n        \n        Note:\n            Make sure to put the files as the following structure:\n            <dataset name>\n            ├── images\n            |   ├── 0a7e06.jpg\n            │   ├── 0aab0a.jpg\n            │   ├── 0b1761.jpg\n            │   ├── ...\n            |\n            └── masks\n                ├── 0\n                |   ├── 0a7e06.png\n                |   ├── 0aab0a.png\n                |   ├── 0b1761.png\n                |   ├── ...\n                |\n                ├── 1\n                |   ├── 0a7e06.png\n                |   ├── 0aab0a.png\n                |   ├── 0b1761.png\n                |   ├── ...\n                ...\n        \"\"\"\n        self.img_ids = img_ids\n        self.img_dir = img_dir\n        self.mask_dir = mask_dir\n        self.img_ext = img_ext\n        self.mask_ext = mask_ext\n        self.num_classes = num_classes\n        self.transform = transform\n\n    def __len__(self):\n        return len(self.img_ids)\n\n    def __getitem__(self, idx):\n        img_id = self.img_ids[idx]\n        \n        img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))\n\n        mask = []\n        for i in range(self.num_classes):\n\n            # print(os.path.join(self.mask_dir, str(i),\n            #             img_id + self.mask_ext))\n\n            
mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),\n                        img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])\n        mask = np.dstack(mask)\n\n        if self.transform is not None:\n            augmented = self.transform(image=img, mask=mask)\n            img = augmented['image']\n            mask = augmented['mask']\n        \n        img = img.astype('float32') / 255\n        img = img.transpose(2, 0, 1)\n        mask = mask.astype('float32') / 255\n        mask = mask.transpose(2, 0, 1)\n\n        if mask.max()<1:\n            mask[mask>0] = 1.0\n\n        return img, mask, {'img_id': img_id}\n"
  },
  {
    "path": "Seg_UKAN/environment.yml",
    "content": "name: ukan\nchannels:\n  - defaults\ndependencies:\n  - _libgcc_mutex=0.1=main\n  - _openmp_mutex=4.5=1_gnu\n  - ca-certificates=2021.10.26=h06a4308_2\n  - certifi=2021.5.30=py36h06a4308_0\n  - ld_impl_linux-64=2.35.1=h7274673_9\n  - libffi=3.3=he6710b0_2\n  - libgcc-ng=9.3.0=h5101ec6_17\n  - libgomp=9.3.0=h5101ec6_17\n  - libstdcxx-ng=9.3.0=hd4cf53a_17\n  - ncurses=6.3=h7f8727e_2\n  - openssl=1.1.1l=h7f8727e_0\n  - pip=21.2.2=py36h06a4308_0\n  - python=3.6.13=h12debd9_1\n  - readline=8.1=h27cfd23_0\n  - setuptools=58.0.4=py36h06a4308_0\n  - sqlite=3.36.0=hc218d9a_0\n  - tk=8.6.11=h1ccaba5_0\n  - wheel=0.37.0=pyhd3eb1b0_1\n  - xz=5.2.5=h7b6447c_0\n  - zlib=1.2.11=h7b6447c_3\n  - pip:\n    - addict==2.4.0\n    - dataclasses==0.8\n    - mmcv-full==1.2.7\n    - numpy==1.19.5\n    - opencv-python==4.5.1.48\n    - perceptual==0.1\n    - pillow==8.4.0\n    - scikit-image==0.17.2\n    - scipy==1.5.4\n    - tifffile==2020.9.3\n    - timm==0.3.2\n    - torch==1.7.1\n    - torchvision==0.8.2\n    - typing-extensions==4.0.0\n    - yapf==0.31.0\n# prefix: /home/jeyamariajose/anaconda3/envs/transweather\n\n"
  },
  {
    "path": "Seg_UKAN/kan.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport math\n\n\nclass KANLinear(torch.nn.Module):\n    def __init__(\n        self,\n        in_features,\n        out_features,\n        grid_size=5,\n        spline_order=3,\n        scale_noise=0.1,\n        scale_base=1.0,\n        scale_spline=1.0,\n        enable_standalone_scale_spline=True,\n        base_activation=torch.nn.SiLU,\n        grid_eps=0.02,\n        grid_range=[-1, 1],\n    ):\n        super(KANLinear, self).__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.grid_size = grid_size\n        self.spline_order = spline_order\n\n        h = (grid_range[1] - grid_range[0]) / grid_size\n        grid = (\n            (\n                torch.arange(-spline_order, grid_size + spline_order + 1) * h\n                + grid_range[0]\n            )\n            .expand(in_features, -1)\n            .contiguous()\n        )\n        self.register_buffer(\"grid\", grid)\n\n        self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))\n        self.spline_weight = torch.nn.Parameter(\n            torch.Tensor(out_features, in_features, grid_size + spline_order)\n        )\n        if enable_standalone_scale_spline:\n            self.spline_scaler = torch.nn.Parameter(\n                torch.Tensor(out_features, in_features)\n            )\n\n        self.scale_noise = scale_noise\n        self.scale_base = scale_base\n        self.scale_spline = scale_spline\n        self.enable_standalone_scale_spline = enable_standalone_scale_spline\n        self.base_activation = base_activation()\n        self.grid_eps = grid_eps\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)\n        with torch.no_grad():\n            noise = (\n                (\n                    torch.rand(self.grid_size + 1, self.in_features, 
self.out_features)\n                    - 1 / 2\n                )\n                * self.scale_noise\n                / self.grid_size\n            )\n            self.spline_weight.data.copy_(\n                (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)\n                * self.curve2coeff(\n                    self.grid.T[self.spline_order : -self.spline_order],\n                    noise,\n                )\n            )\n            if self.enable_standalone_scale_spline:\n                # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)\n                torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)\n\n    def b_splines(self, x: torch.Tensor):\n        \"\"\"\n        Compute the B-spline bases for the given input tensor.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n\n        Returns:\n            torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).\n        \"\"\"\n        assert x.dim() == 2 and x.size(1) == self.in_features\n\n        grid: torch.Tensor = (\n            self.grid\n        )  # (in_features, grid_size + 2 * spline_order + 1)\n        x = x.unsqueeze(-1)\n        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)\n        for k in range(1, self.spline_order + 1):\n            bases = (\n                (x - grid[:, : -(k + 1)])\n                / (grid[:, k:-1] - grid[:, : -(k + 1)])\n                * bases[:, :, :-1]\n            ) + (\n                (grid[:, k + 1 :] - x)\n                / (grid[:, k + 1 :] - grid[:, 1:(-k)])\n                * bases[:, :, 1:]\n            )\n\n        assert bases.size() == (\n            x.size(0),\n            self.in_features,\n            self.grid_size + self.spline_order,\n        )\n        return bases.contiguous()\n\n    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):\n        \"\"\"\n        
Compute the coefficients of the curve that interpolates the given points.\n\n        Args:\n            x (torch.Tensor): Input tensor of shape (batch_size, in_features).\n            y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).\n\n        Returns:\n            torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).\n        \"\"\"\n        assert x.dim() == 2 and x.size(1) == self.in_features\n        assert y.size() == (x.size(0), self.in_features, self.out_features)\n\n        A = self.b_splines(x).transpose(\n            0, 1\n        )  # (in_features, batch_size, grid_size + spline_order)\n        B = y.transpose(0, 1)  # (in_features, batch_size, out_features)\n        solution = torch.linalg.lstsq(\n            A, B\n        ).solution  # (in_features, grid_size + spline_order, out_features)\n        result = solution.permute(\n            2, 0, 1\n        )  # (out_features, in_features, grid_size + spline_order)\n\n        assert result.size() == (\n            self.out_features,\n            self.in_features,\n            self.grid_size + self.spline_order,\n        )\n        return result.contiguous()\n\n    @property\n    def scaled_spline_weight(self):\n        return self.spline_weight * (\n            self.spline_scaler.unsqueeze(-1)\n            if self.enable_standalone_scale_spline\n            else 1.0\n        )\n\n    def forward(self, x: torch.Tensor):\n        assert x.dim() == 2 and x.size(1) == self.in_features\n\n        base_output = F.linear(self.base_activation(x), self.base_weight)\n        spline_output = F.linear(\n            self.b_splines(x).view(x.size(0), -1),\n            self.scaled_spline_weight.view(self.out_features, -1),\n        )\n        return base_output + spline_output\n\n    @torch.no_grad()\n    def update_grid(self, x: torch.Tensor, margin=0.01):\n        assert x.dim() == 2 and x.size(1) == self.in_features\n        batch = 
x.size(0)\n\n        splines = self.b_splines(x)  # (batch, in, coeff)\n        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)\n        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)\n        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)\n        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)\n        unreduced_spline_output = unreduced_spline_output.permute(\n            1, 0, 2\n        )  # (batch, in, out)\n\n        # sort each channel individually to collect data distribution\n        x_sorted = torch.sort(x, dim=0)[0]\n        grid_adaptive = x_sorted[\n            torch.linspace(\n                0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device\n            )\n        ]\n\n        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size\n        grid_uniform = (\n            torch.arange(\n                self.grid_size + 1, dtype=torch.float32, device=x.device\n            ).unsqueeze(1)\n            * uniform_step\n            + x_sorted[0]\n            - margin\n        )\n\n        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive\n        grid = torch.concatenate(\n            [\n                grid[:1]\n                - uniform_step\n                * torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),\n                grid,\n                grid[-1:]\n                + uniform_step\n                * torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),\n            ],\n            dim=0,\n        )\n\n        self.grid.copy_(grid.T)\n        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))\n\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n        \"\"\"\n        Compute the regularization loss.\n\n        This is a dumb simulation of the original L1 regularization as stated in the\n        paper, since the original one 
requires computing absolutes and entropy from the\n        expanded (batch, in_features, out_features) intermediate tensor, which is hidden\n        behind the F.linear function if we want an memory efficient implementation.\n\n        The L1 regularization is now computed as mean absolute value of the spline\n        weights. The authors implementation also includes this term in addition to the\n        sample-based regularization.\n        \"\"\"\n        l1_fake = self.spline_weight.abs().mean(-1)\n        regularization_loss_activation = l1_fake.sum()\n        p = l1_fake / regularization_loss_activation\n        regularization_loss_entropy = -torch.sum(p * p.log())\n        return (\n            regularize_activation * regularization_loss_activation\n            + regularize_entropy * regularization_loss_entropy\n        )\n\n\nclass KAN(torch.nn.Module):\n    def __init__(\n        self,\n        layers_hidden,\n        grid_size=5,\n        spline_order=3,\n        scale_noise=0.1,\n        scale_base=1.0,\n        scale_spline=1.0,\n        base_activation=torch.nn.SiLU,\n        grid_eps=0.02,\n        grid_range=[-1, 1],\n    ):\n        super(KAN, self).__init__()\n        self.grid_size = grid_size\n        self.spline_order = spline_order\n\n        self.layers = torch.nn.ModuleList()\n        for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):\n            self.layers.append(\n                KANLinear(\n                    in_features,\n                    out_features,\n                    grid_size=grid_size,\n                    spline_order=spline_order,\n                    scale_noise=scale_noise,\n                    scale_base=scale_base,\n                    scale_spline=scale_spline,\n                    base_activation=base_activation,\n                    grid_eps=grid_eps,\n                    grid_range=grid_range,\n                )\n            )\n\n    def forward(self, x: torch.Tensor, update_grid=False):\n       
 for layer in self.layers:\n            if update_grid:\n                layer.update_grid(x)\n            x = layer(x)\n        return x\n\n    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):\n        return sum(\n            layer.regularization_loss(regularize_activation, regularize_entropy)\n            for layer in self.layers\n        )\n"
  },
  {
    "path": "Seg_UKAN/losses.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ntry:\n    from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge\nexcept ImportError:\n    pass\n\n__all__ = ['BCEDiceLoss', 'LovaszHingeLoss']\n\n\nclass BCEDiceLoss(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, input, target):\n        bce = F.binary_cross_entropy_with_logits(input, target)\n        smooth = 1e-5\n        input = torch.sigmoid(input)\n        num = target.size(0)\n        input = input.view(num, -1)\n        target = target.view(num, -1)\n        intersection = (input * target)\n        dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)\n        dice = 1 - dice.sum() / num\n        return 0.5 * bce + dice\n\n\nclass LovaszHingeLoss(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, input, target):\n        input = input.squeeze(1)\n        target = target.squeeze(1)\n        loss = lovasz_hinge(input, target, per_image=True)\n\n        return loss\n"
  },
  {
    "path": "Seg_UKAN/metrics.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn.functional as F\n\nfrom medpy.metric.binary import jc, dc, hd, hd95, recall, specificity, precision\n\n\n\ndef iou_score(output, target):\n    smooth = 1e-5\n\n    if torch.is_tensor(output):\n        output = torch.sigmoid(output).data.cpu().numpy()\n    if torch.is_tensor(target):\n        target = target.data.cpu().numpy()\n    output_ = output > 0.5\n    target_ = target > 0.5\n    intersection = (output_ & target_).sum()\n    union = (output_ | target_).sum()\n    iou = (intersection + smooth) / (union + smooth)\n    dice = (2* iou) / (iou+1)\n\n    try:\n        hd95_ = hd95(output_, target_)\n    except:\n        hd95_ = 0\n    \n    return iou, dice, hd95_\n\n\ndef dice_coef(output, target):\n    smooth = 1e-5\n\n    output = torch.sigmoid(output).view(-1).data.cpu().numpy()\n    target = target.view(-1).data.cpu().numpy()\n    intersection = (output * target).sum()\n\n    return (2. * intersection + smooth) / \\\n        (output.sum() + target.sum() + smooth)\n\ndef indicators(output, target):\n    if torch.is_tensor(output):\n        output = torch.sigmoid(output).data.cpu().numpy()\n    if torch.is_tensor(target):\n        target = target.data.cpu().numpy()\n    output_ = output > 0.5\n    target_ = target > 0.5\n\n    iou_ = jc(output_, target_)\n    dice_ = dc(output_, target_)\n    hd_ = hd(output_, target_)\n    hd95_ = hd95(output_, target_)\n    recall_ = recall(output_, target_)\n    specificity_ = specificity(output_, target_)\n    precision_ = precision(output_, target_)\n\n    return iou_, dice_, hd_, hd95_, recall_, specificity_, precision_\n"
  },
  {
    "path": "Seg_UKAN/requirements.txt",
    "content": "addict==2.4.0\ndataclasses\npandas\npyyaml\nalbumentations\ntqdm\ntensorboardX\n# mmcv-full==1.2.7\nnumpy\nopencv-python\nperceptual==0.1\npillow==8.4.0\nscikit-image==0.17.2\nscipy==1.5.4\ntifffile==2020.9.3\ntimm==0.3.2\ntyping-extensions==4.0.0\nyapf==0.31.0\n\n# Install torch separately (not valid requirements.txt syntax as a bare command):\n# pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116"
  },
  {
    "path": "Seg_UKAN/scripts.sh",
    "content": "dataset=busi\ninput_size=256\npython train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_size} --name ${dataset}_UKAN  --data_dir [YOUR_DATA_DIR]\npython val.py --name ${dataset}_UKAN \n\ndataset=glas\ninput_size=512\npython train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_size} --name ${dataset}_UKAN  --data_dir [YOUR_DATA_DIR]\npython val.py --name ${dataset}_UKAN \n\ndataset=cvc\ninput_size=256\npython train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_size} --name ${dataset}_UKAN  --data_dir [YOUR_DATA_DIR]\npython val.py --name ${dataset}_UKAN \n\n\n\n\n\n\n\n"
  },
  {
    "path": "Seg_UKAN/train.py",
    "content": "import argparse\nimport os\nfrom collections import OrderedDict\nfrom glob import glob\nimport random\nimport numpy as np\n\nimport pandas as pd\nimport torch\nimport torch.backends.cudnn as cudnn\nimport torch.nn as nn\nimport torch.optim as optim\nimport yaml\n\nfrom albumentations.augmentations import transforms\nfrom albumentations.augmentations import geometric\n\nfrom albumentations.core.composition import Compose, OneOf\nfrom sklearn.model_selection import train_test_split\nfrom torch.optim import lr_scheduler\nfrom tqdm import tqdm\nfrom albumentations import RandomRotate90, Resize\n\nimport archs\n\nimport losses\nfrom dataset import Dataset\n\nfrom metrics import iou_score, indicators\n\nfrom utils import AverageMeter, str2bool\n\nfrom tensorboardX import SummaryWriter\n\nimport shutil\nimport os\nimport subprocess\n\nfrom pdb import set_trace as st\n\n\nARCH_NAMES = archs.__all__\nLOSS_NAMES = losses.__all__\nLOSS_NAMES.append('BCEWithLogitsLoss')\n\n\ndef list_type(s):\n    str_list = s.split(',')\n    int_list = [int(a) for a in str_list]\n    return int_list\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--name', default=None,\n                        help='model name: (default: arch+timestamp)')\n    parser.add_argument('--epochs', default=400, type=int, metavar='N',\n                        help='number of total epochs to run')\n    parser.add_argument('-b', '--batch_size', default=8, type=int,\n                        metavar='N', help='mini-batch size (default: 16)')\n\n    parser.add_argument('--dataseed', default=2981, type=int,\n                        help='')\n    \n    # model\n    parser.add_argument('--arch', '-a', metavar='ARCH', default='UKAN')\n    \n    parser.add_argument('--deep_supervision', default=False, type=str2bool)\n    parser.add_argument('--input_channels', default=3, type=int,\n                        help='input channels')\n    parser.add_argument('--num_classes', 
default=1, type=int,\n                        help='number of classes')\n    parser.add_argument('--input_w', default=256, type=int,\n                        help='image width')\n    parser.add_argument('--input_h', default=256, type=int,\n                        help='image height')\n    parser.add_argument('--input_list', type=list_type, default=[128, 160, 256])\n\n    # loss\n    parser.add_argument('--loss', default='BCEDiceLoss',\n                        choices=LOSS_NAMES,\n                        help='loss: ' +\n                        ' | '.join(LOSS_NAMES) +\n                        ' (default: BCEDiceLoss)')\n    \n    # dataset\n    parser.add_argument('--dataset', default='busi', help='dataset name')      \n    parser.add_argument('--data_dir', default='inputs', help='dataset dir')\n\n    parser.add_argument('--output_dir', default='outputs', help='ouput dir')\n\n\n    # optimizer\n    parser.add_argument('--optimizer', default='Adam',\n                        choices=['Adam', 'SGD'],\n                        help='loss: ' +\n                        ' | '.join(['Adam', 'SGD']) +\n                        ' (default: Adam)')\n\n    parser.add_argument('--lr', '--learning_rate', default=1e-4, type=float,\n                        metavar='LR', help='initial learning rate')\n    parser.add_argument('--momentum', default=0.9, type=float,\n                        help='momentum')\n    parser.add_argument('--weight_decay', default=1e-4, type=float,\n                        help='weight decay')\n    parser.add_argument('--nesterov', default=False, type=str2bool,\n                        help='nesterov')\n\n    parser.add_argument('--kan_lr', default=1e-2, type=float,\n                        metavar='LR', help='initial learning rate')\n    parser.add_argument('--kan_weight_decay', default=1e-4, type=float,\n                        help='weight decay')\n\n    # scheduler\n    parser.add_argument('--scheduler', default='CosineAnnealingLR',\n                       
 choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])\n    parser.add_argument('--min_lr', default=1e-5, type=float,\n                        help='minimum learning rate')\n    parser.add_argument('--factor', default=0.1, type=float)\n    parser.add_argument('--patience', default=2, type=int)\n    parser.add_argument('--milestones', default='1,2', type=str)\n    parser.add_argument('--gamma', default=2/3, type=float)\n    parser.add_argument('--early_stopping', default=-1, type=int,\n                        metavar='N', help='early stopping (default: -1)')\n    parser.add_argument('--cfg', type=str, metavar=\"FILE\", help='path to config file', )\n    parser.add_argument('--num_workers', default=4, type=int)\n\n    parser.add_argument('--no_kan', action='store_true')\n\n\n\n    config = parser.parse_args()\n\n    return config\n\n\ndef train(config, train_loader, model, criterion, optimizer):\n    avg_meters = {'loss': AverageMeter(),\n                  'iou': AverageMeter()}\n\n    model.train()\n\n    pbar = tqdm(total=len(train_loader))\n    for input, target, _ in train_loader:\n        input = input.cuda()\n        target = target.cuda()\n\n        # compute output\n        if config['deep_supervision']:\n            outputs = model(input)\n            loss = 0\n            for output in outputs:\n                loss += criterion(output, target)\n            loss /= len(outputs)\n\n            iou, dice, _ = iou_score(outputs[-1], target)\n            iou_, dice_, hd_, hd95_, recall_, specificity_, precision_ = indicators(outputs[-1], target)\n            \n        else:\n            output = model(input)\n            loss = criterion(output, target)\n            iou, dice, _ = iou_score(output, target)\n            iou_, dice_, hd_, hd95_, recall_, specificity_, precision_ = indicators(output, target)\n\n        # compute gradient and do optimizing step\n        optimizer.zero_grad()\n        loss.backward()\n        
optimizer.step()\n\n        avg_meters['loss'].update(loss.item(), input.size(0))\n        avg_meters['iou'].update(iou, input.size(0))\n\n        postfix = OrderedDict([\n            ('loss', avg_meters['loss'].avg),\n            ('iou', avg_meters['iou'].avg),\n        ])\n        pbar.set_postfix(postfix)\n        pbar.update(1)\n    pbar.close()\n\n    return OrderedDict([('loss', avg_meters['loss'].avg),\n                        ('iou', avg_meters['iou'].avg)])\n\n\ndef validate(config, val_loader, model, criterion):\n    avg_meters = {'loss': AverageMeter(),\n                  'iou': AverageMeter(),\n                   'dice': AverageMeter()}\n\n    # switch to evaluate mode\n    model.eval()\n\n    with torch.no_grad():\n        pbar = tqdm(total=len(val_loader))\n        for input, target, _ in val_loader:\n            input = input.cuda()\n            target = target.cuda()\n\n            # compute output\n            if config['deep_supervision']:\n                outputs = model(input)\n                loss = 0\n                for output in outputs:\n                    loss += criterion(output, target)\n                loss /= len(outputs)\n                iou, dice, _ = iou_score(outputs[-1], target)\n            else:\n                output = model(input)\n                loss = criterion(output, target)\n                iou, dice, _ = iou_score(output, target)\n\n            avg_meters['loss'].update(loss.item(), input.size(0))\n            avg_meters['iou'].update(iou, input.size(0))\n            avg_meters['dice'].update(dice, input.size(0))\n\n            postfix = OrderedDict([\n                ('loss', avg_meters['loss'].avg),\n                ('iou', avg_meters['iou'].avg),\n                ('dice', avg_meters['dice'].avg)\n            ])\n            pbar.set_postfix(postfix)\n            pbar.update(1)\n        pbar.close()\n\n\n    return OrderedDict([('loss', avg_meters['loss'].avg),\n                        ('iou', 
avg_meters['iou'].avg),\n                        ('dice', avg_meters['dice'].avg)])\n\ndef seed_torch(seed=1029):\n    random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\ndef main():\n    seed_torch()\n    config = vars(parse_args())\n\n    exp_name = config.get('name')\n    output_dir = config.get('output_dir')\n\n    my_writer = SummaryWriter(f'{output_dir}/{exp_name}')\n\n    if config['name'] is None:\n        if config['deep_supervision']:\n            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])\n        else:\n            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])\n    \n    os.makedirs(f'{output_dir}/{exp_name}', exist_ok=True)\n\n    print('-' * 20)\n    for key in config:\n        print('%s: %s' % (key, config[key]))\n    print('-' * 20)\n\n    with open(f'{output_dir}/{exp_name}/config.yml', 'w') as f:\n        yaml.dump(config, f)\n\n    # define loss function (criterion)\n    if config['loss'] == 'BCEWithLogitsLoss':\n        criterion = nn.BCEWithLogitsLoss().cuda()\n    else:\n        criterion = losses.__dict__[config['loss']]().cuda()\n\n    cudnn.benchmark = True\n\n    # create model\n    model = archs.__dict__[config['arch']](config['num_classes'], config['input_channels'], config['deep_supervision'], embed_dims=config['input_list'], no_kan=config['no_kan'])\n\n    model = model.cuda()\n\n\n    param_groups = []\n\n    kan_fc_params = []\n    other_params = []\n\n    for name, param in model.named_parameters():\n        # print(name, \"=>\", param.shape)\n        if 'layer' in name.lower() and 'fc' in name.lower(): # higher lr for kan layers\n            # kan_fc_params.append(name)\n            param_groups.append({'params': param, 'lr': config['kan_lr'], 
'weight_decay': config['kan_weight_decay']}) \n        else:\n            # other_params.append(name)\n            param_groups.append({'params': param, 'lr': config['lr'], 'weight_decay': config['weight_decay']})  \n    \n\n    \n    # st()\n    if config['optimizer'] == 'Adam':\n        optimizer = optim.Adam(param_groups)\n\n\n    elif config['optimizer'] == 'SGD':\n        optimizer = optim.SGD(param_groups, lr=config['lr'], momentum=config['momentum'], nesterov=config['nesterov'], weight_decay=config['weight_decay'])\n    else:\n        raise NotImplementedError\n\n    if config['scheduler'] == 'CosineAnnealingLR':\n        scheduler = lr_scheduler.CosineAnnealingLR(\n            optimizer, T_max=config['epochs'], eta_min=config['min_lr'])\n    elif config['scheduler'] == 'ReduceLROnPlateau':\n        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'], verbose=1, min_lr=config['min_lr'])\n    elif config['scheduler'] == 'MultiStepLR':\n        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])\n    elif config['scheduler'] == 'ConstantLR':\n        scheduler = None\n    else:\n        raise NotImplementedError\n\n    shutil.copy2('train.py', f'{output_dir}/{exp_name}/')\n    shutil.copy2('archs.py', f'{output_dir}/{exp_name}/')\n\n    dataset_name = config['dataset']\n    img_ext = '.png'\n\n    if dataset_name == 'busi':\n        mask_ext = '_mask.png'\n    elif dataset_name == 'glas':\n        mask_ext = '.png'\n    elif dataset_name == 'cvc':\n        mask_ext = '.png'\n\n    # Data loading code\n    img_ids = sorted(glob(os.path.join(config['data_dir'], config['dataset'], 'images', '*' + img_ext)))\n    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]\n\n    train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=config['dataseed'])\n\n    train_transform = Compose([\n        
RandomRotate90(),\n        geometric.transforms.Flip(),\n        Resize(config['input_h'], config['input_w']),\n        transforms.Normalize(),\n    ])\n\n    val_transform = Compose([\n        Resize(config['input_h'], config['input_w']),\n        transforms.Normalize(),\n    ])\n\n    train_dataset = Dataset(\n        img_ids=train_img_ids,\n        img_dir=os.path.join(config['data_dir'], config['dataset'], 'images'),\n        mask_dir=os.path.join(config['data_dir'], config['dataset'], 'masks'),\n        img_ext=img_ext,\n        mask_ext=mask_ext,\n        num_classes=config['num_classes'],\n        transform=train_transform)\n    val_dataset = Dataset(\n        img_ids=val_img_ids,\n        img_dir=os.path.join(config['data_dir'] ,config['dataset'], 'images'),\n        mask_dir=os.path.join(config['data_dir'], config['dataset'], 'masks'),\n        img_ext=img_ext,\n        mask_ext=mask_ext,\n        num_classes=config['num_classes'],\n        transform=val_transform)\n\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=config['batch_size'],\n        shuffle=True,\n        num_workers=config['num_workers'],\n        drop_last=True)\n    val_loader = torch.utils.data.DataLoader(\n        val_dataset,\n        batch_size=config['batch_size'],\n        shuffle=False,\n        num_workers=config['num_workers'],\n        drop_last=False)\n\n    log = OrderedDict([\n        ('epoch', []),\n        ('lr', []),\n        ('loss', []),\n        ('iou', []),\n        ('val_loss', []),\n        ('val_iou', []),\n        ('val_dice', []),\n    ])\n\n\n    best_iou = 0\n    best_dice= 0\n    trigger = 0\n    for epoch in range(config['epochs']):\n        print('Epoch [%d/%d]' % (epoch, config['epochs']))\n\n        # train for one epoch\n        train_log = train(config, train_loader, model, criterion, optimizer)\n        # evaluate on validation set\n        val_log = validate(config, val_loader, model, criterion)\n\n        if 
config['scheduler'] == 'CosineAnnealingLR':\n            scheduler.step()\n        elif config['scheduler'] == 'ReduceLROnPlateau':\n            scheduler.step(val_log['loss'])\n\n        print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'\n              % (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))\n\n        log['epoch'].append(epoch)\n        log['lr'].append(config['lr'])\n        log['loss'].append(train_log['loss'])\n        log['iou'].append(train_log['iou'])\n        log['val_loss'].append(val_log['loss'])\n        log['val_iou'].append(val_log['iou'])\n        log['val_dice'].append(val_log['dice'])\n\n        pd.DataFrame(log).to_csv(f'{output_dir}/{exp_name}/log.csv', index=False)\n\n        my_writer.add_scalar('train/loss', train_log['loss'], global_step=epoch)\n        my_writer.add_scalar('train/iou', train_log['iou'], global_step=epoch)\n        my_writer.add_scalar('val/loss', val_log['loss'], global_step=epoch)\n        my_writer.add_scalar('val/iou', val_log['iou'], global_step=epoch)\n        my_writer.add_scalar('val/dice', val_log['dice'], global_step=epoch)\n\n        my_writer.add_scalar('val/best_iou_value', best_iou, global_step=epoch)\n        my_writer.add_scalar('val/best_dice_value', best_dice, global_step=epoch)\n\n        trigger += 1\n\n        if val_log['iou'] > best_iou:\n            torch.save(model.state_dict(), f'{output_dir}/{exp_name}/model.pth')\n            best_iou = val_log['iou']\n            best_dice = val_log['dice']\n            print(\"=> saved best model\")\n            print('IoU: %.4f' % best_iou)\n            print('Dice: %.4f' % best_dice)\n            trigger = 0\n\n        # early stopping\n        if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:\n            print(\"=> early stopping\")\n            break\n\n        torch.cuda.empty_cache()\n    \nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "Seg_UKAN/utils.py",
    "content": "import argparse\nimport torch.nn as nn\n\nclass qkv_transform(nn.Conv1d):\n    \"\"\"Conv1d for qkv_transform\"\"\"\n\ndef str2bool(v):\n    if v.lower() in ['true', '1']:\n        return True\n    elif v.lower() in ['false', '0']:\n        return False\n    else:\n        raise argparse.ArgumentTypeError('Boolean value expected.')\n\n\ndef count_params(model):\n    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n"
  },
  {
    "path": "Seg_UKAN/val.py",
    "content": "#! /data/cxli/miniconda3/envs/th200/bin/python\nimport argparse\nimport os\nfrom glob import glob\nimport random\nimport numpy as np\n\nimport cv2\nimport torch\nimport torch.backends.cudnn as cudnn\nimport yaml\nfrom albumentations.augmentations import transforms\nfrom albumentations.core.composition import Compose\nfrom sklearn.model_selection import train_test_split\nfrom tqdm import tqdm\nfrom collections import OrderedDict\n\nimport archs\n\nfrom dataset import Dataset\nfrom metrics import iou_score\nfrom utils import AverageMeter\nfrom albumentations import RandomRotate90,Resize\nimport time\n\nfrom PIL import Image\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--name', default=None, help='model name')\n    parser.add_argument('--output_dir', default='outputs', help='ouput dir')\n            \n    args = parser.parse_args()\n\n    return args\n\ndef seed_torch(seed=1029):\n    random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\n\ndef main():\n    seed_torch()\n    args = parse_args()\n\n    with open(f'{args.output_dir}/{args.name}/config.yml', 'r') as f:\n        config = yaml.load(f, Loader=yaml.FullLoader)\n\n    print('-'*20)\n    for key in config.keys():\n        print('%s: %s' % (key, str(config[key])))\n    print('-'*20)\n\n    cudnn.benchmark = True\n\n    model = archs.__dict__[config['arch']](config['num_classes'], config['input_channels'], config['deep_supervision'], embed_dims=config['input_list'])\n\n    model = model.cuda()\n\n    dataset_name = config['dataset']\n    img_ext = '.png'\n\n    if dataset_name == 'busi':\n        mask_ext = '_mask.png'\n    elif dataset_name == 'glas':\n        mask_ext = '.png'\n    elif dataset_name == 'cvc':\n        
mask_ext = '.png'\n\n    # Data loading code\n    img_ids = sorted(glob(os.path.join(config['data_dir'], config['dataset'], 'images', '*' + img_ext)))\n    # img_ids.sort()\n    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]\n\n    _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=config['dataseed'])\n\n    ckpt = torch.load(f'{args.output_dir}/{args.name}/model.pth')\n\n    try:        \n        model.load_state_dict(ckpt)\n    except:\n        print(\"Pretrained model keys:\", ckpt.keys())\n        print(\"Current model keys:\", model.state_dict().keys())\n\n        pretrained_dict = {k: v for k, v in ckpt.items() if k in model.state_dict()}\n        current_dict = model.state_dict()\n        diff_keys = set(current_dict.keys()) - set(pretrained_dict.keys())\n\n        print(\"Difference in model keys:\")\n        for key in diff_keys:\n            print(f\"Key: {key}\")\n\n        model.load_state_dict(ckpt, strict=False)\n        \n    model.eval()\n\n    val_transform = Compose([\n        Resize(config['input_h'], config['input_w']),\n        transforms.Normalize(),\n    ])\n\n    val_dataset = Dataset(\n        img_ids=val_img_ids,\n        img_dir=os.path.join(config['data_dir'], config['dataset'], 'images'),\n        mask_dir=os.path.join(config['data_dir'], config['dataset'], 'masks'),\n        img_ext=img_ext,\n        mask_ext=mask_ext,\n        num_classes=config['num_classes'],\n        transform=val_transform)\n    val_loader = torch.utils.data.DataLoader(\n        val_dataset,\n        batch_size=config['batch_size'],\n        shuffle=False,\n        num_workers=config['num_workers'],\n        drop_last=False)\n\n    iou_avg_meter = AverageMeter()\n    dice_avg_meter = AverageMeter()\n    hd95_avg_meter = AverageMeter()\n\n    with torch.no_grad():\n        for input, target, meta in tqdm(val_loader, total=len(val_loader)):\n            input = input.cuda()\n            target = target.cuda()\n            
model = model.cuda()\n            # compute output\n            output = model(input)\n\n            iou, dice, hd95_ = iou_score(output, target)\n            iou_avg_meter.update(iou, input.size(0))\n            dice_avg_meter.update(dice, input.size(0))\n            hd95_avg_meter.update(hd95_, input.size(0))\n\n            output = torch.sigmoid(output).cpu().numpy()\n            output[output>=0.5]=1\n            output[output<0.5]=0\n\n            os.makedirs(os.path.join(args.output_dir, config['name'], 'out_val'), exist_ok=True)\n            for pred, img_id in zip(output, meta['img_id']):\n                pred_np = pred[0].astype(np.uint8)\n                pred_np = pred_np * 255\n                img = Image.fromarray(pred_np, 'L')\n                img.save(os.path.join(args.output_dir, config['name'], 'out_val/{}.jpg'.format(img_id)))\n\n    \n    print(config['name'])\n    print('IoU: %.4f' % iou_avg_meter.avg)\n    print('Dice: %.4f' % dice_avg_meter.avg)\n    print('HD95: %.4f' % hd95_avg_meter.avg)\n\n\n\nif __name__ == '__main__':\n    main()\n"
  }
]