Repository: DonaldRR/SimpleNet
Branch: main
Commit: 351a2b8d4e8c
Files: 18
Total size: 108.6 KB
Directory structure:
gitextract_ap_hyzgr/
├── .gitignore
├── LICENSE
├── README.md
├── VERSION
├── backbones.py
├── common.py
├── datasets/
│ ├── __init__.py
│ ├── btad.py
│ ├── cifar10.py
│ ├── mvtec.py
│ ├── sdd.py
│ └── sdd2.py
├── main.py
├── metrics.py
├── resnet.py
├── run.sh
├── simplenet.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/*
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 DonaldRR
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# SimpleNet

**SimpleNet: A Simple Network for Image Anomaly Detection and Localization**
**Zhikang Liu, Yiming Zhou, Yuansheng Xu, Zilei Wang**
[Paper link](https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_SimpleNet_A_Simple_Network_for_Image_Anomaly_Detection_and_Localization_CVPR_2023_paper.pdf)
## Introduction
This repo contains source code for **SimpleNet** implemented with pytorch.
SimpleNet is a simple defect detection and localization network built with a feature encoder, a feature generator and a defect discriminator. It is designed to be conceptually simple, without complex network design, training schemes or external data sources.
## Get Started
### Environment
**Python3.8**
**Packages**:
- torch==1.12.1
- torchvision==0.13.1
- numpy==1.22.4
- opencv-python==4.5.1
(The environment above is not the minimum requirement; other versions might work too.)
### Data
Edit `run.sh` to edit dataset class and dataset path.
#### MvTecAD
Download the dataset from [here](https://www.mvtec.com/company/research/datasets/mvtec-ad/).
The dataset folders/files follow its original structure.
### Run
#### Demo train
Please specify the dataset path (line 1) and the log folder (line 10) in `run.sh` before running.
`run.sh` gives the configuration to train models on MVTecAD dataset.
```
bash run.sh
```
## Citation
```
@inproceedings{liu2023simplenet,
title={SimpleNet: A Simple Network for Image Anomaly Detection and Localization},
author={Liu, Zhikang and Zhou, Yiming and Xu, Yuansheng and Wang, Zilei},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={20402--20411},
year={2023}
}
```
## Acknowledgement
Thanks for great inspiration from [PatchCore](https://github.com/amazon-science/patchcore-inspection)
## License
All code within the repo is under [MIT license](https://mit-license.org/)
================================================
FILE: VERSION
================================================
0.1.0
================================================
FILE: backbones.py
================================================
import timm # noqa
import torch
import torchvision.models as models # noqa
def load_ref_wrn50():
    """Build the repo-local WideResNet-50-2 (see resnet.py) with pretrained weights."""
    import resnet

    model = resnet.wide_resnet50_2(True)
    return model
# Registry mapping a short backbone name to a Python expression that builds a
# pretrained model. Expressions are evaluated lazily by load() below, so timm /
# torchvision models are only instantiated when actually requested.
# NOTE(review): "cait_s24_224"/"cait_xs24" reference a `cait` module and
# "mc3_resnet50" references `load_mc3_rn50`, neither of which is defined or
# imported in this file — those entries would raise NameError if selected.
_BACKBONES = {
    "cait_s24_224" : "cait.cait_S24_224(True)",
    "cait_xs24": "cait.cait_XS24(True)",
    "alexnet": "models.alexnet(pretrained=True)",
    "bninception": 'pretrainedmodels.__dict__["bninception"]'
    '(pretrained="imagenet", num_classes=1000)',
    "resnet18": "models.resnet18(pretrained=True)",
    "resnet50": "models.resnet50(pretrained=True)",
    "mc3_resnet50": "load_mc3_rn50()",
    "resnet101": "models.resnet101(pretrained=True)",
    "resnext101": "models.resnext101_32x8d(pretrained=True)",
    "resnet200": 'timm.create_model("resnet200", pretrained=True)',
    "resnest50": 'timm.create_model("resnest50d_4s2x40d", pretrained=True)',
    "resnetv2_50_bit": 'timm.create_model("resnetv2_50x3_bitm", pretrained=True)',
    "resnetv2_50_21k": 'timm.create_model("resnetv2_50x3_bitm_in21k", pretrained=True)',
    "resnetv2_101_bit": 'timm.create_model("resnetv2_101x3_bitm", pretrained=True)',
    "resnetv2_101_21k": 'timm.create_model("resnetv2_101x3_bitm_in21k", pretrained=True)',
    "resnetv2_152_bit": 'timm.create_model("resnetv2_152x4_bitm", pretrained=True)',
    "resnetv2_152_21k": 'timm.create_model("resnetv2_152x4_bitm_in21k", pretrained=True)',
    "resnetv2_152_384": 'timm.create_model("resnetv2_152x2_bit_teacher_384", pretrained=True)',
    "resnetv2_101": 'timm.create_model("resnetv2_101", pretrained=True)',
    "vgg11": "models.vgg11(pretrained=True)",
    "vgg19": "models.vgg19(pretrained=True)",
    "vgg19_bn": "models.vgg19_bn(pretrained=True)",
    "wideresnet50": "models.wide_resnet50_2(pretrained=True)",
    "ref_wideresnet50": "load_ref_wrn50()",
    "wideresnet101": "models.wide_resnet101_2(pretrained=True)",
    "mnasnet_100": 'timm.create_model("mnasnet_100", pretrained=True)',
    "mnasnet_a1": 'timm.create_model("mnasnet_a1", pretrained=True)',
    "mnasnet_b1": 'timm.create_model("mnasnet_b1", pretrained=True)',
    "densenet121": 'timm.create_model("densenet121", pretrained=True)',
    "densenet201": 'timm.create_model("densenet201", pretrained=True)',
    "inception_v4": 'timm.create_model("inception_v4", pretrained=True)',
    "vit_small": 'timm.create_model("vit_small_patch16_224", pretrained=True)',
    "vit_base": 'timm.create_model("vit_base_patch16_224", pretrained=True)',
    "vit_large": 'timm.create_model("vit_large_patch16_224", pretrained=True)',
    "vit_r50": 'timm.create_model("vit_large_r50_s32_224", pretrained=True)',
    "vit_deit_base": 'timm.create_model("deit_base_patch16_224", pretrained=True)',
    "vit_deit_distilled": 'timm.create_model("deit_base_distilled_patch16_224", pretrained=True)',
    "vit_swin_base": 'timm.create_model("swin_base_patch4_window7_224", pretrained=True)',
    "vit_swin_large": 'timm.create_model("swin_large_patch4_window7_224", pretrained=True)',
    "efficientnet_b7": 'timm.create_model("tf_efficientnet_b7", pretrained=True)',
    "efficientnet_b5": 'timm.create_model("tf_efficientnet_b5", pretrained=True)',
    "efficientnet_b3": 'timm.create_model("tf_efficientnet_b3", pretrained=True)',
    "efficientnet_b1": 'timm.create_model("tf_efficientnet_b1", pretrained=True)',
    "efficientnetv2_m": 'timm.create_model("tf_efficientnetv2_m", pretrained=True)',
    "efficientnetv2_l": 'timm.create_model("tf_efficientnetv2_l", pretrained=True)',
    "efficientnet_b3a": 'timm.create_model("efficientnet_b3a", pretrained=True)',
}
def load(name):
    """Instantiate and return the backbone registered under ``name``.

    Raises KeyError if the name is not in the registry.
    """
    expression = _BACKBONES[name]
    # NOTE: eval runs only trusted, hard-coded registry strings — never user
    # input beyond the dictionary key.
    return eval(expression)
================================================
FILE: common.py
================================================
import copy
from typing import List
import numpy as np
import scipy.ndimage as ndimage
import torch
import torch.nn.functional as F
class _BaseMerger:
    def __init__(self):
        """Merges feature embedding by name."""

    def merge(self, features: list):
        """Reduce every feature map and concatenate along the channel axis."""
        reduced = []
        for feature in features:
            reduced.append(self._reduce(feature))
        return np.concatenate(reduced, axis=1)


class AverageMerger(_BaseMerger):
    """Collapses spatial dimensions by averaging: NxCxWxH -> NxC."""

    @staticmethod
    def _reduce(features):
        n_samples, n_channels = features.shape[0], features.shape[1]
        flattened = features.reshape([n_samples, n_channels, -1])
        return flattened.mean(axis=-1)


class ConcatMerger(_BaseMerger):
    """Flattens each feature map entirely: NxCxWxH -> NxCWH."""

    @staticmethod
    def _reduce(features):
        return features.reshape(len(features), -1)
class Preprocessing(torch.nn.Module):
    """Maps each input feature layer to a common output dimension via MeanMapper."""

    def __init__(self, input_dims, output_dim):
        super(Preprocessing, self).__init__()
        self.input_dims = input_dims
        self.output_dim = output_dim
        # One mapper per input layer; every mapper targets the same output_dim.
        mappers = [MeanMapper(output_dim) for _ in input_dims]
        self.preprocessing_modules = torch.nn.ModuleList(mappers)

    def forward(self, features):
        mapped = [
            mapper(feature)
            for mapper, feature in zip(self.preprocessing_modules, features)
        ]
        # -> batchsize x n_layers x output_dim
        return torch.stack(mapped, dim=1)


class MeanMapper(torch.nn.Module):
    """Adaptively average-pools each flattened sample to a fixed dimension."""

    def __init__(self, preprocessing_dim):
        super(MeanMapper, self).__init__()
        self.preprocessing_dim = preprocessing_dim

    def forward(self, features):
        flat = features.reshape(len(features), 1, -1)
        pooled = F.adaptive_avg_pool1d(flat, self.preprocessing_dim)
        return pooled.squeeze(1)
class Aggregator(torch.nn.Module):
    """Flattens features and average-pools them down to a fixed target_dim."""

    def __init__(self, target_dim):
        super(Aggregator, self).__init__()
        self.target_dim = target_dim

    def forward(self, features):
        """Returns reshaped and average pooled features."""
        # batchsize x number_of_layers x input_dim -> batchsize x target_dim
        batch = len(features)
        pooled = F.adaptive_avg_pool1d(features.reshape(batch, 1, -1), self.target_dim)
        return pooled.reshape(batch, -1)
class RescaleSegmentor:
    """Upsamples patch-level anomaly scores and features to image resolution,
    smoothing the score maps with a gaussian filter."""

    def __init__(self, device, target_size=224):
        """
        Args:
            device: torch device used for the interpolation work.
            target_size: [int or (height, width)] output resolution.
        """
        self.device = device
        # Normalize to an (H, W) tuple so the chunking arithmetic in
        # convert_to_segmentation can always index target_size[0]/[1]
        # (the int default used to crash there with a TypeError).
        if isinstance(target_size, int):
            target_size = (target_size, target_size)
        self.target_size = target_size
        self.smoothing = 4  # sigma of the final gaussian blur

    def convert_to_segmentation(self, patch_scores, features):
        """Returns (smoothed per-image score maps, per-image feature maps).

        Args:
            patch_scores: (N, h, w) ndarray or tensor of patch anomaly scores.
            features: (N, h, w, C) ndarray or tensor of patch features.
        """
        with torch.no_grad():
            if isinstance(patch_scores, np.ndarray):
                patch_scores = torch.from_numpy(patch_scores)
            _scores = patch_scores.to(self.device)
            _scores = _scores.unsqueeze(1)  # (N, 1, h, w) for bilinear interp
            _scores = F.interpolate(
                _scores, size=self.target_size, mode="bilinear", align_corners=False
            )
            _scores = _scores.squeeze(1)
            patch_scores = _scores.cpu().numpy()

            if isinstance(features, np.ndarray):
                features = torch.from_numpy(features)
            features = features.to(self.device).permute(0, 3, 1, 2)  # NHWC -> NCHW
            n_output_elements = (
                self.target_size[0]
                * self.target_size[1]
                * features.shape[0]
                * features.shape[1]
            )
            if n_output_elements >= 2**31:
                # Interpolate in sub-batches to stay below the 2**31-element
                # limit of the interpolation kernel.
                subbatch_size = int(
                    (2**31 - 1)
                    / (self.target_size[0] * self.target_size[1] * features.shape[1])
                )
                interpolated_features = []
                for start in range(0, features.shape[0], subbatch_size):
                    subfeatures = features[start : start + subbatch_size]
                    # Fixed: the original called the misspelled `unsuqeeze`,
                    # which raised AttributeError on a 3-d slice.
                    if len(subfeatures.shape) == 3:
                        subfeatures = subfeatures.unsqueeze(0)
                    subfeatures = F.interpolate(
                        subfeatures,
                        size=self.target_size,
                        mode="bilinear",
                        align_corners=False,
                    )
                    interpolated_features.append(subfeatures)
                features = torch.cat(interpolated_features, 0)
            else:
                features = F.interpolate(
                    features, size=self.target_size, mode="bilinear", align_corners=False
                )
            features = features.cpu().numpy()
        return [
            ndimage.gaussian_filter(patch_score, sigma=self.smoothing)
            for patch_score in patch_scores
        ], [
            feature
            for feature in features
        ]
class NetworkFeatureAggregator(torch.nn.Module):
    """Efficient extraction of network features."""

    def __init__(self, backbone, layers_to_extract_from, device, train_backbone=False):
        super(NetworkFeatureAggregator, self).__init__()
        """Extraction of network features.
        Runs a network only to the last layer of the list of layers where
        network features should be extracted from.
        Args:
            backbone: torchvision.model
            layers_to_extract_from: [list of str]
        """
        self.layers_to_extract_from = layers_to_extract_from
        self.backbone = backbone
        self.device = device
        self.train_backbone = train_backbone
        # Remove hooks registered by any previous aggregator on this backbone,
        # so outputs are not also written into a stale dict.
        if not hasattr(backbone, "hook_handles"):
            self.backbone.hook_handles = []
        for handle in self.backbone.hook_handles:
            handle.remove()
        # Filled by the forward hooks: layer name -> layer output tensor.
        self.outputs = {}

        for extract_layer in layers_to_extract_from:
            forward_hook = ForwardHook(
                self.outputs, extract_layer, layers_to_extract_from[-1]
            )
            # Dotted names (e.g. "layer2.3") address a sub-module: either an
            # index into a Sequential container or a named child.
            if "." in extract_layer:
                extract_block, extract_idx = extract_layer.split(".")
                network_layer = backbone.__dict__["_modules"][extract_block]
                if extract_idx.isnumeric():
                    extract_idx = int(extract_idx)
                    network_layer = network_layer[extract_idx]
                else:
                    network_layer = network_layer.__dict__["_modules"][extract_idx]
            else:
                network_layer = backbone.__dict__["_modules"][extract_layer]

            if isinstance(network_layer, torch.nn.Sequential):
                # Hook the last sub-module so the captured tensor is the
                # output of the whole Sequential block.
                self.backbone.hook_handles.append(
                    network_layer[-1].register_forward_hook(forward_hook)
                )
            else:
                self.backbone.hook_handles.append(
                    network_layer.register_forward_hook(forward_hook)
                )
        self.to(self.device)

    def forward(self, images, eval=True):
        """Runs the backbone and returns {layer_name: captured output}."""
        self.outputs.clear()
        if self.train_backbone and not eval:
            # Training mode: gradients flow into the backbone.
            self.backbone(images)
        else:
            with torch.no_grad():
                # The backbone will throw an Exception once it reached the last
                # layer to compute features from. Computation will stop there.
                # NOTE(review): the raise in ForwardHook is commented out, so
                # in practice the full forward pass runs to completion.
                try:
                    _ = self.backbone(images)
                except LastLayerToExtractReachedException:
                    pass
        return self.outputs

    def feature_dimensions(self, input_shape):
        """Computes the feature dimensions for all layers given input_shape."""
        _input = torch.ones([1] + list(input_shape)).to(self.device)
        _output = self(_input)
        return [_output[layer].shape[1] for layer in self.layers_to_extract_from]
class ForwardHook:
    """Forward hook that stores a module's output in a shared dict under
    the configured layer name."""

    def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: str):
        self.hook_dict = hook_dict
        self.layer_name = layer_name
        # True when this hook is attached to the final extraction layer.
        self.raise_exception_to_break = bool(layer_name == last_layer_to_extract)

    def __call__(self, module, input, output):
        self.hook_dict[self.layer_name] = output
        # Early-abort (raising LastLayerToExtractReachedException) is
        # intentionally disabled; the hook only records the output.
        return None
class LastLayerToExtractReachedException(Exception):
    """Raised to abort a backbone forward pass once the final extraction
    layer has produced its output."""
================================================
FILE: datasets/__init__.py
================================================
================================================
FILE: datasets/btad.py
================================================
import os
from enum import Enum
import PIL
import torch
from torchvision import transforms
# The three (numbered) product categories of the BTAD dataset.
_CLASSNAMES = [
    "01",
    "02",
    "03"
]

# ImageNet channel statistics, used to normalize inputs for the
# ImageNet-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
class DatasetSplit(Enum):
    """Dataset partition; values match the on-disk split directory names."""

    TRAIN = "train"
    VAL = "val"
    TEST = "test"
class BTADDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for BTAD (beanTech Anomaly Detection).

    Directory layout mirrors MVTec AD: <source>/<class>/<split>/<anomaly>/...
    with pixel masks under <source>/<class>/ground_truth/.
    """

    def __init__(
        self,
        source,
        classname,
        resize=256,
        imagesize=224,
        split=DatasetSplit.TRAIN,
        train_val_split=1.0,
        rotate_degrees=0,
        translate=0,
        brightness_factor=0,
        contrast_factor=0,
        saturation_factor=0,
        gray_p=0,
        h_flip_p=0,
        v_flip_p=0,
        scale=0,
        **kwargs,
    ):
        """
        Args:
            source: [str]. Path to the BTAD data folder.
            classname: [str or None]. Name of the BTAD class that should be
                       provided in this dataset. If None, the dataset
                       iterates over all available images.
            resize: [int]. (Square) Size the loaded image initially gets
                    resized to.
            imagesize: [int]. (Square) Size the resized loaded image gets
                       (center-)cropped to.
            split: [enum-option]. Indicates if training or test split of the
                   data should be used. Has to be an option taken from
                   DatasetSplit, e.g. DatasetSplit.TRAIN. Note that
                   DatasetSplit.TEST will also load mask data.
        """
        super().__init__()
        self.source = source
        self.split = split
        self.classnames_to_use = [classname] if classname is not None else _CLASSNAMES
        self.train_val_split = train_val_split
        self.transform_std = IMAGENET_STD
        self.transform_mean = IMAGENET_MEAN

        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()

        # Augmentation + ImageNet normalization pipeline for images.
        self.transform_img = [
            transforms.Resize(resize),
            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(rotate_degrees,
                                    translate=(translate, translate),
                                    scale=(1.0-scale, 1.0+scale),
                                    interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        self.transform_img = transforms.Compose(self.transform_img)

        # Masks get geometric transforms only — no color jitter/normalization.
        self.transform_mask = [
            transforms.Resize(resize),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
        ]
        self.transform_mask = transforms.Compose(self.transform_mask)

        # (C, H, W) shape of the tensors this dataset yields.
        self.imagesize = (3, imagesize, imagesize)

    def __getitem__(self, idx):
        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
        image = PIL.Image.open(image_path).convert("RGB")
        image = self.transform_img(image)

        if self.split == DatasetSplit.TEST and mask_path is not None:
            mask = PIL.Image.open(mask_path)
            mask = self.transform_mask(mask)
        else:
            # No ground-truth mask (train/val split or defect-free image).
            mask = torch.zeros([1, *image.size()[1:]])

        return {
            "image": image,
            "mask": mask,
            "classname": classname,
            "anomaly": anomaly,
            "is_anomaly": int(anomaly != "good"),
            "image_name": "/".join(image_path.split("/")[-4:]),
            "image_path": image_path,
        }

    def __len__(self):
        return len(self.data_to_iterate)

    def get_image_data(self):
        """Collects image (and, for TEST, mask) paths per class/anomaly type."""
        imgpaths_per_class = {}
        maskpaths_per_class = {}

        for classname in self.classnames_to_use:
            classpath = os.path.join(self.source, classname, self.split.value)
            maskpath = os.path.join(self.source, classname, "ground_truth")
            anomaly_types = os.listdir(classpath)

            imgpaths_per_class[classname] = {}
            maskpaths_per_class[classname] = {}

            for anomaly in anomaly_types:
                anomaly_path = os.path.join(classpath, anomaly)
                anomaly_files = sorted(os.listdir(anomaly_path))
                imgpaths_per_class[classname][anomaly] = [
                    os.path.join(anomaly_path, x) for x in anomaly_files
                ]

                # Optional train/val partition of the image list.
                if self.train_val_split < 1.0:
                    n_images = len(imgpaths_per_class[classname][anomaly])
                    train_val_split_idx = int(n_images * self.train_val_split)
                    if self.split == DatasetSplit.TRAIN:
                        imgpaths_per_class[classname][anomaly] = imgpaths_per_class[
                            classname
                        ][anomaly][:train_val_split_idx]
                    elif self.split == DatasetSplit.VAL:
                        imgpaths_per_class[classname][anomaly] = imgpaths_per_class[
                            classname
                        ][anomaly][train_val_split_idx:]

                if self.split == DatasetSplit.TEST and anomaly != "good":
                    anomaly_mask_path = os.path.join(maskpath, anomaly)
                    anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
                    maskpaths_per_class[classname][anomaly] = [
                        os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
                    ]
                else:
                    # NOTE(review): always writes under the "good" key regardless
                    # of `anomaly` — presumably relying on non-TEST splits
                    # containing only "good" images; verify for VAL usage.
                    maskpaths_per_class[classname]["good"] = None

        # Unrolls the data dictionary to an easy-to-iterate list.
        data_to_iterate = []
        for classname in sorted(imgpaths_per_class.keys()):
            for anomaly in sorted(imgpaths_per_class[classname].keys()):
                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
                    data_tuple = [classname, anomaly, image_path]
                    if self.split == DatasetSplit.TEST and anomaly != "good":
                        data_tuple.append(maskpaths_per_class[classname][anomaly][i])
                    else:
                        data_tuple.append(None)
                    data_to_iterate.append(data_tuple)

        return imgpaths_per_class, data_to_iterate
================================================
FILE: datasets/cifar10.py
================================================
import os
from enum import Enum
import PIL
import torch
from torchvision import transforms
# ImageNet channel statistics, used to normalize inputs for the
# ImageNet-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
class DatasetSplit(Enum):
    """Dataset partition; values match the on-disk split directory names."""

    TRAIN = "train"
    VAL = "val"
    TEST = "test"
class Cifar10Dataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for CIFAR-10 used as a one-class anomaly benchmark:
    `classname` selects the "normal" class; images of every other class
    count as anomalies.
    """

    # All ten CIFAR-10 class ids.
    _CLASSES = list(range(10))

    def __init__(
        self,
        source,
        classname,
        resize=256,
        imagesize=224,
        split=DatasetSplit.TRAIN,
        train_val_split=1.0,
        rotate_degrees=0,
        translate=0,
        brightness_factor=0,
        contrast_factor=0,
        saturation_factor=0,
        gray_p=0,
        h_flip_p=0,
        v_flip_p=0,
        scale=0,
        **kwargs,
    ):
        """
        Args:
            source: [str]. Path to the CIFAR-10 data folder (expected layout:
                    <source>/<split>/<class_id>/<image files> — see
                    get_image_data).
            classname: [str or int]. Id of the class treated as normal.
            resize: [int]. (Square) Size the loaded image initially gets
                    resized to.
            imagesize: [int]. (Square) Size the resized loaded image gets
                       (center-)cropped to.
            split: [enum-option]. Indicates if training or test split of the
                   data should be used. Has to be an option taken from
                   DatasetSplit.
        """
        super().__init__()
        self.source = source
        self.split = split
        self.classname = int(classname)  # id of the "normal" class
        self.train_val_split = train_val_split

        self.data_to_iterate = self.get_image_data()

        self.transform_std = IMAGENET_STD
        self.transform_mean = IMAGENET_MEAN

        # Augmentation + ImageNet normalization pipeline for images.
        self.transform_img = [
            transforms.Resize(resize),
            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(rotate_degrees,
                                    translate=(translate, translate),
                                    scale=(1.0-scale, 1.0+scale),
                                    interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        self.transform_img = transforms.Compose(self.transform_img)

        # Mask transform is defined for interface parity with the other
        # datasets; CIFAR-10 has no pixel masks.
        self.transform_mask = [
            transforms.Resize(resize),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
        ]
        self.transform_mask = transforms.Compose(self.transform_mask)

        self.imagesize = (3, imagesize, imagesize)

    def __getitem__(self, idx):
        img_path, classname = self.data_to_iterate[idx]
        image = PIL.Image.open(img_path).convert("RGB")
        image = self.transform_img(image)

        return {
            "image": image,
            "classname": classname,
            # Any class other than the chosen normal one is an anomaly.
            "anomaly": int(classname != self.classname),
            "is_anomaly": int(classname != self.classname),
            "image_name": os.path.split(img_path)[-1],
            "image_path": img_path,
        }

    def __len__(self):
        return len(self.data_to_iterate)

    def get_image_data(self):
        """Lists [image_path, class_id] pairs for the configured split."""
        data_to_iterate = []
        for classname in Cifar10Dataset._CLASSES:
            # Training uses only the normal class; test iterates all classes.
            if self.split == DatasetSplit.TRAIN:
                if classname != self.classname:
                    continue
            class_dir = os.path.join(self.source, self.split.value, str(classname))
            for fn in os.listdir(class_dir):
                img_path = os.path.join(class_dir, fn)
                data_to_iterate.append([img_path, classname])
        return data_to_iterate
================================================
FILE: datasets/mvtec.py
================================================
import os
from enum import Enum
import PIL
import torch
from torchvision import transforms
# The 15 MVTec AD object/texture categories.
_CLASSNAMES = [
    "bottle",
    "cable",
    "capsule",
    "carpet",
    "grid",
    "hazelnut",
    "leather",
    "metal_nut",
    "pill",
    "screw",
    "tile",
    "toothbrush",
    "transistor",
    "wood",
    "zipper",
]

# ImageNet channel statistics, used to normalize inputs for the
# ImageNet-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
class DatasetSplit(Enum):
    """Dataset partition; values match the on-disk split directory names."""

    TRAIN = "train"
    VAL = "val"
    TEST = "test"
class MVTecDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for MVTec.

    Expects the original MVTec AD layout:
    <source>/<class>/<split>/<anomaly>/... with pixel masks under
    <source>/<class>/ground_truth/.
    """

    def __init__(
        self,
        source,
        classname,
        resize=256,
        imagesize=224,
        split=DatasetSplit.TRAIN,
        train_val_split=1.0,
        rotate_degrees=0,
        translate=0,
        brightness_factor=0,
        contrast_factor=0,
        saturation_factor=0,
        gray_p=0,
        h_flip_p=0,
        v_flip_p=0,
        scale=0,
        **kwargs,
    ):
        """
        Args:
            source: [str]. Path to the MVTec data folder.
            classname: [str or None]. Name of MVTec class that should be
                       provided in this dataset. If None, the datasets
                       iterates over all available images.
            resize: [int]. (Square) Size the loaded image initially gets
                    resized to.
            imagesize: [int]. (Square) Size the resized loaded image gets
                       (center-)cropped to.
            split: [enum-option]. Indicates if training or test split of the
                   data should be used. Has to be an option taken from
                   DatasetSplit, e.g. mvtec.DatasetSplit.TRAIN. Note that
                   mvtec.DatasetSplit.TEST will also load mask data.
        """
        super().__init__()
        self.source = source
        self.split = split
        self.classnames_to_use = [classname] if classname is not None else _CLASSNAMES
        self.train_val_split = train_val_split
        self.transform_std = IMAGENET_STD
        self.transform_mean = IMAGENET_MEAN

        self.imgpaths_per_class, self.data_to_iterate = self.get_image_data()

        # Augmentation + ImageNet normalization pipeline for images.
        self.transform_img = [
            transforms.Resize(resize),
            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(rotate_degrees,
                                    translate=(translate, translate),
                                    scale=(1.0-scale, 1.0+scale),
                                    interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        self.transform_img = transforms.Compose(self.transform_img)

        # Masks get geometric transforms only — no color jitter/normalization.
        self.transform_mask = [
            transforms.Resize(resize),
            transforms.CenterCrop(imagesize),
            transforms.ToTensor(),
        ]
        self.transform_mask = transforms.Compose(self.transform_mask)

        # (C, H, W) shape of the tensors this dataset yields.
        self.imagesize = (3, imagesize, imagesize)

    def __getitem__(self, idx):
        classname, anomaly, image_path, mask_path = self.data_to_iterate[idx]
        image = PIL.Image.open(image_path).convert("RGB")
        image = self.transform_img(image)

        if self.split == DatasetSplit.TEST and mask_path is not None:
            mask = PIL.Image.open(mask_path)
            mask = self.transform_mask(mask)
        else:
            # No ground-truth mask (train/val split or defect-free image).
            mask = torch.zeros([1, *image.size()[1:]])

        return {
            "image": image,
            "mask": mask,
            "classname": classname,
            "anomaly": anomaly,
            "is_anomaly": int(anomaly != "good"),
            "image_name": "/".join(image_path.split("/")[-4:]),
            "image_path": image_path,
        }

    def __len__(self):
        return len(self.data_to_iterate)

    def get_image_data(self):
        """Collects image (and, for TEST, mask) paths per class/anomaly type."""
        imgpaths_per_class = {}
        maskpaths_per_class = {}

        for classname in self.classnames_to_use:
            classpath = os.path.join(self.source, classname, self.split.value)
            maskpath = os.path.join(self.source, classname, "ground_truth")
            anomaly_types = os.listdir(classpath)

            imgpaths_per_class[classname] = {}
            maskpaths_per_class[classname] = {}

            for anomaly in anomaly_types:
                anomaly_path = os.path.join(classpath, anomaly)
                anomaly_files = sorted(os.listdir(anomaly_path))
                imgpaths_per_class[classname][anomaly] = [
                    os.path.join(anomaly_path, x) for x in anomaly_files
                ]

                # Optional train/val partition of the image list.
                if self.train_val_split < 1.0:
                    n_images = len(imgpaths_per_class[classname][anomaly])
                    train_val_split_idx = int(n_images * self.train_val_split)
                    if self.split == DatasetSplit.TRAIN:
                        imgpaths_per_class[classname][anomaly] = imgpaths_per_class[
                            classname
                        ][anomaly][:train_val_split_idx]
                    elif self.split == DatasetSplit.VAL:
                        imgpaths_per_class[classname][anomaly] = imgpaths_per_class[
                            classname
                        ][anomaly][train_val_split_idx:]

                if self.split == DatasetSplit.TEST and anomaly != "good":
                    anomaly_mask_path = os.path.join(maskpath, anomaly)
                    anomaly_mask_files = sorted(os.listdir(anomaly_mask_path))
                    maskpaths_per_class[classname][anomaly] = [
                        os.path.join(anomaly_mask_path, x) for x in anomaly_mask_files
                    ]
                else:
                    # NOTE(review): always writes under the "good" key regardless
                    # of `anomaly` — presumably relying on non-TEST splits
                    # containing only "good" images; verify for VAL usage.
                    maskpaths_per_class[classname]["good"] = None

        # Unrolls the data dictionary to an easy-to-iterate list.
        data_to_iterate = []
        for classname in sorted(imgpaths_per_class.keys()):
            for anomaly in sorted(imgpaths_per_class[classname].keys()):
                for i, image_path in enumerate(imgpaths_per_class[classname][anomaly]):
                    data_tuple = [classname, anomaly, image_path]
                    if self.split == DatasetSplit.TEST and anomaly != "good":
                        data_tuple.append(maskpaths_per_class[classname][anomaly][i])
                    else:
                        data_tuple.append(None)
                    data_to_iterate.append(data_tuple)

        return imgpaths_per_class, data_to_iterate
================================================
FILE: datasets/sdd.py
================================================
import os
from enum import Enum
import pickle
import cv2
import PIL
import torch
from torchvision import transforms
# ImageNet channel statistics, used to normalize inputs for the
# ImageNet-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
class DatasetSplit(Enum):
    """Dataset partition identifiers (train / val / test)."""

    TRAIN = "train"
    VAL = "val"
    TEST = "test"
class SDDDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for KolektorSDD surface-defect data.

    Items are discovered via the pickled cross-validation split file shipped
    with "KolektorSDD-training-splits"; `classname` selects the split index.
    """

    def __init__(
        self,
        source,
        classname,
        resize=256,
        imagesize=224,
        split=DatasetSplit.TRAIN,
        train_val_split=1.0,
        rotate_degrees=0,
        translate=0,
        brightness_factor=0,
        contrast_factor=0,
        saturation_factor=0,
        gray_p=0,
        h_flip_p=0,
        v_flip_p=0,
        scale=0,
        **kwargs,
    ):
        """
        Args:
            source: [str]. Path to the KolektorSDD data folder.
            classname: [str]. Index (as a string) of the predefined
                       cross-validation split to use.
            resize: [int]. Width the loaded image initially gets resized to;
                    height is scaled by ~2.5x to keep the strip aspect ratio.
            imagesize: [int]. Width the resized loaded image gets
                       (center-)cropped to.
            split: [enum-option]. Indicates if training or test split of the
                   data should be used. Has to be an option taken from
                   DatasetSplit.
        """
        super().__init__()
        self.source = source
        self.split = split
        # `classname` is reused as the cross-validation split index here.
        self.split_id = int(classname)
        self.train_val_split = train_val_split
        self.transform_std = IMAGENET_STD
        self.transform_mean = IMAGENET_MEAN

        self.data_to_iterate = self.get_image_data()

        # SDD images are tall strips: height is kept at ~2.5x the width.
        self.transform_img = [
            transforms.Resize((int(resize*2.5+.5), resize)),
            # transforms.RandomRotation(rotate_degrees, transforms.InterpolationMode.BILINEAR),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(rotate_degrees,
                                    translate=(translate, translate),
                                    scale=(1.0-scale, 1.0+scale),
                                    interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop((int(imagesize * 2.5 + .5), imagesize)),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]
        self.transform_img = transforms.Compose(self.transform_img)

        # Masks get geometric transforms only.
        self.transform_mask = [
            transforms.Resize((int(resize*2.5+.5), resize)),
            transforms.CenterCrop((int(imagesize * 2.5 + .5), imagesize)),
            transforms.ToTensor(),
        ]
        self.transform_mask = transforms.Compose(self.transform_mask)

        # (C, H, W) shape of the tensors this dataset yields.
        self.imagesize = (3, int(imagesize * 2.5 + .5), imagesize)

        # if self.split == DatasetSplit.TEST:
        #     for i in range(len(self.data_to_iterate)):
        #         self.__getitem__(i)

    def __getitem__(self, idx):
        data = self.data_to_iterate[idx]
        image = PIL.Image.open(data["img"]).convert("RGB")
        image = self.transform_img(image)

        if self.split == DatasetSplit.TEST and data["anomaly"] == 1:
            mask = PIL.Image.open(data["label"])
            mask = self.transform_mask(mask)
        else:
            # No ground-truth mask (training split or defect-free part).
            mask = torch.zeros([1, *image.size()[1:]])

        return {
            "image": image,
            "mask": mask,
            "classname": str(self.split_id),
            "anomaly": data["anomaly"],
            "is_anomaly": data["anomaly"],
            "image_path": data["img"],
        }

    def __len__(self):
        return len(self.data_to_iterate)

    def get_image_data(self):
        """Builds the item list from the pickled train/test split file."""
        data_ids = []
        with open(os.path.join(self.source, "KolektorSDD-training-splits", "split.pyb"), "rb") as f:
            train_ids, test_ids, _ = pickle.load(f)
            if self.split == DatasetSplit.TRAIN:
                data_ids = train_ids[self.split_id]
            else:
                data_ids = test_ids[self.split_id]

        data = {}
        for data_id in data_ids:
            item_dir = os.path.join(self.source, data_id)
            fns = os.listdir(item_dir)
            # Each .jpg in an item folder is one "part"; a file whose name
            # contains both the part id and "label" is its defect mask.
            part_ids = [os.path.splitext(fn)[0] for fn in fns if fn.endswith("jpg")]
            parts = {part_id:{"img":"", "label":"", "anomaly":0}
                     for part_id in part_ids}
            for part_id in parts:
                for fn in fns:
                    if part_id in fn:
                        if "label" in fn:
                            # A part is anomalous iff its mask has any
                            # nonzero pixel.
                            label = cv2.imread(os.path.join(item_dir, fn))
                            if label.sum() > 0:
                                parts[part_id]["anomaly"] = 1
                            parts[part_id]["label"] = os.path.join(item_dir, fn)
                        else:
                            parts[part_id]["img"] = os.path.join(item_dir, fn)
            for k, v in parts.items():
                # Training uses defect-free parts only.
                if self.split == DatasetSplit.TRAIN and v["anomaly"] == 1:
                    continue
                data[data_id + '_' + k] = v
        return list(data.values())
================================================
FILE: datasets/sdd2.py
================================================
import os
from enum import Enum
import pickle
import cv2
import PIL
import torch
from torchvision import transforms
# ImageNet channel statistics, used to normalize inputs for the
# ImageNet-pretrained backbones.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
class DatasetSplit(Enum):
    """Dataset partition identifiers (train / val / test)."""

    TRAIN = "train"
    VAL = "val"
    TEST = "test"
class SDD2Dataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for the KolektorSDD2 surface-defect dataset.

    Images are tall and narrow, so they are resized/cropped to a 2.5:1
    (height:width) aspect ratio instead of a square.
    """

    def __init__(
        self,
        source,
        classname,
        resize=256,
        imagesize=224,
        split=DatasetSplit.TRAIN,
        train_val_split=1.0,
        rotate_degrees=0,
        translate=0,
        brightness_factor=0,
        contrast_factor=0,
        saturation_factor=0,
        gray_p=0,
        h_flip_p=0,
        v_flip_p=0,
        scale=0,
        **kwargs,
    ):
        """
        Args:
            source: [str]. Path to the KolektorSDD2 data folder (must contain
                "train" and "test" subfolders with *_GT.png masks).
            classname: [str or None]. Unused here; kept for interface
                compatibility with the other dataset classes.
            resize: [int]. Width the loaded image initially gets resized to;
                height is 2.5x the width (rounded).
            imagesize: [int]. Width the resized image gets center-cropped to.
            split: [enum-option]. Indicates if training or test split of the
                data should be used. Has to be an option taken from
                DatasetSplit. Note that DatasetSplit.TEST will also load
                mask data.
        """
        super().__init__()
        self.source = source
        self.split = split
        self.train_val_split = train_val_split
        self.transform_std = IMAGENET_STD
        self.transform_mean = IMAGENET_MEAN

        self.data_to_iterate = self.get_image_data()

        # KolektorSDD2 images are ~2.5x taller than wide; keep that ratio.
        resize_hw = (int(resize * 2.5 + .5), resize)
        crop_hw = (int(imagesize * 2.5 + .5), imagesize)
        self.transform_img = transforms.Compose([
            transforms.Resize(resize_hw),
            transforms.ColorJitter(brightness_factor, contrast_factor, saturation_factor),
            transforms.RandomHorizontalFlip(h_flip_p),
            transforms.RandomVerticalFlip(v_flip_p),
            transforms.RandomGrayscale(gray_p),
            transforms.RandomAffine(rotate_degrees,
                                    translate=(translate, translate),
                                    scale=(1.0 - scale, 1.0 + scale),
                                    interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(crop_hw),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])
        # NOTE(review): masks are resized with the default (bilinear)
        # interpolation, which can blur binary masks — confirm intended.
        self.transform_mask = transforms.Compose([
            transforms.Resize(resize_hw),
            transforms.CenterCrop(crop_hw),
            transforms.ToTensor(),
        ])
        self.imagesize = (3, crop_hw[0], crop_hw[1])

    def __getitem__(self, idx):
        """Return one sample dict: transformed image, mask, and labels."""
        img_path, gt_path, is_anomaly = self.data_to_iterate[idx]
        image = self.transform_img(PIL.Image.open(img_path).convert("RGB"))

        if self.split == DatasetSplit.TEST and is_anomaly:
            mask = self.transform_mask(PIL.Image.open(gt_path))
        else:
            # Normal samples (and the train split) get an all-zero mask.
            mask = torch.zeros([1, *image.size()[1:]])

        return {
            "image": image,
            "mask": mask,
            "classname": "",
            "anomaly": is_anomaly,
            "is_anomaly": is_anomaly,
            "image_path": img_path,
        }

    def __len__(self):
        return len(self.data_to_iterate)

    def get_image_data(self):
        """Scan the split folder and pair each image with its *_GT.png mask.

        Returns:
            list of [img_path, gt_path, is_anomaly] triples. Anomalous
            samples are excluded from the training split.
        """
        data_dir = os.path.join(self.source, "train" if self.split == DatasetSplit.TRAIN else "test")
        data = []
        for fn in os.listdir(data_dir):
            if "GT" in fn:
                continue
            data_id = os.path.splitext(fn)[0]
            img_path = os.path.join(data_dir, fn)
            gt_path = os.path.join(data_dir, f"{data_id}_GT.png")
            assert os.path.exists(img_path)
            assert os.path.exists(gt_path), gt_path
            # A sample is anomalous iff its ground-truth mask has any
            # nonzero pixel. (Previously the sum was computed twice.)
            is_anomaly = cv2.imread(gt_path).sum() > 0
            if self.split == DatasetSplit.TRAIN and is_anomaly:
                continue
            data.append([img_path, gt_path, is_anomaly])
        return data
================================================
FILE: main.py
================================================
# ------------------------------------------------------------------
# SimpleNet: A Simple Network for Image Anomaly Detection and Localization (https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_SimpleNet_A_Simple_Network_for_Image_Anomaly_Detection_and_Localization_CVPR_2023_paper.pdf)
# Github source: https://github.com/DonaldRR/SimpleNet
# Licensed under the MIT License [see LICENSE for details]
# The script is based on the code of PatchCore (https://github.com/amazon-science/patchcore-inspection)
# ------------------------------------------------------------------
import logging
import os
import sys
import click
import numpy as np
import torch
sys.path.append("src")
import backbones
import common
import metrics
import simplenet
import utils
LOGGER = logging.getLogger(__name__)
# Maps the CLI dataset name to (module path, Dataset class name); the
# ``dataset`` command imports the class dynamically from this table.
_DATASETS = {
    "mvtec": ["datasets.mvtec", "MVTecDataset"],
}
@click.group(chain=True)
@click.option("--results_path", type=str)
@click.option("--gpu", type=int, default=[0], multiple=True, show_default=True)
@click.option("--seed", type=int, default=0, show_default=True)
@click.option("--log_group", type=str, default="group")
@click.option("--log_project", type=str, default="project")
@click.option("--run_name", type=str, default="test")
@click.option("--test", is_flag=True)
@click.option("--save_segmentation_images", is_flag=True, default=False, show_default=True)
def main(**kwargs):
    """Chained CLI root; subcommands return (name, callable) pairs that the
    result callback ``run`` consumes."""
    pass
@main.result_callback()
def run(
    methods,
    results_path,
    gpu,
    seed,
    log_group,
    log_project,
    run_name,
    test,
    save_segmentation_images
):
    """Train a SimpleNet per sub-dataset and store aggregate metrics.

    Receives the ("name", callable) pairs produced by the chained ``net``
    and ``dataset`` subcommands, trains each model, and writes a results
    CSV to the run folder.
    """
    methods = {key: item for (key, item) in methods}

    run_save_path = utils.create_storage_folder(
        results_path, log_project, log_group, run_name, mode="overwrite"
    )

    list_of_dataloaders = methods["get_dataloaders"](seed)

    device = utils.set_torch_device(gpu)

    result_collect = []

    for dataloader_count, dataloaders in enumerate(list_of_dataloaders):
        LOGGER.info(
            "Evaluating dataset [{}] ({}/{})...".format(
                dataloaders["training"].name,
                dataloader_count + 1,
                len(list_of_dataloaders),
            )
        )

        utils.fix_seeds(seed, device)

        dataset_name = dataloaders["training"].name
        imagesize = dataloaders["training"].dataset.imagesize
        simplenet_list = methods["get_simplenet"](imagesize, device)

        models_dir = os.path.join(run_save_path, "models")
        os.makedirs(models_dir, exist_ok=True)
        for i, SimpleNet in enumerate(simplenet_list):
            # Per-backbone seed overrides the global seed when present.
            if SimpleNet.backbone.seed is not None:
                utils.fix_seeds(SimpleNet.backbone.seed, device)
            LOGGER.info(
                "Training models ({}/{})".format(i + 1, len(simplenet_list))
            )

            SimpleNet.set_model_dir(os.path.join(models_dir, f"{i}"), dataset_name)

            if test:
                # FIX: the standalone test path is not implemented. The old
                # code printed a warning and then crashed with a NameError on
                # the undefined metric variables; skip this model instead.
                LOGGER.warning(
                    "--test is set but the test-only path is not implemented; "
                    "skipping this model."
                )
                continue

            i_auroc, p_auroc, pro_auroc = SimpleNet.train(
                dataloaders["training"], dataloaders["testing"]
            )
            result_collect.append(
                {
                    "dataset_name": dataset_name,
                    "instance_auroc": i_auroc,
                    "full_pixel_auroc": p_auroc,
                    "anomaly_pixel_auroc": pro_auroc,
                }
            )

            for key, item in result_collect[-1].items():
                if key != "dataset_name":
                    LOGGER.info("{0}: {1:3.3f}".format(key, item))

            LOGGER.info("\n\n-----\n")

    # Store all results and mean scores to a csv-file (skip when nothing ran).
    if result_collect:
        result_metric_names = list(result_collect[-1].keys())[1:]
        result_dataset_names = [results["dataset_name"] for results in result_collect]
        result_scores = [list(results.values())[1:] for results in result_collect]
        utils.compute_and_store_final_results(
            run_save_path,
            result_scores,
            column_names=result_metric_names,
            row_names=result_dataset_names,
        )
@main.command("net")
@click.option("--backbone_names", "-b", type=str, multiple=True, default=[])
@click.option("--layers_to_extract_from", "-le", type=str, multiple=True, default=[])
@click.option("--pretrain_embed_dimension", type=int, default=1024)
@click.option("--target_embed_dimension", type=int, default=1024)
@click.option("--patchsize", type=int, default=3)
@click.option("--embedding_size", type=int, default=1024)
@click.option("--meta_epochs", type=int, default=1)
@click.option("--aed_meta_epochs", type=int, default=1)
@click.option("--gan_epochs", type=int, default=1)
@click.option("--dsc_layers", type=int, default=2)
@click.option("--dsc_hidden", type=int, default=None)
@click.option("--noise_std", type=float, default=0.05)
@click.option("--dsc_margin", type=float, default=0.8)
@click.option("--dsc_lr", type=float, default=0.0002)
@click.option("--auto_noise", type=float, default=0)
@click.option("--train_backbone", is_flag=True)
@click.option("--cos_lr", is_flag=True)
@click.option("--pre_proj", type=int, default=0)
@click.option("--proj_layer_type", type=int, default=0)
@click.option("--mix_noise", type=int, default=1)
def net(
backbone_names,
layers_to_extract_from,
pretrain_embed_dimension,
target_embed_dimension,
patchsize,
embedding_size,
meta_epochs,
aed_meta_epochs,
gan_epochs,
noise_std,
dsc_layers,
dsc_hidden,
dsc_margin,
dsc_lr,
auto_noise,
train_backbone,
cos_lr,
pre_proj,
proj_layer_type,
mix_noise,
):
backbone_names = list(backbone_names)
if len(backbone_names) > 1:
layers_to_extract_from_coll = [[] for _ in range(len(backbone_names))]
for layer in layers_to_extract_from:
idx = int(layer.split(".")[0])
layer = ".".join(layer.split(".")[1:])
layers_to_extract_from_coll[idx].append(layer)
else:
layers_to_extract_from_coll = [layers_to_extract_from]
def get_simplenet(input_shape, device):
simplenets = []
for backbone_name, layers_to_extract_from in zip(
backbone_names, layers_to_extract_from_coll
):
backbone_seed = None
if ".seed-" in backbone_name:
backbone_name, backbone_seed = backbone_name.split(".seed-")[0], int(
backbone_name.split("-")[-1]
)
backbone = backbones.load(backbone_name)
backbone.name, backbone.seed = backbone_name, backbone_seed
simplenet_inst = simplenet.SimpleNet(device)
simplenet_inst.load(
backbone=backbone,
layers_to_extract_from=layers_to_extract_from,
device=device,
input_shape=input_shape,
pretrain_embed_dimension=pretrain_embed_dimension,
target_embed_dimension=target_embed_dimension,
patchsize=patchsize,
embedding_size=embedding_size,
meta_epochs=meta_epochs,
aed_meta_epochs=aed_meta_epochs,
gan_epochs=gan_epochs,
noise_std=noise_std,
dsc_layers=dsc_layers,
dsc_hidden=dsc_hidden,
dsc_margin=dsc_margin,
dsc_lr=dsc_lr,
auto_noise=auto_noise,
train_backbone=train_backbone,
cos_lr=cos_lr,
pre_proj=pre_proj,
proj_layer_type=proj_layer_type,
mix_noise=mix_noise,
)
simplenets.append(simplenet_inst)
return simplenets
return ("get_simplenet", get_simplenet)
@main.command("dataset")
@click.argument("name", type=str)
@click.argument("data_path", type=click.Path(exists=True, file_okay=False))
@click.option("--subdatasets", "-d", multiple=True, type=str, required=True)
@click.option("--train_val_split", type=float, default=1, show_default=True)
@click.option("--batch_size", default=2, type=int, show_default=True)
@click.option("--num_workers", default=2, type=int, show_default=True)
@click.option("--resize", default=256, type=int, show_default=True)
@click.option("--imagesize", default=224, type=int, show_default=True)
@click.option("--rotate_degrees", default=0, type=int)
@click.option("--translate", default=0, type=float)
@click.option("--scale", default=0.0, type=float)
@click.option("--brightness", default=0.0, type=float)
@click.option("--contrast", default=0.0, type=float)
@click.option("--saturation", default=0.0, type=float)
@click.option("--gray", default=0.0, type=float)
@click.option("--hflip", default=0.0, type=float)
@click.option("--vflip", default=0.0, type=float)
@click.option("--augment", is_flag=True)
def dataset(
name,
data_path,
subdatasets,
train_val_split,
batch_size,
resize,
imagesize,
num_workers,
rotate_degrees,
translate,
scale,
brightness,
contrast,
saturation,
gray,
hflip,
vflip,
augment,
):
dataset_info = _DATASETS[name]
dataset_library = __import__(dataset_info[0], fromlist=[dataset_info[1]])
def get_dataloaders(seed):
dataloaders = []
for subdataset in subdatasets:
train_dataset = dataset_library.__dict__[dataset_info[1]](
data_path,
classname=subdataset,
resize=resize,
train_val_split=train_val_split,
imagesize=imagesize,
split=dataset_library.DatasetSplit.TRAIN,
seed=seed,
rotate_degrees=rotate_degrees,
translate=translate,
brightness_factor=brightness,
contrast_factor=contrast,
saturation_factor=saturation,
gray_p=gray,
h_flip_p=hflip,
v_flip_p=vflip,
scale=scale,
augment=augment,
)
test_dataset = dataset_library.__dict__[dataset_info[1]](
data_path,
classname=subdataset,
resize=resize,
imagesize=imagesize,
split=dataset_library.DatasetSplit.TEST,
seed=seed,
)
LOGGER.info(f"Dataset: train={len(train_dataset)} test={len(test_dataset)}")
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
prefetch_factor=2,
pin_memory=True,
)
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
prefetch_factor=2,
pin_memory=True,
)
train_dataloader.name = name
if subdataset is not None:
train_dataloader.name += "_" + subdataset
if train_val_split < 1:
val_dataset = dataset_library.__dict__[dataset_info[1]](
data_path,
classname=subdataset,
resize=resize,
train_val_split=train_val_split,
imagesize=imagesize,
split=dataset_library.DatasetSplit.VAL,
seed=seed,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
prefetch_factor=4,
pin_memory=True,
)
else:
val_dataloader = None
dataloader_dict = {
"training": train_dataloader,
"validation": val_dataloader,
"testing": test_dataloader,
}
dataloaders.append(dataloader_dict)
return dataloaders
return ("get_dataloaders", get_dataloaders)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
LOGGER.info("Command line arguments: {}".format(" ".join(sys.argv)))
main()
================================================
FILE: metrics.py
================================================
"""Anomaly metrics."""
import cv2
import numpy as np
from sklearn import metrics
def compute_imagewise_retrieval_metrics(
    anomaly_prediction_weights, anomaly_ground_truth_labels
):
    """
    Computes retrieval statistics (AUROC, FPR, TPR, AUC-PR).

    Args:
        anomaly_prediction_weights: [np.array or list] [N] Assignment weights
                                    per image. Higher indicates higher
                                    probability of being an anomaly.
        anomaly_ground_truth_labels: [np.array or list] [N] Binary labels - 1
                                    if image is an anomaly, 0 if not.

    Returns:
        dict with "auroc", "fpr", "tpr", "threshold" and "auc_pr".
    """
    fpr, tpr, thresholds = metrics.roc_curve(
        anomaly_ground_truth_labels, anomaly_prediction_weights
    )
    auroc = metrics.roc_auc_score(
        anomaly_ground_truth_labels, anomaly_prediction_weights
    )
    precision, recall, _ = metrics.precision_recall_curve(
        anomaly_ground_truth_labels, anomaly_prediction_weights
    )
    # FIX: auc_pr was computed and silently discarded; expose it as an
    # additional key (backward-compatible — existing keys unchanged).
    auc_pr = metrics.auc(recall, precision)
    return {
        "auroc": auroc,
        "fpr": fpr,
        "tpr": tpr,
        "threshold": thresholds,
        "auc_pr": auc_pr,
    }
def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_truth_masks):
    """
    Computes pixel-wise statistics (AUROC, FPR, TPR) for anomaly
    segmentations against ground-truth segmentation masks.

    Args:
        anomaly_segmentations: [list of np.arrays or np.array] [NxHxW]
            Generated segmentation masks.
        ground_truth_masks: [list of np.arrays or np.array] [NxHxW]
            Predefined ground truth segmentation masks.
    """
    if isinstance(anomaly_segmentations, list):
        anomaly_segmentations = np.stack(anomaly_segmentations)
    if isinstance(ground_truth_masks, list):
        ground_truth_masks = np.stack(ground_truth_masks)

    flat_scores = anomaly_segmentations.ravel()
    flat_masks = ground_truth_masks.ravel()
    int_labels = flat_masks.astype(int)

    fpr, tpr, thresholds = metrics.roc_curve(int_labels, flat_scores)
    auroc = metrics.roc_auc_score(int_labels, flat_scores)

    precision, recall, thresholds = metrics.precision_recall_curve(
        int_labels, flat_scores
    )
    denom = precision + recall
    F1_scores = np.divide(
        2 * precision * recall,
        denom,
        out=np.zeros_like(precision),
        where=denom != 0,
    )

    # Operating point: the threshold that maximizes F1 on the PR curve.
    optimal_threshold = thresholds[np.argmax(F1_scores)]
    predictions = (flat_scores >= optimal_threshold).astype(int)
    fpr_optim = np.mean(predictions > flat_masks)
    fnr_optim = np.mean(predictions < flat_masks)

    return {
        "auroc": auroc,
        "fpr": fpr,
        "tpr": tpr,
        "optimal_threshold": optimal_threshold,
        "optimal_fpr": fpr_optim,
        "optimal_fnr": fnr_optim,
    }
import pandas as pd
from skimage import measure
def compute_pro(masks, amaps, num_th=200):
    """Compute the PRO (per-region overlap) AUC with FPR clipped to [0, 0.3].

    Args:
        masks: [N, H, W] binary ground-truth masks.
        amaps: [N, H, W] anomaly score maps.
        num_th: number of thresholds swept from amaps.min() to amaps.max().
    """
    # FIX: np.bool was removed in NumPy >= 1.24; use the builtin bool dtype.
    binary_amaps = np.zeros_like(amaps, dtype=bool)

    min_th = amaps.min()
    max_th = amaps.max()
    delta = (max_th - min_th) / num_th

    k = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    rows = []
    for th in np.arange(min_th, max_th, delta):
        binary_amaps[amaps <= th] = 0
        binary_amaps[amaps > th] = 1

        pros = []
        for binary_amap, mask in zip(binary_amaps, masks):
            # Dilate the prediction slightly before measuring region overlap.
            binary_amap = cv2.dilate(binary_amap.astype(np.uint8), k)
            for region in measure.regionprops(measure.label(mask)):
                axes0_ids = region.coords[:, 0]
                axes1_ids = region.coords[:, 1]
                tp_pixels = binary_amap[axes0_ids, axes1_ids].sum()
                pros.append(tp_pixels / region.area)

        inverse_masks = 1 - masks
        fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
        fpr = fp_pixels / inverse_masks.sum()

        # FIX: DataFrame.append was removed in pandas 2.0; collect the rows
        # and build the frame once after the sweep.
        rows.append({"pro": np.mean(pros), "fpr": fpr, "threshold": th})

    df = pd.DataFrame(rows, columns=["pro", "fpr", "threshold"])

    # Normalize FPR from 0 ~ 1 to 0 ~ 0.3
    df = df[df["fpr"] < 0.3]
    df["fpr"] = df["fpr"] / df["fpr"].max()

    pro_auc = metrics.auc(df["fpr"], df["pro"])
    return pro_auc
================================================
FILE: resnet.py
================================================
import torch
from torch import Tensor
import torch.nn as nn
from typing import Type, Any, Callable, Union, List, Optional
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
# Official torchvision pretrained checkpoint URLs; keys match the ``arch``
# argument passed to ``_resnet``.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
# Padding mode shared by every convolution in this file.
PADDING_MODE = 'reflect'  # {'zeros', 'reflect', 'replicate', 'circular'}


def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        padding_mode=PADDING_MODE,
        groups=groups,
        bias=False,
        dilation=dilation,
    )


def conv1x1(in_planes: int, out_planes: int, stride: int = 1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
    # Two 3x3 convs + identity shortcut, as in torchvision's BasicBlock,
    # except the convs use the module-level PADDING_MODE. Parameter names
    # are kept identical so pretrained state dicts load.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ):
        """Residual block; ``downsample`` adapts the identity path when the
        spatial size or channel count changes."""
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor):
        """conv-bn-relu-conv-bn, add the (possibly downsampled) identity, relu."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ):
        """1x1 reduce -> 3x3 (possibly strided) -> 1x1 expand, plus shortcut."""
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # Internal width scales with base_width and groups (ResNeXt/WideResNet).
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor):
        """Bottleneck residual computation with ReLU after the addition."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
class ResNet(nn.Module):
    """Headless torchvision-style ResNet: the avgpool/fc classifier head is
    removed, so ``forward`` returns the layer4 feature map instead of logits.
    Kept structurally compatible with torchvision so pretrained weights load
    (with strict=False for the missing head)."""

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ):
        """
        Args:
            block: residual block class (BasicBlock or Bottleneck).
            layers: number of blocks in each of the four stages.
            num_classes: unused here (head removed); kept for API parity.
            zero_init_residual: zero the last BN of each block so residual
                branches start as identity.
            replace_stride_with_dilation: per-stage flag trading stride for
                dilation.
        """
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, padding_mode = PADDING_MODE,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])
        #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        #self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False):
        """Build one stage of ``blocks`` residual blocks; only the first block
        may stride/downsample."""
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade spatial stride for dilation (keeps resolution).
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut to match shape of the residual output.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),)

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # remove extra layers: return the layer4 feature map, not logits
        #x = self.avgpool(x)
        #x = torch.flatten(x, 1)
        #x = self.fc(x)

        return x

    def forward(self, x: Tensor):
        return self._forward_impl(x)
def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
):
    """Build a ResNet and optionally load the matching torchvision weights.

    NOTE(review): strict=False is needed because this ResNet has no
    avgpool/fc head, so the checkpoint's head weights are silently skipped
    — any other key mismatch is silently skipped too.
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNet-18, `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNet-34, `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNet-50, `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNet-101, `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNet-152, `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNeXt-50 32x4d, `"Aggregated Residual Transformation for Deep
    Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""ResNeXt-101 32x8d, `"Aggregated Residual Transformation for Deep
    Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""Wide ResNet-50-2, `"Wide Residual Networks"
    <https://arxiv.org/pdf/1605.07146.pdf>`_.

    Same as ResNet-50 except the bottleneck channel count is doubled; the
    outer 1x1 convolutions keep their widths (e.g. last block is
    2048-1024-2048 instead of 2048-512-2048).

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any):
    r"""Wide ResNet-101-2, `"Wide Residual Networks"
    <https://arxiv.org/pdf/1605.07146.pdf>`_.

    Same widening scheme as Wide ResNet-50-2, applied to the 101-layer
    configuration.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
# ==============================================================================================================
# Model Class Definition
# ==============================================================================================================
================================================
FILE: run.sh
================================================
# Train SimpleNet on all 15 MVTec-AD categories with the repo's reference
# hyper-parameters. Adjust datapath and --gpu for your machine.
datapath=/data4/MVTec_ad
datasets=('screw' 'pill' 'capsule' 'carpet' 'grid' 'tile' 'wood' 'zipper' 'cable' 'toothbrush' 'transistor' 'metal_nut' 'bottle' 'hazelnut' 'leather')
# Expand every category into a '-d <name>' flag for the dataset subcommand.
dataset_flags=($(for dataset in "${datasets[@]}"; do echo '-d '"${dataset}"; done))

# Chained CLI: global options, then the 'net' and 'dataset' subcommands.
python3 main.py \
--gpu 4 \
--seed 0 \
--log_group simplenet_mvtec \
--log_project MVTecAD_Results \
--results_path results \
--run_name run \
net \
-b wideresnet50 \
-le layer2 \
-le layer3 \
--pretrain_embed_dimension 1536 \
--target_embed_dimension 1536 \
--patchsize 3 \
--meta_epochs 40 \
--embedding_size 256 \
--gan_epochs 4 \
--noise_std 0.015 \
--dsc_hidden 1024 \
--dsc_layers 2 \
--dsc_margin .5 \
--pre_proj 1 \
dataset \
--batch_size 8 \
--resize 329 \
--imagesize 288 "${dataset_flags[@]}" mvtec $datapath
================================================
FILE: simplenet.py
================================================
# ------------------------------------------------------------------
# SimpleNet: A Simple Network for Image Anomaly Detection and Localization (https://openaccess.thecvf.com/content/CVPR2023/papers/Liu_SimpleNet_A_Simple_Network_for_Image_Anomaly_Detection_and_Localization_CVPR_2023_paper.pdf)
# Github source: https://github.com/DonaldRR/SimpleNet
# Licensed under the MIT License [see LICENSE for details]
# The script is based on the code of PatchCore (https://github.com/amazon-science/patchcore-inspection)
# ------------------------------------------------------------------
"""detection methods."""
import logging
import os
import pickle
from collections import OrderedDict
import math
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch.utils.tensorboard import SummaryWriter
import common
import metrics
from utils import plot_segmentation_images
LOGGER = logging.getLogger(__name__)
def init_weight(m):
    """Xavier-normal initialization for Linear/Conv2d weights.

    Intended to be passed to ``Module.apply``; modules of any other type are
    left untouched (biases are not re-initialized either).
    """
    if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
        torch.nn.init.xavier_normal_(m.weight)
class Discriminator(torch.nn.Module):
    """MLP feature scorer: positive outputs indicate "true" features.

    Hidden blocks are Linear -> BatchNorm1d -> LeakyReLU(0.2); the head is a
    single bias-free linear unit. When ``hidden`` is None the width shrinks
    by a factor of 1.5 per hidden block, otherwise every block uses ``hidden``.
    """

    def __init__(self, in_planes, n_layers=1, hidden=None):
        super(Discriminator, self).__init__()

        width = in_planes if hidden is None else hidden
        self.body = torch.nn.Sequential()
        for idx in range(n_layers - 1):
            fan_in = in_planes if idx == 0 else width
            width = int(width // 1.5) if hidden is None else hidden
            self.body.add_module(
                'block%d' % (idx + 1),
                torch.nn.Sequential(
                    torch.nn.Linear(fan_in, width),
                    torch.nn.BatchNorm1d(width),
                    torch.nn.LeakyReLU(0.2),
                ),
            )
        self.tail = torch.nn.Linear(width, 1, bias=False)
        self.apply(init_weight)

    def forward(self, x):
        """Return one raw (unbounded) score per input row."""
        return self.tail(self.body(x))
class Projection(torch.nn.Module):
    """Stack of linear layers that re-embeds backbone features.

    Every layer maps to ``out_planes`` (defaulting to ``in_planes``); with
    ``layer_type > 1`` a LeakyReLU(0.2) follows each layer except the last.
    The BatchNorm variant for ``layer_type > 0`` is intentionally disabled.
    """

    def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
        super(Projection, self).__init__()

        if out_planes is None:
            out_planes = in_planes
        self.layers = torch.nn.Sequential()
        prev_dim = None
        cur_dim = None
        for idx in range(n_layers):
            prev_dim = in_planes if idx == 0 else cur_dim
            cur_dim = out_planes
            self.layers.add_module(f"{idx}fc", torch.nn.Linear(prev_dim, cur_dim))
            if idx < n_layers - 1:
                # if layer_type > 0:
                #     self.layers.add_module(f"{idx}bn",
                #                            torch.nn.BatchNorm1d(cur_dim))
                if layer_type > 1:
                    self.layers.add_module(f"{idx}relu", torch.nn.LeakyReLU(.2))
        self.apply(init_weight)

    def forward(self, x):
        # x = .1 * self.layers(x) + x  (residual variant, disabled)
        return self.layers(x)
class TBWrapper:
    """Thin wrapper pairing a TensorBoard SummaryWriter with a global
    iteration counter (``g_iter``) used as the x-axis for scalar logs."""

    def __init__(self, log_dir):
        self.g_iter = 0
        self.logger = SummaryWriter(log_dir=log_dir)

    def step(self):
        """Advance the global iteration counter by one."""
        self.g_iter = self.g_iter + 1
class SimpleNet(torch.nn.Module):
    """SimpleNet anomaly detector: backbone patch features (optionally passed
    through a Projection) are scored by a Discriminator trained to separate
    true features from Gaussian-noise-perturbed ("fake") copies."""

    def __init__(self, device):
        """Anomaly detection class; sub-modules are built later via load()."""
        super(SimpleNet, self).__init__()
        self.device = device

    def load(
        self,
        backbone,
        layers_to_extract_from,
        device,
        input_shape,
        pretrain_embed_dimension,  # 1536
        target_embed_dimension,  # 1536
        patchsize=3,  # 3
        patchstride=1,
        embedding_size=None,  # 256
        meta_epochs=1,  # 40
        aed_meta_epochs=1,
        gan_epochs=1,  # 4
        noise_std=0.05,
        mix_noise=1,
        noise_type="GAU",
        dsc_layers=2,  # 2
        dsc_hidden=None,  # 1024
        dsc_margin=.8,  # .5
        dsc_lr=0.0002,
        train_backbone=False,
        auto_noise=0,
        cos_lr=False,
        lr=1e-3,
        pre_proj=0,  # 1
        proj_layer_type=0,
        **kwargs,
    ):
        """Build feature extractor, projection, discriminator and optimizers.

        The inline "# N" comments above are the hyperparameter values used in
        the paper's MVTec-AD runs (see run.sh).
        """
        pid = os.getpid()

        def show_mem():
            # NOTE(review): `psutil` is not imported in this module; calling
            # this helper would raise NameError. Leftover debugging code.
            return(psutil.Process(pid).memory_info())

        self.backbone = backbone.to(device)
        self.layers_to_extract_from = layers_to_extract_from
        self.input_shape = input_shape
        self.device = device

        self.patch_maker = PatchMaker(patchsize, stride=patchstride)

        self.forward_modules = torch.nn.ModuleDict({})

        # Hooks the requested backbone layers and returns their feature maps.
        feature_aggregator = common.NetworkFeatureAggregator(
            self.backbone, self.layers_to_extract_from, self.device, train_backbone
        )
        feature_dimensions = feature_aggregator.feature_dimensions(input_shape)
        self.forward_modules["feature_aggregator"] = feature_aggregator

        # Pools each layer's patch features to a shared channel dimension.
        preprocessing = common.Preprocessing(
            feature_dimensions, pretrain_embed_dimension
        )
        self.forward_modules["preprocessing"] = preprocessing

        self.target_embed_dimension = target_embed_dimension
        preadapt_aggregator = common.Aggregator(
            target_dim=target_embed_dimension
        )

        _ = preadapt_aggregator.to(self.device)

        self.forward_modules["preadapt_aggregator"] = preadapt_aggregator

        # Upscales per-patch score maps back to the input image resolution.
        self.anomaly_segmentor = common.RescaleSegmentor(
            device=self.device, target_size=input_shape[-2:]
        )

        self.embedding_size = embedding_size if embedding_size is not None else self.target_embed_dimension
        self.meta_epochs = meta_epochs
        self.lr = lr
        self.cos_lr = cos_lr
        self.train_backbone = train_backbone
        if self.train_backbone:
            self.backbone_opt = torch.optim.AdamW(self.forward_modules["feature_aggregator"].backbone.parameters(), lr)

        # AED
        self.aed_meta_epochs = aed_meta_epochs

        self.pre_proj = pre_proj
        if self.pre_proj > 0:
            # Projection trains with a 10x smaller learning rate than `lr`.
            self.pre_projection = Projection(self.target_embed_dimension, self.target_embed_dimension, pre_proj, proj_layer_type)
            self.pre_projection.to(self.device)
            self.proj_opt = torch.optim.AdamW(self.pre_projection.parameters(), lr*.1)

        # Discriminator
        self.auto_noise = [auto_noise, None]
        self.dsc_lr = dsc_lr
        self.gan_epochs = gan_epochs
        self.mix_noise = mix_noise
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.discriminator = Discriminator(self.target_embed_dimension, n_layers=dsc_layers, hidden=dsc_hidden)
        self.discriminator.to(self.device)
        self.dsc_opt = torch.optim.Adam(self.discriminator.parameters(), lr=self.dsc_lr, weight_decay=1e-5)
        self.dsc_schl = torch.optim.lr_scheduler.CosineAnnealingLR(self.dsc_opt, (meta_epochs - aed_meta_epochs) * gan_epochs, self.dsc_lr*.4)
        self.dsc_margin = dsc_margin

        self.model_dir = ""
        self.dataset_name = ""
        self.tau = 1
        self.logger = None

    def set_model_dir(self, model_dir, dataset_name):
        """Create checkpoint/tensorboard directories and attach a TB logger."""
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.ckpt_dir = os.path.join(self.model_dir, dataset_name)
        os.makedirs(self.ckpt_dir, exist_ok=True)
        self.tb_dir = os.path.join(self.ckpt_dir, "tb")
        os.makedirs(self.tb_dir, exist_ok=True)
        self.logger = TBWrapper(self.tb_dir)  # SummaryWriter(log_dir=tb_dir)

    def embed(self, data):
        """Embed one batch, or (without grad) every batch of a DataLoader.

        For a DataLoader, returns a list of the (features, patch_shapes)
        tuples produced by _embed, one per batch.
        """
        if isinstance(data, torch.utils.data.DataLoader):
            features = []
            for image in data:
                if isinstance(image, dict):
                    image = image["image"]
                input_image = image.to(torch.float).to(self.device)
            with torch.no_grad():
                features.append(self._embed(input_image))
            return features
        return self._embed(data)

    def _embed(self, images, detach=True, provide_patch_shapes=False, evaluation=False):
        """Returns feature embeddings for images.

        Always returns a (features, patch_shapes) tuple, regardless of
        `provide_patch_shapes`. `detach` is currently unused.
        """
        B = len(images)
        if not evaluation and self.train_backbone:
            # Keep gradients flowing into the backbone when fine-tuning it.
            self.forward_modules["feature_aggregator"].train()
            features = self.forward_modules["feature_aggregator"](images, eval=evaluation)
        else:
            _ = self.forward_modules["feature_aggregator"].eval()
            with torch.no_grad():
                features = self.forward_modules["feature_aggregator"](images)

        features = [features[layer] for layer in self.layers_to_extract_from]

        # (B, L, C) outputs are reshaped into square (B, C, H, W) maps;
        # assumes L is a perfect square — TODO confirm for non-CNN backbones.
        for i, feat in enumerate(features):
            if len(feat.shape) == 3:
                B, L, C = feat.shape
                features[i] = feat.reshape(B, int(math.sqrt(L)), int(math.sqrt(L)), C).permute(0, 3, 1, 2)

        features = [
            self.patch_maker.patchify(x, return_spatial_info=True) for x in features
        ]
        patch_shapes = [x[1] for x in features]
        features = [x[0] for x in features]
        ref_num_patches = patch_shapes[0]

        # Bilinearly resize every deeper layer's patch grid to the first
        # layer's patch resolution so all layers align per spatial location.
        for i in range(1, len(features)):
            _features = features[i]
            patch_dims = patch_shapes[i]

            # TODO(pgehler): Add comments
            _features = _features.reshape(
                _features.shape[0], patch_dims[0], patch_dims[1], *_features.shape[2:]
            )
            _features = _features.permute(0, -3, -2, -1, 1, 2)
            perm_base_shape = _features.shape
            _features = _features.reshape(-1, *_features.shape[-2:])
            _features = F.interpolate(
                _features.unsqueeze(1),
                size=(ref_num_patches[0], ref_num_patches[1]),
                mode="bilinear",
                align_corners=False,
            )
            _features = _features.squeeze(1)
            _features = _features.reshape(
                *perm_base_shape[:-2], ref_num_patches[0], ref_num_patches[1]
            )
            _features = _features.permute(0, -2, -1, 1, 2, 3)
            _features = _features.reshape(len(_features), -1, *_features.shape[-3:])
            features[i] = _features
        features = [x.reshape(-1, *x.shape[-3:]) for x in features]

        # As different feature backbones & patching provide differently
        # sized features, these are brought into the correct form here.
        features = self.forward_modules["preprocessing"](features)  # pooling each feature to same channel and stack together
        features = self.forward_modules["preadapt_aggregator"](features)  # further pooling

        return features, patch_shapes

    def test(self, training_data, test_data, save_segmentation_images):
        """Evaluate on `test_data`; returns (image AUROC, pixel AUROC).

        `save_segmentation_images` here is a boolean flag (it shadows the
        method of the same name, which is still reachable through `self`).
        """
        ckpt_path = os.path.join(self.ckpt_dir, "models.ckpt")
        if os.path.exists(ckpt_path):
            state_dicts = torch.load(ckpt_path, map_location=self.device)
            # NOTE(review): self.feature_enc / self.feature_dec are never
            # created in this class; these branches raise AttributeError if a
            # checkpoint actually contains those keys.
            if "pretrained_enc" in state_dicts:
                self.feature_enc.load_state_dict(state_dicts["pretrained_enc"])
            if "pretrained_dec" in state_dicts:
                self.feature_dec.load_state_dict(state_dicts["pretrained_dec"])

        aggregator = {"scores": [], "segmentations": [], "features": []}
        scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)
        aggregator["scores"].append(scores)
        aggregator["segmentations"].append(segmentations)
        aggregator["features"].append(features)

        # Min-max normalize image scores (only one aggregation round here).
        scores = np.array(aggregator["scores"])
        min_scores = scores.min(axis=-1).reshape(-1, 1)
        max_scores = scores.max(axis=-1).reshape(-1, 1)
        scores = (scores - min_scores) / (max_scores - min_scores)
        scores = np.mean(scores, axis=0)

        segmentations = np.array(aggregator["segmentations"])
        min_scores = (
            segmentations.reshape(len(segmentations), -1)
            .min(axis=-1)
            .reshape(-1, 1, 1, 1)
        )
        max_scores = (
            segmentations.reshape(len(segmentations), -1)
            .max(axis=-1)
            .reshape(-1, 1, 1, 1)
        )
        segmentations = (segmentations - min_scores) / (max_scores - min_scores)
        segmentations = np.mean(segmentations, axis=0)

        # MVTec-style datasets mark non-anomalous samples as "good".
        anomaly_labels = [
            x[1] != "good" for x in test_data.dataset.data_to_iterate
        ]

        if save_segmentation_images:
            self.save_segmentation_images(test_data, segmentations, scores)

        auroc = metrics.compute_imagewise_retrieval_metrics(
            scores, anomaly_labels
        )["auroc"]

        # Compute PRO score & PW Auroc for all images
        pixel_scores = metrics.compute_pixelwise_retrieval_metrics(
            segmentations, masks_gt
        )
        full_pixel_auroc = pixel_scores["auroc"]

        return auroc, full_pixel_auroc

    def _evaluate(self, test_data, scores, segmentations, features, labels_gt, masks_gt):
        """Compute image AUROC, pixel AUROC and PRO from raw predictions.

        Returns (-1, -1) for the pixel metrics when no GT masks exist.
        """
        scores = np.squeeze(np.array(scores))
        img_min_scores = scores.min(axis=-1)
        img_max_scores = scores.max(axis=-1)
        scores = (scores - img_min_scores) / (img_max_scores - img_min_scores)
        # scores = np.mean(scores, axis=0)

        auroc = metrics.compute_imagewise_retrieval_metrics(
            scores, labels_gt
        )["auroc"]

        if len(masks_gt) > 0:
            segmentations = np.array(segmentations)
            min_scores = (
                segmentations.reshape(len(segmentations), -1)
                .min(axis=-1)
                .reshape(-1, 1, 1, 1)
            )
            max_scores = (
                segmentations.reshape(len(segmentations), -1)
                .max(axis=-1)
                .reshape(-1, 1, 1, 1)
            )
            # Average of per-image min-max normalizations; the 1e-2 floor
            # guards against division by near-zero score ranges.
            norm_segmentations = np.zeros_like(segmentations)
            for min_score, max_score in zip(min_scores, max_scores):
                norm_segmentations += (segmentations - min_score) / max(max_score - min_score, 1e-2)
            norm_segmentations = norm_segmentations / len(scores)

            # Compute PRO score & PW Auroc for all images
            pixel_scores = metrics.compute_pixelwise_retrieval_metrics(
                norm_segmentations, masks_gt)
                # segmentations, masks_gt
            full_pixel_auroc = pixel_scores["auroc"]

            pro = metrics.compute_pro(np.squeeze(np.array(masks_gt)),
                                      norm_segmentations)
        else:
            full_pixel_auroc = -1
            pro = -1
        return auroc, full_pixel_auroc, pro

    def train(self, training_data, test_data):
        """Meta-training loop: alternate discriminator training and eval.

        If a checkpoint exists, it is loaded and evaluated immediately
        (early return). Otherwise trains for `meta_epochs`, keeping the state
        dict of the best image-AUROC epoch, and returns the best
        (auroc, pixel auroc, pro) record.
        """
        state_dict = {}
        ckpt_path = os.path.join(self.ckpt_dir, "ckpt.pth")
        if os.path.exists(ckpt_path):
            state_dict = torch.load(ckpt_path, map_location=self.device)
            if 'discriminator' in state_dict:
                self.discriminator.load_state_dict(state_dict['discriminator'])
                if "pre_projection" in state_dict:
                    self.pre_projection.load_state_dict(state_dict["pre_projection"])
            else:
                self.load_state_dict(state_dict, strict=False)

            self.predict(training_data, "train_")
            scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)
            auroc, full_pixel_auroc, anomaly_pixel_auroc = self._evaluate(test_data, scores, segmentations, features, labels_gt, masks_gt)

            return auroc, full_pixel_auroc, anomaly_pixel_auroc

        def update_state_dict(d):
            # Snapshot current weights to CPU; the `d` parameter is unused —
            # the enclosing `state_dict` is mutated directly.
            state_dict["discriminator"] = OrderedDict({
                k: v.detach().cpu()
                for k, v in self.discriminator.state_dict().items()})
            if self.pre_proj > 0:
                state_dict["pre_projection"] = OrderedDict({
                    k: v.detach().cpu()
                    for k, v in self.pre_projection.state_dict().items()})

        best_record = None
        for i_mepoch in range(self.meta_epochs):

            self._train_discriminator(training_data)

            # torch.cuda.empty_cache()
            scores, segmentations, features, labels_gt, masks_gt = self.predict(test_data)
            auroc, full_pixel_auroc, pro = self._evaluate(test_data, scores, segmentations, features, labels_gt, masks_gt)

            self.logger.logger.add_scalar("i-auroc", auroc, i_mepoch)
            self.logger.logger.add_scalar("p-auroc", full_pixel_auroc, i_mepoch)
            self.logger.logger.add_scalar("pro", pro, i_mepoch)

            # Keep the snapshot from the epoch with the best image AUROC
            # (pixel AUROC breaks ties).
            if best_record is None:
                best_record = [auroc, full_pixel_auroc, pro]
                update_state_dict(state_dict)
                # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})
            else:
                if auroc > best_record[0]:
                    best_record = [auroc, full_pixel_auroc, pro]
                    update_state_dict(state_dict)
                    # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})
                elif auroc == best_record[0] and full_pixel_auroc > best_record[1]:
                    best_record[1] = full_pixel_auroc
                    best_record[2] = pro
                    update_state_dict(state_dict)
                    # state_dict = OrderedDict({k:v.detach().cpu() for k, v in self.state_dict().items()})

            print(f"----- {i_mepoch} I-AUROC:{round(auroc, 4)}(MAX:{round(best_record[0], 4)})"
                  f" P-AUROC{round(full_pixel_auroc, 4)}(MAX:{round(best_record[1], 4)}) -----"
                  f" PRO-AUROC{round(pro, 4)}(MAX:{round(best_record[2], 4)}) -----")

        torch.save(state_dict, ckpt_path)

        return best_record

    def _train_discriminator(self, input_data):
        """Computes and sets the support features for SPADE."""
        # (Docstring kept from upstream; in practice this runs `gan_epochs`
        # epochs of hinge-style discriminator training on true vs. noised
        # features, optionally also stepping projection/backbone optimizers.)
        _ = self.forward_modules.eval()

        if self.pre_proj > 0:
            self.pre_projection.train()
        self.discriminator.train()
        # self.feature_enc.eval()
        # self.feature_dec.eval()
        i_iter = 0
        LOGGER.info(f"Training discriminator...")
        with tqdm.tqdm(total=self.gan_epochs) as pbar:
            for i_epoch in range(self.gan_epochs):
                all_loss = []
                all_p_true = []
                all_p_fake = []
                all_p_interp = []
                embeddings_list = []
                for data_item in input_data:
                    self.dsc_opt.zero_grad()
                    if self.pre_proj > 0:
                        self.proj_opt.zero_grad()
                    # self.dec_opt.zero_grad()

                    i_iter += 1
                    img = data_item["image"]
                    img = img.to(torch.float).to(self.device)
                    if self.pre_proj > 0:
                        true_feats = self.pre_projection(self._embed(img, evaluation=False)[0])
                    else:
                        true_feats = self._embed(img, evaluation=False)[0]

                    # Sample one of `mix_noise` Gaussian scales per feature
                    # row; std grows by a factor 1.1 per mixture component.
                    noise_idxs = torch.randint(0, self.mix_noise, torch.Size([true_feats.shape[0]]))
                    noise_one_hot = torch.nn.functional.one_hot(noise_idxs, num_classes=self.mix_noise).to(self.device)  # (N, K)
                    noise = torch.stack([
                        torch.normal(0, self.noise_std * 1.1**(k), true_feats.shape)
                        for k in range(self.mix_noise)], dim=1).to(self.device)  # (N, K, C)
                    noise = (noise * noise_one_hot.unsqueeze(-1)).sum(1)
                    fake_feats = true_feats + noise

                    # One forward pass over true+fake; len(true)==len(fake),
                    # so scores[len(fake_feats):] is the fake half.
                    scores = self.discriminator(torch.cat([true_feats, fake_feats]))
                    true_scores = scores[:len(true_feats)]
                    fake_scores = scores[len(fake_feats):]
                    th = self.dsc_margin
                    p_true = (true_scores.detach() >= th).sum() / len(true_scores)
                    p_fake = (fake_scores.detach() < -th).sum() / len(fake_scores)
                    # Hinge losses pushing true scores above +th and fake
                    # scores below -th.
                    true_loss = torch.clip(-true_scores + th, min=0)
                    fake_loss = torch.clip(fake_scores + th, min=0)

                    self.logger.logger.add_scalar(f"p_true", p_true, self.logger.g_iter)
                    self.logger.logger.add_scalar(f"p_fake", p_fake, self.logger.g_iter)

                    loss = true_loss.mean() + fake_loss.mean()
                    self.logger.logger.add_scalar("loss", loss, self.logger.g_iter)
                    self.logger.step()

                    loss.backward()
                    if self.pre_proj > 0:
                        self.proj_opt.step()
                    if self.train_backbone:
                        self.backbone_opt.step()
                    self.dsc_opt.step()

                    loss = loss.detach().cpu()
                    all_loss.append(loss.item())
                    all_p_true.append(p_true.cpu().item())
                    all_p_fake.append(p_fake.cpu().item())

                # NOTE(review): embeddings_list is never appended to in this
                # loop, so this auto-noise update is currently dead code.
                if len(embeddings_list) > 0:
                    self.auto_noise[1] = torch.cat(embeddings_list).std(0).mean(-1)

                if self.cos_lr:
                    self.dsc_schl.step()

                all_loss = sum(all_loss) / len(input_data)
                all_p_true = sum(all_p_true) / len(input_data)
                all_p_fake = sum(all_p_fake) / len(input_data)
                cur_lr = self.dsc_opt.state_dict()['param_groups'][0]['lr']
                pbar_str = f"epoch:{i_epoch} loss:{round(all_loss, 5)} "
                pbar_str += f"lr:{round(cur_lr, 6)}"
                pbar_str += f" p_true:{round(all_p_true, 3)} p_fake:{round(all_p_fake, 3)}"
                # all_p_interp is never filled above, so this stays skipped.
                if len(all_p_interp) > 0:
                    pbar_str += f" p_interp:{round(sum(all_p_interp) / len(input_data), 3)}"
                pbar.set_description_str(pbar_str)
                pbar.update(1)

    def predict(self, data, prefix=""):
        """Dispatch to the DataLoader or single-batch prediction path."""
        if isinstance(data, torch.utils.data.DataLoader):
            return self._predict_dataloader(data, prefix)
        return self._predict(data)

    def _predict_dataloader(self, dataloader, prefix):
        """This function provides anomaly scores/maps for full dataloaders."""
        _ = self.forward_modules.eval()

        img_paths = []
        scores = []
        masks = []
        # NOTE(review): `features` is never appended to and is returned empty;
        # the TSNE import below is likewise unused here.
        features = []
        labels_gt = []
        masks_gt = []
        from sklearn.manifold import TSNE

        with tqdm.tqdm(dataloader, desc="Inferring...", leave=False) as data_iterator:
            for data in data_iterator:
                if isinstance(data, dict):
                    labels_gt.extend(data["is_anomaly"].numpy().tolist())
                    if data.get("mask", None) is not None:
                        masks_gt.extend(data["mask"].numpy().tolist())
                    image = data["image"]
                    img_paths.extend(data['image_path'])
                _scores, _masks, _feats = self._predict(image)
                for score, mask, feat, is_anomaly in zip(_scores, _masks, _feats, data["is_anomaly"].numpy().tolist()):
                    scores.append(score)
                    masks.append(mask)

        return scores, masks, features, labels_gt, masks_gt

    def _predict(self, images):
        """Infer score and mask for a batch of images."""
        images = images.to(torch.float).to(self.device)
        _ = self.forward_modules.eval()

        batchsize = images.shape[0]
        if self.pre_proj > 0:
            self.pre_projection.eval()
        self.discriminator.eval()
        with torch.no_grad():
            features, patch_shapes = self._embed(images,
                                                 provide_patch_shapes=True,
                                                 evaluation=True)
            if self.pre_proj > 0:
                features = self.pre_projection(features)

            # features = features.cpu().numpy()
            # features = np.ascontiguousarray(features.cpu().numpy())
            # Higher discriminator score = more "true", so negate for anomaly.
            patch_scores = image_scores = -self.discriminator(features)
            patch_scores = patch_scores.cpu().numpy()
            image_scores = image_scores.cpu().numpy()

            image_scores = self.patch_maker.unpatch_scores(
                image_scores, batchsize=batchsize
            )
            image_scores = image_scores.reshape(*image_scores.shape[:2], -1)
            # Image score = max (or mean of top_k) patch score.
            image_scores = self.patch_maker.score(image_scores)

            patch_scores = self.patch_maker.unpatch_scores(
                patch_scores, batchsize=batchsize
            )
            scales = patch_shapes[0]
            patch_scores = patch_scores.reshape(batchsize, scales[0], scales[1])
            features = features.reshape(batchsize, scales[0], scales[1], -1)
            masks, features = self.anomaly_segmentor.convert_to_segmentation(patch_scores, features)

        return list(image_scores), list(masks), list(features)

    @staticmethod
    def _params_file(filepath, prepend=""):
        """Path of the pickled parameter file inside `filepath`."""
        return os.path.join(filepath, prepend + "params.pkl")

    def save_to_path(self, save_path: str, prepend: str = ""):
        """Persist model parameters to disk.

        NOTE(review): `self.anomaly_scorer` is never created in this class —
        this method looks carried over from PatchCore and would raise
        AttributeError if called.
        """
        LOGGER.info("Saving data.")
        self.anomaly_scorer.save(
            save_path, save_features_separately=False, prepend=prepend
        )
        params = {
            "backbone.name": self.backbone.name,
            "layers_to_extract_from": self.layers_to_extract_from,
            "input_shape": self.input_shape,
            "pretrain_embed_dimension": self.forward_modules[
                "preprocessing"
            ].output_dim,
            "target_embed_dimension": self.forward_modules[
                "preadapt_aggregator"
            ].target_dim,
            "patchsize": self.patch_maker.patchsize,
            "patchstride": self.patch_maker.stride,
            "anomaly_scorer_num_nn": self.anomaly_scorer.n_nearest_neighbours,
        }
        with open(self._params_file(save_path, prepend), "wb") as save_file:
            pickle.dump(params, save_file, pickle.HIGHEST_PROTOCOL)

    def save_segmentation_images(self, data, segmentations, scores):
        """Render segmentation overlays for every test image into ./output."""
        image_paths = [
            x[2] for x in data.dataset.data_to_iterate
        ]
        mask_paths = [
            x[3] for x in data.dataset.data_to_iterate
        ]

        def image_transform(image):
            # Undo dataset normalization back to a uint8 RGB array.
            in_std = np.array(
                data.dataset.transform_std
            ).reshape(-1, 1, 1)
            in_mean = np.array(
                data.dataset.transform_mean
            ).reshape(-1, 1, 1)
            image = data.dataset.transform_img(image)
            return np.clip(
                (image.numpy() * in_std + in_mean) * 255, 0, 255
            ).astype(np.uint8)

        def mask_transform(mask):
            return data.dataset.transform_mask(mask).numpy()

        plot_segmentation_images(
            './output',
            image_paths,
            segmentations,
            scores,
            mask_paths,
            image_transform=image_transform,
            mask_transform=mask_transform,
        )
# Image handling classes.
class PatchMaker:
    """Splits dense feature maps into overlapping local patches and reduces
    per-patch scores back to image-level scores."""

    def __init__(self, patchsize, top_k=0, stride=None):
        self.patchsize = patchsize
        self.stride = stride
        self.top_k = top_k

    def patchify(self, features, return_spatial_info=False):
        """Convert a feature tensor into a tensor of local patches.

        Args:
            features: [torch.Tensor, bs x c x w x h]
        Returns:
            [torch.Tensor, bs x num_patches x c x patchsize x patchsize],
            plus the per-axis patch counts when `return_spatial_info` is set.
        """
        pad = (self.patchsize - 1) // 2
        unfolder = torch.nn.Unfold(
            kernel_size=self.patchsize, stride=self.stride, padding=pad, dilation=1
        )
        unfolded = unfolder(features)  # (bs, c*p*p, num_patches)

        # Standard conv-style output-size formula per spatial axis.
        spatial_info = [
            int((size + 2 * pad - (self.patchsize - 1) - 1) / self.stride + 1)
            for size in features.shape[-2:]
        ]

        bs, channels = features.shape[:2]
        unfolded = unfolded.reshape(
            bs, channels, self.patchsize, self.patchsize, -1
        ).permute(0, 4, 1, 2, 3)

        if return_spatial_info:
            return unfolded, spatial_info
        return unfolded

    def unpatch_scores(self, x, batchsize):
        """Regroup a flat patch axis into (batchsize, patches_per_image, ...)."""
        return x.reshape(batchsize, -1, *x.shape[1:])

    def score(self, x):
        """Reduce patch scores to one score per image: max over trailing
        dims, then max (or mean of the top_k maxima) over patches."""
        was_numpy = isinstance(x, np.ndarray)
        if was_numpy:
            x = torch.from_numpy(x)
        while x.ndim > 2:
            x = x.max(dim=-1).values
        if x.ndim == 2:
            if self.top_k > 1:
                x = torch.topk(x, self.top_k, dim=1).values.mean(1)
            else:
                x = x.max(dim=1).values
        return x.numpy() if was_numpy else x
================================================
FILE: utils.py
================================================
import csv
import logging
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import tqdm
LOGGER = logging.getLogger(__name__)
def plot_segmentation_images(
    savefolder,
    image_paths,
    segmentations,
    anomaly_scores=None,
    mask_paths=None,
    image_transform=lambda x: x,
    mask_transform=lambda x: x,
    save_depth=4,
):
    """Generate anomaly segmentation images.

    Args:
        savefolder: [str] Folder the rendered figures are written to.
        image_paths: List[str] List of paths to images.
        segmentations: [List[np.ndarray]] Generated anomaly segmentations.
        anomaly_scores: [List[float]] Anomaly scores for each image.
        mask_paths: [List[str]] List of paths to ground truth masks.
        image_transform: [function or lambda] Optional transformation of images.
        mask_transform: [function or lambda] Optional transformation of masks.
        save_depth: [int] Number of path-strings to use for image savenames.
    """
    if mask_paths is None:
        mask_paths = ["-1" for _ in range(len(image_paths))]
    # Sentinel "-1" marks datasets that ship no ground-truth masks.
    masks_provided = mask_paths[0] != "-1"
    if anomaly_scores is None:
        anomaly_scores = ["-1" for _ in range(len(image_paths))]

    os.makedirs(savefolder, exist_ok=True)

    for image_path, mask_path, anomaly_score, segmentation in tqdm.tqdm(
        zip(image_paths, mask_paths, anomaly_scores, segmentations),
        total=len(image_paths),
        desc="Generating Segmentation Images...",
        leave=False,
    ):
        image = PIL.Image.open(image_path).convert("RGB")
        image = image_transform(image)
        if not isinstance(image, np.ndarray):
            image = image.numpy()

        if masks_provided:
            if mask_path is not None:
                mask = PIL.Image.open(mask_path).convert("RGB")
                mask = mask_transform(mask)
                if not isinstance(mask, np.ndarray):
                    mask = mask.numpy()
            else:
                mask = np.zeros_like(image)

        savename = image_path.split("/")
        savename = "_".join(savename[-save_depth:])
        savename = os.path.join(savefolder, savename)
        f, axes = plt.subplots(1, 2 + int(masks_provided))
        axes[0].imshow(image.transpose(1, 2, 0))
        if masks_provided:
            # Mask column only exists when masks were provided.
            axes[1].imshow(mask.transpose(1, 2, 0))
        # BUGFIX: previously axes[2] was indexed unconditionally, which raised
        # IndexError on the 2-subplot figure (and read the unbound `mask`)
        # whenever no ground-truth masks were provided. The segmentation is
        # always drawn on the last axis instead.
        axes[-1].imshow(segmentation)
        f.set_size_inches(3 * (2 + int(masks_provided)), 3)
        f.tight_layout()
        f.savefig(savename)
        plt.close(f)  # close the figure we created, not just the current one
def create_storage_folder(
    main_folder_path, project_folder, group_folder, run_name, mode="iterate"
):
    """Create and return the directory results are stored in.

    mode="iterate" appends an increasing counter to the group folder until an
    unused path is found (note: the counter-suffixed path drops `run_name`);
    mode="overwrite" reuses an existing path. Any other mode returns the path
    without creating it.
    """
    os.makedirs(main_folder_path, exist_ok=True)
    project_dir = os.path.join(main_folder_path, project_folder)
    os.makedirs(project_dir, exist_ok=True)
    target = os.path.join(project_dir, group_folder, run_name)
    if mode == "iterate":
        suffix = 0
        while os.path.exists(target):
            target = os.path.join(project_dir, f"{group_folder}_{suffix}")
            suffix += 1
        os.makedirs(target)
    elif mode == "overwrite":
        os.makedirs(target, exist_ok=True)
    return target
def set_torch_device(gpu_ids):
    """Return the torch.device for the first listed GPU, or the CPU.

    Args:
        gpu_ids: [list] list of gpu ids. If empty, cpu is used.
    """
    if not gpu_ids:
        return torch.device("cpu")
    # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_ids[0])
    return torch.device(f"cuda:{gpu_ids[0]}")
def fix_seeds(seed, with_torch=True, with_cuda=True):
    """Fix the available RNG seeds for reproducibility.

    Args:
        seed: [int] Seed value.
        with_torch: Flag. If true, torch-related seeds are fixed.
        with_cuda: Flag. If true, torch+cuda-related seeds are fixed
    """
    random.seed(seed)
    np.random.seed(seed)
    if not with_torch:
        return
    torch.manual_seed(seed)
    if with_cuda:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels trade speed for reproducibility.
        torch.backends.cudnn.deterministic = True
def compute_and_store_final_results(
    results_path,
    results,
    row_names=None,
    # Tuple default fixes the mutable-default-argument hazard of the
    # previous list default (callers may still pass their own list).
    column_names=(
        "Instance AUROC",
        "Full Pixel AUROC",
        "Full PRO",
        "Anomaly Pixel AUROC",
        "Anomaly PRO",
    ),
):
    """Store computed results as a CSV file and return per-column means.

    Args:
        results_path: [str] Where to store result csv.
        results: [List[List[float]]] One row of metric values per dataset,
            aligned index-for-index with `column_names`. Dataset names go in
            `row_names`, not in the rows themselves (the previous docstring
            claimed results[i][0] was the dataset name, which contradicted
            the mean computation below).
        row_names: [List[str] or None] Optional row labels, one per row.

    Returns:
        dict mapping "mean_<column name>" to that column's mean value.
    """
    if row_names is not None:
        assert len(row_names) == len(results), "#Rownames != #Result-rows."

    mean_metrics = {}
    for i, result_key in enumerate(column_names):
        mean_metrics[result_key] = np.mean([x[i] for x in results])
        # Equivalent to the module-level LOGGER (same __name__).
        logging.getLogger(__name__).info(
            "{0}: {1:3.3f}".format(result_key, mean_metrics[result_key])
        )

    savename = os.path.join(results_path, "results.csv")
    # newline="" is required by the csv module to avoid blank lines between
    # rows on Windows.
    with open(savename, "w", newline="") as csv_file:
        csv_writer = csv.writer(csv_file, delimiter=",")
        header = list(column_names)
        if row_names is not None:
            header = ["Row Names"] + header

        csv_writer.writerow(header)
        for i, result_list in enumerate(results):
            csv_row = list(result_list)
            if row_names is not None:
                csv_row = [row_names[i]] + csv_row
            csv_writer.writerow(csv_row)
        mean_scores = list(mean_metrics.values())
        if row_names is not None:
            mean_scores = ["Mean"] + mean_scores
        csv_writer.writerow(mean_scores)

    mean_metrics = {"mean_{0}".format(key): item for key, item in mean_metrics.items()}
    return mean_metrics
gitextract_ap_hyzgr/ ├── .gitignore ├── LICENSE ├── README.md ├── VERSION ├── backbones.py ├── common.py ├── datasets/ │ ├── __init__.py │ ├── btad.py │ ├── cifar10.py │ ├── mvtec.py │ ├── sdd.py │ └── sdd2.py ├── main.py ├── metrics.py ├── resnet.py ├── run.sh ├── simplenet.py └── utils.py
SYMBOL INDEX (125 symbols across 12 files)
FILE: backbones.py
function load_ref_wrn50 (line 5) | def load_ref_wrn50():
function load (line 61) | def load(name):
FILE: common.py
class _BaseMerger (line 10) | class _BaseMerger:
method __init__ (line 11) | def __init__(self):
method merge (line 14) | def merge(self, features: list):
class AverageMerger (line 19) | class AverageMerger(_BaseMerger):
method _reduce (line 21) | def _reduce(features):
class ConcatMerger (line 28) | class ConcatMerger(_BaseMerger):
method _reduce (line 30) | def _reduce(features):
class Preprocessing (line 35) | class Preprocessing(torch.nn.Module):
method __init__ (line 36) | def __init__(self, input_dims, output_dim):
method forward (line 46) | def forward(self, features):
class MeanMapper (line 53) | class MeanMapper(torch.nn.Module):
method __init__ (line 54) | def __init__(self, preprocessing_dim):
method forward (line 58) | def forward(self, features):
class Aggregator (line 63) | class Aggregator(torch.nn.Module):
method __init__ (line 64) | def __init__(self, target_dim):
method forward (line 68) | def forward(self, features):
class RescaleSegmentor (line 76) | class RescaleSegmentor:
method __init__ (line 77) | def __init__(self, device, target_size=224):
method convert_to_segmentation (line 82) | def convert_to_segmentation(self, patch_scores, features):
class NetworkFeatureAggregator (line 124) | class NetworkFeatureAggregator(torch.nn.Module):
method __init__ (line 127) | def __init__(self, backbone, layers_to_extract_from, device, train_bac...
method forward (line 173) | def forward(self, images, eval=True):
method feature_dimensions (line 187) | def feature_dimensions(self, input_shape):
class ForwardHook (line 194) | class ForwardHook:
method __init__ (line 195) | def __init__(self, hook_dict, layer_name: str, last_layer_to_extract: ...
method __call__ (line 202) | def __call__(self, module, input, output):
class LastLayerToExtractReachedException (line 209) | class LastLayerToExtractReachedException(Exception):
FILE: datasets/btad.py
class DatasetSplit (line 18) | class DatasetSplit(Enum):
class BTADDataset (line 24) | class BTADDataset(torch.utils.data.Dataset):
method __init__ (line 29) | def __init__(
method __getitem__ (line 98) | def __getitem__(self, idx):
method __len__ (line 119) | def __len__(self):
method get_image_data (line 122) | def get_image_data(self):
FILE: datasets/cifar10.py
class DatasetSplit (line 13) | class DatasetSplit(Enum):
class Cifar10Dataset (line 19) | class Cifar10Dataset(torch.utils.data.Dataset):
method __init__ (line 26) | def __init__(
method __getitem__ (line 95) | def __getitem__(self, idx):
method __len__ (line 110) | def __len__(self):
method get_image_data (line 113) | def get_image_data(self):
FILE: datasets/mvtec.py
class DatasetSplit (line 30) | class DatasetSplit(Enum):
class MVTecDataset (line 36) | class MVTecDataset(torch.utils.data.Dataset):
method __init__ (line 41) | def __init__(
method __getitem__ (line 110) | def __getitem__(self, idx):
method __len__ (line 131) | def __len__(self):
method get_image_data (line 134) | def get_image_data(self):
FILE: datasets/sdd.py
class DatasetSplit (line 15) | class DatasetSplit(Enum):
class SDDDataset (line 21) | class SDDDataset(torch.utils.data.Dataset):
method __init__ (line 26) | def __init__(
method __getitem__ (line 99) | def __getitem__(self, idx):
method __len__ (line 119) | def __len__(self):
method get_image_data (line 122) | def get_image_data(self):
FILE: datasets/sdd2.py
class DatasetSplit (line 15) | class DatasetSplit(Enum):
class SDD2Dataset (line 21) | class SDD2Dataset(torch.utils.data.Dataset):
method __init__ (line 26) | def __init__(
method __getitem__ (line 98) | def __getitem__(self, idx):
method __len__ (line 118) | def __len__(self):
method get_image_data (line 121) | def get_image_data(self):
FILE: main.py
function main (line 39) | def main(**kwargs):
function run (line 44) | def run(
function net (line 150) | def net(
function dataset (line 245) | def dataset(
FILE: metrics.py
function compute_imagewise_retrieval_metrics (line 7) | def compute_imagewise_retrieval_metrics(
function compute_pixelwise_retrieval_metrics (line 35) | def compute_pixelwise_retrieval_metrics(anomaly_segmentations, ground_tr...
function compute_pro (line 88) | def compute_pro(masks, amaps, num_th=200):
FILE: resnet.py
function conv3x3 (line 30) | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: in...
function conv1x1 (line 36) | def conv1x1(in_planes: int, out_planes: int, stride: int = 1):
class BasicBlock (line 41) | class BasicBlock(nn.Module):
method __init__ (line 44) | def __init__(
method forward (line 71) | def forward(self, x: Tensor):
class Bottleneck (line 90) | class Bottleneck(nn.Module):
method __init__ (line 99) | def __init__(
method forward (line 125) | def forward(self, x: Tensor):
class ResNet (line 148) | class ResNet(nn.Module):
method __init__ (line 150) | def __init__(
method _make_layer (line 207) | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], plan...
method _forward_impl (line 231) | def _forward_impl(self, x: Tensor):
method forward (line 248) | def forward(self, x: Tensor):
function _resnet (line 252) | def _resnet(
function resnet18 (line 268) | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: ...
function resnet34 (line 280) | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: ...
function resnet50 (line 292) | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: ...
function resnet101 (line 304) | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs:...
function resnet152 (line 316) | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs:...
function resnext50_32x4d (line 328) | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **k...
function resnext101_32x8d (line 342) | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **...
function wide_resnet50_2 (line 356) | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **k...
function wide_resnet101_2 (line 374) | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **...
FILE: simplenet.py
function init_weight (line 27) | def init_weight(m):
class Discriminator (line 35) | class Discriminator(torch.nn.Module):
method __init__ (line 36) | def __init__(self, in_planes, n_layers=1, hidden=None):
method forward (line 53) | def forward(self,x):
class Projection (line 59) | class Projection(torch.nn.Module):
method __init__ (line 61) | def __init__(self, in_planes, out_planes=None, n_layers=1, layer_type=0):
method forward (line 83) | def forward(self, x):
class TBWrapper (line 90) | class TBWrapper:
method __init__ (line 92) | def __init__(self, log_dir):
method step (line 96) | def step(self):
class SimpleNet (line 99) | class SimpleNet(torch.nn.Module):
method __init__ (line 100) | def __init__(self, device):
method load (line 105) | def load(
method set_model_dir (line 205) | def set_model_dir(self, model_dir, dataset_name):
method embed (line 216) | def embed(self, data):
method _embed (line 228) | def _embed(self, images, detach=True, provide_patch_shapes=False, eval...
method test (line 289) | def test(self, training_data, test_data, save_segmentation_images):
method _evaluate (line 344) | def _evaluate(self, test_data, scores, segmentations, features, labels...
method train (line 389) | def train(self, training_data, test_data):
method _train_discriminator (line 455) | def _train_discriminator(self, input_data):
method predict (line 543) | def predict(self, data, prefix=""):
method _predict_dataloader (line 548) | def _predict_dataloader(self, dataloader, prefix):
method _predict (line 576) | def _predict(self, images):
method _params_file (line 615) | def _params_file(filepath, prepend=""):
method save_to_path (line 618) | def save_to_path(self, save_path: str, prepend: str = ""):
method save_segmentation_images (line 640) | def save_segmentation_images(self, data, segmentations, scores):
class PatchMaker (line 674) | class PatchMaker:
method __init__ (line 675) | def __init__(self, patchsize, top_k=0, stride=None):
method patchify (line 680) | def patchify(self, features, return_spatial_info=False):
method unpatch_scores (line 708) | def unpatch_scores(self, x, batchsize):
method score (line 711) | def score(self, x):
FILE: utils.py
function plot_segmentation_images (line 15) | def plot_segmentation_images(
function create_storage_folder (line 77) | def create_storage_folder(
function set_torch_device (line 96) | def set_torch_device(gpu_ids):
function fix_seeds (line 109) | def fix_seeds(seed, with_torch=True, with_cuda=True):
function compute_and_store_final_results (line 127) | def compute_and_store_final_results(
Condensed preview — 18 files, each showing its path, character count, and a content snippet. Download the .json file, or copy it to your clipboard, to get the full structured content (116K chars).
[
{
"path": ".gitignore",
"chars": 14,
"preview": "__pycache__/*\n"
},
{
"path": "LICENSE",
"chars": 1065,
"preview": "MIT License\n\nCopyright (c) 2023 DonaldRR\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\no"
},
{
"path": "README.md",
"chars": 1914,
"preview": "# SimpleNet\n\n\n\n\n**SimpleNet: A Simple Network for Image Anomaly Detection and Localization**\n\n*Zhikan"
},
{
"path": "VERSION",
"chars": 6,
"preview": "0.1.0\n"
},
{
"path": "backbones.py",
"chars": 3596,
"preview": "import timm # noqa\nimport torch\nimport torchvision.models as models # noqa\n\ndef load_ref_wrn50():\n \n import resn"
},
{
"path": "common.py",
"chars": 7808,
"preview": "import copy\nfrom typing import List\n\nimport numpy as np\nimport scipy.ndimage as ndimage\nimport torch\nimport torch.nn.fun"
},
{
"path": "datasets/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "datasets/btad.py",
"chars": 6669,
"preview": "import os\nfrom enum import Enum\n\nimport PIL\nimport torch\nfrom torchvision import transforms\n\n_CLASSNAMES = [\n \"01\",\n "
},
{
"path": "datasets/cifar10.py",
"chars": 4185,
"preview": "import os\nfrom enum import Enum\n\nimport PIL\nimport torch\nfrom torchvision import transforms\n\n\nIMAGENET_MEAN = [0.485, 0."
},
{
"path": "datasets/mvtec.py",
"chars": 6856,
"preview": "import os\nfrom enum import Enum\n\nimport PIL\nimport torch\nfrom torchvision import transforms\n\n_CLASSNAMES = [\n \"bottle"
},
{
"path": "datasets/sdd.py",
"chars": 5568,
"preview": "import os\nfrom enum import Enum\nimport pickle\n\nimport cv2\nimport PIL\nimport torch\nfrom torchvision import transforms\n\n\nI"
},
{
"path": "datasets/sdd2.py",
"chars": 4989,
"preview": "import os\nfrom enum import Enum\nimport pickle\n\nimport cv2\nimport PIL\nimport torch\nfrom torchvision import transforms\n\n\nI"
},
{
"path": "main.py",
"chars": 12740,
"preview": "# ------------------------------------------------------------------\n# SimpleNet: A Simple Network for Image Anomaly Det"
},
{
"path": "metrics.py",
"chars": 4341,
"preview": "\"\"\"Anomaly metrics.\"\"\"\nimport cv2\nimport numpy as np\nfrom sklearn import metrics\n\n\ndef compute_imagewise_retrieval_metri"
},
{
"path": "resnet.py",
"chars": 15772,
"preview": "import torch\nfrom torch import Tensor\nimport torch.nn as nn\nfrom typing import Type, Any, Callable, Union, List, Optiona"
},
{
"path": "run.sh",
"chars": 781,
"preview": "datapath=/data4/MVTec_ad\ndatasets=('screw' 'pill' 'capsule' 'carpet' 'grid' 'tile' 'wood' 'zipper' 'cable' 'toothbrush' "
},
{
"path": "simplenet.py",
"chars": 29117,
"preview": "# ------------------------------------------------------------------\n# SimpleNet: A Simple Network for Image Anomaly Det"
},
{
"path": "utils.py",
"chars": 5825,
"preview": "import csv\nimport logging\nimport os\nimport random\n\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport PIL\nimport "
}
]
About this extraction
This page contains the full source code of the DonaldRR/SimpleNet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 18 files (108.6 KB, approximately 25.9k tokens) and includes a symbol index with 125 extracted functions, classes, methods, constants, and types. Use it with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.