Repository: DonaldRR/SimpleNet Branch: main Commit: 351a2b8d4e8c Files: 18 Total size: 108.6 KB Directory structure: gitextract_ap_hyzgr/ ├── .gitignore ├── LICENSE ├── README.md ├── VERSION ├── backbones.py ├── common.py ├── datasets/ │ ├── __init__.py │ ├── btad.py │ ├── cifar10.py │ ├── mvtec.py │ ├── sdd.py │ └── sdd2.py ├── main.py ├── metrics.py ├── resnet.py ├── run.sh ├── simplenet.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ __pycache__/* ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2023 DonaldRR Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================
FILE: README.md
================================================
# SimpleNet

![](imgs/cover.png)

**SimpleNet: A Simple Network for Image Anomaly Detection and Localization**

*Zhikang Liu, Yiming Zhou, Yuansheng Xu, Zilei Wang*

[Paper link](https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_SimpleNet_A_Simple_Network_for_Image_Anomaly_Detection_and_Localization_CVPR_2023_paper.pdf)

## Introduction

This repo contains the source code for **SimpleNet**, implemented with PyTorch. SimpleNet is a simple defect detection and localization network built with a feature encoder, a feature generator and a defect discriminator. It is designed to be conceptually simple, without complex network design, training schemes or external data sources.

## Get Started

### Environment

**Python 3.8**

**Packages**:

- torch==1.12.1
- torchvision==0.13.1
- numpy==1.22.4
- opencv-python==4.5.1

(The environment setup above is not the minimum requirement; other versions might work too.)

### Data

Edit `run.sh` to set the dataset class and dataset path.

#### MVTecAD

Download the dataset from [here](https://www.mvtec.com/company/research/datasets/mvtec-ad/). The dataset folders/files follow its original structure.

### Run

#### Demo train

Please specify the dataset path (line 1) and log folder (line 10) in `run.sh` before running. `run.sh` gives the configuration to train models on the MVTecAD dataset.
``` bash run.sh ``` ## Citation ``` @inproceedings{liu2023simplenet, title={SimpleNet: A Simple Network for Image Anomaly Detection and Localization}, author={Liu, Zhikang and Zhou, Yiming and Xu, Yuansheng and Wang, Zilei}, booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, pages={20402--20411}, year={2023} } ``` ## Acknowledgement Thanks for great inspiration from [PatchCore](https://github.com/amazon-science/patchcore-inspection) ## License All code within the repo is under [MIT license](https://mit-license.org/) ================================================ FILE: VERSION ================================================ 0.1.0 ================================================ FILE: backbones.py ================================================ import timm # noqa import torch import torchvision.models as models # noqa def load_ref_wrn50(): import resnet return resnet.wide_resnet50_2(True) _BACKBONES = { "cait_s24_224" : "cait.cait_S24_224(True)", "cait_xs24": "cait.cait_XS24(True)", "alexnet": "models.alexnet(pretrained=True)", "bninception": 'pretrainedmodels.__dict__["bninception"]' '(pretrained="imagenet", num_classes=1000)', "resnet18": "models.resnet18(pretrained=True)", "resnet50": "models.resnet50(pretrained=True)", "mc3_resnet50": "load_mc3_rn50()", "resnet101": "models.resnet101(pretrained=True)", "resnext101": "models.resnext101_32x8d(pretrained=True)", "resnet200": 'timm.create_model("resnet200", pretrained=True)', "resnest50": 'timm.create_model("resnest50d_4s2x40d", pretrained=True)', "resnetv2_50_bit": 'timm.create_model("resnetv2_50x3_bitm", pretrained=True)', "resnetv2_50_21k": 'timm.create_model("resnetv2_50x3_bitm_in21k", pretrained=True)', "resnetv2_101_bit": 'timm.create_model("resnetv2_101x3_bitm", pretrained=True)', "resnetv2_101_21k": 'timm.create_model("resnetv2_101x3_bitm_in21k", pretrained=True)', "resnetv2_152_bit": 'timm.create_model("resnetv2_152x4_bitm", pretrained=True)', 
"resnetv2_152_21k": 'timm.create_model("resnetv2_152x4_bitm_in21k", pretrained=True)', "resnetv2_152_384": 'timm.create_model("resnetv2_152x2_bit_teacher_384", pretrained=True)', "resnetv2_101": 'timm.create_model("resnetv2_101", pretrained=True)', "vgg11": "models.vgg11(pretrained=True)", "vgg19": "models.vgg19(pretrained=True)", "vgg19_bn": "models.vgg19_bn(pretrained=True)", "wideresnet50": "models.wide_resnet50_2(pretrained=True)", "ref_wideresnet50": "load_ref_wrn50()", "wideresnet101": "models.wide_resnet101_2(pretrained=True)", "mnasnet_100": 'timm.create_model("mnasnet_100", pretrained=True)', "mnasnet_a1": 'timm.create_model("mnasnet_a1", pretrained=True)', "mnasnet_b1": 'timm.create_model("mnasnet_b1", pretrained=True)', "densenet121": 'timm.create_model("densenet121", pretrained=True)', "densenet201": 'timm.create_model("densenet201", pretrained=True)', "inception_v4": 'timm.create_model("inception_v4", pretrained=True)', "vit_small": 'timm.create_model("vit_small_patch16_224", pretrained=True)', "vit_base": 'timm.create_model("vit_base_patch16_224", pretrained=True)', "vit_large": 'timm.create_model("vit_large_patch16_224", pretrained=True)', "vit_r50": 'timm.create_model("vit_large_r50_s32_224", pretrained=True)', "vit_deit_base": 'timm.create_model("deit_base_patch16_224", pretrained=True)', "vit_deit_distilled": 'timm.create_model("deit_base_distilled_patch16_224", pretrained=True)', "vit_swin_base": 'timm.create_model("swin_base_patch4_window7_224", pretrained=True)', "vit_swin_large": 'timm.create_model("swin_large_patch4_window7_224", pretrained=True)', "efficientnet_b7": 'timm.create_model("tf_efficientnet_b7", pretrained=True)', "efficientnet_b5": 'timm.create_model("tf_efficientnet_b5", pretrained=True)', "efficientnet_b3": 'timm.create_model("tf_efficientnet_b3", pretrained=True)', "efficientnet_b1": 'timm.create_model("tf_efficientnet_b1", pretrained=True)', "efficientnetv2_m": 'timm.create_model("tf_efficientnetv2_m", pretrained=True)', 
"efficientnetv2_l": 'timm.create_model("tf_efficientnetv2_l", pretrained=True)', "efficientnet_b3a": 'timm.create_model("efficientnet_b3a", pretrained=True)', } def load(name): return eval(_BACKBONES[name]) ================================================ FILE: common.py ================================================ import copy from typing import List import numpy as np import scipy.ndimage as ndimage import torch import torch.nn.functional as F class _BaseMerger: def __init__(self): """Merges feature embedding by name.""" def merge(self, features: list): features = [self._reduce(feature) for feature in features] return np.concatenate(features, axis=1) class AverageMerger(_BaseMerger): @staticmethod def _reduce(features): # NxCxWxH -> NxC return features.reshape([features.shape[0], features.shape[1], -1]).mean( axis=-1 ) class ConcatMerger(_BaseMerger): @staticmethod def _reduce(features): # NxCxWxH -> NxCWH return features.reshape(len(features), -1) class Preprocessing(torch.nn.Module): def __init__(self, input_dims, output_dim): super(Preprocessing, self).__init__() self.input_dims = input_dims self.output_dim = output_dim self.preprocessing_modules = torch.nn.ModuleList() for input_dim in input_dims: module = MeanMapper(output_dim) self.preprocessing_modules.append(module) def forward(self, features): _features = [] for module, feature in zip(self.preprocessing_modules, features): _features.append(module(feature)) return torch.stack(_features, dim=1) class MeanMapper(torch.nn.Module): def __init__(self, preprocessing_dim): super(MeanMapper, self).__init__() self.preprocessing_dim = preprocessing_dim def forward(self, features): features = features.reshape(len(features), 1, -1) return F.adaptive_avg_pool1d(features, self.preprocessing_dim).squeeze(1) class Aggregator(torch.nn.Module): def __init__(self, target_dim): super(Aggregator, self).__init__() self.target_dim = target_dim def forward(self, features): """Returns reshaped and average pooled features.""" # 
batchsize x number_of_layers x input_dim -> batchsize x target_dim features = features.reshape(len(features), 1, -1) features = F.adaptive_avg_pool1d(features, self.target_dim) return features.reshape(len(features), -1) class RescaleSegmentor: def __init__(self, device, target_size=224): self.device = device self.target_size = target_size self.smoothing = 4 def convert_to_segmentation(self, patch_scores, features): with torch.no_grad(): if isinstance(patch_scores, np.ndarray): patch_scores = torch.from_numpy(patch_scores) _scores = patch_scores.to(self.device) _scores = _scores.unsqueeze(1) _scores = F.interpolate( _scores, size=self.target_size, mode="bilinear", align_corners=False ) _scores = _scores.squeeze(1) patch_scores = _scores.cpu().numpy() if isinstance(features, np.ndarray): features = torch.from_numpy(features) features = features.to(self.device).permute(0, 3, 1, 2) if self.target_size[0] * self.target_size[1] * features.shape[0] * features.shape[1] >= 2**31: subbatch_size = int((2**31-1) / (self.target_size[0] * self.target_size[1] * features.shape[1])) interpolated_features = [] for i_subbatch in range(int(features.shape[0] / subbatch_size + 1)): subfeatures = features[i_subbatch*subbatch_size:(i_subbatch+1)*subbatch_size] subfeatures = subfeatures.unsuqeeze(0) if len(subfeatures.shape) == 3 else subfeatures subfeatures = F.interpolate( subfeatures, size=self.target_size, mode="bilinear", align_corners=False ) interpolated_features.append(subfeatures) features = torch.cat(interpolated_features, 0) else: features = F.interpolate( features, size=self.target_size, mode="bilinear", align_corners=False ) features = features.cpu().numpy() return [ ndimage.gaussian_filter(patch_score, sigma=self.smoothing) for patch_score in patch_scores ], [ feature for feature in features ] class NetworkFeatureAggregator(torch.nn.Module): """Efficient extraction of network features.""" def __init__(self, backbone, layers_to_extract_from, device, train_backbone=False): 
super(NetworkFeatureAggregator, self).__init__() """Extraction of network features. Runs a network only to the last layer of the list of layers where network features should be extracted from. Args: backbone: torchvision.model layers_to_extract_from: [list of str] """ self.layers_to_extract_from = layers_to_extract_from self.backbone = backbone self.device = device self.train_backbone = train_backbone if not hasattr(backbone, "hook_handles"): self.backbone.hook_handles = [] for handle in self.backbone.hook_handles: handle.remove() self.outputs = {} for extract_layer in layers_to_extract_from: forward_hook = ForwardHook( self.outputs, extract_layer, layers_to_extract_from[-1] ) if "." in extract_layer: extract_block, extract_idx = extract_layer.split(".") network_layer = backbone.__dict__["_modules"][extract_block] if extract_idx.isnumeric(): extract_idx = int(extract_idx) network_layer = network_layer[extract_idx] else: network_layer = network_layer.__dict__["_modules"][extract_idx] else: network_layer = backbone.__dict__["_modules"][extract_layer] if isinstance(network_layer, torch.nn.Sequential): self.backbone.hook_handles.append( network_layer[-1].register_forward_hook(forward_hook) ) else: self.backbone.hook_handles.append( network_layer.register_forward_hook(forward_hook) ) self.to(self.device) def forward(self, images, eval=True): self.outputs.clear() if self.train_backbone and not eval: self.backbone(images) else: with torch.no_grad(): # The backbone will throw an Exception once it reached the last # layer to compute features from. Computation will stop there. 
try: _ = self.backbone(images) except LastLayerToExtractReachedException: pass return self.outputs def feature_dimensions(self, input_shape): """Computes the feature dimensions for all layers given input_shape.""" _input = torch.ones([1] + list(input_shape)).to(self.device) _output = self(_input) return [_output[layer].shape[1] for layer in self.layers_to_extract_from] class ForwardHook: def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str): self.hook_dict = hook_dict self.layer_name = layer_name self.raise_exception_to_break = copy.deepcopy( layer_name == last_layer_to_extract ) def __call__(self, module, input, output): self.hook_dict[self.layer_name] = output # if self.raise_exception_to_break: # raise LastLayerToExtractReachedException() return None class LastLayerToExtractReachedException(Exception): pass ================================================ FILE: datasets/__init__.py ================================================ ================================================ FILE: datasets/btad.py ================================================ import os from enum import Enum import PIL import torch from torchvision import transforms _CLASSNAMES = [ "01", "02", "03" ] IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] class DatasetSplit(Enum): TRAIN = "train" VAL = "val" TEST = "test" class BTADDataset(torch.utils.data.Dataset): """ PyTorch Dataset for MVTec. """ def __init__( self, source, classname, resize=256, imagesize=224, split=DatasetSplit.TRAIN, train_val_split=1.0, rotate_degrees=0, translate=0, brightness_factor=0, contrast_factor=0, saturation_factor=0, gray_p=0, h_flip_p=0, v_flip_p=0, scale=0, **kwargs, ): """ Args: source: [str]. Path to the MVTec data folder. classname: [str or None]. Name of MVTec class that should be provided in this dataset. If None, the datasets iterates over all available images. resize: [int]. (Square) Size the loaded image initially gets resized to. imagesize: [int]. 
(Square) Size the resized loaded image gets (center-)cropped to. split: [enum-option]. Indicates if training or test split of the data should be used. Has to be an option taken from DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that mvtec.DatasetSplit.TEST will also load mask data. """ super().__init__() self.source = source self.split = split self.classnames_to_use = [classname] if classname is not None else _CLASSNAMES self.train_val_split = train_val_split self.transform_std = IMAGENET_STD self.transform_mean = IMAGENET_MEAN self.imgpaths_per_class, self.data_to_iterate = self.get_image_data() self.transform_img = [ transforms.Resize(resize), # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR), transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor), transforms.RandomHorizontalFlip(h_flip_p), transforms.RandomVerticalFlip(v_flip_p), transforms.RandomGrayscale(gray_p), transforms.RandomAffine(rotate_degrees, translate=(translate, translate), scale=(1.0-scale, 1.0+scale), interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(imagesize), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ] self.transform_img = transforms.Compose(self.transform_img) self.transform_mask = [ transforms.Resize(resize), transforms.CenterCrop(imagesize), transforms.ToTensor(), ] self.transform_mask = transforms.Compose(self.transform_mask) self.imagesize = (3, imagesize, imagesize) def __getitem__(self, idx): classname, anomaly, image_path, mask_path = self.data_to_iterate[idx] image = PIL.Image.open(image_path).convert("RGB") image = self.transform_img(image) if self.split == DatasetSplit.TEST and mask_path is not None: mask = PIL.Image.open(mask_path) mask = self.transform_mask(mask) else: mask = torch.zeros([1, *image.size()[1:]]) return { "image": image, "mask": mask, "classname": classname, "anomaly": anomaly, "is_anomaly": int(anomaly != "good"), "image_name": 
"/".join(image_path.split("/")[-4:]), "image_path": image_path, } def __len__(self): return len(self.data_to_iterate) def get_image_data(self): imgpaths_per_class = {} maskpaths_per_class = {} for classname in self.classnames_to_use: classpath = os.path.join(self.source, classname, self.split.value) maskpath = os.path.join(self.source, classname, "ground_truth") anomaly_types = os.listdir(classpath) imgpaths_per_class[classname] = {} maskpaths_per_class[classname] = {} for anomaly in anomaly_types: anomaly_path = os.path.join(classpath, anomaly) anomaly_files = sorted(os.listdir(anomaly_path)) imgpaths_per_class[classname][anomaly] = [ os.path.join(anomaly_path, x) for x in anomaly_files ] if self.train_val_split < 1.0: n_images = len(imgpaths_per_class[classname][anomaly]) train_val_split_idx = int(n_images * self.train_val_split) if self.split == DatasetSplit.TRAIN: imgpaths_per_class[classname][anomaly] = imgpaths_per_class[ classname ][anomaly][:train_val_split_idx] elif self.split == DatasetSplit.VAL: imgpaths_per_class[classname][anomaly] = imgpaths_per_class[ classname ][anomaly][train_val_split_idx:] if self.split == DatasetSplit.TEST and anomaly != "good": anomaly_mask_path = os.path.join(maskpath, anomaly) anomaly_mask_files = sorted(os.listdir(anomaly_mask_path)) maskpaths_per_class[classname][anomaly] = [ os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files ] else: maskpaths_per_class[classname]["good"] = None # Unrolls the data dictionary to an easy-to-iterate list. 
data_to_iterate = [] for classname in sorted(imgpaths_per_class.keys()): for anomaly in sorted(imgpaths_per_class[classname].keys()): for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]): data_tuple = [classname, anomaly, image_path] if self.split == DatasetSplit.TEST and anomaly != "good": data_tuple.append(maskpaths_per_class[classname][anomaly][i]) else: data_tuple.append(None) data_to_iterate.append(data_tuple) return imgpaths_per_class, data_to_iterate ================================================ FILE: datasets/cifar10.py ================================================ import os from enum import Enum import PIL import torch from torchvision import transforms IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] class DatasetSplit(Enum): TRAIN = "train" VAL = "val" TEST = "test" class Cifar10Dataset(torch.utils.data.Dataset): """ PyTorch Dataset for MVTec. """ _CLASSES = list(range(10)) def __init__( self, source, classname, resize=256, imagesize=224, split=DatasetSplit.TRAIN, train_val_split=1.0, rotate_degrees=0, translate=0, brightness_factor=0, contrast_factor=0, saturation_factor=0, gray_p=0, h_flip_p=0, v_flip_p=0, scale=0, **kwargs, ): """ Args: source: [str]. Path to the MVTec data folder. classname: [str or None]. Name of MVTec class that should be provided in this dataset. If None, the datasets iterates over all available images. resize: [int]. (Square) Size the loaded image initially gets resized to. imagesize: [int]. (Square) Size the resized loaded image gets (center-)cropped to. split: [enum-option]. Indicates if training or test split of the data should be used. Has to be an option taken from DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that mvtec.DatasetSplit.TEST will also load mask data. 
""" super().__init__() self.source = source self.split = split self.classname = int(classname) self.train_val_split = train_val_split self.data_to_iterate = self.get_image_data() self.transform_std = IMAGENET_STD self.transform_mean = IMAGENET_MEAN self.transform_img = [ transforms.Resize(resize), # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR), transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor), transforms.RandomHorizontalFlip(h_flip_p), transforms.RandomVerticalFlip(v_flip_p), transforms.RandomGrayscale(gray_p), transforms.RandomAffine(rotate_degrees, translate=(translate, translate), scale=(1.0-scale, 1.0+scale), interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(imagesize), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ] self.transform_img = transforms.Compose(self.transform_img) self.transform_mask = [ transforms.Resize(resize), transforms.CenterCrop(imagesize), transforms.ToTensor(), ] self.transform_mask = transforms.Compose(self.transform_mask) self.imagesize = (3, imagesize, imagesize) def __getitem__(self, idx): img_path, classname = self.data_to_iterate[idx] image = PIL.Image.open(img_path).convert("RGB") image = self.transform_img(image) return { "image": image, "classname": classname, "anomaly": int(classname != self.classname), "is_anomaly": int(classname != self.classname), "image_name": os.path.split(img_path)[-1], "image_path": img_path, } def __len__(self): return len(self.data_to_iterate) def get_image_data(self): data_to_iterate = [] for classname in Cifar10Dataset._CLASSES: if self.split == DatasetSplit.TRAIN: if classname != self.classname: continue class_dir = os.path.join(self.source, self.split.value, str(classname)) for fn in os.listdir(class_dir): img_path = os.path.join(class_dir, fn) data_to_iterate.append([img_path, classname]) return data_to_iterate ================================================ FILE: 
datasets/mvtec.py ================================================ import os from enum import Enum import PIL import torch from torchvision import transforms _CLASSNAMES = [ "bottle", "cable", "capsule", "carpet", "grid", "hazelnut", "leather", "metal_nut", "pill", "screw", "tile", "toothbrush", "transistor", "wood", "zipper", ] IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] class DatasetSplit(Enum): TRAIN = "train" VAL = "val" TEST = "test" class MVTecDataset(torch.utils.data.Dataset): """ PyTorch Dataset for MVTec. """ def __init__( self, source, classname, resize=256, imagesize=224, split=DatasetSplit.TRAIN, train_val_split=1.0, rotate_degrees=0, translate=0, brightness_factor=0, contrast_factor=0, saturation_factor=0, gray_p=0, h_flip_p=0, v_flip_p=0, scale=0, **kwargs, ): """ Args: source: [str]. Path to the MVTec data folder. classname: [str or None]. Name of MVTec class that should be provided in this dataset. If None, the datasets iterates over all available images. resize: [int]. (Square) Size the loaded image initially gets resized to. imagesize: [int]. (Square) Size the resized loaded image gets (center-)cropped to. split: [enum-option]. Indicates if training or test split of the data should be used. Has to be an option taken from DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that mvtec.DatasetSplit.TEST will also load mask data. 
""" super().__init__() self.source = source self.split = split self.classnames_to_use = [classname] if classname is not None else _CLASSNAMES self.train_val_split = train_val_split self.transform_std = IMAGENET_STD self.transform_mean = IMAGENET_MEAN self.imgpaths_per_class, self.data_to_iterate = self.get_image_data() self.transform_img = [ transforms.Resize(resize), # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR), transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor), transforms.RandomHorizontalFlip(h_flip_p), transforms.RandomVerticalFlip(v_flip_p), transforms.RandomGrayscale(gray_p), transforms.RandomAffine(rotate_degrees, translate=(translate, translate), scale=(1.0-scale, 1.0+scale), interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop(imagesize), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ] self.transform_img = transforms.Compose(self.transform_img) self.transform_mask = [ transforms.Resize(resize), transforms.CenterCrop(imagesize), transforms.ToTensor(), ] self.transform_mask = transforms.Compose(self.transform_mask) self.imagesize = (3, imagesize, imagesize) def __getitem__(self, idx): classname, anomaly, image_path, mask_path = self.data_to_iterate[idx] image = PIL.Image.open(image_path).convert("RGB") image = self.transform_img(image) if self.split == DatasetSplit.TEST and mask_path is not None: mask = PIL.Image.open(mask_path) mask = self.transform_mask(mask) else: mask = torch.zeros([1, *image.size()[1:]]) return { "image": image, "mask": mask, "classname": classname, "anomaly": anomaly, "is_anomaly": int(anomaly != "good"), "image_name": "/".join(image_path.split("/")[-4:]), "image_path": image_path, } def __len__(self): return len(self.data_to_iterate) def get_image_data(self): imgpaths_per_class = {} maskpaths_per_class = {} for classname in self.classnames_to_use: classpath = os.path.join(self.source, classname, 
self.split.value) maskpath = os.path.join(self.source, classname, "ground_truth") anomaly_types = os.listdir(classpath) imgpaths_per_class[classname] = {} maskpaths_per_class[classname] = {} for anomaly in anomaly_types: anomaly_path = os.path.join(classpath, anomaly) anomaly_files = sorted(os.listdir(anomaly_path)) imgpaths_per_class[classname][anomaly] = [ os.path.join(anomaly_path, x) for x in anomaly_files ] if self.train_val_split < 1.0: n_images = len(imgpaths_per_class[classname][anomaly]) train_val_split_idx = int(n_images * self.train_val_split) if self.split == DatasetSplit.TRAIN: imgpaths_per_class[classname][anomaly] = imgpaths_per_class[ classname ][anomaly][:train_val_split_idx] elif self.split == DatasetSplit.VAL: imgpaths_per_class[classname][anomaly] = imgpaths_per_class[ classname ][anomaly][train_val_split_idx:] if self.split == DatasetSplit.TEST and anomaly != "good": anomaly_mask_path = os.path.join(maskpath, anomaly) anomaly_mask_files = sorted(os.listdir(anomaly_mask_path)) maskpaths_per_class[classname][anomaly] = [ os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files ] else: maskpaths_per_class[classname]["good"] = None # Unrolls the data dictionary to an easy-to-iterate list. 
data_to_iterate = [] for classname in sorted(imgpaths_per_class.keys()): for anomaly in sorted(imgpaths_per_class[classname].keys()): for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]): data_tuple = [classname, anomaly, image_path] if self.split == DatasetSplit.TEST and anomaly != "good": data_tuple.append(maskpaths_per_class[classname][anomaly][i]) else: data_tuple.append(None) data_to_iterate.append(data_tuple) return imgpaths_per_class, data_to_iterate ================================================ FILE: datasets/sdd.py ================================================ import os from enum import Enum import pickle import cv2 import PIL import torch from torchvision import transforms IMAGENET_MEAN = [0.485, 0.456, 0.406] IMAGENET_STD = [0.229, 0.224, 0.225] class DatasetSplit(Enum): TRAIN = "train" VAL = "val" TEST = "test" class SDDDataset(torch.utils.data.Dataset): """ PyTorch Dataset for MVTec. """ def __init__( self, source, classname, resize=256, imagesize=224, split=DatasetSplit.TRAIN, train_val_split=1.0, rotate_degrees=0, translate=0, brightness_factor=0, contrast_factor=0, saturation_factor=0, gray_p=0, h_flip_p=0, v_flip_p=0, scale=0, **kwargs, ): """ Args: source: [str]. Path to the MVTec data folder. classname: [str or None]. Name of MVTec class that should be provided in this dataset. If None, the datasets iterates over all available images. resize: [int]. (Square) Size the loaded image initially gets resized to. imagesize: [int]. (Square) Size the resized loaded image gets (center-)cropped to. split: [enum-option]. Indicates if training or test split of the data should be used. Has to be an option taken from DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that mvtec.DatasetSplit.TEST will also load mask data. 
""" super().__init__() self.source = source self.split = split self.split_id = int(classname) self.train_val_split = train_val_split self.transform_std = IMAGENET_STD self.transform_mean = IMAGENET_MEAN self.data_to_iterate = self.get_image_data() self.transform_img = [ transforms.Resize((int(resize*2.5+.5), resize)), # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR), transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor), transforms.RandomHorizontalFlip(h_flip_p), transforms.RandomVerticalFlip(v_flip_p), transforms.RandomGrayscale(gray_p), transforms.RandomAffine(rotate_degrees, translate=(translate, translate), scale=(1.0-scale, 1.0+scale), interpolation=transforms.InterpolationMode.BILINEAR), transforms.CenterCrop((int(imagesize * 2.5 + .5), imagesize)), transforms.ToTensor(), transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD), ] self.transform_img = transforms.Compose(self.transform_img) self.transform_mask = [ transforms.Resize((int(resize*2.5+.5), resize)), transforms.CenterCrop((int(imagesize * 2.5 + .5), imagesize)), transforms.ToTensor(), ] self.transform_mask = transforms.Compose(self.transform_mask) self.imagesize = (3, int(imagesize * 2.5 + .5), imagesize) # if self.split == DatasetSplit.TEST: # for i in range(len(self.data_to_iterate)): # self.__getitem__(i) def __getitem__(self, idx): data = self.data_to_iterate[idx] image = PIL.Image.open(data["img"]).convert("RGB") image = self.transform_img(image) if self.split == DatasetSplit.TEST and data["anomaly"] == 1: mask = PIL.Image.open(data["label"]) mask = self.transform_mask(mask) else: mask = torch.zeros([1, *image.size()[1:]]) return { "image": image, "mask": mask, "classname": str(self.split_id), "anomaly": data["anomaly"], "is_anomaly": data["anomaly"], "image_path": data["img"], } def __len__(self): return len(self.data_to_iterate) def get_image_data(self): data_ids = [] with open(os.path.join(self.source, 
class SDD2Dataset(torch.utils.data.Dataset):
    """PyTorch dataset for the KolektorSDD2 surface-defect data.

    KolektorSDD2 images are tall and narrow, so every spatial transform here
    uses a (2.5 * size, size) aspect ratio instead of a square crop.
    """

    def __init__(
        self,
        source,
        classname,
        resize=256,
        imagesize=224,
        split=DatasetSplit.TRAIN,
        train_val_split=1.0,
        rotate_degrees=0,
        translate=0,
        brightness_factor=0,
        contrast_factor=0,
        saturation_factor=0,
        gray_p=0,
        h_flip_p=0,
        v_flip_p=0,
        scale=0,
        **kwargs,
    ):
        """
        Args:
            source: [str]. Path to the SDD2 data folder (contains "train"/"test").
            classname: unused; kept for interface parity with the other dataset classes.
            resize: [int]. Width the loaded image initially gets resized to;
                height becomes int(resize * 2.5 + .5).
            imagesize: [int]. Width of the final center crop.
            split: [enum-option]. DatasetSplit.TRAIN or DatasetSplit.TEST.
                TRAIN drops anomalous samples; TEST also loads mask data.
        """
        super().__init__()
        self.source = source
        self.split = split
        self.train_val_split = train_val_split
        self.transform_std = IMAGENET_STD
        self.transform_mean = IMAGENET_MEAN

        self.data_to_iterate = self.get_image_data()

        self.transform_img = transforms.Compose([
            transforms.Resize((int(resize * 2.5 + .5), resize)),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(
                rotate_degrees,
                translate=(translate, translate),
                scale=(1.0 - scale, 1.0 + scale),
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
            transforms.CenterCrop((int(imagesize * 2.5 + .5), imagesize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])

        self.transform_mask = transforms.Compose([
            transforms.Resize((int(resize * 2.5 + .5), resize)),
            transforms.CenterCrop((int(imagesize * 2.5 + .5), imagesize)),
            transforms.ToTensor(),
        ])

        # (C, H, W) of the tensors this dataset yields.
        self.imagesize = (3, int(imagesize * 2.5 + .5), imagesize)

    def __getitem__(self, idx):
        img_path, gt_path, is_anomaly = self.data_to_iterate[idx]
        image = PIL.Image.open(img_path).convert("RGB")
        image = self.transform_img(image)

        # Masks are only loaded for anomalous test samples; everything else
        # gets an all-zero mask of matching spatial size.
        if self.split == DatasetSplit.TEST and is_anomaly:
            mask = PIL.Image.open(gt_path)
            mask = self.transform_mask(mask)
        else:
            mask = torch.zeros([1, *image.size()[1:]])

        return {
            "image": image,
            "mask": mask,
            "classname": "",
            "anomaly": is_anomaly,
            "is_anomaly": is_anomaly,
            "image_path": img_path,
        }

    def __len__(self):
        return len(self.data_to_iterate)

    def get_image_data(self):
        """Scan the split folder and return [img_path, gt_path, is_anomaly] triples.

        A sample is anomalous iff its "<id>_GT.png" mask contains any nonzero
        pixel. Anomalous samples are excluded from the TRAIN split.
        """
        data_dir = os.path.join(
            self.source, "train" if self.split == DatasetSplit.TRAIN else "test"
        )
        data = []
        for fn in os.listdir(data_dir):
            if "GT" in fn:
                continue  # mask files are paired with their image below
            data_id = os.path.splitext(fn)[0]
            img_path = os.path.join(data_dir, fn)
            gt_path = os.path.join(data_dir, f"{data_id}_GT.png")
            assert os.path.exists(img_path)
            assert os.path.exists(gt_path), gt_path
            # FIX: compute the anomaly flag once and reuse it (previously
            # gt.sum() > 0 was evaluated twice per sample); also dropped the
            # unused `data_ids` list and normal/anomalous counters.
            is_anomaly = cv2.imread(gt_path).sum() > 0
            if self.split == DatasetSplit.TRAIN and is_anomaly:
                continue
            data.append([img_path, gt_path, is_anomaly])
        return data
@main.result_callback()
def run(
    methods,
    results_path,
    gpu,
    seed,
    log_group,
    log_project,
    run_name,
    test,
    save_segmentation_images,
):
    # Collect the factories registered by the chained subcommands
    # ("net" -> get_simplenet, "dataset" -> get_dataloaders).
    methods = {key: item for (key, item) in methods}

    run_save_path = utils.create_storage_folder(
        results_path, log_project, log_group, run_name, mode="overwrite"
    )

    list_of_dataloaders = methods["get_dataloaders"](seed)

    device = utils.set_torch_device(gpu)

    result_collect = []

    for dataloader_count, dataloaders in enumerate(list_of_dataloaders):
        LOGGER.info(
            "Evaluating dataset [{}] ({}/{})...".format(
                dataloaders["training"].name,
                dataloader_count + 1,
                len(list_of_dataloaders),
            )
        )

        utils.fix_seeds(seed, device)

        dataset_name = dataloaders["training"].name
        imagesize = dataloaders["training"].dataset.imagesize
        simplenet_list = methods["get_simplenet"](imagesize, device)

        models_dir = os.path.join(run_save_path, "models")
        os.makedirs(models_dir, exist_ok=True)
        for i, SimpleNet in enumerate(simplenet_list):
            # A per-backbone seed ("name.seed-N") overrides the run-level seed.
            if SimpleNet.backbone.seed is not None:
                utils.fix_seeds(SimpleNet.backbone.seed, device)
            LOGGER.info(
                "Training models ({}/{})".format(i + 1, len(simplenet_list))
            )
            SimpleNet.set_model_dir(os.path.join(models_dir, f"{i}"), dataset_name)
            if not test:
                i_auroc, p_auroc, pro_auroc = SimpleNet.train(
                    dataloaders["training"], dataloaders["testing"]
                )
            else:
                # FIX: this branch previously only printed a warning while the
                # evaluation call was commented out, leaving i_auroc/p_auroc/
                # pro_auroc unbound and crashing with NameError just below.
                i_auroc, p_auroc, pro_auroc = SimpleNet.test(
                    dataloaders["training"],
                    dataloaders["testing"],
                    save_segmentation_images,
                )

            result_collect.append(
                {
                    "dataset_name": dataset_name,
                    "instance_auroc": i_auroc,
                    "full_pixel_auroc": p_auroc,
                    "anomaly_pixel_auroc": pro_auroc,
                }
            )

            for key, item in result_collect[-1].items():
                if key != "dataset_name":
                    LOGGER.info("{0}: {1:3.3f}".format(key, item))

        LOGGER.info("\n\n-----\n")

    # Store all results and mean scores to a csv-file.
    result_metric_names = list(result_collect[-1].keys())[1:]
    result_dataset_names = [results["dataset_name"] for results in result_collect]
    result_scores = [list(results.values())[1:] for results in result_collect]
    utils.compute_and_store_final_results(
        run_save_path,
        result_scores,
        column_names=result_metric_names,
        row_names=result_dataset_names,
    )
@click.option("--mix_noise", type=int, default=1) def net( backbone_names, layers_to_extract_from, pretrain_embed_dimension, target_embed_dimension, patchsize, embedding_size, meta_epochs, aed_meta_epochs, gan_epochs, noise_std, dsc_layers, dsc_hidden, dsc_margin, dsc_lr, auto_noise, train_backbone, cos_lr, pre_proj, proj_layer_type, mix_noise, ): backbone_names = list(backbone_names) if len(backbone_names) > 1: layers_to_extract_from_coll = [[] for _ in range(len(backbone_names))] for layer in layers_to_extract_from: idx = int(layer.split(".")[0]) layer = ".".join(layer.split(".")[1:]) layers_to_extract_from_coll[idx].append(layer) else: layers_to_extract_from_coll = [layers_to_extract_from] def get_simplenet(input_shape, device): simplenets = [] for backbone_name, layers_to_extract_from in zip( backbone_names, layers_to_extract_from_coll ): backbone_seed = None if ".seed-" in backbone_name: backbone_name, backbone_seed = backbone_name.split(".seed-")[0], int( backbone_name.split("-")[-1] ) backbone = backbones.load(backbone_name) backbone.name, backbone.seed = backbone_name, backbone_seed simplenet_inst = simplenet.SimpleNet(device) simplenet_inst.load( backbone=backbone, layers_to_extract_from=layers_to_extract_from, device=device, input_shape=input_shape, pretrain_embed_dimension=pretrain_embed_dimension, target_embed_dimension=target_embed_dimension, patchsize=patchsize, embedding_size=embedding_size, meta_epochs=meta_epochs, aed_meta_epochs=aed_meta_epochs, gan_epochs=gan_epochs, noise_std=noise_std, dsc_layers=dsc_layers, dsc_hidden=dsc_hidden, dsc_margin=dsc_margin, dsc_lr=dsc_lr, auto_noise=auto_noise, train_backbone=train_backbone, cos_lr=cos_lr, pre_proj=pre_proj, proj_layer_type=proj_layer_type, mix_noise=mix_noise, ) simplenets.append(simplenet_inst) return simplenets return ("get_simplenet", get_simplenet) @main.command("dataset") @click.argument("name", type=str) @click.argument("data_path", type=click.Path(exists=True, file_okay=False)) 
@click.option("--subdatasets", "-d", multiple=True, type=str, required=True)
@click.option("--train_val_split", type=float, default=1, show_default=True)
@click.option("--batch_size", default=2, type=int, show_default=True)
@click.option("--num_workers", default=2, type=int, show_default=True)
@click.option("--resize", default=256, type=int, show_default=True)
@click.option("--imagesize", default=224, type=int, show_default=True)
@click.option("--rotate_degrees", default=0, type=int)
@click.option("--translate", default=0, type=float)
@click.option("--scale", default=0.0, type=float)
@click.option("--brightness", default=0.0, type=float)
@click.option("--contrast", default=0.0, type=float)
@click.option("--saturation", default=0.0, type=float)
@click.option("--gray", default=0.0, type=float)
@click.option("--hflip", default=0.0, type=float)
@click.option("--vflip", default=0.0, type=float)
@click.option("--augment", is_flag=True)
def dataset(
    name,
    data_path,
    subdatasets,
    train_val_split,
    batch_size,
    resize,
    imagesize,
    num_workers,
    rotate_degrees,
    translate,
    scale,
    brightness,
    contrast,
    saturation,
    gray,
    hflip,
    vflip,
    augment,
):
    # "dataset" subcommand: resolves the dataset class from _DATASETS and
    # registers a deferred dataloader factory (one train/val/test triple per
    # subdataset). Deliberately no docstring -- click would show it as help.
    dataset_info = _DATASETS[name]
    dataset_library = __import__(dataset_info[0], fromlist=[dataset_info[1]])

    def get_dataloaders(seed):
        dataloaders = []
        for subdataset in subdatasets:
            # Augmentation options only apply to the training split.
            train_dataset = dataset_library.__dict__[dataset_info[1]](
                data_path,
                classname=subdataset,
                resize=resize,
                train_val_split=train_val_split,
                imagesize=imagesize,
                split=dataset_library.DatasetSplit.TRAIN,
                seed=seed,
                rotate_degrees=rotate_degrees,
                translate=translate,
                brightness_factor=brightness,
                contrast_factor=contrast,
                saturation_factor=saturation,
                gray_p=gray,
                h_flip_p=hflip,
                v_flip_p=vflip,
                scale=scale,
                augment=augment,
            )

            test_dataset = dataset_library.__dict__[dataset_info[1]](
                data_path,
                classname=subdataset,
                resize=resize,
                imagesize=imagesize,
                split=dataset_library.DatasetSplit.TEST,
                seed=seed,
            )

            LOGGER.info(f"Dataset: train={len(train_dataset)} test={len(test_dataset)}")

            train_dataloader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=num_workers,
                prefetch_factor=2,
                pin_memory=True,
            )

            test_dataloader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                prefetch_factor=2,
                pin_memory=True,
            )

            # run() reads this name for logging and result rows.
            train_dataloader.name = name
            if subdataset is not None:
                train_dataloader.name += "_" + subdataset

            if train_val_split < 1:
                val_dataset = dataset_library.__dict__[dataset_info[1]](
                    data_path,
                    classname=subdataset,
                    resize=resize,
                    train_val_split=train_val_split,
                    imagesize=imagesize,
                    split=dataset_library.DatasetSplit.VAL,
                    seed=seed,
                )

                val_dataloader = torch.utils.data.DataLoader(
                    val_dataset,
                    batch_size=batch_size,
                    shuffle=False,
                    num_workers=num_workers,
                    prefetch_factor=4,
                    pin_memory=True,
                )
            else:
                val_dataloader = None
            dataloader_dict = {
                "training": train_dataloader,
                "validation": val_dataloader,
                "testing": test_dataloader,
            }

            dataloaders.append(dataloader_dict)
        return dataloaders

    # Returned (key, factory) pair is collected by run() via result_callback.
    return ("get_dataloaders", get_dataloaders)
""" fpr, tpr, thresholds = metrics.roc_curve( anomaly_ground_truth_labels, anomaly_prediction_weights ) auroc = metrics.roc_auc_score( anomaly_ground_truth_labels, anomaly_prediction_weights ) precision, recall, _ = metrics.precision_recall_curve( anomaly_ground_truth_labels, anomaly_prediction_weights ) auc_pr = metrics.auc(recall, precision) return {"auroc": auroc, "fpr": fpr, "tpr": tpr, "threshold": thresholds} def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks): """ Computes pixel-wise statistics (AUROC, FPR, TPR) for anomaly segmentations and ground truth segmentation masks. Args: anomaly_segmentations: [list of np.arrays or np.array] [NxHxW] Contains generated segmentation masks. ground_truth_masks: [list of np.arrays or np.array] [NxHxW] Contains predefined ground truth segmentation masks """ if isinstance(anomaly_segmentations, list): anomaly_segmentations = np.stack(anomaly_segmentations) if isinstance(ground_truth_masks, list): ground_truth_masks = np.stack(ground_truth_masks) flat_anomaly_segmentations = anomaly_segmentations.ravel() flat_ground_truth_masks = ground_truth_masks.ravel() fpr, tpr, thresholds = metrics.roc_curve( flat_ground_truth_masks.astype(int), flat_anomaly_segmentations ) auroc = metrics.roc_auc_score( flat_ground_truth_masks.astype(int), flat_anomaly_segmentations ) precision, recall, thresholds = metrics.precision_recall_curve( flat_ground_truth_masks.astype(int), flat_anomaly_segmentations ) F1_scores = np.divide( 2 * precision * recall, precision + recall, out=np.zeros_like(precision), where=(precision + recall) != 0, ) optimal_threshold = thresholds[np.argmax(F1_scores)] predictions = (flat_anomaly_segmentations >= optimal_threshold).astype(int) fpr_optim = np.mean(predictions > flat_ground_truth_masks) fnr_optim = np.mean(predictions < flat_ground_truth_masks) return { "auroc": auroc, "fpr": fpr, "tpr": tpr, "optimal_threshold": optimal_threshold, "optimal_fpr": fpr_optim, "optimal_fnr": 
fnr_optim, } import pandas as pd from skimage import measure def compute_pro(masks, amaps, num_th=200): df = pd.DataFrame([], columns=["pro", "fpr", "threshold"]) binary_amaps = np.zeros_like(amaps, dtype=np.bool) min_th = amaps.min() max_th = amaps.max() delta = (max_th - min_th) / num_th k = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5)) for th in np.arange(min_th, max_th, delta): binary_amaps[amaps <= th] = 0 binary_amaps[amaps > th] = 1 pros = [] for binary_amap, mask in zip(binary_amaps, masks): binary_amap = cv2.dilate(binary_amap.astype(np.uint8), k) for region in measure.regionprops(measure.label(mask)): axes0_ids = region.coords[:, 0] axes1_ids = region.coords[:, 1] tp_pixels = binary_amap[axes0_ids, axes1_ids].sum() pros.append(tp_pixels / region.area) inverse_masks = 1 - masks fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum() fpr = fp_pixels / inverse_masks.sum() df = df.append({"pro": np.mean(pros), "fpr": fpr, "threshold": th}, ignore_index=True) # Normalize FPR from 0 ~ 1 to 0 ~ 0.3 df = df[df["fpr"] < 0.3] df["fpr"] = df["fpr"] / df["fpr"].max() pro_auc = metrics.auc(df["fpr"], df["pro"]) return pro_auc ================================================ FILE: resnet.py ================================================ import torch from torch import Tensor import torch.nn as nn from typing import Type, Any, Callable, Union, List, Optional try: from torch.hub import load_state_dict_from_url except ImportError: from torch.utils.model_zoo import load_url as load_state_dict_from_url __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2'] model_urls = { 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 
class BasicBlock(nn.Module):
    # Residual block of two 3x3 convolutions (ResNet-18/34 building block).
    # Attribute names (conv1, bn1, ...) must match torchvision's so that
    # pretrained state_dicts load; do not rename.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the identity when stride/channel count changed.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class ResNet(nn.Module):
    # Feature-extraction variant of torchvision's ResNet: the classification
    # head (avgpool + fc) is commented out and forward() returns the layer4
    # feature map instead of class logits.

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem uses the module-level PADDING_MODE (reflect by default here).
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               padding_mode = PADDING_MODE,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        #self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False):
        # Builds one ResNet stage; `dilate` trades the stage's stride for
        # dilation, preserving spatial resolution.
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # remove extra layers
        #x = self.avgpool(x)
        #x = torch.flatten(x, 1)
        #x = self.fc(x)

        return x

    def forward(self, x: Tensor):
        return self._forward_impl(x)
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
                   **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
                   **kwargs)


def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs['width_per_group'] = 64 * 2 return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) : r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs['width_per_group'] = 64 * 2 return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) # ============================================================================================================== # Model Class Definition # ============================================================================================================== ================================================ FILE: run.sh ================================================ datapath=/data4/MVTec_ad datasets=('screw' 'pill' 'capsule' 'carpet' 'grid' 'tile' 'wood' 'zipper' 'cable' 'toothbrush' 'transistor' 'metal_nut' 'bottle' 'hazelnut' 'leather') dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '"${dataset}"; done)) python3 main.py \ --gpu 4 \ --seed 0 \ --log_group simplenet_mvtec \ --log_project MVTecAD_Results \ --results_path results \ --run_name run \ net \ -b wideresnet50 \ -le layer2 \ -le layer3 \ --pretrain_embed_dimension 1536 \ --target_embed_dimension 1536 \ --patchsize 3 \ --meta_epochs 40 \ --embedding_size 256 \ --gan_epochs 4 \ --noise_std 0.015 \ --dsc_hidden 1024 \ 
--dsc_layers 2 \ --dsc_margin .5 \ --pre_proj 1 \ dataset \ --batch_size 8 \ --resize 329 \ --imagesize 288 "${dataset_flags[@]}" mvtec $datapath ================================================ FILE: simplenet.py ================================================ # ------------------------------------------------------------------ # SimpleNet: A Simple Network for Image Anomaly Detection and Localization (https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_SimpleNet_A_Simple_Network_for_Image_Anomaly_Detection_and_Localization_CVPR_2023_paper.pdf) # Github source: https://github.com/DonaldRR/SimpleNet # Licensed under the MIT License [see LICENSE for details] # The script is based on the code of PatchCore (https://github.com/amazon-science/patchcore-inspection) # ------------------------------------------------------------------ """detection methods.""" import logging import os import pickle from collections import OrderedDict import math import numpy as np import torch import torch.nn.functional as F import tqdm from torch.utils.tensorboard import SummaryWriter import common import metrics from utils import plot_segmentation_images LOGGER = logging.getLogger(__name__) def init_weight(m): if isinstance(m, torch.nn.Linear): torch.nn.init.xavier_normal_(m.weight) elif isinstance(m, torch.nn.Conv2d): torch.nn.init.xavier_normal_(m.weight) class Discriminator(torch.nn.Module): def __init__(self, in_planes, n_layers=1, hidden=None): super(Discriminator, self).__init__() _hidden = in_planes if hidden is None else hidden self.body = torch.nn.Sequential() for i in range(n_layers-1): _in = in_planes if i == 0 else _hidden _hidden = int(_hidden // 1.5) if hidden is None else hidden self.body.add_module('block%d'%(i+1), torch.nn.Sequential( torch.nn.Linear(_in, _hidden), torch.nn.BatchNorm1d(_hidden), torch.nn.LeakyReLU(0.2) )) self.tail = torch.nn.Linear(_hidden, 1, bias=False) self.apply(init_weight) def forward(self,x): x = self.body(x) x = self.tail(x) return x class 
class Projection(torch.nn.Module):
    """Stack of fully-connected layers that re-embeds backbone features.

    The first layer maps ``in_planes`` to ``out_planes`` (defaulting to
    ``in_planes``); every later layer maps ``out_planes`` to itself.  When
    ``layer_type > 1`` a LeakyReLU is inserted between consecutive layers,
    never after the last one.
    """

    def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
        super(Projection, self).__init__()

        out_planes = in_planes if out_planes is None else out_planes
        self.layers = torch.nn.Sequential()
        for idx in range(n_layers):
            fan_in = in_planes if idx == 0 else out_planes
            self.layers.add_module(f"{idx}fc",
                                   torch.nn.Linear(fan_in, out_planes))
            # No activation after the final layer.
            if idx < n_layers - 1 and layer_type > 1:
                self.layers.add_module(f"{idx}relu",
                                       torch.nn.LeakyReLU(.2))
        self.apply(init_weight)

    def forward(self, x):
        """Project features; output has the same leading dimensions as the input."""
        return self.layers(x)
        self.forward_modules["feature_aggregator"] = feature_aggregator

        # Channel-wise pooling of each extracted backbone layer to a common width.
        preprocessing = common.Preprocessing(
            feature_dimensions, pretrain_embed_dimension
        )
        self.forward_modules["preprocessing"] = preprocessing

        self.target_embed_dimension = target_embed_dimension
        # Pools the stacked per-layer features down to target_embed_dimension.
        preadapt_aggregator = common.Aggregator(
            target_dim=target_embed_dimension
        )

        _ = preadapt_aggregator.to(self.device)

        self.forward_modules["preadapt_aggregator"] = preadapt_aggregator

        # Rescales patch-level score maps back to the input image resolution.
        self.anomaly_segmentor = common.RescaleSegmentor(
            device=self.device, target_size=input_shape[-2:]
        )

        self.embedding_size = embedding_size if embedding_size is not None else self.target_embed_dimension
        self.meta_epochs = meta_epochs
        self.lr = lr
        self.cos_lr = cos_lr
        self.train_backbone = train_backbone
        if self.train_backbone:
            self.backbone_opt = torch.optim.AdamW(self.forward_modules["feature_aggregator"].backbone.parameters(), lr)
        # AED
        self.aed_meta_epochs = aed_meta_epochs

        # Optional projection MLP applied to features before the discriminator.
        self.pre_proj = pre_proj
        if self.pre_proj > 0:
            self.pre_projection = Projection(self.target_embed_dimension, self.target_embed_dimension, pre_proj, proj_layer_type)
            self.pre_projection.to(self.device)
            # The projection trains at a tenth of the base learning rate.
            self.proj_opt = torch.optim.AdamW(self.pre_projection.parameters(), lr*.1)

        # Discriminator
        self.auto_noise = [auto_noise, None]
        self.dsc_lr = dsc_lr
        self.gan_epochs = gan_epochs
        self.mix_noise = mix_noise
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.discriminator = Discriminator(self.target_embed_dimension, n_layers=dsc_layers, hidden=dsc_hidden)
        self.discriminator.to(self.device)
        self.dsc_opt = torch.optim.Adam(self.discriminator.parameters(), lr=self.dsc_lr, weight_decay=1e-5)
        # Cosine schedule across all discriminator steps, decaying to 40% of dsc_lr.
        self.dsc_schl = torch.optim.lr_scheduler.CosineAnnealingLR(self.dsc_opt, (meta_epochs - aed_meta_epochs) * gan_epochs, self.dsc_lr*.4)
        self.dsc_margin= dsc_margin

        self.model_dir = ""
        self.dataset_name = ""
        self.tau = 1
        self.logger = None

    def set_model_dir(self, model_dir, dataset_name):
        """Create the model/checkpoint directory tree and attach a TensorBoard logger."""
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.ckpt_dir =
    def embed(self, data):
        """Embed images into patch features.

        Accepts either a single batch tensor or a DataLoader; for a loader,
        returns a list with one `_embed` result per batch (computed without
        gradients).
        """
        if isinstance(data, torch.utils.data.DataLoader):
            features = []
            for image in data:
                if isinstance(image, dict):
                    image = image["image"]
                    input_image = image.to(torch.float).to(self.device)
                # NOTE(review): `input_image` is only assigned for dict batches;
                # a non-dict batch would hit a NameError here — confirm that all
                # dataloaders used with this method yield dicts.
                with torch.no_grad():
                    features.append(self._embed(input_image))
            return features
        return self._embed(data)
            _features = _features.reshape(
                *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
            )
            # Undo the permutation and flatten back to (B, n_patches, c, ph, pw).
            _features = _features.permute(0, -2, -1, 1, 2, 3)
            _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
            features[i] = _features
        features = [x.reshape(-1, *x.shape[-3:]) for x in features]

        # As different feature backbones & patching provide differently
        # sized features, these are brought into the correct form here.
        features = self.forward_modules["preprocessing"](features)  # pooling each feature to same channel and stack together
        features = self.forward_modules["preadapt_aggregator"](features)  # further pooling

        return features, patch_shapes

    def test(self, training_data, test_data, save_segmentation_images):
        """Evaluate a saved model on *test_data*; returns (image AUROC, pixel AUROC)."""
        ckpt_path = os.path.join(self.ckpt_dir, "models.ckpt")
        if os.path.exists(ckpt_path):
            state_dicts = torch.load(ckpt_path, map_location=self.device)
            # NOTE(review): `feature_enc`/`feature_dec` are not set anywhere in
            # this class — these branches look like leftovers from another model.
            if "pretrained_enc" in state_dicts:
                self.feature_enc.load_state_dict(state_dicts["pretrained_enc"])
            if "pretrained_dec" in state_dicts:
                self.feature_dec.load_state_dict(state_dicts["pretrained_dec"])

        aggregator = {"scores": [], "segmentations": [], "features": []}
        scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)
        aggregator["scores"].append(scores)
        aggregator["segmentations"].append(segmentations)
        aggregator["features"].append(features)

        # Min-max normalize image scores (single aggregation round, so mean is a no-op).
        scores = np.array(aggregator["scores"])
        min_scores = scores.min(axis=-1).reshape(-1, 1)
        max_scores = scores.max(axis=-1).reshape(-1, 1)
        scores = (scores - min_scores) / (max_scores - min_scores)
        scores = np.mean(scores, axis=0)

        segmentations = np.array(aggregator["segmentations"])
        min_scores = (
            segmentations.reshape(len(segmentations), -1)
            .min(axis=-1)
            .reshape(-1, 1, 1, 1)
        )
        max_scores = (
            segmentations.reshape(len(segmentations), -1)
            .max(axis=-1)
            .reshape(-1, 1, 1, 1)
        )
        segmentations = (segmentations - min_scores) / (max_scores - min_scores)
        segmentations = np.mean(segmentations, axis=0)

        # Ground-truth image labels: anything not in the "good" split is anomalous.
        anomaly_labels = [
            x[1] != "good" for x in
    def _evaluate(self, test_data, scores, segmentations, features, labels_gt, masks_gt):
        """Compute image AUROC, pixel AUROC and PRO from raw per-image predictions.

        Returns (-1, -1) for the pixel metrics when no ground-truth masks exist.
        """
        scores = np.squeeze(np.array(scores))
        img_min_scores = scores.min(axis=-1)
        img_max_scores = scores.max(axis=-1)
        # Min-max normalize image scores to [0, 1].
        scores = (scores - img_min_scores) / (img_max_scores - img_min_scores)
        # scores = np.mean(scores, axis=0)

        auroc = metrics.compute_imagewise_retrieval_metrics(
            scores, labels_gt
        )["auroc"]

        if len(masks_gt) > 0:
            segmentations = np.array(segmentations)
            # Per-image min/max over the flattened maps, shaped for broadcasting.
            min_scores = (
                segmentations.reshape(len(segmentations), -1)
                .min(axis=-1)
                .reshape(-1, 1, 1, 1)
            )
            max_scores = (
                segmentations.reshape(len(segmentations), -1)
                .max(axis=-1)
                .reshape(-1, 1, 1, 1)
            )
            # Average the normalizations obtained with every image's (min, max)
            # pair; the 1e-2 floor guards against a degenerate (flat) score map.
            norm_segmentations = np.zeros_like(segmentations)
            for min_score, max_score in zip(min_scores, max_scores):
                norm_segmentations += (segmentations - min_score) / max(max_score - min_score, 1e-2)
            norm_segmentations = norm_segmentations / len(scores)

            # Compute PRO score & PW Auroc for all images
            pixel_scores = metrics.compute_pixelwise_retrieval_metrics(
                norm_segmentations, masks_gt)  # segmentations, masks_gt
            full_pixel_auroc = pixel_scores["auroc"]

            pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)), norm_segmentations)
        else:
            full_pixel_auroc = -1
            pro = -1
        return auroc, full_pixel_auroc, pro
                state_dict:
                    self.pre_projection.load_state_dict(state_dict["pre_projection"])
            else:
                # Fall back to loading whatever matches on the full module.
                self.load_state_dict(state_dict, strict=False)

            self.predict(training_data, "train_")
            scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)
            auroc, full_pixel_auroc, anomaly_pixel_auroc = self._evaluate(test_data, scores, segmentations, features, labels_gt, masks_gt)

            return auroc, full_pixel_auroc, anomaly_pixel_auroc

        def update_state_dict(d):
            # NOTE(review): the `d` argument is never used; the enclosing
            # `state_dict` is mutated via closure instead.
            state_dict["discriminator"] = OrderedDict({
                k:v.detach().cpu()
                for k, v in self.discriminator.state_dict().items()})
            if self.pre_proj > 0:
                state_dict["pre_projection"] = OrderedDict({
                    k:v.detach().cpu()
                    for k, v in self.pre_projection.state_dict().items()})

        best_record = None
        for i_mepoch in range(self.meta_epochs):

            self._train_discriminator(training_data)

            # torch.cuda.empty_cache()
            scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)
            auroc, full_pixel_auroc, pro = self._evaluate(test_data, scores, segmentations, features, labels_gt, masks_gt)
            self.logger.logger.add_scalar("i-auroc", auroc, i_mepoch)
            self.logger.logger.add_scalar("p-auroc", full_pixel_auroc, i_mepoch)
            self.logger.logger.add_scalar("pro", pro, i_mepoch)

            # Keep the checkpoint of the best image AUROC; pixel AUROC breaks ties.
            if best_record is None:
                best_record = [auroc, full_pixel_auroc, pro]
                update_state_dict(state_dict)
                # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})
            else:
                if auroc > best_record[0]:
                    best_record = [auroc, full_pixel_auroc, pro]
                    update_state_dict(state_dict)
                    # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})
                elif auroc == best_record[0] and full_pixel_auroc > best_record[1]:
                    best_record[1] = full_pixel_auroc
                    best_record[2] = pro
                    update_state_dict(state_dict)
                    # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})

            print(f"----- {i_mepoch} I-AUROC:{round(auroc, 4)}(MAX:{round(best_record[0], 4)})"
                  f" P-AUROC{round(full_pixel_auroc, 4)}(MAX:{round(best_record[1], 4)}) -----"
                  f" 
PRO-AUROC{round(pro, 4)}(MAX:{round(best_record[2], 4)}) -----")

            torch.save(state_dict, ckpt_path)

        return best_record

    def _train_discriminator(self, input_data):
        """Train the discriminator (and optional projection) to separate true
        features from Gaussian-noised fakes, for ``gan_epochs`` passes."""
        _ = self.forward_modules.eval()

        if self.pre_proj > 0:
            self.pre_projection.train()
        self.discriminator.train()
        # self.feature_enc.eval()
        # self.feature_dec.eval()
        i_iter = 0
        LOGGER.info(f"Training discriminator...")
        with tqdm.tqdm(total=self.gan_epochs) as pbar:
            for i_epoch in range(self.gan_epochs):
                all_loss = []
                all_p_true = []
                all_p_fake = []
                all_p_interp = []
                embeddings_list = []
                for data_item in input_data:
                    self.dsc_opt.zero_grad()
                    if self.pre_proj > 0:
                        self.proj_opt.zero_grad()
                    # self.dec_opt.zero_grad()

                    i_iter += 1
                    img = data_item["image"]
                    img = img.to(torch.float).to(self.device)
                    if self.pre_proj > 0:
                        true_feats = self.pre_projection(self._embed(img, evaluation=False)[0])
                    else:
                        true_feats = self._embed(img, evaluation=False)[0]

                    # Sample a mixture component per feature row; component k uses
                    # std noise_std * 1.1**k (mix_noise == 1 reduces to a single Gaussian).
                    noise_idxs = torch.randint(0, self.mix_noise, torch.Size([true_feats.shape[0]]))
                    noise_one_hot = torch.nn.functional.one_hot(noise_idxs, num_classes=self.mix_noise).to(self.device) # (N, K)
                    noise = torch.stack([
                        torch.normal(0, self.noise_std * 1.1**(k), true_feats.shape)
                        for k in range(self.mix_noise)], dim=1).to(self.device) # (N, K, C)
                    noise = (noise * noise_one_hot.unsqueeze(-1)).sum(1)
                    fake_feats = true_feats + noise

                    scores = self.discriminator(torch.cat([true_feats, fake_feats]))
                    true_scores = scores[:len(true_feats)]
                    # Relies on len(true_feats) == len(fake_feats); this slice is
                    # the second half of `scores`.
                    fake_scores = scores[len(fake_feats):]
                    th = self.dsc_margin
                    # Fractions of samples already classified beyond the margin.
                    p_true = (true_scores.detach() >= th).sum() / len(true_scores)
                    p_fake = (fake_scores.detach() < -th).sum() / len(fake_scores)
                    # Hinge-style losses pushing true scores above +th, fakes below -th.
                    true_loss = torch.clip(-true_scores + th, min=0)
                    fake_loss = torch.clip(fake_scores + th, min=0)

                    self.logger.logger.add_scalar(f"p_true", p_true, self.logger.g_iter)
                    self.logger.logger.add_scalar(f"p_fake", p_fake, self.logger.g_iter)

                    loss = true_loss.mean() + fake_loss.mean()
                    self.logger.logger.add_scalar("loss", loss,
self.logger.g_iter) self.logger.step() loss.backward() if self.pre_proj > 0: self.proj_opt.step() if self.train_backbone: self.backbone_opt.step() self.dsc_opt.step() loss = loss.detach().cpu() all_loss.append(loss.item()) all_p_true.append(p_true.cpu().item()) all_p_fake.append(p_fake.cpu().item()) if len(embeddings_list) > 0: self.auto_noise[1] = torch.cat(embeddings_list).std(0).mean(-1) if self.cos_lr: self.dsc_schl.step() all_loss = sum(all_loss) / len(input_data) all_p_true = sum(all_p_true) / len(input_data) all_p_fake = sum(all_p_fake) / len(input_data) cur_lr = self.dsc_opt.state_dict()['param_groups'][0]['lr'] pbar_str = f"epoch:{i_epoch} loss:{round(all_loss, 5)} " pbar_str += f"lr:{round(cur_lr, 6)}" pbar_str += f" p_true:{round(all_p_true, 3)} p_fake:{round(all_p_fake, 3)}" if len(all_p_interp) > 0: pbar_str += f" p_interp:{round(sum(all_p_interp) / len(input_data), 3)}" pbar.set_description_str(pbar_str) pbar.update(1) def predict(self, data, prefix=""): if isinstance(data, torch.utils.data.DataLoader): return self._predict_dataloader(data, prefix) return self._predict(data) def _predict_dataloader(self, dataloader, prefix): """This function provides anomaly scores/maps for full dataloaders.""" _ = self.forward_modules.eval() img_paths = [] scores = [] masks = [] features = [] labels_gt = [] masks_gt = [] from sklearn.manifold import TSNE with tqdm.tqdm(dataloader, desc="Inferring...", leave=False) as data_iterator: for data in data_iterator: if isinstance(data, dict): labels_gt.extend(data["is_anomaly"].numpy().tolist()) if data.get("mask", None) is not None: masks_gt.extend(data["mask"].numpy().tolist()) image = data["image"] img_paths.extend(data['image_path']) _scores, _masks, _feats = self._predict(image) for score, mask, feat, is_anomaly in zip(_scores, _masks, _feats, data["is_anomaly"].numpy().tolist()): scores.append(score) masks.append(mask) return scores, masks, features, labels_gt, masks_gt def _predict(self, images): """Infer score and 
mask for a batch of images.""" images = images.to(torch.float).to(self.device) _ = self.forward_modules.eval() batchsize = images.shape[0] if self.pre_proj > 0: self.pre_projection.eval() self.discriminator.eval() with torch.no_grad(): features, patch_shapes = self._embed(images, provide_patch_shapes=True, evaluation=True) if self.pre_proj > 0: features = self.pre_projection(features) # features = features.cpu().numpy() # features = np.ascontiguousarray(features.cpu().numpy()) patch_scores = image_scores = -self.discriminator(features) patch_scores = patch_scores.cpu().numpy() image_scores = image_scores.cpu().numpy() image_scores = self.patch_maker.unpatch_scores( image_scores, batchsize=batchsize ) image_scores = image_scores.reshape(*image_scores.shape[:2], -1) image_scores = self.patch_maker.score(image_scores) patch_scores = self.patch_maker.unpatch_scores( patch_scores, batchsize=batchsize ) scales = patch_shapes[0] patch_scores = patch_scores.reshape(batchsize, scales[0], scales[1]) features = features.reshape(batchsize, scales[0], scales[1], -1) masks, features = self.anomaly_segmentor.convert_to_segmentation(patch_scores, features) return list(image_scores), list(masks), list(features) @staticmethod def _params_file(filepath, prepend=""): return os.path.join(filepath, prepend + "params.pkl") def save_to_path(self, save_path: str, prepend: str = ""): LOGGER.info("Saving data.") self.anomaly_scorer.save( save_path, save_features_separately=False, prepend=prepend ) params = { "backbone.name": self.backbone.name, "layers_to_extract_from": self.layers_to_extract_from, "input_shape": self.input_shape, "pretrain_embed_dimension": self.forward_modules[ "preprocessing" ].output_dim, "target_embed_dimension": self.forward_modules[ "preadapt_aggregator" ].target_dim, "patchsize": self.patch_maker.patchsize, "patchstride": self.patch_maker.stride, "anomaly_scorer_num_nn": self.anomaly_scorer.n_nearest_neighbours, } with open(self._params_file(save_path, prepend), 
"wb") as save_file: pickle.dump(params, save_file, pickle.HIGHEST_PROTOCOL) def save_segmentation_images(self, data, segmentations, scores): image_paths = [ x[2] for x in data.dataset.data_to_iterate ] mask_paths = [ x[3] for x in data.dataset.data_to_iterate ] def image_transform(image): in_std = np.array( data.dataset.transform_std ).reshape(-1, 1, 1) in_mean = np.array( data.dataset.transform_mean ).reshape(-1, 1, 1) image = data.dataset.transform_img(image) return np.clip( (image.numpy() * in_std + in_mean) * 255, 0, 255 ).astype(np.uint8) def mask_transform(mask): return data.dataset.transform_mask(mask).numpy() plot_segmentation_images( './output', image_paths, segmentations, scores, mask_paths, image_transform=image_transform, mask_transform=mask_transform, ) # Image handling classes. class PatchMaker: def __init__(self, patchsize, top_k=0, stride=None): self.patchsize = patchsize self.stride = stride self.top_k = top_k def patchify(self, features, return_spatial_info=False): """Convert a tensor into a tensor of respective patches. 
Args: x: [torch.Tensor, bs x c x w x h] Returns: x: [torch.Tensor, bs * w//stride * h//stride, c, patchsize, patchsize] """ padding = int((self.patchsize - 1) / 2) unfolder = torch.nn.Unfold( kernel_size=self.patchsize, stride=self.stride, padding=padding, dilation=1 ) unfolded_features = unfolder(features) number_of_total_patches = [] for s in features.shape[-2:]: n_patches = ( s + 2 * padding - 1 * (self.patchsize - 1) - 1 ) / self.stride + 1 number_of_total_patches.append(int(n_patches)) unfolded_features = unfolded_features.reshape( *features.shape[:2], self.patchsize, self.patchsize, -1 ) unfolded_features = unfolded_features.permute(0, 4, 1, 2, 3) if return_spatial_info: return unfolded_features, number_of_total_patches return unfolded_features def unpatch_scores(self, x, batchsize): return x.reshape(batchsize, -1, *x.shape[1:]) def score(self, x): was_numpy = False if isinstance(x, np.ndarray): was_numpy = True x = torch.from_numpy(x) while x.ndim > 2: x = torch.max(x, dim=-1).values if x.ndim == 2: if self.top_k > 1: x = torch.topk(x, self.top_k, dim=1).values.mean(1) else: x = torch.max(x, dim=1).values if was_numpy: return x.numpy() return x ================================================ FILE: utils.py ================================================ import csv import logging import os import random import matplotlib.pyplot as plt import numpy as np import PIL import torch import tqdm LOGGER = logging.getLogger(__name__) def plot_segmentation_images( savefolder, image_paths, segmentations, anomaly_scores=None, mask_paths=None, image_transform=lambda x: x, mask_transform=lambda x: x, save_depth=4, ): """Generate anomaly segmentation images. Args: image_paths: List[str] List of paths to images. segmentations: [List[np.ndarray]] Generated anomaly segmentations. anomaly_scores: [List[float]] Anomaly scores for each image. mask_paths: [List[str]] List of paths to ground truth masks. image_transform: [function or lambda] Optional transformation of images. 
def create_storage_folder(
    main_folder_path, project_folder, group_folder, run_name, mode="iterate"
):
    """Create (and return) the directory a run's results are stored in.

    Layout: main_folder_path/project_folder/group_folder/run_name.

    Args:
        main_folder_path: root results directory (created if missing).
        project_folder: project-level subdirectory.
        group_folder: group-level subdirectory.
        run_name: leaf directory for this run.
        mode: "iterate" appends an increasing numeric suffix to run_name
            until a free path is found; "overwrite" reuses an existing
            directory; any other value returns the path without creating it.
    """
    os.makedirs(main_folder_path, exist_ok=True)
    project_path = os.path.join(main_folder_path, project_folder)
    os.makedirs(project_path, exist_ok=True)
    save_path = os.path.join(project_path, group_folder, run_name)
    if mode == "iterate":
        counter = 0
        while os.path.exists(save_path):
            # Fix: keep run_name in the fallback path. Previously the suffix
            # was appended to group_folder at the project level, silently
            # dropping run_name and changing the directory layout on collision.
            save_path = os.path.join(
                project_path, group_folder, run_name + "_" + str(counter)
            )
            counter += 1
        os.makedirs(save_path)
    elif mode == "overwrite":
        os.makedirs(save_path, exist_ok=True)

    return save_path
def fix_seeds(seed, with_torch=True, with_cuda=True):
    """Seed the stdlib, numpy and (optionally) torch/CUDA RNGs for reproducibility.

    Args:
        seed: [int] Seed value.
        with_torch: Flag. If true, torch-related seeds are fixed.
        with_cuda: Flag. If true, torch+cuda-related seeds are fixed
    """
    np.random.seed(seed)
    random.seed(seed)
    if with_torch:
        torch.manual_seed(seed)
    # Intentionally independent of with_torch: the two flags guard
    # separate RNG families.
    if with_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
    mean_metrics = {}
    # NOTE(review): x[i] indexes from 0, so each row of `results` must contain
    # only the numeric metrics (row names travel separately in `row_names`);
    # the docstring's claim that results[i][0] is the dataset name looks stale.
    for i, result_key in enumerate(column_names):
        mean_metrics[result_key] = np.mean([x[i] for x in results])
        LOGGER.info("{0}: {1:3.3f}".format(result_key, mean_metrics[result_key]))

    savename = os.path.join(results_path, "results.csv")
    # NOTE(review): per the csv module docs, this open() should pass
    # newline="" to avoid blank lines in the CSV on Windows.
    with open(savename, "w") as csv_file:
        csv_writer = csv.writer(csv_file, delimiter=",")
        header = column_names
        if row_names is not None:
            header = ["Row Names"] + header

        csv_writer.writerow(header)
        for i, result_list in enumerate(results):
            csv_row = result_list
            if row_names is not None:
                csv_row = [row_names[i]] + result_list
            csv_writer.writerow(csv_row)
        # Final row: per-column means across all result rows.
        mean_scores = list(mean_metrics.values())
        if row_names is not None:
            mean_scores = ["Mean"] + mean_scores
        csv_writer.writerow(mean_scores)

    mean_metrics = {"mean_{0}".format(key): item for key, item in mean_metrics.items()}
    return mean_metrics