[
  {
    "path": ".gitignore",
    "content": "__pycache__/*\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 DonaldRR\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# SimpleNet\n\n\n![](imgs/cover.png)\n\n**SimpleNet: A Simple Network for Image Anomaly Detection and Localization**\n\n*Zhikang Liu, Yiming Zhou, Yuansheng Xu, Zilei Wang**\n\n[Paper link](https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_SimpleNet_A_Simple_Network_for_Image_Anomaly_Detection_and_Localization_CVPR_2023_paper.pdf)\n\n##  Introduction\n\nThis repo contains source code for **SimpleNet** implemented with pytorch.\n\nSimpleNet is a simple defect detection and localization network that built with a feature encoder, feature generator and defect discriminator. It is designed conceptionally simple without complex network deisng, training schemes or external data source.\n\n## Get Started \n\n### Environment \n\n**Python3.8**\n\n**Packages**:\n- torch==1.12.1\n- torchvision==0.13.1\n- numpy==1.22.4\n- opencv-python==4.5.1\n\n(Above environment setups are not the minimum requiremetns, other versions might work too.)\n\n\n### Data\n\nEdit `run.sh` to edit dataset class and dataset path.\n\n#### MvTecAD\n\nDownload the dataset from [here](https://www.mvtec.com/company/research/datasets/mvtec-ad/).\n\nThe dataset folders/files follow its original structure.\n\n### Run\n\n#### Demo train\n\nPlease specicy dataset path (line1) and log folder (line10) in `run.sh` before running.\n\n`run.sh` gives the configuration to train models on MVTecAD dataset.\n```\nbash run.sh\n```\n\n## Citation\n```\n@inproceedings{liu2023simplenet,\n  title={SimpleNet: A Simple Network for Image Anomaly Detection and Localization},\n  author={Liu, Zhikang and Zhou, Yiming and Xu, Yuansheng and Wang, Zilei},\n  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n  pages={20402--20411},\n  year={2023}\n}\n```\n\n## Acknowledgement\n\nThanks for great inspiration from [PatchCore](https://github.com/amazon-science/patchcore-inspection)\n\n## License\n\nAll code within the repo is under [MIT license](https://mit-license.org/)\n"
  },
  {
    "path": "VERSION",
    "content": "0.1.0\n"
  },
  {
    "path": "backbones.py",
    "content": "import timm  # noqa\nimport torch\nimport torchvision.models as models  # noqa\n\ndef load_ref_wrn50():\n    \n    import resnet \n    return resnet.wide_resnet50_2(True)\n\n_BACKBONES = {\n    \"cait_s24_224\" : \"cait.cait_S24_224(True)\",\n    \"cait_xs24\": \"cait.cait_XS24(True)\",\n    \"alexnet\": \"models.alexnet(pretrained=True)\",\n    \"bninception\": 'pretrainedmodels.__dict__[\"bninception\"]'\n    '(pretrained=\"imagenet\", num_classes=1000)',\n    \"resnet18\": \"models.resnet18(pretrained=True)\",\n    \"resnet50\": \"models.resnet50(pretrained=True)\",\n    \"mc3_resnet50\": \"load_mc3_rn50()\", \n    \"resnet101\": \"models.resnet101(pretrained=True)\",\n    \"resnext101\": \"models.resnext101_32x8d(pretrained=True)\",\n    \"resnet200\": 'timm.create_model(\"resnet200\", pretrained=True)',\n    \"resnest50\": 'timm.create_model(\"resnest50d_4s2x40d\", pretrained=True)',\n    \"resnetv2_50_bit\": 'timm.create_model(\"resnetv2_50x3_bitm\", pretrained=True)',\n    \"resnetv2_50_21k\": 'timm.create_model(\"resnetv2_50x3_bitm_in21k\", pretrained=True)',\n    \"resnetv2_101_bit\": 'timm.create_model(\"resnetv2_101x3_bitm\", pretrained=True)',\n    \"resnetv2_101_21k\": 'timm.create_model(\"resnetv2_101x3_bitm_in21k\", pretrained=True)',\n    \"resnetv2_152_bit\": 'timm.create_model(\"resnetv2_152x4_bitm\", pretrained=True)',\n    \"resnetv2_152_21k\": 'timm.create_model(\"resnetv2_152x4_bitm_in21k\", pretrained=True)',\n    \"resnetv2_152_384\": 'timm.create_model(\"resnetv2_152x2_bit_teacher_384\", pretrained=True)',\n    \"resnetv2_101\": 'timm.create_model(\"resnetv2_101\", pretrained=True)',\n    \"vgg11\": \"models.vgg11(pretrained=True)\",\n    \"vgg19\": \"models.vgg19(pretrained=True)\",\n    \"vgg19_bn\": \"models.vgg19_bn(pretrained=True)\",\n    \"wideresnet50\": \"models.wide_resnet50_2(pretrained=True)\",\n    \"ref_wideresnet50\": \"load_ref_wrn50()\",\n    \"wideresnet101\": \"models.wide_resnet101_2(pretrained=True)\",\n    \"mnasnet_100\": 'timm.create_model(\"mnasnet_100\", pretrained=True)',\n    \"mnasnet_a1\": 'timm.create_model(\"mnasnet_a1\", pretrained=True)',\n    \"mnasnet_b1\": 'timm.create_model(\"mnasnet_b1\", pretrained=True)',\n    \"densenet121\": 'timm.create_model(\"densenet121\", pretrained=True)',\n    \"densenet201\": 'timm.create_model(\"densenet201\", pretrained=True)',\n    \"inception_v4\": 'timm.create_model(\"inception_v4\", pretrained=True)',\n    \"vit_small\": 'timm.create_model(\"vit_small_patch16_224\", pretrained=True)',\n    \"vit_base\": 'timm.create_model(\"vit_base_patch16_224\", pretrained=True)',\n    \"vit_large\": 'timm.create_model(\"vit_large_patch16_224\", pretrained=True)',\n    \"vit_r50\": 'timm.create_model(\"vit_large_r50_s32_224\", pretrained=True)',\n    \"vit_deit_base\": 'timm.create_model(\"deit_base_patch16_224\", pretrained=True)',\n    \"vit_deit_distilled\": 'timm.create_model(\"deit_base_distilled_patch16_224\", pretrained=True)',\n    \"vit_swin_base\": 'timm.create_model(\"swin_base_patch4_window7_224\", pretrained=True)',\n    \"vit_swin_large\": 'timm.create_model(\"swin_large_patch4_window7_224\", pretrained=True)',\n    \"efficientnet_b7\": 'timm.create_model(\"tf_efficientnet_b7\", pretrained=True)',\n    \"efficientnet_b5\": 'timm.create_model(\"tf_efficientnet_b5\", pretrained=True)',\n    \"efficientnet_b3\": 'timm.create_model(\"tf_efficientnet_b3\", pretrained=True)',\n    \"efficientnet_b1\": 'timm.create_model(\"tf_efficientnet_b1\", pretrained=True)',\n    \"efficientnetv2_m\": 'timm.create_model(\"tf_efficientnetv2_m\", pretrained=True)',\n    \"efficientnetv2_l\": 'timm.create_model(\"tf_efficientnetv2_l\", pretrained=True)',\n    \"efficientnet_b3a\": 'timm.create_model(\"efficientnet_b3a\", pretrained=True)',\n}\n\n\ndef load(name):\n    return eval(_BACKBONES[name])\n"
  },
  {
    "path": "common.py",
    "content": "import copy\nfrom typing import List\n\nimport numpy as np\nimport scipy.ndimage as ndimage\nimport torch\nimport torch.nn.functional as F\n\n\nclass _BaseMerger:\n    def __init__(self):\n        \"\"\"Merges feature embedding by name.\"\"\"\n\n    def merge(self, features: list):\n        features = [self._reduce(feature) for feature in features]\n        return np.concatenate(features, axis=1)\n\n\nclass AverageMerger(_BaseMerger):\n    @staticmethod\n    def _reduce(features):\n        # NxCxWxH -> NxC\n        return features.reshape([features.shape[0], features.shape[1], -1]).mean(\n            axis=-1\n        )\n\n\nclass ConcatMerger(_BaseMerger):\n    @staticmethod\n    def _reduce(features):\n        # NxCxWxH -> NxCWH\n        return features.reshape(len(features), -1)\n\n\nclass Preprocessing(torch.nn.Module):\n    def __init__(self, input_dims, output_dim):\n        super(Preprocessing, self).__init__()\n        self.input_dims = input_dims\n        self.output_dim = output_dim\n\n        self.preprocessing_modules = torch.nn.ModuleList()\n        for input_dim in input_dims:\n            module = MeanMapper(output_dim)\n            self.preprocessing_modules.append(module)\n\n    def forward(self, features):\n        _features = []\n        for module, feature in zip(self.preprocessing_modules, features):\n            _features.append(module(feature))\n        return torch.stack(_features, dim=1)\n\n\nclass MeanMapper(torch.nn.Module):\n    def __init__(self, preprocessing_dim):\n        super(MeanMapper, self).__init__()\n        self.preprocessing_dim = preprocessing_dim\n\n    def forward(self, features):\n        features = features.reshape(len(features), 1, -1)\n        return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1)\n\n\nclass Aggregator(torch.nn.Module):\n    def __init__(self, target_dim):\n        super(Aggregator, self).__init__()\n        self.target_dim = target_dim\n\n    def forward(self, features):\n        \"\"\"Returns reshaped and average pooled features.\"\"\"\n        # batchsize x number_of_layers x input_dim -> batchsize x target_dim\n        features = features.reshape(len(features), 1, -1)\n        features = F.adaptive_avg_pool1d(features, self.target_dim)\n        return features.reshape(len(features), -1)\n\n\nclass RescaleSegmentor:\n    def __init__(self, device, target_size=224):\n        self.device = device\n        self.target_size = target_size\n        self.smoothing = 4\n\n    def convert_to_segmentation(self, patch_scores, features):\n\n        with torch.no_grad():\n            if isinstance(patch_scores, np.ndarray):\n                patch_scores = torch.from_numpy(patch_scores)\n            _scores = patch_scores.to(self.device)\n            _scores = _scores.unsqueeze(1)\n            _scores = F.interpolate(\n                _scores, size=self.target_size, mode=\"bilinear\", align_corners=False\n            )\n            _scores = _scores.squeeze(1)\n            patch_scores = _scores.cpu().numpy()\n\n            if isinstance(features, np.ndarray):\n                features = torch.from_numpy(features)\n            features = features.to(self.device).permute(0, 3, 1, 2)\n            if self.target_size[0] * self.target_size[1] * features.shape[0] * features.shape[1] >= 2**31:\n                subbatch_size = int((2**31-1) / (self.target_size[0] * self.target_size[1] * features.shape[1]))\n                interpolated_features = []\n                for i_subbatch in range(int(features.shape[0] / subbatch_size + 1)):\n                    subfeatures = features[i_subbatch*subbatch_size:(i_subbatch+1)*subbatch_size]\n                    subfeatures = subfeatures.unsuqeeze(0) if len(subfeatures.shape) == 3 else subfeatures\n                    subfeatures = F.interpolate(\n                        subfeatures, size=self.target_size, mode=\"bilinear\", align_corners=False\n                    )\n                    interpolated_features.append(subfeatures)\n                features = torch.cat(interpolated_features, 0)\n            else:\n                features = F.interpolate(\n                    features, size=self.target_size, mode=\"bilinear\", align_corners=False\n                )\n            features = features.cpu().numpy()\n\n        return [\n            ndimage.gaussian_filter(patch_score, sigma=self.smoothing)\n            for patch_score in patch_scores\n        ], [ \n            feature\n            for feature in features\n        ]\n\n\nclass NetworkFeatureAggregator(torch.nn.Module):\n    \"\"\"Efficient extraction of network features.\"\"\"\n\n    def __init__(self, backbone, layers_to_extract_from, device, train_backbone=False):\n        super(NetworkFeatureAggregator, self).__init__()\n        \"\"\"Extraction of network features.\n\n        Runs a network only to the last layer of the list of layers where\n        network features should be extracted from.\n\n        Args:\n            backbone: torchvision.model\n            layers_to_extract_from: [list of str]\n        \"\"\"\n        self.layers_to_extract_from = layers_to_extract_from\n        self.backbone = backbone\n        self.device = device\n        self.train_backbone = train_backbone\n        if not hasattr(backbone, \"hook_handles\"):\n            self.backbone.hook_handles = []\n        for handle in self.backbone.hook_handles:\n            handle.remove()\n        self.outputs = {}\n\n        for extract_layer in layers_to_extract_from:\n            forward_hook = ForwardHook(\n                self.outputs, extract_layer, layers_to_extract_from[-1]\n            )\n            if \".\" in extract_layer:\n                extract_block, extract_idx = extract_layer.split(\".\")\n                network_layer = backbone.__dict__[\"_modules\"][extract_block]\n                if extract_idx.isnumeric():\n                    extract_idx = int(extract_idx)\n                    network_layer = network_layer[extract_idx]\n                else:\n                    network_layer = network_layer.__dict__[\"_modules\"][extract_idx]\n            else:\n                network_layer = backbone.__dict__[\"_modules\"][extract_layer]\n\n            if isinstance(network_layer, torch.nn.Sequential):\n                self.backbone.hook_handles.append(\n                    network_layer[-1].register_forward_hook(forward_hook)\n                )\n            else:\n                self.backbone.hook_handles.append(\n                    network_layer.register_forward_hook(forward_hook)\n                )\n        self.to(self.device)\n\n    def forward(self, images, eval=True):\n        self.outputs.clear()\n        if self.train_backbone and not eval:\n            self.backbone(images)\n        else:\n            with torch.no_grad():\n                # The backbone will throw an Exception once it reached the last\n                # layer to compute features from. Computation will stop there.\n                try:\n                    _ = self.backbone(images)\n                except LastLayerToExtractReachedException:\n                    pass\n        return self.outputs\n\n    def feature_dimensions(self, input_shape):\n        \"\"\"Computes the feature dimensions for all layers given input_shape.\"\"\"\n        _input = torch.ones([1] + list(input_shape)).to(self.device)\n        _output = self(_input)\n        return [_output[layer].shape[1] for layer in self.layers_to_extract_from]\n\n\nclass ForwardHook:\n    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):\n        self.hook_dict = hook_dict\n        self.layer_name = layer_name\n        self.raise_exception_to_break = copy.deepcopy(\n            layer_name == last_layer_to_extract\n        )\n\n    def __call__(self, module, input, output):\n        self.hook_dict[self.layer_name] = output\n        # if self.raise_exception_to_break:\n        #     raise LastLayerToExtractReachedException()\n        return None\n\n\nclass LastLayerToExtractReachedException(Exception):\n    pass\n\n\n"
  },
  {
    "path": "datasets/__init__.py",
    "content": ""
  },
  {
    "path": "datasets/btad.py",
    "content": "import os\nfrom enum import Enum\n\nimport PIL\nimport torch\nfrom torchvision import transforms\n\n_CLASSNAMES = [\n    \"01\",\n    \"02\",\n    \"03\"\n]\n\nIMAGENET_MEAN = [0.485, 0.456, 0.406]\nIMAGENET_STD = [0.229, 0.224, 0.225]\n\n\nclass DatasetSplit(Enum):\n    TRAIN = \"train\"\n    VAL = \"val\"\n    TEST = \"test\"\n\n\nclass BTADDataset(torch.utils.data.Dataset):\n    \"\"\"\n    PyTorch Dataset for MVTec.\n    \"\"\"\n\n    def __init__(\n        self,\n        source,\n        classname,\n        resize=256,\n        imagesize=224,\n        split=DatasetSplit.TRAIN,\n        train_val_split=1.0,\n        rotate_degrees=0,\n        translate=0,\n        brightness_factor=0,\n        contrast_factor=0,\n        saturation_factor=0,\n        gray_p=0,\n        h_flip_p=0,\n        v_flip_p=0,\n        scale=0,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            source: [str]. Path to the MVTec data folder.\n            classname: [str or None]. Name of MVTec class that should be\n                       provided in this dataset. If None, the datasets\n                       iterates over all available images.\n            resize: [int]. (Square) Size the loaded image initially gets\n                    resized to.\n            imagesize: [int]. (Square) Size the resized loaded image gets\n                       (center-)cropped to.\n            split: [enum-option]. Indicates if training or test split of the\n                   data should be used. Has to be an option taken from\n                   DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that\n                   mvtec.DatasetSplit.TEST will also load mask data.\n        \"\"\"\n        super().__init__()\n        self.source = source\n        self.split = split\n        self.classnames_to_use = [classname] if classname is not None else _CLASSNAMES\n        self.train_val_split = train_val_split\n        self.transform_std = IMAGENET_STD\n        self.transform_mean = IMAGENET_MEAN\n        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()\n\n        self.transform_img = [\n            transforms.Resize(resize),\n            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),\n            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),\n            transforms.RandomHorizontalFlip(h_flip_p),\n            transforms.RandomVerticalFlip(v_flip_p),\n            transforms.RandomGrayscale(gray_p),\n            transforms.RandomAffine(rotate_degrees, \n                                    translate=(translate, translate),\n                                    scale=(1.0-scale, 1.0+scale),\n                                    interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(imagesize),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n        ]\n        self.transform_img = transforms.Compose(self.transform_img)\n\n        self.transform_mask = [\n            transforms.Resize(resize),\n            transforms.CenterCrop(imagesize),\n            transforms.ToTensor(),\n        ]\n        self.transform_mask = transforms.Compose(self.transform_mask)\n\n        self.imagesize = (3, imagesize, imagesize)\n\n    def __getitem__(self, idx):\n        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]\n        image = PIL.Image.open(image_path).convert(\"RGB\")\n        image = self.transform_img(image)\n\n        if self.split == DatasetSplit.TEST and mask_path is not None:\n            mask = PIL.Image.open(mask_path)\n            mask = self.transform_mask(mask)\n        else:\n            mask = torch.zeros([1, *image.size()[1:]])\n\n        return {\n            \"image\": image,\n            \"mask\": mask,\n            \"classname\": classname,\n            \"anomaly\": anomaly,\n            \"is_anomaly\": int(anomaly != \"good\"),\n            \"image_name\": \"/\".join(image_path.split(\"/\")[-4:]),\n            \"image_path\": image_path,\n        }\n\n    def __len__(self):\n        return len(self.data_to_iterate)\n\n    def get_image_data(self):\n        imgpaths_per_class = {}\n        maskpaths_per_class = {}\n\n        for classname in self.classnames_to_use:\n            classpath = os.path.join(self.source, classname, self.split.value)\n            maskpath = os.path.join(self.source, classname, \"ground_truth\")\n            anomaly_types = os.listdir(classpath)\n\n            imgpaths_per_class[classname] = {}\n            maskpaths_per_class[classname] = {}\n\n            for anomaly in anomaly_types:\n                anomaly_path = os.path.join(classpath, anomaly)\n                anomaly_files = sorted(os.listdir(anomaly_path))\n                imgpaths_per_class[classname][anomaly] = [\n                    os.path.join(anomaly_path, x) for x in anomaly_files\n                ]\n\n                if self.train_val_split < 1.0:\n                    n_images = len(imgpaths_per_class[classname][anomaly])\n                    train_val_split_idx = int(n_images * self.train_val_split)\n                    if self.split == DatasetSplit.TRAIN:\n                        imgpaths_per_class[classname][anomaly] = imgpaths_per_class[\n                            classname\n                        ][anomaly][:train_val_split_idx]\n                    elif self.split == DatasetSplit.VAL:\n                        imgpaths_per_class[classname][anomaly] = imgpaths_per_class[\n                            classname\n                        ][anomaly][train_val_split_idx:]\n\n                if self.split == DatasetSplit.TEST and anomaly != \"good\":\n                    anomaly_mask_path = os.path.join(maskpath, anomaly)\n                    anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))\n                    maskpaths_per_class[classname][anomaly] = [\n                        os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files\n                    ]\n                else:\n                    maskpaths_per_class[classname][\"good\"] = None\n\n        # Unrolls the data dictionary to an easy-to-iterate list.\n        data_to_iterate = []\n        for classname in sorted(imgpaths_per_class.keys()):\n            for anomaly in sorted(imgpaths_per_class[classname].keys()):\n                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):\n                    data_tuple = [classname, anomaly, image_path]\n                    if self.split == DatasetSplit.TEST and anomaly != \"good\":\n                        data_tuple.append(maskpaths_per_class[classname][anomaly][i])\n                    else:\n                        data_tuple.append(None)\n                    data_to_iterate.append(data_tuple)\n\n        return imgpaths_per_class, data_to_iterate\n"
  },
  {
    "path": "datasets/cifar10.py",
    "content": "import os\nfrom enum import Enum\n\nimport PIL\nimport torch\nfrom torchvision import transforms\n\n\nIMAGENET_MEAN = [0.485, 0.456, 0.406]\nIMAGENET_STD = [0.229, 0.224, 0.225]\n\n\nclass DatasetSplit(Enum):\n    TRAIN = \"train\"\n    VAL = \"val\"\n    TEST = \"test\"\n\n\nclass Cifar10Dataset(torch.utils.data.Dataset):\n    \"\"\"\n    PyTorch Dataset for MVTec.\n    \"\"\"\n\n    _CLASSES = list(range(10))\n\n    def __init__(\n        self,\n        source,\n        classname,\n        resize=256,\n        imagesize=224,\n        split=DatasetSplit.TRAIN,\n        train_val_split=1.0,\n        rotate_degrees=0,\n        translate=0,\n        brightness_factor=0,\n        contrast_factor=0,\n        saturation_factor=0,\n        gray_p=0,\n        h_flip_p=0,\n        v_flip_p=0,\n        scale=0,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            source: [str]. Path to the MVTec data folder.\n            classname: [str or None]. Name of MVTec class that should be\n                       provided in this dataset. If None, the datasets\n                       iterates over all available images.\n            resize: [int]. (Square) Size the loaded image initially gets\n                    resized to.\n            imagesize: [int]. (Square) Size the resized loaded image gets\n                       (center-)cropped to.\n            split: [enum-option]. Indicates if training or test split of the\n                   data should be used. Has to be an option taken from\n                   DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that\n                   mvtec.DatasetSplit.TEST will also load mask data.\n        \"\"\"\n        super().__init__()\n        self.source = source\n        self.split = split\n        self.classname = int(classname)\n        self.train_val_split = train_val_split\n\n        self.data_to_iterate = self.get_image_data()\n        self.transform_std = IMAGENET_STD\n        self.transform_mean = IMAGENET_MEAN\n        self.transform_img = [\n            transforms.Resize(resize),\n            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),\n            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),\n            transforms.RandomHorizontalFlip(h_flip_p),\n            transforms.RandomVerticalFlip(v_flip_p),\n            transforms.RandomGrayscale(gray_p),\n            transforms.RandomAffine(rotate_degrees, \n                                    translate=(translate, translate),\n                                    scale=(1.0-scale, 1.0+scale),\n                                    interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(imagesize),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n        ]\n        self.transform_img = transforms.Compose(self.transform_img)\n\n        self.transform_mask = [\n            transforms.Resize(resize),\n            transforms.CenterCrop(imagesize),\n            transforms.ToTensor(),\n        ]\n        self.transform_mask = transforms.Compose(self.transform_mask)\n\n        self.imagesize = (3, imagesize, imagesize)\n\n    def __getitem__(self, idx):\n        img_path, classname = self.data_to_iterate[idx]\n\n        image = PIL.Image.open(img_path).convert(\"RGB\")\n        image = self.transform_img(image)\n\n        return {\n            \"image\": image,\n            \"classname\": classname,\n            \"anomaly\": int(classname != self.classname),\n            \"is_anomaly\": int(classname != self.classname),\n            \"image_name\": os.path.split(img_path)[-1],\n            \"image_path\": img_path,\n        }\n\n    def __len__(self):\n        return len(self.data_to_iterate)\n\n    def get_image_data(self):\n        data_to_iterate = []\n\n        for classname in Cifar10Dataset._CLASSES:\n            if self.split == DatasetSplit.TRAIN:\n                if classname != self.classname:\n                    continue\n            class_dir = os.path.join(self.source, self.split.value, str(classname))\n            for fn in os.listdir(class_dir):\n                img_path = os.path.join(class_dir, fn)\n                data_to_iterate.append([img_path, classname])\n\n        return data_to_iterate\n"
  },
  {
    "path": "datasets/mvtec.py",
    "content": "import os\nfrom enum import Enum\n\nimport PIL\nimport torch\nfrom torchvision import transforms\n\n_CLASSNAMES = [\n    \"bottle\",\n    \"cable\",\n    \"capsule\",\n    \"carpet\",\n    \"grid\",\n    \"hazelnut\",\n    \"leather\",\n    \"metal_nut\",\n    \"pill\",\n    \"screw\",\n    \"tile\",\n    \"toothbrush\",\n    \"transistor\",\n    \"wood\",\n    \"zipper\",\n]\n\nIMAGENET_MEAN = [0.485, 0.456, 0.406]\nIMAGENET_STD = [0.229, 0.224, 0.225]\n\n\nclass DatasetSplit(Enum):\n    TRAIN = \"train\"\n    VAL = \"val\"\n    TEST = \"test\"\n\n\nclass MVTecDataset(torch.utils.data.Dataset):\n    \"\"\"\n    PyTorch Dataset for MVTec.\n    \"\"\"\n\n    def __init__(\n        self,\n        source,\n        classname,\n        resize=256,\n        imagesize=224,\n        split=DatasetSplit.TRAIN,\n        train_val_split=1.0,\n        rotate_degrees=0,\n        translate=0,\n        brightness_factor=0,\n        contrast_factor=0,\n        saturation_factor=0,\n        gray_p=0,\n        h_flip_p=0,\n        v_flip_p=0,\n        scale=0,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            source: [str]. Path to the MVTec data folder.\n            classname: [str or None]. Name of MVTec class that should be\n                       provided in this dataset. If None, the datasets\n                       iterates over all available images.\n            resize: [int]. (Square) Size the loaded image initially gets\n                    resized to.\n            imagesize: [int]. (Square) Size the resized loaded image gets\n                       (center-)cropped to.\n            split: [enum-option]. Indicates if training or test split of the\n                   data should be used. Has to be an option taken from\n                   DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that\n                   mvtec.DatasetSplit.TEST will also load mask data.\n        \"\"\"\n        super().__init__()\n        self.source = source\n        self.split = split\n        self.classnames_to_use = [classname] if classname is not None else _CLASSNAMES\n        self.train_val_split = train_val_split\n        self.transform_std = IMAGENET_STD\n        self.transform_mean = IMAGENET_MEAN\n        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()\n\n        self.transform_img = [\n            transforms.Resize(resize),\n            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),\n            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),\n            transforms.RandomHorizontalFlip(h_flip_p),\n            transforms.RandomVerticalFlip(v_flip_p),\n            transforms.RandomGrayscale(gray_p),\n            transforms.RandomAffine(rotate_degrees, \n                                    translate=(translate, translate),\n                                    scale=(1.0-scale, 1.0+scale),\n                                    interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop(imagesize),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n        ]\n        self.transform_img = transforms.Compose(self.transform_img)\n\n        self.transform_mask = [\n            transforms.Resize(resize),\n            transforms.CenterCrop(imagesize),\n            transforms.ToTensor(),\n        ]\n        self.transform_mask = transforms.Compose(self.transform_mask)\n\n        self.imagesize = (3, imagesize, imagesize)\n\n    def __getitem__(self, idx):\n        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]\n        image = PIL.Image.open(image_path).convert(\"RGB\")\n        image = self.transform_img(image)\n\n        if self.split == DatasetSplit.TEST and mask_path is not None:\n            mask = PIL.Image.open(mask_path)\n            mask = self.transform_mask(mask)\n        else:\n            mask = torch.zeros([1, *image.size()[1:]])\n\n        return {\n            \"image\": image,\n            \"mask\": mask,\n            \"classname\": classname,\n            \"anomaly\": anomaly,\n            \"is_anomaly\": int(anomaly != \"good\"),\n            \"image_name\": \"/\".join(image_path.split(\"/\")[-4:]),\n            \"image_path\": image_path,\n        }\n\n    def __len__(self):\n        return len(self.data_to_iterate)\n\n    def get_image_data(self):\n        imgpaths_per_class = {}\n        maskpaths_per_class = {}\n\n        for classname in self.classnames_to_use:\n            classpath = os.path.join(self.source, classname, self.split.value)\n            maskpath = os.path.join(self.source, classname, \"ground_truth\")\n            anomaly_types = os.listdir(classpath)\n\n            imgpaths_per_class[classname] = {}\n            maskpaths_per_class[classname] = {}\n\n            for anomaly in anomaly_types:\n                anomaly_path = os.path.join(classpath, anomaly)\n                anomaly_files = sorted(os.listdir(anomaly_path))\n                imgpaths_per_class[classname][anomaly] = [\n                    os.path.join(anomaly_path, x) for x in anomaly_files\n                ]\n\n                if self.train_val_split < 1.0:\n                    n_images = len(imgpaths_per_class[classname][anomaly])\n                    train_val_split_idx = int(n_images * self.train_val_split)\n                    if self.split == DatasetSplit.TRAIN:\n                        imgpaths_per_class[classname][anomaly] = imgpaths_per_class[\n                            classname\n                        ][anomaly][:train_val_split_idx]\n                    elif self.split == DatasetSplit.VAL:\n                        imgpaths_per_class[classname][anomaly] = imgpaths_per_class[\n                            classname\n                        ][anomaly][train_val_split_idx:]\n\n                if self.split == DatasetSplit.TEST and anomaly != \"good\":\n                    anomaly_mask_path = os.path.join(maskpath, anomaly)\n                    anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))\n                    maskpaths_per_class[classname][anomaly] = [\n                        os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files\n                    ]\n                else:\n                    maskpaths_per_class[classname][\"good\"] = None\n\n        # Unrolls the data dictionary to an easy-to-iterate list.\n        data_to_iterate = []\n        for classname in sorted(imgpaths_per_class.keys()):\n            for anomaly in sorted(imgpaths_per_class[classname].keys()):\n                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):\n                    data_tuple = [classname, anomaly, image_path]\n                    if self.split == DatasetSplit.TEST and anomaly != \"good\":\n                        data_tuple.append(maskpaths_per_class[classname][anomaly][i])\n                    else:\n                        data_tuple.append(None)\n                    data_to_iterate.append(data_tuple)\n\n        return imgpaths_per_class, data_to_iterate\n"
  },
  {
    "path": "datasets/sdd.py",
    "content": "import os\nfrom enum import Enum\nimport pickle\n\nimport cv2\nimport PIL\nimport torch\nfrom torchvision import transforms\n\n\nIMAGENET_MEAN = [0.485, 0.456, 0.406]\nIMAGENET_STD = [0.229, 0.224, 0.225]\n\n\nclass DatasetSplit(Enum):\n    TRAIN = \"train\"\n    VAL = \"val\"\n    TEST = \"test\"\n\n\nclass SDDDataset(torch.utils.data.Dataset):\n    \"\"\"\n    PyTorch Dataset for MVTec.\n    \"\"\"\n\n    def __init__(\n        self,\n        source,\n        classname,\n        resize=256,\n        imagesize=224,\n        split=DatasetSplit.TRAIN,\n        train_val_split=1.0,\n        rotate_degrees=0,\n        translate=0,\n        brightness_factor=0,\n        contrast_factor=0,\n        saturation_factor=0,\n        gray_p=0,\n        h_flip_p=0,\n        v_flip_p=0,\n        scale=0,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            source: [str]. Path to the MVTec data folder.\n            classname: [str or None]. Name of MVTec class that should be\n                       provided in this dataset. If None, the datasets\n                       iterates over all available images.\n            resize: [int]. (Square) Size the loaded image initially gets\n                    resized to.\n            imagesize: [int]. (Square) Size the resized loaded image gets\n                       (center-)cropped to.\n            split: [enum-option]. Indicates if training or test split of the\n                   data should be used. Has to be an option taken from\n                   DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that\n                   mvtec.DatasetSplit.TEST will also load mask data.\n        \"\"\"\n        super().__init__()\n        self.source = source\n        self.split = split\n        self.split_id = int(classname)\n        self.train_val_split = train_val_split\n        self.transform_std = IMAGENET_STD\n        self.transform_mean = IMAGENET_MEAN\n        self.data_to_iterate = self.get_image_data()\n\n        self.transform_img = [\n            transforms.Resize((int(resize*2.5+.5), resize)),\n            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),\n            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),\n            transforms.RandomHorizontalFlip(h_flip_p),\n            transforms.RandomVerticalFlip(v_flip_p),\n            transforms.RandomGrayscale(gray_p),\n            transforms.RandomAffine(rotate_degrees, \n                                    translate=(translate, translate),\n                                    scale=(1.0-scale, 1.0+scale),\n                                    interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop((int(imagesize * 2.5 + .5), imagesize)),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n        ]\n        self.transform_img = transforms.Compose(self.transform_img)\n\n        self.transform_mask = [\n            transforms.Resize((int(resize*2.5+.5), resize)),\n            transforms.CenterCrop((int(imagesize * 2.5 + .5), imagesize)),\n            transforms.ToTensor(),\n        ]\n        self.transform_mask = transforms.Compose(self.transform_mask)\n\n        self.imagesize = (3, int(imagesize * 2.5 + .5), imagesize)\n        \n        # if self.split == DatasetSplit.TEST:\n        #     for i in range(len(self.data_to_iterate)):\n        #         self.__getitem__(i)\n\n    def __getitem__(self, idx):\n        data = self.data_to_iterate[idx]\n        image = PIL.Image.open(data[\"img\"]).convert(\"RGB\")\n        image = self.transform_img(image)\n\n        if self.split == DatasetSplit.TEST and data[\"anomaly\"] == 1:\n            mask = PIL.Image.open(data[\"label\"])\n            mask = self.transform_mask(mask)\n        else:\n            mask = torch.zeros([1, *image.size()[1:]])\n\n        return {\n            \"image\": image,\n            \"mask\": mask,\n            \"classname\": str(self.split_id),\n            \"anomaly\": data[\"anomaly\"],\n            \"is_anomaly\": data[\"anomaly\"],\n            \"image_path\": data[\"img\"],\n        }\n\n    def __len__(self):\n        return len(self.data_to_iterate)\n\n    def get_image_data(self):\n\n        data_ids = []\n        with open(os.path.join(self.source, \"KolektorSDD-training-splits\", \"split.pyb\"), \"rb\") as f:\n            train_ids, test_ids, _ = pickle.load(f)\n            if self.split == DatasetSplit.TRAIN:\n                data_ids = train_ids[self.split_id]\n            else:\n                data_ids = test_ids[self.split_id]\n        \n        data = {}\n        for data_id in data_ids:\n            item_dir = os.path.join(self.source, data_id)\n            fns = os.listdir(item_dir)\n            part_ids = [os.path.splitext(fn)[0] for fn in fns if fn.endswith(\"jpg\")]\n            parts = {part_id:{\"img\":\"\", \"label\":\"\", \"anomaly\":0}\n                     for part_id in part_ids}\n            for part_id in parts:\n                for fn in fns:\n                    if part_id in fn:\n                        if \"label\" in fn:\n                            label = cv2.imread(os.path.join(item_dir, fn))\n                            if label.sum() > 0:\n                                parts[part_id][\"anomaly\"] = 1\n                            parts[part_id][\"label\"] = os.path.join(item_dir, fn)\n                        else:\n                            parts[part_id][\"img\"] = os.path.join(item_dir, fn)\n            for k, v in parts.items():\n                if self.split == DatasetSplit.TRAIN and v[\"anomaly\"] == 1:\n                    continue\n                data[data_id + '_' + k] = v\n\n        return list(data.values())\n"
  },
  {
    "path": "datasets/sdd2.py",
    "content": "import os\nfrom enum import Enum\nimport pickle\n\nimport cv2\nimport PIL\nimport torch\nfrom torchvision import transforms\n\n\nIMAGENET_MEAN = [0.485, 0.456, 0.406]\nIMAGENET_STD = [0.229, 0.224, 0.225]\n\n\nclass DatasetSplit(Enum):\n    TRAIN = \"train\"\n    VAL = \"val\"\n    TEST = \"test\"\n\n\nclass SDD2Dataset(torch.utils.data.Dataset):\n    \"\"\"\n    PyTorch Dataset for MVTec.\n    \"\"\"\n\n    def __init__(\n        self,\n        source,\n        classname,\n        resize=256,\n        imagesize=224,\n        split=DatasetSplit.TRAIN,\n        train_val_split=1.0,\n        rotate_degrees=0,\n        translate=0,\n        brightness_factor=0,\n        contrast_factor=0,\n        saturation_factor=0,\n        gray_p=0,\n        h_flip_p=0,\n        v_flip_p=0,\n        scale=0,\n        **kwargs,\n    ):\n        \"\"\"\n        Args:\n            source: [str]. Path to the MVTec data folder.\n            classname: [str or None]. Name of MVTec class that should be\n                       provided in this dataset. If None, the datasets\n                       iterates over all available images.\n            resize: [int]. (Square) Size the loaded image initially gets\n                    resized to.\n            imagesize: [int]. (Square) Size the resized loaded image gets\n                       (center-)cropped to.\n            split: [enum-option]. Indicates if training or test split of the\n                   data should be used. Has to be an option taken from\n                   DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that\n                   mvtec.DatasetSplit.TEST will also load mask data.\n        \"\"\"\n        super().__init__()\n        self.source = source\n        self.split = split\n        self.train_val_split = train_val_split\n        self.transform_std = IMAGENET_STD\n        self.transform_mean = IMAGENET_MEAN\n        self.data_to_iterate = self.get_image_data()\n\n        self.transform_img = [\n            transforms.Resize((int(resize*2.5+.5), resize)),\n            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),\n            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),\n            transforms.RandomHorizontalFlip(h_flip_p),\n            transforms.RandomVerticalFlip(v_flip_p),\n            transforms.RandomGrayscale(gray_p),\n            transforms.RandomAffine(rotate_degrees, \n                                    translate=(translate, translate),\n                                    scale=(1.0-scale, 1.0+scale),\n                                    interpolation=transforms.InterpolationMode.BILINEAR),\n            transforms.CenterCrop((int(imagesize * 2.5 + .5), imagesize)),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n        ]\n        self.transform_img = transforms.Compose(self.transform_img)\n\n        self.transform_mask = [\n            transforms.Resize((int(resize*2.5+.5), resize)),\n            transforms.CenterCrop((int(imagesize * 2.5 + .5), imagesize)),\n            transforms.ToTensor(),\n        ]\n        self.transform_mask = transforms.Compose(self.transform_mask)\n\n        self.imagesize = (3, int(imagesize * 2.5 + .5), imagesize)\n        \n        # if self.split == DatasetSplit.TEST:\n        #     for i in range(len(self.data_to_iterate)):\n        #         self.__getitem__(i)\n\n    def __getitem__(self, idx):\n        img_path, gt_path, is_anomaly = self.data_to_iterate[idx]\n        image = PIL.Image.open(img_path).convert(\"RGB\")\n        image = self.transform_img(image)\n\n        if self.split == DatasetSplit.TEST and is_anomaly:\n            mask = PIL.Image.open(gt_path)\n            mask = self.transform_mask(mask)\n        else:\n            mask = torch.zeros([1, *image.size()[1:]])\n\n        return {\n            \"image\": image,\n            \"mask\": mask,\n            \"classname\": \"\",\n            \"anomaly\": is_anomaly,\n            \"is_anomaly\": is_anomaly,\n            \"image_path\": img_path,\n        }\n\n    def __len__(self):\n        return len(self.data_to_iterate)\n\n    def get_image_data(self):\n\n        data_ids = []\n        \n        data_dir = os.path.join(self.source, \"train\" if self.split == DatasetSplit.TRAIN else \"test\")\n        data = []\n        test = [0, 0]\n        for fn in os.listdir(data_dir):\n            if \"GT\" not in fn:\n                data_id = os.path.splitext(fn)[0]\n                img_path = os.path.join(data_dir, fn)\n                gt_path = os.path.join(data_dir, f\"{data_id}_GT.png\")\n                assert os.path.exists(img_path)\n                assert os.path.exists(gt_path), gt_path\n                gt = cv2.imread(gt_path)\n                is_anomaly = gt.sum() > 0\n                if is_anomaly:\n                    test[1] = test[1] + 1\n                else:\n                    test[0] = test[0] + 1\n                if self.split == DatasetSplit.TRAIN and is_anomaly:\n                    continue\n                data.append([img_path, gt_path, gt.sum() > 0])\n\n        return data\n"
  },
  {
    "path": "main.py",
    "content": "# ------------------------------------------------------------------\n# SimpleNet: A Simple Network for Image Anomaly Detection and Localization (https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_SimpleNet_A_Simple_Network_for_Image_Anomaly_Detection_and_Localization_CVPR_2023_paper.pdf)\n# Github source: https://github.com/DonaldRR/SimpleNet\n# Licensed under the MIT License [see LICENSE for details]\n# The script is based on the code of PatchCore (https://github.com/amazon-science/patchcore-inspection)\n# ------------------------------------------------------------------\n\nimport logging\nimport os\nimport sys\n\nimport click\nimport numpy as np\nimport torch\n\nsys.path.append(\"src\")\nimport backbones\nimport common\nimport metrics\nimport simplenet \nimport utils\n\nLOGGER = logging.getLogger(__name__)\n\n_DATASETS = {\n    \"mvtec\": [\"datasets.mvtec\", \"MVTecDataset\"],\n}\n\n\n@click.group(chain=True)\n@click.option(\"--results_path\", type=str)\n@click.option(\"--gpu\", type=int, default=[0], multiple=True, show_default=True)\n@click.option(\"--seed\", type=int, default=0, show_default=True)\n@click.option(\"--log_group\", type=str, default=\"group\")\n@click.option(\"--log_project\", type=str, default=\"project\")\n@click.option(\"--run_name\", type=str, default=\"test\")\n@click.option(\"--test\", is_flag=True)\n@click.option(\"--save_segmentation_images\", is_flag=True, default=False, show_default=True)\ndef main(**kwargs):\n    pass\n\n\n@main.result_callback()\ndef run(\n    methods,\n    results_path,\n    gpu,\n    seed,\n    log_group,\n    log_project,\n    run_name,\n    test,\n    save_segmentation_images\n):\n    methods = {key: item for (key, item) in methods}\n\n    run_save_path = utils.create_storage_folder(\n        results_path, log_project, log_group, run_name, mode=\"overwrite\"\n    )\n\n    pid = os.getpid()\n    list_of_dataloaders = methods[\"get_dataloaders\"](seed)\n\n    device = utils.set_torch_device(gpu)\n\n    result_collect = []\n    for dataloader_count, dataloaders in enumerate(list_of_dataloaders):\n        LOGGER.info(\n            \"Evaluating dataset [{}] ({}/{})...\".format(\n                dataloaders[\"training\"].name,\n                dataloader_count + 1,\n                len(list_of_dataloaders),\n            )\n        )\n\n        utils.fix_seeds(seed, device)\n\n        dataset_name = dataloaders[\"training\"].name\n\n        imagesize = dataloaders[\"training\"].dataset.imagesize\n        simplenet_list = methods[\"get_simplenet\"](imagesize, device)\n\n        models_dir = os.path.join(run_save_path, \"models\")\n        os.makedirs(models_dir, exist_ok=True)\n        for i, SimpleNet in enumerate(simplenet_list):\n            # torch.cuda.empty_cache()\n            if SimpleNet.backbone.seed is not None:\n                utils.fix_seeds(SimpleNet.backbone.seed, device)\n            LOGGER.info(\n                \"Training models ({}/{})\".format(i + 1, len(simplenet_list))\n            )\n            # torch.cuda.empty_cache()\n\n            SimpleNet.set_model_dir(os.path.join(models_dir, f\"{i}\"), dataset_name)\n            if not test:\n                i_auroc, p_auroc, pro_auroc = SimpleNet.train(dataloaders[\"training\"], dataloaders[\"testing\"])\n            else:\n                # BUG: the following line is not using. Set test with True by default.\n                # i_auroc, p_auroc, pro_auroc =  SimpleNet.test(dataloaders[\"training\"], dataloaders[\"testing\"], save_segmentation_images)\n                print(\"Warning: Pls set test with true by default\")\n\n            result_collect.append(\n                {\n                    \"dataset_name\": dataset_name,\n                    \"instance_auroc\": i_auroc, # auroc,\n                    \"full_pixel_auroc\": p_auroc, # full_pixel_auroc,\n                    \"anomaly_pixel_auroc\": pro_auroc,\n                }\n            )\n\n            for key, item in result_collect[-1].items():\n                if key != \"dataset_name\":\n                    LOGGER.info(\"{0}: {1:3.3f}\".format(key, item))\n\n        LOGGER.info(\"\\n\\n-----\\n\")\n\n    # Store all results and mean scores to a csv-file.\n    result_metric_names = list(result_collect[-1].keys())[1:]\n    result_dataset_names = [results[\"dataset_name\"] for results in result_collect]\n    result_scores = [list(results.values())[1:] for results in result_collect]\n    utils.compute_and_store_final_results(\n        run_save_path,\n        result_scores,\n        column_names=result_metric_names,\n        row_names=result_dataset_names,\n    )\n\n\n@main.command(\"net\")\n@click.option(\"--backbone_names\", \"-b\", type=str, multiple=True, default=[])\n@click.option(\"--layers_to_extract_from\", \"-le\", type=str, multiple=True, default=[])\n@click.option(\"--pretrain_embed_dimension\", type=int, default=1024)\n@click.option(\"--target_embed_dimension\", type=int, default=1024)\n@click.option(\"--patchsize\", type=int, default=3)\n@click.option(\"--embedding_size\", type=int, default=1024)\n@click.option(\"--meta_epochs\", type=int, default=1)\n@click.option(\"--aed_meta_epochs\", type=int, default=1)\n@click.option(\"--gan_epochs\", type=int, default=1)\n@click.option(\"--dsc_layers\", type=int, default=2)\n@click.option(\"--dsc_hidden\", type=int, default=None)\n@click.option(\"--noise_std\", type=float, default=0.05)\n@click.option(\"--dsc_margin\", type=float, default=0.8)\n@click.option(\"--dsc_lr\", type=float, default=0.0002)\n@click.option(\"--auto_noise\", type=float, default=0)\n@click.option(\"--train_backbone\", is_flag=True)\n@click.option(\"--cos_lr\", is_flag=True)\n@click.option(\"--pre_proj\", type=int, default=0)\n@click.option(\"--proj_layer_type\", type=int, default=0)\n@click.option(\"--mix_noise\", type=int, default=1)\ndef net(\n    backbone_names,\n    layers_to_extract_from,\n    pretrain_embed_dimension,\n    target_embed_dimension,\n    patchsize,\n    embedding_size,\n    meta_epochs,\n    aed_meta_epochs,\n    gan_epochs,\n    noise_std,\n    dsc_layers, \n    dsc_hidden,\n    dsc_margin,\n    dsc_lr,\n    auto_noise,\n    train_backbone,\n    cos_lr,\n    pre_proj,\n    proj_layer_type,\n    mix_noise,\n):\n    backbone_names = list(backbone_names)\n    if len(backbone_names) > 1:\n        layers_to_extract_from_coll = [[] for _ in range(len(backbone_names))]\n        for layer in layers_to_extract_from:\n            idx = int(layer.split(\".\")[0])\n            layer = \".\".join(layer.split(\".\")[1:])\n            layers_to_extract_from_coll[idx].append(layer)\n    else:\n        layers_to_extract_from_coll = [layers_to_extract_from]\n\n    def get_simplenet(input_shape, device):\n        simplenets = []\n        for backbone_name, layers_to_extract_from in zip(\n            backbone_names, layers_to_extract_from_coll\n        ):\n            backbone_seed = None\n            if \".seed-\" in backbone_name:\n                backbone_name, backbone_seed = backbone_name.split(\".seed-\")[0], int(\n                    backbone_name.split(\"-\")[-1]\n                )\n            backbone = backbones.load(backbone_name)\n            backbone.name, backbone.seed = backbone_name, backbone_seed\n\n            simplenet_inst = simplenet.SimpleNet(device)\n            simplenet_inst.load(\n                backbone=backbone,\n                layers_to_extract_from=layers_to_extract_from,\n                device=device,\n                input_shape=input_shape,\n                pretrain_embed_dimension=pretrain_embed_dimension,\n                target_embed_dimension=target_embed_dimension,\n                patchsize=patchsize,\n                embedding_size=embedding_size,\n                meta_epochs=meta_epochs,\n                aed_meta_epochs=aed_meta_epochs,\n                gan_epochs=gan_epochs,\n                noise_std=noise_std,\n                dsc_layers=dsc_layers,\n                dsc_hidden=dsc_hidden,\n                dsc_margin=dsc_margin,\n                dsc_lr=dsc_lr,\n                auto_noise=auto_noise,\n                train_backbone=train_backbone,\n                cos_lr=cos_lr,\n                pre_proj=pre_proj,\n                proj_layer_type=proj_layer_type,\n                mix_noise=mix_noise,\n            )\n            simplenets.append(simplenet_inst)\n        return simplenets\n\n    return (\"get_simplenet\", get_simplenet)\n\n\n@main.command(\"dataset\")\n@click.argument(\"name\", type=str)\n@click.argument(\"data_path\", type=click.Path(exists=True, file_okay=False))\n@click.option(\"--subdatasets\", \"-d\", multiple=True, type=str, required=True)\n@click.option(\"--train_val_split\", type=float, default=1, show_default=True)\n@click.option(\"--batch_size\", default=2, type=int, show_default=True)\n@click.option(\"--num_workers\", default=2, type=int, show_default=True)\n@click.option(\"--resize\", default=256, type=int, show_default=True)\n@click.option(\"--imagesize\", default=224, type=int, show_default=True)\n@click.option(\"--rotate_degrees\", default=0, type=int)\n@click.option(\"--translate\", default=0, type=float)\n@click.option(\"--scale\", default=0.0, type=float)\n@click.option(\"--brightness\", default=0.0, type=float)\n@click.option(\"--contrast\", default=0.0, type=float)\n@click.option(\"--saturation\", default=0.0, type=float)\n@click.option(\"--gray\", default=0.0, type=float)\n@click.option(\"--hflip\", default=0.0, type=float)\n@click.option(\"--vflip\", default=0.0, type=float)\n@click.option(\"--augment\", is_flag=True)\ndef dataset(\n    name,\n    data_path,\n    subdatasets,\n    train_val_split,\n    batch_size,\n    resize,\n    imagesize,\n    num_workers,\n    rotate_degrees,\n    translate,\n    scale,\n    brightness,\n    contrast,\n    saturation,\n    gray,\n    hflip,\n    vflip,\n    augment,\n):\n    dataset_info = _DATASETS[name]\n    dataset_library = __import__(dataset_info[0], fromlist=[dataset_info[1]])\n\n    def get_dataloaders(seed):\n        dataloaders = []\n        for subdataset in subdatasets:\n            train_dataset = dataset_library.__dict__[dataset_info[1]](\n                data_path,\n                classname=subdataset,\n                resize=resize,\n                train_val_split=train_val_split,\n                imagesize=imagesize,\n                split=dataset_library.DatasetSplit.TRAIN,\n                seed=seed,\n                rotate_degrees=rotate_degrees,\n                translate=translate,\n                brightness_factor=brightness,\n                contrast_factor=contrast,\n                saturation_factor=saturation,\n                gray_p=gray,\n                h_flip_p=hflip,\n                v_flip_p=vflip,\n                scale=scale,\n                augment=augment,\n            )\n\n            test_dataset = dataset_library.__dict__[dataset_info[1]](\n                data_path,\n                classname=subdataset,\n                resize=resize,\n                imagesize=imagesize,\n                split=dataset_library.DatasetSplit.TEST,\n                seed=seed,\n            )\n            \n            LOGGER.info(f\"Dataset: train={len(train_dataset)} test={len(test_dataset)}\")\n\n            train_dataloader = torch.utils.data.DataLoader(\n                train_dataset,\n                batch_size=batch_size,\n                shuffle=True,\n                num_workers=num_workers,\n                prefetch_factor=2,\n                pin_memory=True,\n            )\n\n            test_dataloader = torch.utils.data.DataLoader(\n                test_dataset,\n                batch_size=batch_size,\n                shuffle=False,\n                num_workers=num_workers,\n                prefetch_factor=2,\n                pin_memory=True,\n            )\n\n            train_dataloader.name = name\n            if subdataset is not None:\n                train_dataloader.name += \"_\" + subdataset\n\n            if train_val_split < 1:\n                val_dataset = dataset_library.__dict__[dataset_info[1]](\n                    data_path,\n                    classname=subdataset,\n                    resize=resize,\n                    train_val_split=train_val_split,\n                    imagesize=imagesize,\n                    split=dataset_library.DatasetSplit.VAL,\n                    seed=seed,\n                )\n\n                val_dataloader = torch.utils.data.DataLoader(\n                    val_dataset,\n                    batch_size=batch_size,\n                    shuffle=False,\n                    num_workers=num_workers,\n                    prefetch_factor=4,\n                    pin_memory=True,\n                )\n            else:\n                val_dataloader = None\n            dataloader_dict = {\n                \"training\": train_dataloader,\n                \"validation\": val_dataloader,\n                \"testing\": test_dataloader,\n            }\n\n            dataloaders.append(dataloader_dict)\n        return dataloaders\n\n    return (\"get_dataloaders\", get_dataloaders)\n\n\nif __name__ == \"__main__\":\n    logging.basicConfig(level=logging.INFO)\n    LOGGER.info(\"Command line arguments: {}\".format(\" \".join(sys.argv)))\n    main()\n"
  },
  {
    "path": "metrics.py",
    "content": "\"\"\"Anomaly metrics.\"\"\"\nimport cv2\nimport numpy as np\nfrom sklearn import metrics\n\n\ndef compute_imagewise_retrieval_metrics(\n    anomaly_prediction_weights, anomaly_ground_truth_labels\n):\n    \"\"\"\n    Computes retrieval statistics (AUROC, FPR, TPR).\n\n    Args:\n        anomaly_prediction_weights: [np.array or list] [N] Assignment weights\n                                    per image. Higher indicates higher\n                                    probability of being an anomaly.\n        anomaly_ground_truth_labels: [np.array or list] [N] Binary labels - 1\n                                    if image is an anomaly, 0 if not.\n    \"\"\"\n    fpr, tpr, thresholds = metrics.roc_curve(\n        anomaly_ground_truth_labels, anomaly_prediction_weights\n    )\n    auroc = metrics.roc_auc_score(\n        anomaly_ground_truth_labels, anomaly_prediction_weights\n    )\n    \n    precision, recall, _ = metrics.precision_recall_curve(\n        anomaly_ground_truth_labels, anomaly_prediction_weights\n    )\n    auc_pr = metrics.auc(recall, precision)\n    \n    return {\"auroc\": auroc, \"fpr\": fpr, \"tpr\": tpr, \"threshold\": thresholds}\n\n\ndef compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks):\n    \"\"\"\n    Computes pixel-wise statistics (AUROC, FPR, TPR) for anomaly segmentations\n    and ground truth segmentation masks.\n\n    Args:\n        anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains\n                                generated segmentation masks.\n        ground_truth_masks: [list of np.arrays or np.array] [NxHxW] Contains\n                            predefined ground truth segmentation masks\n    \"\"\"\n    if isinstance(anomaly_segmentations, list):\n        anomaly_segmentations = np.stack(anomaly_segmentations)\n    if isinstance(ground_truth_masks, list):\n        ground_truth_masks = np.stack(ground_truth_masks)\n\n    flat_anomaly_segmentations = anomaly_segmentations.ravel()\n    flat_ground_truth_masks = ground_truth_masks.ravel()\n\n    fpr, tpr, thresholds = metrics.roc_curve(\n        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations\n    )\n    auroc = metrics.roc_auc_score(\n        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations\n    )\n\n    precision, recall, thresholds = metrics.precision_recall_curve(\n        flat_ground_truth_masks.astype(int), flat_anomaly_segmentations\n    )\n    F1_scores = np.divide(\n        2 * precision * recall,\n        precision + recall,\n        out=np.zeros_like(precision),\n        where=(precision + recall) != 0,\n    )\n\n    optimal_threshold = thresholds[np.argmax(F1_scores)]\n    predictions = (flat_anomaly_segmentations >= optimal_threshold).astype(int)\n    fpr_optim = np.mean(predictions > flat_ground_truth_masks)\n    fnr_optim = np.mean(predictions < flat_ground_truth_masks)\n\n    return {\n        \"auroc\": auroc,\n        \"fpr\": fpr,\n        \"tpr\": tpr,\n        \"optimal_threshold\": optimal_threshold,\n        \"optimal_fpr\": fpr_optim,\n        \"optimal_fnr\": fnr_optim,\n    }\n\n\nimport pandas as pd\nfrom skimage import measure\ndef compute_pro(masks, amaps, num_th=200):\n\n    df = pd.DataFrame([], columns=[\"pro\", \"fpr\", \"threshold\"])\n    binary_amaps = np.zeros_like(amaps, dtype=np.bool)\n\n    min_th = amaps.min()\n    max_th = amaps.max()\n    delta = (max_th - min_th) / num_th\n\n    k = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))\n    for th in np.arange(min_th, max_th, delta):\n        binary_amaps[amaps <= th] = 0\n        binary_amaps[amaps > th] = 1\n\n        pros = []\n        for binary_amap, mask in zip(binary_amaps, masks):\n            binary_amap = cv2.dilate(binary_amap.astype(np.uint8), k)\n            for region in measure.regionprops(measure.label(mask)):\n                axes0_ids = region.coords[:, 0]\n                axes1_ids = region.coords[:, 1]\n                tp_pixels = binary_amap[axes0_ids, axes1_ids].sum()\n                pros.append(tp_pixels / region.area)\n\n        inverse_masks = 1 - masks\n        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()\n        fpr = fp_pixels / inverse_masks.sum()\n\n        df = df.append({\"pro\": np.mean(pros), \"fpr\": fpr, \"threshold\": th}, ignore_index=True)\n\n    # Normalize FPR from 0 ~ 1 to 0 ~ 0.3\n    df = df[df[\"fpr\"] < 0.3]\n    df[\"fpr\"] = df[\"fpr\"] / df[\"fpr\"].max()\n\n    pro_auc = metrics.auc(df[\"fpr\"], df[\"pro\"])\n    return pro_auc"
  },
  {
    "path": "resnet.py",
    "content": "import torch\nfrom torch import Tensor\nimport torch.nn as nn\nfrom typing import Type, Any, Callable, Union, List, Optional\ntry:\n    from torch.hub import load_state_dict_from_url\nexcept ImportError:\n    from torch.utils.model_zoo import load_url as load_state_dict_from_url\n\n__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',\n           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',\n           'wide_resnet50_2', 'wide_resnet101_2']\n\n\nmodel_urls = {\n    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',\n    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',\n    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',\n    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',\n    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',\n    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',\n    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',\n    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',\n    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',\n}\n\nPADDING_MODE = 'reflect' # {'zeros', 'reflect', 'replicate', 'circular'}\n# PADDING_MODE = 'zeros' # {'zeros', 'reflect', 'replicate', 'circular'}\n\ndef conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1):\n    \"\"\"3x3 convolution with padding\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,\n                     padding=dilation, padding_mode = PADDING_MODE, groups=groups, bias=False, dilation=dilation)\n\n\ndef conv1x1(in_planes: int, out_planes: int, stride: int = 1):\n    \"\"\"1x1 convolution\"\"\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)\n\n\nclass BasicBlock(nn.Module):\n    expansion: int = 1\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ):\n        super(BasicBlock, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        if groups != 1 or base_width != 64:\n            raise ValueError('BasicBlock only supports groups=1 and base_width=64')\n        if dilation > 1:\n            raise NotImplementedError(\"Dilation > 1 not supported in BasicBlock\")\n        # Both self.conv1 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv3x3(inplanes, planes, stride)\n        self.bn1 = norm_layer(planes)\n        self.relu = nn.ReLU(inplace=True)\n        self.conv2 = conv3x3(planes, planes)\n        self.bn2 = norm_layer(planes)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass Bottleneck(nn.Module):\n    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)\n    # while original implementation places the stride at the first 1x1 convolution(self.conv1)\n    # according to \"Deep residual learning for image recognition\"https://arxiv.org/abs/1512.03385.\n    # This variant is also known as ResNet V1.5 and improves accuracy according to\n    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.\n\n    expansion: int = 4\n\n    def __init__(\n        self,\n        inplanes: int,\n        planes: int,\n        stride: int = 1,\n        downsample: Optional[nn.Module] = None,\n        groups: int = 1,\n        base_width: int = 64,\n        dilation: int = 1,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ):\n        super(Bottleneck, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        width = int(planes * (base_width / 64.)) * groups\n        # Both self.conv2 and self.downsample layers downsample the input when stride != 1\n        self.conv1 = conv1x1(inplanes, width)\n        self.bn1 = norm_layer(width)\n        self.conv2 = conv3x3(width, width, stride, groups, dilation)\n        self.bn2 = norm_layer(width)\n        self.conv3 = conv1x1(width, planes * self.expansion)\n        self.bn3 = norm_layer(planes * self.expansion)\n        self.relu = nn.ReLU(inplace=True)\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x: Tensor):\n        identity = x\n\n        out = self.conv1(x)\n        out = self.bn1(out)\n        out = self.relu(out)\n\n        out = self.conv2(out)\n        out = self.bn2(out)\n        out = self.relu(out)\n\n        out = self.conv3(out)\n        out = self.bn3(out)\n\n        if self.downsample is not None:\n            identity = self.downsample(x)\n\n        out += identity\n        out = self.relu(out)\n\n        return out\n\n\nclass ResNet(nn.Module):\n\n    def __init__(\n        self,\n        block: Type[Union[BasicBlock, Bottleneck]],\n        layers: List[int],\n        num_classes: int = 1000,\n        zero_init_residual: bool = False,\n        groups: int = 1,\n        width_per_group: int = 64,\n        replace_stride_with_dilation: Optional[List[bool]] = None,\n        norm_layer: Optional[Callable[..., nn.Module]] = None\n    ):\n        super(ResNet, self).__init__()\n        if norm_layer is None:\n            norm_layer = nn.BatchNorm2d\n        self._norm_layer = norm_layer\n\n        self.inplanes = 64\n        self.dilation = 1\n        if replace_stride_with_dilation is None:\n            # each element in the tuple indicates if we should replace\n            # the 2x2 stride with a dilated convolution instead\n            replace_stride_with_dilation = [False, False, False]\n        if len(replace_stride_with_dilation) != 3:\n            raise ValueError(\"replace_stride_with_dilation should be None \"\n                             \"or a 3-element tuple, got {}\".format(replace_stride_with_dilation))\n        self.groups = groups\n        self.base_width = width_per_group\n        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, padding_mode = PADDING_MODE,\n                               bias=False)\n        self.bn1 = norm_layer(self.inplanes)\n        self.relu = nn.ReLU(inplace=True)\n        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n        \n        self.layer1 = self._make_layer(block, 64, layers[0])\n        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])\n        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])\n        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])\n        #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n        #self.fc = nn.Linear(512 * block.expansion, num_classes)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):\n                nn.init.constant_(m.weight, 1)\n                nn.init.constant_(m.bias, 0)\n\n        # Zero-initialize the last BN in each residual branch,\n        # so that the residual branch starts with zeros, and each residual block behaves like an identity.\n        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677\n        if zero_init_residual:\n            for m in self.modules():\n                if isinstance(m, Bottleneck):\n                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]\n                elif isinstance(m, BasicBlock):\n                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]\n\n    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,\n                    stride: int = 1, dilate: bool = False):\n        norm_layer = self._norm_layer\n        downsample = None\n        previous_dilation = self.dilation\n        if dilate:\n            self.dilation *= stride\n            stride = 1\n        if stride != 1 or self.inplanes != planes * block.expansion:\n            downsample = nn.Sequential(\n                conv1x1(self.inplanes, planes * block.expansion, stride),\n                norm_layer(planes * block.expansion),)\n\n        layers = []\n        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,\n                            self.base_width, previous_dilation, norm_layer))\n        self.inplanes = planes * block.expansion\n        for _ in range(1, blocks):\n            layers.append(block(self.inplanes, planes, groups=self.groups,\n                                base_width=self.base_width, dilation=self.dilation,\n                                norm_layer=norm_layer))\n\n        return nn.Sequential(*layers)\n\n    def _forward_impl(self, x: Tensor):\n        # See note [TorchScript super()]\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = self.maxpool(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.layer4(x)\n        # remove extra layers\n        #x = self.avgpool(x)\n        #x = torch.flatten(x, 1)\n        #x = self.fc(x)\n        return x\n\n    def forward(self, x: Tensor):\n        return self._forward_impl(x)\n\n\ndef _resnet(\n    arch: str,\n    block: Type[Union[BasicBlock, Bottleneck]],\n    layers: List[int],\n    pretrained: bool,\n    progress: bool,\n    **kwargs: Any\n):\n    model = ResNet(block, layers, **kwargs)\n    if pretrained:\n        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)\n        #model.load_state_dict(state_dict)\n        model.load_state_dict(state_dict, strict=False)\n    return model\n\n\ndef resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any):\n    r\"\"\"ResNet-18 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) :\n    r\"\"\"ResNet-34 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) :\n    r\"\"\"ResNet-50 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) :\n    r\"\"\"ResNet-101 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) :\n    r\"\"\"ResNet-152 model from\n    `\"Deep Residual Learning for Image Recognition\" <https://arxiv.org/pdf/1512.03385.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,\n                   **kwargs)\n\n\ndef resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) :\n    r\"\"\"ResNeXt-50 32x4d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 4\n    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) :\n    r\"\"\"ResNeXt-101 32x8d model from\n    `\"Aggregated Residual Transformation for Deep Neural Networks\" <https://arxiv.org/pdf/1611.05431.pdf>`_.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['groups'] = 32\n    kwargs['width_per_group'] = 8\n    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) :\n    r\"\"\"Wide ResNet-50-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_.\n\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],\n                   pretrained, progress, **kwargs)\n\n\ndef wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) :\n    r\"\"\"Wide ResNet-101-2 model from\n    `\"Wide Residual Networks\" <https://arxiv.org/pdf/1605.07146.pdf>`_.\n\n    The model is the same as ResNet except for the bottleneck number of channels\n    which is twice larger in every block. The number of channels in outer 1x1\n    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048\n    channels, and in Wide ResNet-50-2 has 2048-1024-2048.\n\n    Args:\n        pretrained (bool): If True, returns a model pre-trained on ImageNet\n        progress (bool): If True, displays a progress bar of the download to stderr\n    \"\"\"\n    kwargs['width_per_group'] = 64 * 2\n    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],\n                   pretrained, progress, **kwargs)\n\n\n# ==============================================================================================================\n#                                           Model Class Definition\n# =============================================================================================================="
  },
  {
    "path": "run.sh",
    "content": "datapath=/data4/MVTec_ad\ndatasets=('screw' 'pill' 'capsule' 'carpet' 'grid' 'tile' 'wood' 'zipper' 'cable' 'toothbrush' 'transistor' 'metal_nut' 'bottle' 'hazelnut' 'leather')\ndataset_flags=($(for dataset in \"${datasets[@]}\"; do echo '-d '\"${dataset}\"; done))\n\npython3 main.py \\\n--gpu 4 \\\n--seed 0 \\\n--log_group simplenet_mvtec \\\n--log_project MVTecAD_Results \\\n--results_path results \\\n--run_name run \\\nnet \\\n-b wideresnet50 \\\n-le layer2 \\\n-le layer3 \\\n--pretrain_embed_dimension 1536 \\\n--target_embed_dimension 1536 \\\n--patchsize 3 \\\n--meta_epochs 40 \\\n--embedding_size 256 \\\n--gan_epochs 4 \\\n--noise_std 0.015 \\\n--dsc_hidden 1024 \\\n--dsc_layers 2 \\\n--dsc_margin .5 \\\n--pre_proj 1 \\\ndataset \\\n--batch_size 8 \\\n--resize 329 \\\n--imagesize 288 \"${dataset_flags[@]}\" mvtec $datapath\n"
  },
  {
    "path": "simplenet.py",
    "content": "# ------------------------------------------------------------------\n# SimpleNet: A Simple Network for Image Anomaly Detection and Localization (https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_SimpleNet_A_Simple_Network_for_Image_Anomaly_Detection_and_Localization_CVPR_2023_paper.pdf)\n# Github source: https://github.com/DonaldRR/SimpleNet\n# Licensed under the MIT License [see LICENSE for details]\n# The script is based on the code of PatchCore (https://github.com/amazon-science/patchcore-inspection)\n# ------------------------------------------------------------------\n\n\"\"\"detection methods.\"\"\"\nimport logging\nimport os\nimport pickle\nfrom collections import OrderedDict\n\nimport math\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport tqdm\nfrom torch.utils.tensorboard import SummaryWriter\n\nimport common\nimport metrics\nfrom utils import plot_segmentation_images\n\nLOGGER = logging.getLogger(__name__)\n\ndef init_weight(m):\n\n    if isinstance(m, torch.nn.Linear):\n        torch.nn.init.xavier_normal_(m.weight)\n    elif isinstance(m, torch.nn.Conv2d):\n        torch.nn.init.xavier_normal_(m.weight)\n\n\nclass Discriminator(torch.nn.Module):\n    def __init__(self, in_planes, n_layers=1, hidden=None):\n        super(Discriminator, self).__init__()\n\n        _hidden = in_planes if hidden is None else hidden\n        self.body = torch.nn.Sequential()\n        for i in range(n_layers-1):\n            _in = in_planes if i == 0 else _hidden\n            _hidden = int(_hidden // 1.5) if hidden is None else hidden\n            self.body.add_module('block%d'%(i+1),\n                                 torch.nn.Sequential(\n                                     torch.nn.Linear(_in, _hidden),\n                                     torch.nn.BatchNorm1d(_hidden),\n                                     torch.nn.LeakyReLU(0.2)\n                                 ))\n        self.tail = torch.nn.Linear(_hidden, 1, bias=False)\n        self.apply(init_weight)\n\n    def forward(self,x):\n        x = self.body(x)\n        x = self.tail(x)\n        return x\n\n\nclass Projection(torch.nn.Module):\n    \n    def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):\n        super(Projection, self).__init__()\n        \n        if out_planes is None:\n            out_planes = in_planes\n        self.layers = torch.nn.Sequential()\n        _in = None\n        _out = None\n        for i in range(n_layers):\n            _in = in_planes if i == 0 else _out\n            _out = out_planes \n            self.layers.add_module(f\"{i}fc\", \n                                   torch.nn.Linear(_in, _out))\n            if i < n_layers - 1:\n                # if layer_type > 0:\n                #     self.layers.add_module(f\"{i}bn\", \n                #                            torch.nn.BatchNorm1d(_out))\n                if layer_type > 1:\n                    self.layers.add_module(f\"{i}relu\",\n                                           torch.nn.LeakyReLU(.2))\n        self.apply(init_weight)\n    \n    def forward(self, x):\n        \n        # x = .1 * self.layers(x) + x\n        x = self.layers(x)\n        return x\n\n\nclass TBWrapper:\n    \n    def __init__(self, log_dir):\n        self.g_iter = 0\n        self.logger = SummaryWriter(log_dir=log_dir)\n    \n    def step(self):\n        self.g_iter += 1\n\nclass SimpleNet(torch.nn.Module):\n    def __init__(self, device):\n        \"\"\"anomaly detection class.\"\"\"\n        super(SimpleNet, self).__init__()\n        self.device = device\n\n    def load(\n        self,\n        backbone,\n        layers_to_extract_from,\n        device,\n        input_shape,\n        pretrain_embed_dimension, # 1536\n        target_embed_dimension, # 1536\n        patchsize=3, # 3\n        patchstride=1, \n        embedding_size=None, # 256\n        meta_epochs=1, # 40\n        aed_meta_epochs=1,\n        gan_epochs=1, # 4\n        noise_std=0.05,\n        mix_noise=1,\n        noise_type=\"GAU\",\n        dsc_layers=2, # 2\n        dsc_hidden=None, # 1024\n        dsc_margin=.8, # .5\n        dsc_lr=0.0002,\n        train_backbone=False,\n        auto_noise=0,\n        cos_lr=False,\n        lr=1e-3,\n        pre_proj=0, # 1\n        proj_layer_type=0,\n        **kwargs,\n    ):\n        pid = os.getpid()\n        def show_mem():\n            return(psutil.Process(pid).memory_info())\n\n        self.backbone = backbone.to(device)\n        self.layers_to_extract_from = layers_to_extract_from\n        self.input_shape = input_shape\n\n        self.device = device\n        self.patch_maker = PatchMaker(patchsize, stride=patchstride)\n\n        self.forward_modules = torch.nn.ModuleDict({})\n\n        feature_aggregator = common.NetworkFeatureAggregator(\n            self.backbone, self.layers_to_extract_from, self.device, train_backbone\n        )\n        feature_dimensions = feature_aggregator.feature_dimensions(input_shape)\n        self.forward_modules[\"feature_aggregator\"] = feature_aggregator\n\n        preprocessing = common.Preprocessing(\n            feature_dimensions, pretrain_embed_dimension\n        )\n        self.forward_modules[\"preprocessing\"] = preprocessing\n\n        self.target_embed_dimension = target_embed_dimension\n        preadapt_aggregator = common.Aggregator(\n            target_dim=target_embed_dimension\n        )\n\n        _ = preadapt_aggregator.to(self.device)\n\n        self.forward_modules[\"preadapt_aggregator\"] = preadapt_aggregator\n\n        self.anomaly_segmentor = common.RescaleSegmentor(\n            device=self.device, target_size=input_shape[-2:]\n        )\n\n        self.embedding_size = embedding_size if embedding_size is not None else self.target_embed_dimension\n        self.meta_epochs = meta_epochs\n        self.lr = lr\n        self.cos_lr = cos_lr\n        self.train_backbone = train_backbone\n        if self.train_backbone:\n            self.backbone_opt = torch.optim.AdamW(self.forward_modules[\"feature_aggregator\"].backbone.parameters(), lr)\n        # AED\n        self.aed_meta_epochs = aed_meta_epochs\n\n        self.pre_proj = pre_proj\n        if self.pre_proj > 0:\n            self.pre_projection = Projection(self.target_embed_dimension, self.target_embed_dimension, pre_proj, proj_layer_type)\n            self.pre_projection.to(self.device)\n            self.proj_opt = torch.optim.AdamW(self.pre_projection.parameters(), lr*.1)\n\n        # Discriminator\n        self.auto_noise = [auto_noise, None]\n        self.dsc_lr = dsc_lr\n        self.gan_epochs = gan_epochs\n        self.mix_noise = mix_noise\n        self.noise_type = noise_type\n        self.noise_std = noise_std\n        self.discriminator = Discriminator(self.target_embed_dimension, n_layers=dsc_layers, hidden=dsc_hidden)\n        self.discriminator.to(self.device)\n        self.dsc_opt = torch.optim.Adam(self.discriminator.parameters(), lr=self.dsc_lr, weight_decay=1e-5)\n        self.dsc_schl = torch.optim.lr_scheduler.CosineAnnealingLR(self.dsc_opt, (meta_epochs - aed_meta_epochs) * gan_epochs, self.dsc_lr*.4)\n        self.dsc_margin= dsc_margin \n\n        self.model_dir = \"\"\n        self.dataset_name = \"\"\n        self.tau = 1\n        self.logger = None\n\n    def set_model_dir(self, model_dir, dataset_name):\n\n        self.model_dir = model_dir \n        os.makedirs(self.model_dir, exist_ok=True)\n        self.ckpt_dir = os.path.join(self.model_dir, dataset_name)\n        os.makedirs(self.ckpt_dir, exist_ok=True)\n        self.tb_dir = os.path.join(self.ckpt_dir, \"tb\")\n        os.makedirs(self.tb_dir, exist_ok=True)\n        self.logger = TBWrapper(self.tb_dir) #SummaryWriter(log_dir=tb_dir)\n    \n\n    def embed(self, data):\n        if isinstance(data, torch.utils.data.DataLoader):\n            features = []\n            for image in data:\n                if isinstance(image, dict):\n                    image = image[\"image\"]\n                    input_image = image.to(torch.float).to(self.device)\n                with torch.no_grad():\n                    features.append(self._embed(input_image))\n            return features\n        return self._embed(data)\n\n    def _embed(self, images, detach=True, provide_patch_shapes=False, evaluation=False):\n        \"\"\"Returns feature embeddings for images.\"\"\"\n\n        B = len(images)\n        if not evaluation and self.train_backbone:\n            self.forward_modules[\"feature_aggregator\"].train()\n            features = self.forward_modules[\"feature_aggregator\"](images, eval=evaluation)\n        else:\n            _ = self.forward_modules[\"feature_aggregator\"].eval()\n            with torch.no_grad():\n                features = self.forward_modules[\"feature_aggregator\"](images)\n\n        features = [features[layer] for layer in self.layers_to_extract_from]\n\n        for i, feat in enumerate(features):\n            if len(feat.shape) == 3:\n                B, L, C = feat.shape\n                features[i] = feat.reshape(B, int(math.sqrt(L)), int(math.sqrt(L)), C).permute(0, 3, 1, 2)\n\n        features = [\n            self.patch_maker.patchify(x, return_spatial_info=True) for x in features\n        ]\n        patch_shapes = [x[1] for x in features]\n        features = [x[0] for x in features]\n        ref_num_patches = patch_shapes[0]\n\n        for i in range(1, len(features)):\n            _features = features[i]\n            patch_dims = patch_shapes[i]\n\n            # TODO(pgehler): Add comments\n            _features = _features.reshape(\n                _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]\n            )\n            _features = _features.permute(0, -3, -2, -1, 1, 2)\n            perm_base_shape = _features.shape\n            _features = _features.reshape(-1, *_features.shape[-2:])\n            _features = F.interpolate(\n                _features.unsqueeze(1),\n                size=(ref_num_patches[0], ref_num_patches[1]),\n                mode=\"bilinear\",\n                align_corners=False,\n            )\n            _features = _features.squeeze(1)\n            _features = _features.reshape(\n                *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]\n            )\n            _features = _features.permute(0, -2, -1, 1, 2, 3)\n            _features = _features.reshape(len(_features), -1, *_features.shape[-3:])\n            features[i] = _features\n        features = [x.reshape(-1, *x.shape[-3:]) for x in features]\n        \n        # As different feature backbones & patching provide differently\n        # sized features, these are brought into the correct form here.\n        features = self.forward_modules[\"preprocessing\"](features) # pooling each feature to same channel and stack together\n        features = self.forward_modules[\"preadapt_aggregator\"](features) # further pooling        \n\n\n        return features, patch_shapes\n\n    \n    def test(self, training_data, test_data, save_segmentation_images):\n\n        ckpt_path = os.path.join(self.ckpt_dir, \"models.ckpt\")\n        if os.path.exists(ckpt_path):\n            state_dicts = torch.load(ckpt_path, map_location=self.device)\n            if \"pretrained_enc\" in state_dicts:\n                self.feature_enc.load_state_dict(state_dicts[\"pretrained_enc\"])\n            if \"pretrained_dec\" in state_dicts:\n                self.feature_dec.load_state_dict(state_dicts[\"pretrained_dec\"])\n\n        aggregator = {\"scores\": [], \"segmentations\": [], \"features\": []}\n        scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)\n        aggregator[\"scores\"].append(scores)\n        aggregator[\"segmentations\"].append(segmentations)\n        aggregator[\"features\"].append(features)\n\n        scores = np.array(aggregator[\"scores\"])\n        min_scores = scores.min(axis=-1).reshape(-1, 1)\n        max_scores = scores.max(axis=-1).reshape(-1, 1)\n        scores = (scores - min_scores) / (max_scores - min_scores)\n        scores = np.mean(scores, axis=0)\n\n        segmentations = np.array(aggregator[\"segmentations\"])\n        min_scores = (\n            segmentations.reshape(len(segmentations), -1)\n            .min(axis=-1)\n            .reshape(-1, 1, 1, 1)\n        )\n        max_scores = (\n            segmentations.reshape(len(segmentations), -1)\n            .max(axis=-1)\n            .reshape(-1, 1, 1, 1)\n        )\n        segmentations = (segmentations - min_scores) / (max_scores - min_scores)\n        segmentations = np.mean(segmentations, axis=0)\n\n        anomaly_labels = [\n            x[1] != \"good\" for x in test_data.dataset.data_to_iterate\n        ]\n\n        if save_segmentation_images:\n            self.save_segmentation_images(test_data, segmentations, scores)\n            \n        auroc = metrics.compute_imagewise_retrieval_metrics(\n            scores, anomaly_labels\n        )[\"auroc\"]\n\n        # Compute PRO score & PW Auroc for all images\n        pixel_scores = metrics.compute_pixelwise_retrieval_metrics(\n            segmentations, masks_gt\n        )\n        full_pixel_auroc = pixel_scores[\"auroc\"]\n\n        return auroc, full_pixel_auroc\n    \n    def _evaluate(self, test_data, scores, segmentations, features, labels_gt, masks_gt):\n        \n        scores = np.squeeze(np.array(scores))\n        img_min_scores = scores.min(axis=-1)\n        img_max_scores = scores.max(axis=-1)\n        scores = (scores - img_min_scores) / (img_max_scores - img_min_scores)\n        # scores = np.mean(scores, axis=0)\n\n        auroc = metrics.compute_imagewise_retrieval_metrics(\n            scores, labels_gt \n        )[\"auroc\"]\n\n        if len(masks_gt) > 0:\n            segmentations = np.array(segmentations)\n            min_scores = (\n                segmentations.reshape(len(segmentations), -1)\n                .min(axis=-1)\n                .reshape(-1, 1, 1, 1)\n            )\n            max_scores = (\n                segmentations.reshape(len(segmentations), -1)\n                .max(axis=-1)\n                .reshape(-1, 1, 1, 1)\n            )\n            norm_segmentations = np.zeros_like(segmentations)\n            for min_score, max_score in zip(min_scores, max_scores):\n                norm_segmentations += (segmentations - min_score) / max(max_score - min_score, 1e-2)\n            norm_segmentations = norm_segmentations / len(scores)\n\n\n            # Compute PRO score & PW Auroc for all images\n            pixel_scores = metrics.compute_pixelwise_retrieval_metrics(\n                norm_segmentations, masks_gt)\n                # segmentations, masks_gt\n            full_pixel_auroc = pixel_scores[\"auroc\"]\n\n            pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)), \n                                            norm_segmentations)\n        else:\n            full_pixel_auroc = -1 \n            pro = -1\n\n        return auroc, full_pixel_auroc, pro\n        \n    \n    def train(self, training_data, test_data):\n\n        \n        state_dict = {}\n        ckpt_path = os.path.join(self.ckpt_dir, \"ckpt.pth\")\n        if os.path.exists(ckpt_path):\n            state_dict = torch.load(ckpt_path, map_location=self.device)\n            if 'discriminator' in state_dict:\n                self.discriminator.load_state_dict(state_dict['discriminator'])\n                if \"pre_projection\" in state_dict:\n                    self.pre_projection.load_state_dict(state_dict[\"pre_projection\"])\n            else:\n                self.load_state_dict(state_dict, strict=False)\n\n            self.predict(training_data, \"train_\")\n            scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)\n            auroc, full_pixel_auroc, anomaly_pixel_auroc = self._evaluate(test_data, scores, segmentations, features, labels_gt, masks_gt)\n            \n            return auroc, full_pixel_auroc, anomaly_pixel_auroc\n        \n        def update_state_dict(d):\n            \n            state_dict[\"discriminator\"] = OrderedDict({\n                k:v.detach().cpu() \n                for k, v in self.discriminator.state_dict().items()})\n            if self.pre_proj > 0:\n                state_dict[\"pre_projection\"] = OrderedDict({\n                    k:v.detach().cpu() \n                    for k, v in self.pre_projection.state_dict().items()})\n\n        best_record = None\n        for i_mepoch in range(self.meta_epochs):\n\n            self._train_discriminator(training_data)\n\n            # torch.cuda.empty_cache()\n            scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)\n            auroc, full_pixel_auroc, pro = self._evaluate(test_data, scores, segmentations, features, labels_gt, masks_gt)\n            self.logger.logger.add_scalar(\"i-auroc\", auroc, i_mepoch)\n            self.logger.logger.add_scalar(\"p-auroc\", full_pixel_auroc, i_mepoch)\n            self.logger.logger.add_scalar(\"pro\", pro, i_mepoch)\n\n            if best_record is None:\n                best_record = [auroc, full_pixel_auroc, pro]\n                update_state_dict(state_dict)\n                # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})\n            else:\n                if auroc > best_record[0]:\n                    best_record = [auroc, full_pixel_auroc, pro]\n                    update_state_dict(state_dict)\n                    # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})\n                elif auroc == best_record[0] and full_pixel_auroc > best_record[1]:\n                    best_record[1] = full_pixel_auroc\n                    best_record[2] = pro \n                    update_state_dict(state_dict)\n                    # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})\n\n            print(f\"----- {i_mepoch} I-AUROC:{round(auroc, 4)}(MAX:{round(best_record[0], 4)})\"\n                  f\"  P-AUROC{round(full_pixel_auroc, 4)}(MAX:{round(best_record[1], 4)}) -----\"\n                  f\"  PRO-AUROC{round(pro, 4)}(MAX:{round(best_record[2], 4)}) -----\")\n        \n        torch.save(state_dict, ckpt_path)\n        \n        return best_record\n            \n\n    def _train_discriminator(self, input_data):\n        \"\"\"Computes and sets the support features for SPADE.\"\"\"\n        _ = self.forward_modules.eval()\n        \n        if self.pre_proj > 0:\n            self.pre_projection.train()\n        self.discriminator.train()\n        # self.feature_enc.eval()\n        # self.feature_dec.eval()\n        i_iter = 0\n        LOGGER.info(f\"Training discriminator...\")\n        with tqdm.tqdm(total=self.gan_epochs) as pbar:\n            for i_epoch in range(self.gan_epochs):\n                all_loss = []\n                all_p_true = []\n                all_p_fake = []\n                all_p_interp = []\n                embeddings_list = []\n                for data_item in input_data:\n                    self.dsc_opt.zero_grad()\n                    if self.pre_proj > 0:\n                        self.proj_opt.zero_grad()\n                    # self.dec_opt.zero_grad()\n\n                    i_iter += 1\n                    img = data_item[\"image\"]\n                    img = img.to(torch.float).to(self.device)\n                    if self.pre_proj > 0:\n                        true_feats = self.pre_projection(self._embed(img, evaluation=False)[0])\n                    else:\n                        true_feats = self._embed(img, evaluation=False)[0]\n                    \n                    noise_idxs = torch.randint(0, self.mix_noise, torch.Size([true_feats.shape[0]]))\n                    noise_one_hot = torch.nn.functional.one_hot(noise_idxs, num_classes=self.mix_noise).to(self.device) # (N, K)\n                    noise = torch.stack([\n                        torch.normal(0, self.noise_std * 1.1**(k), true_feats.shape)\n                        for k in range(self.mix_noise)], dim=1).to(self.device) # (N, K, C)\n                    noise = (noise * noise_one_hot.unsqueeze(-1)).sum(1)\n                    fake_feats = true_feats + noise\n\n                    scores = self.discriminator(torch.cat([true_feats, fake_feats]))\n                    true_scores = scores[:len(true_feats)]\n                    fake_scores = scores[len(fake_feats):]\n                    \n                    th = self.dsc_margin\n                    p_true = (true_scores.detach() >= th).sum() / len(true_scores)\n                    p_fake = (fake_scores.detach() < -th).sum() / len(fake_scores)\n                    true_loss = torch.clip(-true_scores + th, min=0)\n                    fake_loss = torch.clip(fake_scores + th, min=0)\n\n                    self.logger.logger.add_scalar(f\"p_true\", p_true, self.logger.g_iter)\n                    self.logger.logger.add_scalar(f\"p_fake\", p_fake, self.logger.g_iter)\n\n                    loss = true_loss.mean() + fake_loss.mean()\n                    self.logger.logger.add_scalar(\"loss\", loss, self.logger.g_iter)\n                    self.logger.step()\n\n                    loss.backward()\n                    if self.pre_proj > 0:\n                        self.proj_opt.step()\n                    if self.train_backbone:\n                        self.backbone_opt.step()\n                    self.dsc_opt.step()\n\n                    loss = loss.detach().cpu() \n                    all_loss.append(loss.item())\n                    all_p_true.append(p_true.cpu().item())\n                    all_p_fake.append(p_fake.cpu().item())\n                \n                if len(embeddings_list) > 0:\n                    self.auto_noise[1] = torch.cat(embeddings_list).std(0).mean(-1)\n                \n                if self.cos_lr:\n                    self.dsc_schl.step()\n                \n                all_loss = sum(all_loss) / len(input_data)\n                all_p_true = sum(all_p_true) / len(input_data)\n                all_p_fake = sum(all_p_fake) / len(input_data)\n                cur_lr = self.dsc_opt.state_dict()['param_groups'][0]['lr']\n                pbar_str = f\"epoch:{i_epoch} loss:{round(all_loss, 5)} \"\n                pbar_str += f\"lr:{round(cur_lr, 6)}\"\n                pbar_str += f\" p_true:{round(all_p_true, 3)} p_fake:{round(all_p_fake, 3)}\"\n                if len(all_p_interp) > 0:\n                    pbar_str += f\" p_interp:{round(sum(all_p_interp) / len(input_data), 3)}\"\n                pbar.set_description_str(pbar_str)\n                pbar.update(1)\n\n\n    def predict(self, data, prefix=\"\"):\n        if isinstance(data, torch.utils.data.DataLoader):\n            return self._predict_dataloader(data, prefix)\n        return self._predict(data)\n\n    def _predict_dataloader(self, dataloader, prefix):\n        \"\"\"This function provides anomaly scores/maps for full dataloaders.\"\"\"\n        _ = self.forward_modules.eval()\n\n\n        img_paths = []\n        scores = []\n        masks = []\n        features = []\n        labels_gt = []\n        masks_gt = []\n        from sklearn.manifold import TSNE\n\n        with tqdm.tqdm(dataloader, desc=\"Inferring...\", leave=False) as data_iterator:\n            for data in data_iterator:\n                if isinstance(data, dict):\n                    labels_gt.extend(data[\"is_anomaly\"].numpy().tolist())\n                    if data.get(\"mask\", None) is not None:\n                        masks_gt.extend(data[\"mask\"].numpy().tolist())\n                    image = data[\"image\"]\n                    img_paths.extend(data['image_path'])\n                _scores, _masks, _feats = self._predict(image)\n                for score, mask, feat, is_anomaly in zip(_scores, _masks, _feats, data[\"is_anomaly\"].numpy().tolist()):\n                    scores.append(score)\n                    masks.append(mask)\n\n        return scores, masks, features, labels_gt, masks_gt\n\n    def _predict(self, images):\n        \"\"\"Infer score and mask for a batch of images.\"\"\"\n        images = images.to(torch.float).to(self.device)\n        _ = self.forward_modules.eval()\n\n        batchsize = images.shape[0]\n        if self.pre_proj > 0:\n            self.pre_projection.eval()\n        self.discriminator.eval()\n        with torch.no_grad():\n            features, patch_shapes = self._embed(images,\n                                                 provide_patch_shapes=True, \n                                                 evaluation=True)\n            if self.pre_proj > 0:\n                features = self.pre_projection(features)\n\n            # features = features.cpu().numpy()\n            # features = np.ascontiguousarray(features.cpu().numpy())\n            patch_scores = image_scores = -self.discriminator(features)\n            patch_scores = patch_scores.cpu().numpy()\n            image_scores = image_scores.cpu().numpy()\n\n            image_scores = self.patch_maker.unpatch_scores(\n                image_scores, batchsize=batchsize\n            )\n            image_scores = image_scores.reshape(*image_scores.shape[:2], -1)\n            image_scores = self.patch_maker.score(image_scores)\n\n            patch_scores = self.patch_maker.unpatch_scores(\n                patch_scores, batchsize=batchsize\n            )\n            scales = patch_shapes[0]\n            patch_scores = patch_scores.reshape(batchsize, scales[0], scales[1])\n            features = features.reshape(batchsize, scales[0], scales[1], -1)\n            masks, features = self.anomaly_segmentor.convert_to_segmentation(patch_scores, features)\n\n        return list(image_scores), list(masks), list(features)\n\n    @staticmethod\n    def _params_file(filepath, prepend=\"\"):\n        return os.path.join(filepath, prepend + \"params.pkl\")\n\n    def save_to_path(self, save_path: str, prepend: str = \"\"):\n        LOGGER.info(\"Saving data.\")\n        self.anomaly_scorer.save(\n            save_path, save_features_separately=False, prepend=prepend\n        )\n        params = {\n            \"backbone.name\": self.backbone.name,\n            \"layers_to_extract_from\": self.layers_to_extract_from,\n            \"input_shape\": self.input_shape,\n            \"pretrain_embed_dimension\": self.forward_modules[\n                \"preprocessing\"\n            ].output_dim,\n            \"target_embed_dimension\": self.forward_modules[\n                \"preadapt_aggregator\"\n            ].target_dim,\n            \"patchsize\": self.patch_maker.patchsize,\n            \"patchstride\": self.patch_maker.stride,\n            \"anomaly_scorer_num_nn\": self.anomaly_scorer.n_nearest_neighbours,\n        }\n        with open(self._params_file(save_path, prepend), \"wb\") as save_file:\n            pickle.dump(params, save_file, pickle.HIGHEST_PROTOCOL)\n\n    def save_segmentation_images(self, data, segmentations, scores):\n        image_paths = [\n            x[2] for x in data.dataset.data_to_iterate\n        ]\n        mask_paths = [\n            x[3] for x in data.dataset.data_to_iterate\n        ]\n\n        def image_transform(image):\n            in_std = np.array(\n                data.dataset.transform_std\n            ).reshape(-1, 1, 1)\n            in_mean = np.array(\n                data.dataset.transform_mean\n            ).reshape(-1, 1, 1)\n            image = data.dataset.transform_img(image)\n            return np.clip(\n                (image.numpy() * in_std + in_mean) * 255, 0, 255\n            ).astype(np.uint8)\n\n        def mask_transform(mask):\n            return data.dataset.transform_mask(mask).numpy()\n\n        plot_segmentation_images(\n            './output',\n            image_paths,\n            segmentations,\n            scores,\n            mask_paths,\n            image_transform=image_transform,\n            mask_transform=mask_transform,\n        )\n\n# Image handling classes.\nclass PatchMaker:\n    def __init__(self, patchsize, top_k=0, stride=None):\n        self.patchsize = patchsize\n        self.stride = stride\n        self.top_k = top_k\n\n    def patchify(self, features, return_spatial_info=False):\n        \"\"\"Convert a tensor into a tensor of respective patches.\n        Args:\n            x: [torch.Tensor, bs x c x w x h]\n        Returns:\n            x: [torch.Tensor, bs * w//stride * h//stride, c, patchsize,\n            patchsize]\n        \"\"\"\n        padding = int((self.patchsize - 1) / 2)\n        unfolder = torch.nn.Unfold(\n            kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1\n        )\n        unfolded_features = unfolder(features)\n        number_of_total_patches = []\n        for s in features.shape[-2:]:\n            n_patches = (\n                s + 2 * padding - 1 * (self.patchsize - 1) - 1\n            ) / self.stride + 1\n            number_of_total_patches.append(int(n_patches))\n        unfolded_features = unfolded_features.reshape(\n            *features.shape[:2], self.patchsize, self.patchsize, -1\n        )\n        unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3)\n\n        if return_spatial_info:\n            return unfolded_features, number_of_total_patches\n        return unfolded_features\n\n    def unpatch_scores(self, x, batchsize):\n        return x.reshape(batchsize, -1, *x.shape[1:])\n\n    def score(self, x):\n        was_numpy = False\n        if isinstance(x, np.ndarray):\n            was_numpy = True\n            x = torch.from_numpy(x)\n        while x.ndim > 2:\n            x = torch.max(x, dim=-1).values\n        if x.ndim == 2:\n            if self.top_k > 1:\n                x = torch.topk(x, self.top_k, dim=1).values.mean(1)\n            else:\n                x = torch.max(x, dim=1).values\n        if was_numpy:\n            return x.numpy()\n        return x\n"
  },
  {
    "path": "utils.py",
    "content": "import csv\nimport logging\nimport os\nimport random\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport PIL\nimport torch\nimport tqdm\n\nLOGGER = logging.getLogger(__name__)\n\n\ndef plot_segmentation_images(\n    savefolder,\n    image_paths,\n    segmentations,\n    anomaly_scores=None,\n    mask_paths=None,\n    image_transform=lambda x: x,\n    mask_transform=lambda x: x,\n    save_depth=4,\n):\n    \"\"\"Generate anomaly segmentation images.\n\n    Args:\n        image_paths: List[str] List of paths to images.\n        segmentations: [List[np.ndarray]] Generated anomaly segmentations.\n        anomaly_scores: [List[float]] Anomaly scores for each image.\n        mask_paths: [List[str]] List of paths to ground truth masks.\n        image_transform: [function or lambda] Optional transformation of images.\n        mask_transform: [function or lambda] Optional transformation of masks.\n        save_depth: [int] Number of path-strings to use for image savenames.\n    \"\"\"\n    if mask_paths is None:\n        mask_paths = [\"-1\" for _ in range(len(image_paths))]\n    masks_provided = mask_paths[0] != \"-1\"\n    if anomaly_scores is None:\n        anomaly_scores = [\"-1\" for _ in range(len(image_paths))]\n\n    os.makedirs(savefolder, exist_ok=True)\n\n    for image_path, mask_path, anomaly_score, segmentation in tqdm.tqdm(\n        zip(image_paths, mask_paths, anomaly_scores, segmentations),\n        total=len(image_paths),\n        desc=\"Generating Segmentation Images...\",\n        leave=False,\n    ):\n        image = PIL.Image.open(image_path).convert(\"RGB\")\n        image = image_transform(image)\n        if not isinstance(image, np.ndarray):\n            image = image.numpy()\n\n        if masks_provided:\n            if mask_path is not None:\n                mask = PIL.Image.open(mask_path).convert(\"RGB\")\n                mask = mask_transform(mask)\n                if not isinstance(mask, np.ndarray):\n                    mask = mask.numpy()\n            else:\n                mask = np.zeros_like(image)\n\n        savename = image_path.split(\"/\")\n        savename = \"_\".join(savename[-save_depth:])\n        savename = os.path.join(savefolder, savename)\n        f, axes = plt.subplots(1, 2 + int(masks_provided))\n        axes[0].imshow(image.transpose(1, 2, 0))\n        axes[1].imshow(mask.transpose(1, 2, 0))\n        axes[2].imshow(segmentation)\n        f.set_size_inches(3 * (2 + int(masks_provided)), 3)\n        f.tight_layout()\n        f.savefig(savename)\n        plt.close()\n\n\ndef create_storage_folder(\n    main_folder_path, project_folder, group_folder, run_name, mode=\"iterate\"\n):\n    os.makedirs(main_folder_path, exist_ok=True)\n    project_path = os.path.join(main_folder_path, project_folder)\n    os.makedirs(project_path, exist_ok=True)\n    save_path = os.path.join(project_path, group_folder, run_name)\n    if mode == \"iterate\":\n        counter = 0\n        while os.path.exists(save_path):\n            save_path = os.path.join(project_path, group_folder + \"_\" + str(counter))\n            counter += 1\n        os.makedirs(save_path)\n    elif mode == \"overwrite\":\n        os.makedirs(save_path, exist_ok=True)\n\n    return save_path\n\n\ndef set_torch_device(gpu_ids):\n    \"\"\"Returns correct torch.device.\n\n    Args:\n        gpu_ids: [list] list of gpu ids. If empty, cpu is used.\n    \"\"\"\n    if len(gpu_ids):\n        # os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n        # os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(gpu_ids[0])\n        return torch.device(\"cuda:{}\".format(gpu_ids[0]))\n    return torch.device(\"cpu\")\n\n\ndef fix_seeds(seed, with_torch=True, with_cuda=True):\n    \"\"\"Fixed available seeds for reproducibility.\n\n    Args:\n        seed: [int] Seed value.\n        with_torch: Flag. If true, torch-related seeds are fixed.\n        with_cuda: Flag. If true, torch+cuda-related seeds are fixed\n    \"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    if with_torch:\n        torch.manual_seed(seed)\n    if with_cuda:\n        torch.cuda.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n        torch.backends.cudnn.deterministic = True\n\n\ndef compute_and_store_final_results(\n    results_path,\n    results,\n    row_names=None,\n    column_names=[\n        \"Instance AUROC\",\n        \"Full Pixel AUROC\",\n        \"Full PRO\",\n        \"Anomaly Pixel AUROC\",\n        \"Anomaly PRO\",\n    ],\n):\n    \"\"\"Store computed results as CSV file.\n\n    Args:\n        results_path: [str] Where to store result csv.\n        results: [List[List]] List of lists containing results per dataset,\n                 with results[i][0] == 'dataset_name' and results[i][1:6] =\n                 [instance_auroc, full_pixelwisew_auroc, full_pro,\n                 anomaly-only_pw_auroc, anomaly-only_pro]\n    \"\"\"\n    if row_names is not None:\n        assert len(row_names) == len(results), \"#Rownames != #Result-rows.\"\n\n    mean_metrics = {}\n    for i, result_key in enumerate(column_names):\n        mean_metrics[result_key] = np.mean([x[i] for x in results])\n        LOGGER.info(\"{0}: {1:3.3f}\".format(result_key, mean_metrics[result_key]))\n\n    savename = os.path.join(results_path, \"results.csv\")\n    with open(savename, \"w\") as csv_file:\n        csv_writer = csv.writer(csv_file, delimiter=\",\")\n        header = column_names\n        if row_names is not None:\n            header = [\"Row Names\"] + header\n\n        csv_writer.writerow(header)\n        for i, result_list in enumerate(results):\n            csv_row = result_list\n            if row_names is not None:\n                csv_row = [row_names[i]] + result_list\n            csv_writer.writerow(csv_row)\n        mean_scores = list(mean_metrics.values())\n        if row_names is not None:\n            mean_scores = [\"Mean\"] + mean_scores\n        csv_writer.writerow(mean_scores)\n\n    mean_metrics = {\"mean_{0}\".format(key): item for key, item in mean_metrics.items()}\n    return mean_metrics\n"
  }
]