[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Arian Mousakhan\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Anomaly Detection with Conditioned Denoising Diffusion Models.\n\nOfficial implementation of [DDAD](https://arxiv.org/abs/2305.15956) \n\n\n[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/anomaly-detection-with-conditioned-denoising/anomaly-detection-on-mvtec-ad)](https://paperswithcode.com/sota/anomaly-detection-on-mvtec-ad?p=anomaly-detection-with-conditioned-denoising)  [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/anomaly-detection-with-conditioned-denoising/anomaly-detection-on-visa)](https://paperswithcode.com/sota/anomaly-detection-on-visa?p=anomaly-detection-with-conditioned-denoising)\n\n\n![Framework](images/DDAD_Framework.png)\n\n\n\n## Requirements\nThis repository is implemented and tested on Python 3.8 and PyTorch 2.1.\nTo install requirements:\n\n```setup\npip install -r requirements.txt\n```\n\n## Train and Evaluation of the Model\nYou can download the model checkpoints directly from [Checkpoints](https://drive.google.com/drive/u/0/folders/1FF83llo3a-mN5pJN8-_mw0hL5eZqe9fC) \n\nTo train the denoising UNet, run:\n\n```train\npython main.py --train True\n```\n\nModify the settings in the config.yaml file to train the model on different categories.\n\n\nFor fine-tuning the feature extractor, use the following command:\n\n```domain_adaptation\npython main.py --domain_adaptation True\n```\n\nTo evaluate and test the model, run:\n\n```detection\npython main.py --detection True\n```\n\n\n## Dataset\nYou can download  [MVTec AD: MVTec Software](https://www.mvtec.com/company/research/datasets/mvtec-ad/) and [VisA](https://amazon-visual-anomaly.s3.us-west-2.amazonaws.com/VisA_20220922.tar) Benchmarks.\nFor preprocessing of VisA dataset check out the [Data preparation](https://github.com/amazon-science/spot-diff/tree/main) section of this repository.\n\nThe dataset should be placed in the 'datasets' folder. The training dataset should only contain one subcategory consisting of nominal samples, which should be named 'good'. The test dataset should include one category named 'good' for nominal samples, and any other subcategories of anomalous samples. 
It should be structured as follows:\n\n```shell\nName_of_Dataset\n|-- Category\n|-----|----- ground_truth\n|-----|----- test\n|-----|--------|------ good\n|-----|--------|------ ...\n|-----|--------|------ ...\n|-----|----- train\n|-----|--------|------ good\n```\n\n\n\n\n## Results\nRunning the code as explained in this file should achieve the following results for MVTec AD:\n\nAnomaly Detection (Image AUROC) and Anomaly Localization (Pixel AUROC, PRO)\n\nExpected results for MVTec AD:\n| Category | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal nut | Pill | Screw | Toothbrush | Transistor | Zipper | Average\n|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n| Detection | 99.3% | 100% | 100% | 100% | 100% | 100% | 99.4% | 99.4% | 100% | 100% | 100% | 99.0% | 100% | 100% | 100% | 99.8% \n| Localization | (98.7%,93.9%) | (99.4%,97.3%) | (99.4%,97.7%) | (98.2%,93.1%) | (95.0%,82.9%) | (98.7%,91.8%) | (98.1%,88.9%) | (95.7%,93.4%) | (98.4%,86.7%) | (99.0%,91.1%) | (99.1%,95.5%) | (99.3%,96.3%) | (98.7%,92.6%) | (95.3%,90.1%) | (98.2%,93.2%) | (98.1%,92.3%)\n\nThe settings used for these results are detailed in the table below.\n\n| **Categories** | Carpet | Grid | Leather | Tile | Wood | Bottle | Cable | Capsule | Hazelnut | Metal nut | Pill | Screw | Toothbrush | Transistor | Zipper |\n| -------------- | ------ | ---- | ------- | ---- | ---- | ------ | ----- | ------- | -------- | --------- | ---- | ----- | ----------- | ---------- | ------ |\n| **\\(w\\)**       | 0      | 4    | 11      | 4    | 11   | 3      | 3     | 8       | 5        | 7         | 9    | 2     | 0           | 0          | 10     |\n| **Training epochs** | 2500 | 2000 | 2000 | 1000 | 2000 | 1000 | 3000 | 1500 | 2000 | 3000 | 1000 | 2000 | 2000 | 2000 | 1000 |\n| **FE epochs**   | 0      | 6    | 8       | 0    | 16   | 5      | 0     | 8       | 3        | 1         | 4    | 4     | 2           | 0          | 6      |\n\n\nThe following are the expected results on the VisA dataset.
\n\n| Category | Candle | Capsules |  Cashew | Chewing gum | Fryum | Macaroni1 |  Macaroni2 | PCB1 | PCB2 | PCB3 | PCB4 | Pipe fryum | Average\n|---|---|---|---|---|---|---|---|---|---|---|---|---|---|\n| Detection | 99.9% | 100% | 94.5% | 98.1% | 99.0% | 99.2% | 99.2% | 100% |  99.7% | 97.2% | 100% | 100% | 98.9%\n| Localization | (98.7%,96.6%) |  (99.5%,95.0%) | (97.4%,80.3%) | (96.5%,85.2%) | (96.9%,94.2%) | (98.7%,98.5%) | (98.2%,99.3%) | (93.4%,93.3%) | (97.4%,93.3%) | (96.3%,86.6%) | (98.5%,95.5%) | (99.5%,94.7%) |(97.6%,92.7%)\n\nThe settings used for these results are detailed in the table.\n\n| **Categories**   | Candle | Capsules | Cashew | Chewing gum | Fryum | Macaroni1 | Macaroni2 | PCB1 | PCB2 | PCB3 | PCB4 | Pipe fryum |\n| ---------------- | ------ | -------- | ------ | ------------ | ----- | --------- | --------- | ---- | ---- | ---- | ---- | ---------- |\n| **\\(w\\)**         | 6      | 5        | 0      | 6            | 4     | 5         | 2         | 9    | 5    | 6    | 6    | 8          |\n| **Training epochs** | 1000   | 1000     | 1750   | 1250         | 1000  | 500       | 500       | 500  | 500  | 500  | 500  | 500        |\n| **FE epochs**     | 1      | 3        | 0      | 0            | 3     | 7         | 11        | 8    | 5    | 1    | 1    | 6          |\n\n\n![Framework](images/Qualitative.png)\n\n## Citation\n\n```\n@article{mousakhan2023anomaly,\n  title={Anomaly Detection with Conditioned Denoising Diffusion Models},\n  author={Mousakhan, Arian and Brox, Thomas and Tayyub, Jawad},\n  journal={arXiv preprint arXiv:2305.15956},\n  year={2023}\n}\n```\n\n## Feedback\n\nFor any feedback or inquiries, please contact arian.mousakhan@gmail.com\n"
  },
  {
    "path": "anomaly_map.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom kornia.filters import gaussian_blur2d\nfrom torchvision.transforms import transforms\nimport math \nfrom dataset import *\nfrom visualize import *\nfrom feature_extractor import *\nimport numpy as np\n\n\ndef heat_map(output, target, FE, config):\n    '''\n    Compute the anomaly map\n    :param output: the output of the reconstruction\n    :param target: the target image\n    :param FE: the feature extractor\n    :param sigma: the sigma of the gaussian kernel\n    :param i_d: the pixel distance\n    :param f_d: the feature distance\n    '''\n    sigma = 4\n    kernel_size = 2 * int(4 * sigma + 0.5) +1\n    anomaly_map = 0\n\n    output = output.to(config.model.device)\n    target = target.to(config.model.device)\n\n    i_d = pixel_distance(output, target)\n    f_d = feature_distance((output),  (target), FE, config)\n    f_d = torch.Tensor(f_d).to(config.model.device)\n\n    anomaly_map += f_d + config.model.v * (torch.max(f_d)/ torch.max(i_d)) * i_d  \n    anomaly_map = gaussian_blur2d(\n        anomaly_map , kernel_size=(kernel_size,kernel_size), sigma=(sigma,sigma)\n        )\n    anomaly_map = torch.sum(anomaly_map, dim=1).unsqueeze(1)\n    return anomaly_map\n\n\n\ndef pixel_distance(output, target):\n    '''\n    Pixel distance between image1 and image2\n    '''\n    distance_map = torch.mean(torch.abs(output - target), dim=1).unsqueeze(1)\n    return distance_map\n\n\n\n\ndef feature_distance(output, target, FE, config):\n    '''\n    Feature distance between output and target\n    '''\n    FE.eval()\n    transform = transforms.Compose([\n            transforms.Lambda(lambda t: (t + 1) / (2)),\n            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n        ])\n    target = transform(target)\n    output = transform(output)\n    inputs_features = FE(target)\n    output_features = FE(output)\n    out_size = config.data.image_size\n    anomaly_map = torch.zeros([inputs_features[0].shape[0] ,1 ,out_size, out_size]).to(config.model.device)\n    for i in range(len(inputs_features)):\n        if i == 0:\n            continue\n        a_map = 1 - F.cosine_similarity(patchify(inputs_features[i]), patchify(output_features[i]))\n        a_map = torch.unsqueeze(a_map, dim=1)\n        a_map = F.interpolate(a_map, size=out_size, mode='bilinear', align_corners=True)\n        anomaly_map += a_map\n    return anomaly_map \n\n\n#https://github.com/amazon-science/patchcore-inspection\ndef patchify(features, return_spatial_info=False):\n    \"\"\"Convert a tensor into a tensor of respective patches.\n    Args:\n        x: [torch.Tensor, bs x c x w x h]\n    Returns:\n        x: [torch.Tensor, bs * w//stride * h//stride, c, patchsize,\n        patchsize]\n    \"\"\"\n    patchsize = 3\n    stride = 1\n    padding = int((patchsize - 1) / 2)\n    unfolder = torch.nn.Unfold(\n        kernel_size=patchsize, stride=stride, padding=padding, dilation=1\n    )\n    unfolded_features = unfolder(features)\n    number_of_total_patches = []\n    for s in features.shape[-2:]:\n        n_patches = (\n            s + 2 * padding - 1 * (patchsize - 1) - 1\n        ) / stride + 1\n        number_of_total_patches.append(int(n_patches))\n    unfolded_features = unfolded_features.reshape(\n        *features.shape[:2], patchsize, patchsize, -1\n    )\n    unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3)\n    max_features = torch.mean(unfolded_features, dim=(3,4))\n    features = max_features.reshape(features.shape[0], 
int(math.sqrt(max_features.shape[1])) , int(math.sqrt(max_features.shape[1])), max_features.shape[-1]).permute(0,3,1,2)\n    if return_spatial_info:\n        return unfolded_features, number_of_total_patches\n    return features\n\n"
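Note (not part of the source file): stated as a formula, `heat_map` and `feature_distance` above compute, for a reconstruction \(\hat{x}_0\) and target \(y\), with \(D_p\) the pixel distance, \(D_f\) the accumulated patch-wise cosine distance over the second and third feature maps, \(v\) = `config.model.v`, and \(G_\sigma\) a Gaussian blur with \(\sigma = 4\):

```math
D_f = \sum_{j \ge 2} \mathrm{upsample}\Big(1 - \cos\big(\mathrm{patchify}(FE_j(y)),\, \mathrm{patchify}(FE_j(\hat{x}_0))\big)\Big), \qquad \mathrm{AnomalyMap} = G_\sigma\Big(D_f + v\, \tfrac{\max D_f}{\max D_p}\, D_p\Big)
```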
  },
  {
    "path": "config.yaml",
    "content": "data :\n  name: MVTec  #MVTec #MTD #VisA \n  data_dir: datasets/MVTec  #MVTec #VisA #MTD  \n  category: screw  #['carpet', 'bottle', 'hazelnut', 'leather', 'cable', 'capsule', 'grid', 'pill', 'transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood']    \n                   # ['candle', 'capsules', 'cashew', 'chewinggum', 'fryum', 'macaroni1', 'macaroni2', 'pcb1', 'pcb2' ,'pcb3', 'pcb4', 'pipe_fryum']\n  image_size: 256 \n  batch_size: 32 # 32 for DDAD and 16 for DDADS\n  DA_batch_size: 16 #16 for MVTec and [macaroni2, pcb1] in VisA, and 32 for other categories in VisA\n  test_batch_size: 16 #16 for MVTec, 32 for VisA\n  mask : True \n  input_channel : 3\n\n\n\nmodel:\n  DDADS: False\n  checkpoint_dir: checkpoints/MVTec   #MTD  #MVTec  #VisA\n  checkpoint_name: weights\n  exp_name: default\n  feature_extractor: wide_resnet101_2 #wide_resnet101_2  # wide_resnet50_2 #resnet50\n  learning_rate: 3e-4 \n  weight_decay: 0.05\n  epochs: 3000\n  load_chp : 2000 # From this epoch checkpoint will be loaded. Every 250 epochs a checkpoint is saved. Try to load 750 or 1000 epochs for Visa and 1000-1500-2000 for MVTec.\n  DA_epochs: 4 # Number of epochs for Domain adaptation.\n  DA_chp: 4\n  v : 1 #7 # 1 for MVTec and cashew in VisA, and 7 for VisA (1.5 for cashew). Control parameter for pixel-wise and feature-wise comparison. v * D_p + D_f\n  w : 2 # Conditionig parameter. The higher the value, the more the model is conditioned on the target image. \"Fine tuninig this parameter results in better performance\".\n  w_DA : 3 #3 # Conditionig parameter for domain adaptation. The higher the value, the more the model is conditioned on the target image.\n  DLlambda : 0.1 # 0.1 for MVTec and 0.01 for VisA\n  trajectory_steps: 1000\n  test_trajectoy_steps: 250   # Starting point for denoining trajectory.\n  test_trajectoy_steps_DA: 250  # Starting point for denoining trajectory for domain adaptation.\n  skip : 25   # Number of steps to skip for denoising trajectory.\n  skip_DA : 25\n  eta : 1 # Stochasticity parameter for denoising process.\n  beta_start : 0.0001\n  beta_end : 0.02 \n  device: 'cuda' #<\"cpu\", \"gpu\", \"tpu\", \"ipu\">\n  save_model: True\n  num_workers : 2\n  seed : 42\n\n\n\nmetrics:\n  auroc: True\n  pro: True\n  misclassifications: False\n  visualisation: False"
  },
  {
    "path": "dataset.py",
    "content": "import os\nfrom glob import glob\nfrom pathlib import Path\nimport shutil\nimport numpy as np\nimport csv\nimport torch\nimport torch.utils.data\nfrom PIL import Image\nfrom torchvision import transforms\nimport torch.nn.functional as F\nimport torchvision.datasets as datasets\nfrom torchvision.datasets import CIFAR10\n\n\n\nclass Dataset_maker(torch.utils.data.Dataset):\n    def __init__(self, root, category, config, is_train=True):\n        self.image_transform = transforms.Compose(\n            [\n                transforms.Resize((config.data.image_size, config.data.image_size)),  \n                transforms.ToTensor(), # Scales data into [0,1] \n                transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1] \n            ]\n        )\n        self.config = config\n        self.mask_transform = transforms.Compose(\n            [\n                transforms.Resize((config.data.image_size, config.data.image_size)),\n                transforms.ToTensor(), # Scales data into [0,1] \n            ]\n        )\n        if is_train:\n            if category:\n                self.image_files = glob(\n                    os.path.join(root, category, \"train\", \"good\", \"*.png\")\n                )\n            else:\n                self.image_files = glob(\n                    os.path.join(root, \"train\", \"good\", \"*.png\")\n                )\n        else:\n            if category:\n                self.image_files = glob(os.path.join(root, category, \"test\", \"*\", \"*.png\"))\n            else:\n                self.image_files = glob(os.path.join(root, \"test\", \"*\", \"*.png\"))\n        self.is_train = is_train\n\n    def __getitem__(self, index):\n        image_file = self.image_files[index]\n        image = Image.open(image_file)\n        image = self.image_transform(image)\n        if(image.shape[0] == 1):\n            image = image.expand(3, self.config.data.image_size, self.config.data.image_size)\n        if self.is_train:\n            label = 'good'\n            return image, label\n        else:\n            if self.config.data.mask:\n                if os.path.dirname(image_file).endswith(\"good\"):\n                    target = torch.zeros([1, image.shape[-2], image.shape[-1]])\n                    label = 'good'\n                else :\n                    if self.config.data.name == 'MVTec':\n                        target = Image.open(\n                            image_file.replace(\"/test/\", \"/ground_truth/\").replace(\n                                \".png\", \"_mask.png\"\n                            )\n                        )\n                    else:\n                        target = Image.open(\n                            image_file.replace(\"/test/\", \"/ground_truth/\"))\n                    target = self.mask_transform(target)\n                    label = 'defective'\n            else:\n                if os.path.dirname(image_file).endswith(\"good\"):\n                    target = torch.zeros([1, image.shape[-2], image.shape[-1]])\n                    label = 'good'\n                else :\n                    target = torch.zeros([1, image.shape[-2], image.shape[-1]])\n                    label = 'defective'\n                \n            return image, target, label\n\n    def __len__(self):\n        return len(self.image_files)\n"
  },
  {
    "path": "ddad.py",
    "content": "from asyncio import constants\nfrom typing import Any\nimport torch\nfrom unet import *\nfrom dataset import *\nfrom visualize import *\nfrom anomaly_map import *\nfrom metrics import *\nfrom feature_extractor import *\nfrom reconstruction import *\nos.environ['CUDA_VISIBLE_DEVICES'] = \"0,1,2\"\n\nclass DDAD:\n    def __init__(self, unet, config) -> None:\n        self.test_dataset = Dataset_maker(\n            root= config.data.data_dir,\n            category=config.data.category,\n            config = config,\n            is_train=False,\n        )\n        self.testloader = torch.utils.data.DataLoader(\n            self.test_dataset,\n            batch_size= config.data.test_batch_size,\n            shuffle=False,\n            num_workers= config.model.num_workers,\n            drop_last=False,\n        )\n        self.unet = unet\n        self.config = config\n        self.reconstruction = Reconstruction(self.unet, self.config)\n        self.transform = transforms.Compose([\n                            transforms.CenterCrop((224)), \n                        ])\n\n    def __call__(self) -> Any:\n        feature_extractor = domain_adaptation(self.unet, self.config, fine_tune=False)\n        feature_extractor.eval()\n        \n        labels_list = []\n        predictions= []\n        anomaly_map_list = []\n        gt_list = []\n        reconstructed_list = []\n        forward_list = []\n\n\n\n        with torch.no_grad():\n            for input, gt, labels in self.testloader:\n                input = input.to(self.config.model.device)\n                x0 = self.reconstruction(input, input, self.config.model.w)[-1]\n                anomaly_map = heat_map(x0, input, feature_extractor, self.config)\n\n                anomaly_map = self.transform(anomaly_map)\n                gt = self.transform(gt)\n\n                forward_list.append(input)\n                anomaly_map_list.append(anomaly_map)\n\n\n                gt_list.append(gt)\n                reconstructed_list.append(x0)\n                for pred, label in zip(anomaly_map, labels):\n                    labels_list.append(0 if label == 'good' else 1)\n                    predictions.append(torch.max(pred).item())\n\n        \n        metric = Metric(labels_list, predictions, anomaly_map_list, gt_list, self.config)\n        metric.optimal_threshold()\n        if self.config.metrics.auroc:\n            print('AUROC: ({:.1f},{:.1f})'.format(metric.image_auroc() * 100, metric.pixel_auroc() * 100))\n        if self.config.metrics.pro:\n            print('PRO: {:.1f}'.format(metric.pixel_pro() * 100))\n        if self.config.metrics.misclassifications:\n            metric.miscalssified()\n        reconstructed_list = torch.cat(reconstructed_list, dim=0)\n        forward_list = torch.cat(forward_list, dim=0)\n        anomaly_map_list = torch.cat(anomaly_map_list, dim=0)\n        pred_mask = (anomaly_map_list > metric.threshold).float()\n        gt_list = torch.cat(gt_list, dim=0)\n        if not os.path.exists('results'):\n                os.mkdir('results')\n        if self.config.metrics.visualisation:\n            visualize(forward_list, reconstructed_list, gt_list, pred_mask, anomaly_map_list, self.config.data.category)\n"
  },
  {
    "path": "feature_extractor.py",
    "content": "import logging\nimport torch\nfrom dataset import *\nfrom dataset import *\nfrom unet import *\nfrom visualize import *\nfrom resnet import *\nimport torchvision.transforms as T\nfrom reconstruction import *\n\nos.environ['CUDA_VISIBLE_DEVICES'] = \"0,1,2\"\n\ndef loss_fucntion(a, b, c, d, config):\n    cos_loss = torch.nn.CosineSimilarity()\n    loss1 = 0\n    loss2 = 0\n    loss3 = 0\n    for item in range(len(a)):\n        loss1 += torch.mean(1-cos_loss(a[item].view(a[item].shape[0],-1),b[item].view(b[item].shape[0],-1))) \n        loss2 += torch.mean(1-cos_loss(b[item].view(b[item].shape[0],-1),c[item].view(c[item].shape[0],-1))) * config.model.DLlambda\n        loss3 += torch.mean(1-cos_loss(a[item].view(a[item].shape[0],-1),d[item].view(d[item].shape[0],-1))) * config.model.DLlambda\n    loss = loss1+loss2+loss3\n    return loss\n\n\n\ndef domain_adaptation(unet, config, fine_tune):\n    if config.model.feature_extractor == 'wide_resnet101_2':\n        feature_extractor = wide_resnet101_2(pretrained=True)\n        frozen_feature_extractor = wide_resnet101_2(pretrained=True)\n    elif config.model.feature_extractor == 'wide_resnet50_2':\n        feature_extractor = wide_resnet50_2(pretrained=True)\n        frozen_feature_extractor = wide_resnet50_2(pretrained=True)\n    elif config.model.feature_extractor == 'resnet50': \n        feature_extractor = resnet50(pretrained=True)\n        frozen_feature_extractor = resnet50(pretrained=True)\n    else:\n        logging.warning(\"Feature extractor is not correctly selected, Default: wide_resnet101_2\")\n        feature_extractor = wide_resnet101_2(pretrained=True)\n        frozen_feature_extractor = wide_resnet101_2(pretrained=True)\n\n    feature_extractor.to(config.model.device)  \n    frozen_feature_extractor.to(config.model.device)\n\n    frozen_feature_extractor.eval()\n\n    feature_extractor = torch.nn.DataParallel(feature_extractor)\n    frozen_feature_extractor = torch.nn.DataParallel(frozen_feature_extractor)\n\n\n    train_dataset = Dataset_maker(\n        root= config.data.data_dir,\n        category= config.data.category,\n        config = config,\n        is_train=True,\n    )\n    trainloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=config.data.DA_batch_size,\n        shuffle=True,\n        num_workers=config.model.num_workers,\n        drop_last=True,\n    )   \n\n    if fine_tune:      \n        unet.eval()\n        feature_extractor.train()\n\n\n        transform = transforms.Compose([\n                    transforms.Lambda(lambda t: (t + 1) / (2)),\n                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n                ])\n\n        optimizer = torch.optim.AdamW(feature_extractor.parameters(),lr= 1e-4)\n        torch.save(frozen_feature_extractor.state_dict(), os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), config.data.category,f'feat0'))\n        reconstruction = Reconstruction(unet, config)\n        for epoch in range(config.model.DA_epochs):\n            for step, batch in enumerate(trainloader):\n                half_batch_size = batch[0].shape[0]//2\n                target = batch[0][:half_batch_size].to(config.model.device)  \n                input = batch[0][half_batch_size:].to(config.model.device)   \n                \n                x0 = reconstruction(input, target, config.model.w_DA)[-1].to(config.model.device)\n                x0 = transform(x0)\n                target = transform(target)\n\n                
reconst_fe = feature_extractor(x0)\n                target_fe = feature_extractor(target)\n\n                target_frozen_fe = frozen_feature_extractor(target)\n                reconst_frozen_fe = frozen_feature_extractor(x0)\n                \n                \n\n                loss = loss_fucntion(reconst_fe, target_fe, target_frozen_fe,reconst_frozen_fe, config)\n                optimizer.zero_grad()\n                loss.backward()\n                optimizer.step()\n\n            print(f\"Epoch {epoch+1} | Loss: {loss.item()}\")\n            # if (epoch+1) % 5 == 0:\n            torch.save(feature_extractor.state_dict(), os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), config.data.category,f'feat{epoch+1}'))\n    else:\n        checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), config.data.category,f'feat{config.model.DA_chp}'))#{config.model.DA_chp}            \n        feature_extractor.load_state_dict(checkpoint)  \n    return feature_extractor"
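Note (not part of the source file): as a reading aid, `loss_fucntion` above corresponds to the following domain-adaptation objective, restated from the code; \(FE\) is the fine-tuned feature extractor, \(\bar{FE}\) the frozen copy, \(\hat{x}_0\) the reconstruction of the input half-batch, \(y\) the target half-batch, \(d_{\cos}\) the mean cosine distance over flattened feature maps \(j\), and \(\lambda\) is `DLlambda`:

```math
\mathcal{L}_{DA} = \sum_{j} \Big[ d_{\cos}\big(FE_j(\hat{x}_0), FE_j(y)\big) + \lambda\, d_{\cos}\big(FE_j(y), \bar{FE}_j(y)\big) + \lambda\, d_{\cos}\big(FE_j(\hat{x}_0), \bar{FE}_j(\hat{x}_0)\big) \Big]
```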
  },
  {
    "path": "loss.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\n\n\ndef get_loss(model, x_0, t, config):\n    x_0 = x_0.to(config.model.device)\n    betas = np.linspace(config.model.beta_start, config.model.beta_end, config.model.trajectory_steps, dtype=np.float64)\n    b = torch.tensor(betas).type(torch.float).to(config.model.device)\n    e = torch.randn_like(x_0, device = x_0.device)\n    at = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1)\n\n\n    x = at.sqrt() * x_0 + (1- at).sqrt() * e \n    output = model(x, t.float())\n    return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)\n\n"
  },
  {
    "path": "main.py",
    "content": "import torch\nimport numpy as np\nimport os\nimport argparse\nfrom unet import *\nfrom omegaconf import OmegaConf\nfrom train import trainer\nfrom feature_extractor import * \nfrom ddad import *\nos.environ['CUDA_VISIBLE_DEVICES'] = \"0,1,2\"\n\ndef build_model(config):\n    if config.model.DDADS:\n        unet = UNetModel(config.data.image_size, 32, dropout=0.3, n_heads=2 ,in_channels=config.data.input_channel)\n    else:\n        unet = UNetModel(config.data.image_size, 64, dropout=0.0, n_heads=4 ,in_channels=config.data.input_channel)\n    return unet\n\ndef train(config):\n    torch.manual_seed(42)\n    np.random.seed(42)\n    unet = build_model(config)\n    print(\" Num params: \", sum(p.numel() for p in unet.parameters()))\n    unet = unet.to(config.model.device)\n    unet.train()\n    unet = torch.nn.DataParallel(unet)\n    # checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), config.data.category,'1000'))\n    # unet.load_state_dict(checkpoint)  \n    trainer(unet, config.data.category, config)#config.data.category, \n\n\ndef detection(config):\n    unet = build_model(config)\n    checkpoint = torch.load(os.path.join(os.getcwd(), config.model.checkpoint_dir, config.data.category, str(config.model.load_chp)))\n    unet = torch.nn.DataParallel(unet)\n    unet.load_state_dict(checkpoint)    \n    unet.to(config.model.device)\n    checkpoint = torch.load(os.path.join(os.getcwd(), config.model.checkpoint_dir, config.data.category, str(config.model.load_chp)))\n    unet.eval()\n    ddad = DDAD(unet, config)\n    ddad()\n    \n\ndef finetuning(config):\n    unet = build_model(config)\n    checkpoint = torch.load(os.path.join(os.getcwd(), config.model.checkpoint_dir, config.data.category, str(config.model.load_chp)))\n    unet = torch.nn.DataParallel(unet)\n    unet.load_state_dict(checkpoint)    \n    unet.to(config.model.device)\n    unet.eval()\n    domain_adaptation(unet, config, fine_tune=True)\n\n\n\n\n\ndef parse_args():\n    cmdline_parser = argparse.ArgumentParser('DDAD')    \n    cmdline_parser.add_argument('-cfg', '--config', \n                                default= os.path.join(os.path.dirname(os.path.abspath(__file__)),'config.yaml'), \n                                help='config file')\n    cmdline_parser.add_argument('--train', \n                                default= False, \n                                help='Train the diffusion model')\n    cmdline_parser.add_argument('--detection', \n                                default= False, \n                                help='Detection anomalies')\n    cmdline_parser.add_argument('--domain_adaptation', \n                                default= False, \n                                help='Domain adaptation')\n    args, unknowns = cmdline_parser.parse_known_args()\n    return args\n\n\n    \nif __name__ == \"__main__\":\n    torch.cuda.empty_cache()\n    args = parse_args()\n    config = OmegaConf.load(args.config)\n    print(\"Class: \",config.data.category, \"   w:\", config.model.w, \"   v:\", config.model.v, \"   load_chp:\", config.model.load_chp,   \"   feature extractor:\", config.model.feature_extractor,\"         w_DA: \",config.model.w_DA,\"         DLlambda: \",config.model.DLlambda)\n    print(f'{config.model.test_trajectoy_steps=} , {config.data.test_batch_size=}')\n    torch.manual_seed(42)\n    np.random.seed(42)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(42)\n    if args.train:\n        print('Training...')\n        
train(config)\n    if args.domain_adaptation:\n        print('Domain Adaptation...')\n        finetuning(config)\n    if args.detection:\n        print('Detecting Anomalies...')\n        detection(config)\n\n\n        "
  },
  {
    "path": "metrics.py",
    "content": "import torch\nfrom torchmetrics import ROC, AUROC, F1Score\nimport os\nfrom torchvision.transforms import transforms\nfrom skimage import measure\nimport pandas as pd\nfrom statistics import mean\nimport numpy as np\nfrom sklearn.metrics import auc\nfrom sklearn import metrics\nfrom sklearn.metrics import roc_auc_score, roc_curve\n\n\nclass Metric:\n    def __init__(self,labels_list, predictions, anomaly_map_list, gt_list, config) -> None:\n        self.labels_list = labels_list\n        self.predictions = predictions\n        self.anomaly_map_list = anomaly_map_list\n        self.gt_list = gt_list\n        self.config = config\n        self.threshold = 0.5\n    \n    def image_auroc(self):\n        auroc_image = roc_auc_score(self.labels_list, self.predictions)\n        return auroc_image\n    \n    def pixel_auroc(self):\n        resutls_embeddings = self.anomaly_map_list[0]\n        for feature in self.anomaly_map_list[1:]:\n            resutls_embeddings = torch.cat((resutls_embeddings, feature), 0)\n        resutls_embeddings =  ((resutls_embeddings - resutls_embeddings.min())/ (resutls_embeddings.max() - resutls_embeddings.min())) \n\n        gt_embeddings = self.gt_list[0]\n        for feature in self.gt_list[1:]:\n            gt_embeddings = torch.cat((gt_embeddings, feature), 0)\n\n        resutls_embeddings = resutls_embeddings.clone().detach().requires_grad_(False)\n        gt_embeddings = gt_embeddings.clone().detach().requires_grad_(False)\n\n        auroc_p = AUROC(task=\"binary\")\n        \n        gt_embeddings = torch.flatten(gt_embeddings).type(torch.bool).cpu().detach()\n        resutls_embeddings = torch.flatten(resutls_embeddings).cpu().detach()\n        auroc_pixel = auroc_p(resutls_embeddings, gt_embeddings)\n        return auroc_pixel\n    \n    def optimal_threshold(self):\n        fpr, tpr, thresholds = roc_curve(self.labels_list, self.predictions)\n\n        # Calculate Youden's J statistic for each threshold\n        youden_j = tpr - fpr\n\n        # Find the optimal threshold that maximizes Youden's J statistic\n        optimal_threshold_index = np.argmax(youden_j)\n        optimal_threshold = thresholds[optimal_threshold_index]\n        self.threshold = optimal_threshold\n        return optimal_threshold\n    \n\n    def pixel_pro(self):\n        #https://github.com/hq-deng/RD4AD/blob/main/test.py#L337\n        def _compute_pro(masks, amaps, num_th = 200):\n            resutls_embeddings = amaps[0]\n            for feature in amaps[1:]:\n                resutls_embeddings = torch.cat((resutls_embeddings, feature), 0)\n            amaps =  ((resutls_embeddings - resutls_embeddings.min())/ (resutls_embeddings.max() - resutls_embeddings.min())) \n            amaps = amaps.squeeze(1)\n            amaps = amaps.cpu().detach().numpy()\n            gt_embeddings = masks[0]\n            for feature in masks[1:]:\n                gt_embeddings = torch.cat((gt_embeddings, feature), 0)\n            masks = gt_embeddings.squeeze(1).cpu().detach().numpy()\n            min_th = amaps.min()\n            max_th = amaps.max()\n            delta = (max_th - min_th) / num_th\n            binary_amaps = np.zeros_like(amaps)\n            df = pd.DataFrame([], columns=[\"pro\", \"fpr\", \"threshold\"])\n\n            for th in np.arange(min_th, max_th, delta):\n                binary_amaps[amaps <= th] = 0\n                binary_amaps[amaps > th] = 1\n\n                pros = []\n                for binary_amap, mask in zip(binary_amaps, masks):\n                  
  for region in measure.regionprops(measure.label(mask)):\n                        axes0_ids = region.coords[:, 0]\n                        axes1_ids = region.coords[:, 1]\n                        tp_pixels = binary_amap[axes0_ids, axes1_ids].sum()\n                        pros.append(tp_pixels / region.area)\n\n                inverse_masks = 1 - masks\n                fp_pixels = np.logical_and(inverse_masks , binary_amaps).sum()\n                fpr = fp_pixels / inverse_masks.sum()\n                # print(f\"Threshold: {th}, FPR: {fpr}, PRO: {mean(pros)}\")\n\n                df = pd.concat([df, pd.DataFrame({\"pro\": mean(pros), \"fpr\": fpr, \"threshold\": th}, index=[0])], ignore_index=True)\n                # df = df.concat({\"pro\": mean(pros), \"fpr\": fpr, \"threshold\": th}, ignore_index=True)\n\n            # Normalize FPR from 0 ~ 1 to 0 ~ 0.3\n            df = df[df[\"fpr\"] < 0.3]\n            df[\"fpr\"] = df[\"fpr\"] / df[\"fpr\"].max()\n\n            pro_auc = auc(df[\"fpr\"], df[\"pro\"])\n            return pro_auc\n        \n        pro = _compute_pro(self.gt_list, self.anomaly_map_list, num_th = 200)\n        return pro\n    \n\n    def miscalssified(self):\n        predictions = torch.tensor(self.predictions)\n        labels_list = torch.tensor(self.labels_list)\n        predictions0_1 = (predictions > self.threshold).int()\n        for i,(l,p) in enumerate(zip(labels_list, predictions0_1)):\n            print('Sample : ', i, ' predicted as: ',p.item() ,' label is: ',l.item(),'\\n' ) if l != p else None\n\n"
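Note (not part of the source file): for reference, `_compute_pro` above follows the usual PRO definition from the RD4AD code linked in the comment. For a threshold \(\theta\), the per-region overlap is averaged over all connected components \(C_k\) of the ground-truth masks, and the reported score is the area under the PRO-vs-FPR curve restricted to \(\mathrm{FPR} \le 0.3\), with the FPR axis rescaled to \([0,1]\):

```math
\mathrm{PRO}(\theta) = \frac{1}{N} \sum_{k=1}^{N} \frac{|P_\theta \cap C_k|}{|C_k|}
```

where \(P_\theta\) is the set of pixels whose anomaly score exceeds \(\theta\).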
  },
  {
    "path": "reconstruction.py",
    "content": "from typing import Any\nimport torch\n# from forward_process import *\nimport numpy as np\nimport os\nos.environ['CUDA_VISIBLE_DEVICES'] = \"0,1,2\"\n\nclass Reconstruction:\n    '''\n    The reconstruction process\n    :param y: the target image\n    :param x: the input image\n    :param seq: the sequence of denoising steps\n    :param unet: the UNet model\n    :param x0_t: the prediction of x0 at time step t\n    '''\n    def __init__(self, unet, config) -> None:\n        self.unet = unet\n        self.config = config\n\n    \n    \n    def __call__(self, x, y0, w) -> Any:\n        def _compute_alpha(t):\n            betas = np.linspace(self.config.model.beta_start, self.config.model.beta_end, self.config.model.trajectory_steps, dtype=np.float64)\n            betas = torch.tensor(betas).type(torch.float).to(self.config.model.device)\n            beta = torch.cat([torch.zeros(1).to(self.config.model.device), betas], dim=0)\n            beta = beta.to(self.config.model.device)\n            a = (1 - beta).cumprod(dim=0).index_select(0, t + 1).view(-1, 1, 1, 1)\n            return a\n        \n        test_trajectoy_steps = torch.Tensor([self.config.model.test_trajectoy_steps]).type(torch.int64).to(self.config.model.device).long()\n        at = _compute_alpha(test_trajectoy_steps)\n        xt = at.sqrt() * x + (1- at).sqrt() * torch.randn_like(x).to(self.config.model.device)\n        seq = range(0 , self.config.model.test_trajectoy_steps, self.config.model.skip)\n\n\n        with torch.no_grad():\n            n = x.size(0)\n            seq_next = [-1] + list(seq[:-1])\n            xs = [xt]\n            for index, (i, j) in enumerate(zip(reversed(seq), reversed(seq_next))):\n                t = (torch.ones(n) * i).to(self.config.model.device)\n                next_t = (torch.ones(n) * j).to(self.config.model.device)\n                at = _compute_alpha(t.long())\n                at_next = _compute_alpha(next_t.long())\n                xt = xs[-1].to(self.config.model.device)\n                self.unet = self.unet.to(self.config.model.device)\n                et = self.unet(xt, t)\n                yt = at.sqrt() * y0 + (1- at).sqrt() *  et\n                et_hat = et - (1 - at).sqrt() * w * (yt-xt)\n                x0_t = (xt - et_hat * (1 - at).sqrt()) / at.sqrt()\n                c1 = (\n                    self.config.model.eta * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()\n                )\n                c2 = ((1 - at_next) - c1 ** 2).sqrt()\n                xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et_hat\n                xs.append(xt_next)\n        return xs\n\n         \n\n\n\n"
  },
  {
    "path": "requirements.txt",
    "content": "\nkornia==0.6.12\nmatplotlib==3.7.1\nnumpy==1.24.3\nomegaconf==2.1.2\nopencv-python-headless==4.5.5.64\npandas==2.0.1\nPillow==9.5.0\nscikit-image==0.19.2\nscikit-learn==1.2.2\nscipy==1.10.1\nsklearn==0.0.post5\ntorch==2.0.1\ntorchmetrics==0.11.4\ntorchvision==0.15.2"
  },
  {
    "path": "resnet.py",
    "content": "import torch\nfrom torch import Tensor\nimport torch.nn as nn\ntry:\n    from torch.hub import load_state_dict_from_url\nexcept ImportError:\n    from torch.utils.model_zoo import load_url as load_state_dict_from_url\nfrom typing import Type, Any, Callable, Union, List, Optional\n\n\n__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',\n           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',\n           'wide_resnet50_2', 'wide_resnet101_2']\n\n\nmodel_urls = {\n    'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',\n    'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',\n    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',\n    'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',\n    'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',\n    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',\n    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',\n    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',\n    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',\n}\n\n\ndef conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, groups=groups, bias=False, dilation=dilation)\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion: int = 1\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to 
\"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n\n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n\n    def __init__(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        layers: List[int],\n        num_classes: int = 1000,\n        zero_init_residual: bool = False,\n        groups: int = 1,\n        width_per_group: int = 64,\n        replace_stride_with_dilation: Optional[List[bool]] = None,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ) -> None:\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,\n                               bias=False)\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,\n                                       dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,\n      
                                 dilate=replace_stride_with_dilation[1])\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,\n                                       dilate=replace_stride_with_dilation[2])\n        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]\n\n    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,\n                    stride: int = 1, dilate: bool = False) -> nn.Sequential:\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        # See note [TorchScript super()]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        feature_a = self.layer1(x)\n        feature_b = self.layer2(feature_a)\n        feature_c = self.layer3(feature_b)\n        feature_d = self.layer4(feature_c)\n\n\n        return [feature_a, feature_b, feature_c]\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef _resnet(\n    arch: str,\n    block: Type[Union[BasicBlock, Bottleneck]],\n    layers: List[int],\n    pretrained: bool,\n    progress: bool,\n    **kwargs: Any\n) -> ResNet:\n    model = ResNet(block, layers, **kwargs)\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch],\n                                              progress=progress)\n        #for k,v in list(state_dict.items()):\n        #    if 'layer4' in k or 'fc' in k:\n        #        state_dict.pop(k)\n        model.load_state_dict(state_dict)\n    return model\n\nclass AttnBasicBlock(nn.Module):\n    expansion: 
int = 1\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n        attention: bool = True,\n    ) -> None:\n        super(AttnBasicBlock, self).__init__()\n        self.attention = attention\n        #print(\"Attention:\", self.attention)\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        #self.cbam = GLEAM(planes, 16)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        #if self.attention:\n        #    x = self.cbam(x)\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\nclass AttnBottleneck(nn.Module):\n    \n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None,\n        attention: bool = True,\n    ) -> None:\n        super(AttnBottleneck, self).__init__()\n        self.attention = attention\n        #print(\"Attention:\",self.attention)\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        #self.cbam = GLEAM([int(planes * self.expansion/4),\n        #                   int(planes * self.expansion//2),\n        #                   planes * self.expansion], 16)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor) -> Tensor:\n        #if self.attention:\n        #    x = self.cbam(x)\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n\n        out += identity\n      
  out = self.relu(out)\n\n        return out\n\n\nclass BN_layer(nn.Module):\n    def __init__(self,\n                 block: Type[Union[BasicBlock, Bottleneck]],\n                 layers: int,\n                 groups: int = 1,\n                 width_per_group: int = 64,\n                 norm_layer: Optional[Callable[..., nn.Module]] = None,\n                 ):\n        super(BN_layer, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n        self.groups = groups\n        self.base_width = width_per_group\n        self.inplanes = 256 * block.expansion\n        self.dilation = 1\n        self.bn_layer = self._make_layer(block, 512, layers, stride=2)\n\n        self.conv1 = conv3x3(64 * block.expansion, 128 * block.expansion, 2)\n        self.bn1 = norm_layer(128 * block.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(128 * block.expansion, 256 * block.expansion, 2)\n        self.bn2 = norm_layer(256 * block.expansion)\n        self.conv3 = conv3x3(128 * block.expansion, 256 * block.expansion, 2)\n        self.bn3 = norm_layer(256 * block.expansion)\n\n        self.conv4 = conv1x1(1024 * block.expansion, 512 * block.expansion, 1)\n        self.bn4 = norm_layer(512 * block.expansion)\n\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,\n                    stride: int = 1, dilate: bool = False) -> nn.Sequential:\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes*3, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),\n            )\n\n        layers = []\n        layers.append(block(self.inplanes*3, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor) -> Tensor:\n        # See note [TorchScript super()]\n        #x = self.cbam(x)\n        l1 = self.relu(self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x[0]))))))\n        l2 = self.relu(self.bn3(self.conv3(x[1])))\n        feature = torch.cat([l1,l2,x[2]],1)\n        output = self.bn_layer(feature)\n        #x = self.avgpool(feature_d)\n        #x = torch.flatten(x, 1)\n        #x = self.fc(x)\n\n        return output.contiguous()\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self._forward_impl(x)\n\n\ndef resnet18(pretrained: bool = False, progress: bool = True,**kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-18 model from\n    `\"Deep Residual Learning for Image Recognition\" 
<https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-34 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-50 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-101 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNet-152 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNeXt-50 32x4d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 4\n    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"ResNeXt-101 32x8d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 8\n    return 
_resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"Wide ResNet-50-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_.\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:\n    r\"\"\"Wide ResNet-101-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_.\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n\n"
  },
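  {
    "path": "examples/wide_resnet_width_sketch.py",
    "content": "'''\nIllustrative sketch, not part of the original repository: it spells out the claim in the\nwide_resnet50_2 / wide_resnet101_2 docstrings that the bottleneck 3x3 width is doubled while\nthe outer 1x1 channels stay the same. It assumes the torchvision convention\nwidth = int(planes * (base_width / 64.)) * groups inside the (unshown) Bottleneck block; the\nfile name and helper below are hypothetical.\n'''\n\n\ndef bottleneck_width(planes, base_width=64, groups=1):\n    # hidden width of the 3x3 conv inside a torchvision-style Bottleneck block\n    return int(planes * (base_width / 64.)) * groups\n\n\nif __name__ == '__main__':\n    # last stage of ResNet-50: planes=512, expansion=4 -> 2048-512-2048 channels\n    print('resnet50        :', 512 * 4, bottleneck_width(512), 512 * 4)\n    # wide_resnet50_2 passes width_per_group=64*2 -> 2048-1024-2048 channels\n    print('wide_resnet50_2 :', 512 * 4, bottleneck_width(512, base_width=128), 512 * 4)\n"
  },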
  {
    "path": "train.py",
    "content": "import torch\nimport os\nimport torch.nn as nn\nfrom dataset import *\n\nfrom dataset import *\nfrom loss import *\n\n\ndef trainer(model, category, config):\n    '''\n    Training the UNet model\n    :param model: the UNet model\n    :param category: the category of the dataset\n    '''\n    # optimizer = torch.optim.AdamW(\n    #     model.parameters(), lr=config.model.learning_rate)\n    optimizer = torch.optim.Adam(\n        model.parameters(), lr=config.model.learning_rate, weight_decay=config.model.weight_decay\n    )\n    train_dataset = Dataset_maker(\n        root= config.data.data_dir,\n        category=category,\n        config = config,\n        is_train=True,\n    )\n    trainloader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=config.data.batch_size,\n        shuffle=True,\n        num_workers=config.model.num_workers,\n        drop_last=True,\n    )\n    if not os.path.exists('checkpoints'):\n        os.mkdir('checkpoints')\n    if not os.path.exists(config.model.checkpoint_dir):\n        os.mkdir(config.model.checkpoint_dir)\n\n\n    for epoch in range(config.model.epochs):\n        for step, batch in enumerate(trainloader):\n            optimizer.zero_grad()\n            # loss = 0\n            # for _ in range(2):\n            t = torch.randint(0, config.model.trajectory_steps, (batch[0].shape[0],), device=config.model.device).long()\n            loss = get_loss(model, batch[0], t, config) \n            loss.backward()\n            optimizer.step()\n            if (epoch+1) % 25 == 0 and step == 0:\n                print(f\"Epoch {epoch+1} | Loss: {loss.item()}\")\n            if (epoch+1) %250 == 0 and epoch>0 and step ==0:\n                if config.model.save_model:\n                    model_save_dir = os.path.join(os.getcwd(), config.model.checkpoint_dir, category)\n                    if not os.path.exists(model_save_dir):\n                        os.mkdir(model_save_dir)\n                    torch.save(model.state_dict(), os.path.join(model_save_dir, str(epoch+1)))\n                \n"
  },
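  {
    "path": "examples/ddpm_loss_sketch.py",
    "content": "'''\nIllustrative sketch, not part of the original repository: a generic DDPM-style training\nobjective of the kind trainer() in train.py delegates to get_loss() from loss.py. The actual\nbeta schedule, weighting and conditioning used by DDAD live in loss.py and config.yaml; the\nfile name and the helper simple_ddpm_loss below are hypothetical, and the only assumption\nabout the model is the model(x_t, t) call signature of UNetModel.forward in unet.py.\n'''\nimport torch\nimport torch.nn.functional as F\n\n\ndef simple_ddpm_loss(model, x_0, t, trajectory_steps=1000):\n    # linear beta schedule and cumulative alpha products for closed-form noising\n    betas = torch.linspace(1e-4, 0.02, trajectory_steps, device=x_0.device)\n    alpha_bar = torch.cumprod(1.0 - betas, dim=0)[t].view(-1, 1, 1, 1)\n    # forward diffusion: x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * eps\n    eps = torch.randn_like(x_0)\n    x_t = alpha_bar.sqrt() * x_0 + (1.0 - alpha_bar).sqrt() * eps\n    # the denoising UNet predicts the injected noise; the objective is a plain MSE\n    return F.mse_loss(model(x_t, t), eps)\n"
  },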
  {
    "path": "unet.py",
    "content": "# https://github.com/openai/guided-diffusion/tree/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924\nimport math\nfrom abc import abstractmethod\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\n\nclass TimestepBlock(nn.Module):\n    \"\"\"\n    Any module where forward() takes timestep embeddings as a second argument.\n    \"\"\"\n\n    @abstractmethod\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the module to `x` given `emb` timestep embeddings.\n        \"\"\"\n\n\nclass TimestepEmbedSequential(nn.Sequential, TimestepBlock):\n    \"\"\"\n    A sequential module that passes timestep embeddings to the children that\n    support it as an extra input.\n    \"\"\"\n\n    def forward(self, x, emb):\n        for layer in self:\n            if isinstance(layer, TimestepBlock):\n                x = layer(x, emb)\n            else:\n                x = layer(x)\n        return x\n\n\nclass PositionalEmbedding(nn.Module):\n    # PositionalEmbedding\n    \"\"\"\n    Computes Positional Embedding of the timestep\n    \"\"\"\n\n    def __init__(self, dim, scale=1):\n        super().__init__()\n        assert dim % 2 == 0\n        self.dim = dim\n        self.scale = scale\n\n    def forward(self, x):\n        device = x.device\n        half_dim = self.dim // 2\n        emb = np.log(10000) / half_dim\n        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)\n        emb = torch.outer(x * self.scale, emb)\n        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)\n        return emb\n\n\nclass Downsample(nn.Module):\n    def __init__(self, in_channels, use_conv, out_channels=None):\n        super().__init__()\n        self.channels = in_channels\n        out_channels = out_channels or in_channels\n        if use_conv:\n            # downsamples by 1/2\n            self.downsample = nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)\n        else:\n            assert in_channels == out_channels\n            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2)\n\n    def forward(self, x, time_embed=None):\n        assert x.shape[1] == self.channels\n        return self.downsample(x)\n\n\nclass Upsample(nn.Module):\n    def __init__(self, in_channels, use_conv, out_channels=None):\n        super().__init__()\n        self.channels = in_channels\n        self.use_conv = use_conv\n        # uses upsample then conv to avoid checkerboard artifacts\n        # self.upsample = nn.Upsample(scale_factor=2, mode=\"nearest\")\n        if use_conv:\n            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)\n\n    def forward(self, x, time_embed=None):\n        assert x.shape[1] == self.channels\n        x = F.interpolate(x, scale_factor=2, mode=\"nearest\")\n        if self.use_conv:\n            x = self.conv(x)\n        return x\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    An attention block that allows spatial positions to attend to each other.\n    Originally ported from here, but adapted to the N-d case.\n    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.\n    \"\"\"\n\n    def __init__(self, in_channels, n_heads=1, n_head_channels=-1):\n        super().__init__()\n        self.in_channels = in_channels\n        self.norm = GroupNorm32(32, self.in_channels)\n        if n_head_channels == -1:\n            self.num_heads = n_heads\n        else:\n            assert (\n                    in_channels % n_head_channels == 0\n 
           ), f\"q,k,v channels {in_channels} is not divisible by num_head_channels {n_head_channels}\"\n            self.num_heads = in_channels // n_head_channels\n\n        # query, key, value for attention\n        self.to_qkv = nn.Conv1d(in_channels, in_channels * 3, 1)\n        self.attention = QKVAttention(self.num_heads)\n        self.proj_out = zero_module(nn.Conv1d(in_channels, in_channels, 1))\n\n    def forward(self, x, time=None):\n        b, c, *spatial = x.shape\n        x = x.reshape(b, c, -1)\n        qkv = self.to_qkv(self.norm(x))\n        h = self.attention(qkv)\n        h = self.proj_out(h)\n        return (x + h).reshape(b, c, *spatial)\n\n\nclass QKVAttention(nn.Module):\n    \"\"\"\n    A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping\n    \"\"\"\n\n    def __init__(self, n_heads):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv, time=None):\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = torch.einsum(\n                \"bct,bcs->bts\", q * scale, k * scale\n                )  # More stable with f16 than dividing afterwards\n        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = torch.einsum(\"bts,bcs->bct\", weight, v)\n        return a.reshape(bs, -1, length)\n\n\nclass ResBlock(TimestepBlock):\n    def __init__(\n            self,\n            in_channels,\n            time_embed_dim,\n            dropout,\n            out_channels=None,\n            use_conv=False,\n            up=False,\n            down=False\n            ):\n        super().__init__()\n        out_channels = out_channels or in_channels\n        self.in_layers = nn.Sequential(\n                GroupNorm32(32, in_channels),\n                nn.SiLU(),\n                nn.Conv2d(in_channels, out_channels, 3, padding=1)\n                )\n        self.updown = up or down\n\n        if up:\n            self.h_upd = Upsample(in_channels, False)\n            self.x_upd = Upsample(in_channels, False)\n        elif down:\n            self.h_upd = Downsample(in_channels, False)\n            self.x_upd = Downsample(in_channels, False)\n        else:\n            self.h_upd = self.x_upd = nn.Identity()\n\n        self.embed_layers = nn.Sequential(\n                nn.SiLU(),\n                nn.Linear(time_embed_dim, out_channels)\n                )\n        self.out_layers = nn.Sequential(\n                GroupNorm32(32, out_channels),\n                nn.SiLU(),\n                nn.Dropout(p=dropout),\n                zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))\n                )\n        if out_channels == in_channels:\n            self.skip_connection = nn.Identity()\n        elif use_conv:\n            self.skip_connection = nn.Conv2d(in_channels, out_channels, 3, padding=1)\n        else:\n            self.skip_connection = nn.Conv2d(in_channels, out_channels, 1)\n\n    def forward(self, x, time_embed):\n        if self.updown:\n            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]\n            h = in_rest(x)\n            h = 
self.h_upd(h)\n            x = self.x_upd(x)\n            h = in_conv(h)\n        else:\n            h = self.in_layers(x)\n        emb_out = self.embed_layers(time_embed).type(h.dtype)\n        while len(emb_out.shape) < len(h.shape):\n            emb_out = emb_out[..., None]\n\n        h = h + emb_out\n        h = self.out_layers(h)\n        return self.skip_connection(x) + h\n\n\nclass UNetModel(nn.Module):\n    # UNet model\n    def __init__(\n            self,\n            img_size,\n            base_channels,\n            conv_resample=True,\n            n_heads=1,\n            n_head_channels=-1,\n            channel_mults=\"\",\n            num_res_blocks=2,\n            dropout=0,\n            attention_resolutions=\"32,16,8\",\n            biggan_updown=True,\n            in_channels=1\n            ):\n        self.dtype = torch.float32\n        super().__init__()\n\n        if channel_mults == \"\":\n            if img_size == 512:\n                channel_mults = (0.5, 1, 1, 2, 2, 4, 4)\n            elif img_size == 256:\n                channel_mults = (1, 1, 2, 2, 4, 4)# \n            elif img_size == 128:\n                channel_mults = (1, 1, 2, 3, 4)\n            elif img_size == 64:\n                channel_mults = (1, 2, 3, 4)\n            elif img_size == 32:\n                channel_mults = (1, 2, 3, 4)\n            else:\n                raise ValueError(f\"unsupported image size: {img_size}\")\n        attention_ds = []\n        for res in attention_resolutions.split(\",\"):\n            attention_ds.append(img_size // int(res))\n\n        self.image_size = img_size\n        self.in_channels = in_channels\n        self.model_channels = base_channels\n        self.out_channels = in_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mults\n        self.conv_resample = conv_resample\n\n        self.dtype = torch.float32\n        self.num_heads = n_heads\n        self.num_head_channels = n_head_channels\n\n        time_embed_dim = base_channels * 4\n        self.time_embedding = nn.Sequential(\n                PositionalEmbedding(base_channels, 1),\n                nn.Linear(base_channels, time_embed_dim),\n                nn.SiLU(),\n                nn.Linear(time_embed_dim, time_embed_dim),\n                )\n\n        ch = int(channel_mults[0] * base_channels)\n        self.down = nn.ModuleList(\n                [TimestepEmbedSequential(nn.Conv2d(self.in_channels, base_channels, 3, padding=1))]\n                )\n        channels = [ch]\n        ds = 1\n        for i, mult in enumerate(channel_mults):\n            # out_channels = base_channels * mult\n\n            for _ in range(num_res_blocks):\n                layers = [ResBlock(\n                        ch,\n                        time_embed_dim=time_embed_dim,\n                        out_channels=base_channels * mult,\n                        dropout=dropout,\n                        )]\n                ch = base_channels * mult\n                # channels.append(ch)\n\n                if ds in attention_ds:\n                    layers.append(\n                            AttentionBlock(\n                                    ch,\n                                    n_heads=n_heads,\n                                    n_head_channels=n_head_channels,\n                                    )\n                            )\n                
self.down.append(TimestepEmbedSequential(*layers))\n                channels.append(ch)\n            if i != len(channel_mults) - 1:\n                out_channels = ch\n                self.down.append(\n                        TimestepEmbedSequential(\n                                ResBlock(\n                                        ch,\n                                        time_embed_dim=time_embed_dim,\n                                        out_channels=out_channels,\n                                        dropout=dropout,\n                                        down=True\n                                        )\n                                if biggan_updown\n                                else\n                                Downsample(ch, conv_resample, out_channels=out_channels)\n                                )\n                        )\n                ds *= 2\n                ch = out_channels\n                channels.append(ch)\n\n        self.middle = TimestepEmbedSequential(\n                ResBlock(\n                        ch,\n                        time_embed_dim=time_embed_dim,\n                        dropout=dropout\n                        ),\n                AttentionBlock(\n                        ch,\n                        n_heads=n_heads,\n                        n_head_channels=n_head_channels\n                        ),\n                ResBlock(\n                        ch,\n                        time_embed_dim=time_embed_dim,\n                        dropout=dropout\n                        )\n                )\n        self.up = nn.ModuleList([])\n\n        for i, mult in reversed(list(enumerate(channel_mults))):\n            for j in range(num_res_blocks + 1):\n                inp_chs = channels.pop()\n                layers = [\n                    ResBlock(\n                            ch + inp_chs,\n                            time_embed_dim=time_embed_dim,\n                            out_channels=base_channels * mult,\n                            dropout=dropout\n                            )\n                    ]\n                ch = base_channels * mult\n                if ds in attention_ds:\n                    layers.append(\n                            AttentionBlock(\n                                    ch,\n                                    n_heads=n_heads,\n                                    n_head_channels=n_head_channels\n                                    ),\n                            )\n\n                if i and j == num_res_blocks:\n                    out_channels = ch\n                    layers.append(\n                            ResBlock(\n                                    ch,\n                                    time_embed_dim=time_embed_dim,\n                                    out_channels=out_channels,\n                                    dropout=dropout,\n                                    up=True\n                                    )\n                            if biggan_updown\n                            else\n                            Upsample(ch, conv_resample, out_channels=out_channels)\n                            )\n                    ds //= 2\n                self.up.append(TimestepEmbedSequential(*layers))\n\n        self.out = nn.Sequential(\n                GroupNorm32(32, ch),\n                nn.SiLU(),\n                zero_module(nn.Conv2d(base_channels * channel_mults[0], self.out_channels, 3, padding=1))\n                )\n\n    def forward(self, x, time):\n\n        
time_embed = self.time_embedding(time)\n\n        skips = []\n\n        h = x.type(self.dtype)\n        for i, module in enumerate(self.down):\n            h = module(h, time_embed)\n            skips.append(h)\n        h = self.middle(h, time_embed)\n        for i, module in enumerate(self.up):\n            h = torch.cat([h, skips.pop()], dim=1)\n            h = module(h, time_embed)\n        h = h.type(x.dtype)\n        h = self.out(h)\n        return h\n\n\nclass GroupNorm32(nn.GroupNorm):\n    def forward(self, x):\n        return super().forward(x.float()).type(x.dtype)\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef update_ema_params(target, source, decay_rate=0.9999):\n    targParams = dict(target.named_parameters())\n    srcParams = dict(source.named_parameters())\n    for k in targParams:\n        targParams[k].data.mul_(decay_rate).add_(srcParams[k].data, alpha=1 - decay_rate)\n\n"
  },
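  {
    "path": "examples/unet_smoke_test.py",
    "content": "'''\nIllustrative sketch, not part of the original repository: a minimal CPU smoke test for\nUNetModel from unet.py. The sizes below (img_size=32, base_channels=32, in_channels=3) are\nassumptions chosen so the example runs quickly; DDAD's real settings come from config.yaml,\nand the file name is hypothetical.\n'''\nimport torch\n\nfrom unet import UNetModel\n\n\nif __name__ == '__main__':\n    model = UNetModel(img_size=32, base_channels=32, in_channels=3)\n    x = torch.randn(2, 3, 32, 32)       # a batch of two random 3-channel inputs\n    t = torch.randint(0, 1000, (2,))    # one diffusion timestep per image\n    eps_hat = model(x, t)               # predicted noise, same shape as the input\n    print(eps_hat.shape)                # expected: torch.Size([2, 3, 32, 32])\n"
  },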
  {
    "path": "visualize.py",
    "content": "import matplotlib.pyplot as plt\nfrom torchvision.transforms import transforms\nimport numpy as np\nimport torch\nimport os\nfrom dataset import *\n\n\ndef visualalize_reconstruction(input, recon, target):\n    plt.figure(figsize=(11,11))\n    plt.subplot(1, 3, 1).axis('off')\n    plt.subplot(1, 3, 2).axis('off')\n    plt.subplot(1, 3, 3).axis('off')\n\n    plt.subplot(1, 3, 1)\n    plt.imshow(show_tensor_image(input))\n    plt.title('input image')\n    \n\n    plt.subplot(1, 3, 2)\n    plt.imshow(show_tensor_mask(recon))\n    plt.title('recon image')\n\n    plt.subplot(1, 3, 3)\n    plt.imshow(show_tensor_mask(target))\n    plt.title('target image')\n\n\n    k = 0\n    while os.path.exists('results/heatmap{}.png'.format(k)):\n        k += 1\n    plt.savefig('results/heatmap{}.png'.format(k))\n    plt.close()\n\n\n# def visualize_reconstructed(input, data,s):\n#     fig, axs = plt.subplots(int(len(data)/5),6)\n#     row = 0\n#     col = 1\n#     axs[0,0].imshow(show_tensor_image(input))\n#     axs[0, 0].get_xaxis().set_visible(False)\n#     axs[0, 0].get_yaxis().set_visible(False)\n#     axs[0,0].set_title('input')\n#     for i, img in enumerate(data):\n#         axs[row, col].imshow(show_tensor_image(img))\n#         axs[row, col].get_xaxis().set_visible(False)\n#         axs[row, col].get_yaxis().set_visible(False)\n#         axs[row, col].set_title(str(i))\n#         col += 1\n#         if col == 6:\n#             row += 1\n#             col = 0\n#     col = 6\n#     row = int(len(data)/5)\n#     remain = col * row - len(data) -1\n#     for j in range(remain):\n#         col -= 1\n#         axs[row-1, col].remove()\n#         axs[row-1, col].get_xaxis().set_visible(False)\n#         axs[row-1, col].get_yaxis().set_visible(False)\n        \n    \n        \n#     plt.subplots_adjust(left=0.1,\n#                     bottom=0.1,\n#                     right=0.9,\n#                     top=0.9,\n#                     wspace=0.4,\n#                     hspace=0.4)\n#     k = 0\n\n#     while os.path.exists(f'results/reconstructed{k}{s}.png'):\n#         k += 1\n#     plt.savefig(f'results/reconstructed{k}{s}.png')\n#     plt.close()\n\n\n\ndef visualize(image, noisy_image, GT, pred_mask, anomaly_map, category) :\n    for idx, img in enumerate(image):\n        plt.figure(figsize=(11,11))\n        plt.subplot(1, 2, 1).axis('off')\n        plt.subplot(1, 2, 2).axis('off')\n        plt.subplot(1, 2, 1)\n        plt.imshow(show_tensor_image(image[idx]))\n        plt.title('clear image')\n\n        plt.subplot(1, 2, 2)\n\n        plt.imshow(show_tensor_image(noisy_image[idx]))\n        plt.title('reconstructed image')\n        plt.savefig('results/{}sample{}.png'.format(category,idx))\n        plt.close()\n\n        plt.figure(figsize=(11,11))\n        plt.subplot(1, 3, 1).axis('off')\n        plt.subplot(1, 3, 2).axis('off')\n        plt.subplot(1, 3, 3).axis('off')\n\n        plt.subplot(1, 3, 1)\n        plt.imshow(show_tensor_mask(GT[idx]))\n        plt.title('ground truth')\n\n        plt.subplot(1, 3, 2)\n        plt.imshow(show_tensor_mask(pred_mask[idx]))\n        plt.title('normal' if torch.max(pred_mask[idx]) == 0 else 'abnormal', color=\"g\" if torch.max(pred_mask[idx]) == 0 else \"r\")\n\n        plt.subplot(1, 3, 3)\n        plt.imshow(show_tensor_image(anomaly_map[idx]))\n        plt.title('heat map')\n        plt.savefig('results/{}sample{}heatmap.png'.format(category,idx))\n        plt.close()\n\n\n\ndef show_tensor_image(image):\n    reverse_transforms = 
transforms.Compose([\n        transforms.Lambda(lambda t: (t + 1) / (2)),\n        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC\n        transforms.Lambda(lambda t: t * 255.),\n        transforms.Lambda(lambda t: t.cpu().numpy().astype(np.uint8)),\n    ])\n\n    # Takes the first image of batch\n    if len(image.shape) == 4:\n        image = image[0, :, :, :] \n    return reverse_transforms(image)\n\ndef show_tensor_mask(image):\n    reverse_transforms = transforms.Compose([\n        # transforms.Lambda(lambda t: (t + 1) / (2)),\n        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC\n        transforms.Lambda(lambda t: t.cpu().numpy().astype(np.int8)),\n    ])\n\n    # Takes the first image of batch\n    if len(image.shape) == 4:\n        image = image[0, :, :, :] \n    return reverse_transforms(image)\n        \n\n"
  }
]