Repository: nomewang/M3DM
Branch: main
Commit: 62b7f468aa00
Files: 20
Total size: 156.1 KB

Directory structure:
gitextract_ziog4jk2/
├── LICENSE
├── README.md
├── dataset.py
├── engine_fusion_pretrain.py
├── feature_extractors/
│   ├── features.py
│   └── multiple_features.py
├── fusion_pretrain.py
├── m3dm_runner.py
├── main.py
├── models/
│   ├── feature_fusion.py
│   ├── models.py
│   └── pointnet2_utils.py
├── requirements.txt
└── utils/
    ├── au_pro_util.py
    ├── lr_sched.py
    ├── misc.py
    ├── mvtec3d_util.py
    ├── preprocess_eyecandies.py
    ├── preprocessing.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================

MIT License

Copyright (c) 2023 nomewang

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================

# Multimodal Industrial Anomaly Detection via Hybrid Fusion (CVPR 2023)

## Abstract

> 2D-based Industrial Anomaly Detection has been widely discussed; however, multimodal industrial anomaly detection based on 3D point clouds and RGB images still has many untouched fields. Existing multimodal industrial anomaly detection methods directly concatenate the multimodal features, which leads to a strong disturbance between features and harms the detection performance. In this paper, we propose **Multi-3D-Memory** (**M3DM**), a novel multimodal anomaly detection method with a hybrid fusion scheme: first, we design an unsupervised feature fusion with patch-wise contrastive learning to encourage the interaction of different modal features; second, we use a decision layer fusion with multiple memory banks to avoid loss of information, and additional novelty classifiers to make the final decision. We further propose a point feature alignment operation to better align the point cloud and RGB features. Extensive experiments show that our multimodal industrial anomaly detection model outperforms the state-of-the-art (SOTA) methods in both detection and segmentation precision on the MVTec-3D AD dataset.
![pipeline](figures/pipeline.png) - `The pipeline of Multi-3D-Memory (M3DM).`

Our M3DM contains three important parts: (1) **Point Feature Alignment** (PFA) converts point group features to plane features with interpolation and projection operations; $\text{FPS}$ is farthest point sampling and $\mathcal{F_{pt}}$ is a pretrained Point Transformer. (2) **Unsupervised Feature Fusion** (UFF) fuses point features and image features together with a patch-wise contrastive loss $\mathcal{L_{con}}$, where $\mathcal{F_{rgb}}$ is a Vision Transformer, $\chi_{rgb},\chi_{pt}$ are MLP layers, and $\sigma_r, \sigma_p$ are single fully connected layers. (3) **Decision Layer Fusion** (DLF) combines multimodal information with multiple memory banks and makes the final decision with two learnable modules $\mathcal{D}_a, \mathcal{D}_s$ for anomaly detection and segmentation, where $\mathcal{M_{rgb}}$, $\mathcal{M_{fs}}$, $\mathcal{M_{pt}}$ are memory banks, $\phi, \psi$ are score functions for single-memory-bank detection and segmentation, and $\mathcal{P}$ is the memory bank building algorithm.

### [Paper](https://arxiv.org/pdf/2303.00601.pdf)

## Setup

We implement this repo with the following environment:
- Ubuntu 18.04
- Python 3.8
- PyTorch 1.9.0
- CUDA 11.3

Install the other packages via:
```bash
pip install -r requirements.txt
# install knn_cuda
pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl
# install pointnet2_ops_lib
pip install "git+git://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"
```

## Data Download and Preprocess

### Dataset

- The `MVTec-3D AD` dataset can be downloaded from the [Official Website of MVTec-3D AD](https://www.mvtec.com/company/research/datasets/mvtec-3d-ad).
- The `Eyecandies` dataset can be downloaded from the [Official Website of Eyecandies](https://eyecan-ai.github.io/eyecandies/).

After downloading, put the dataset in the `datasets` folder.

### Data Preprocessing

To run the preprocessing:
```bash
python utils/preprocessing.py datasets/mvtec3d/
```

It may take a few hours to run the preprocessing.

### Checkpoints

The following table lists the pretrained models used in M3DM:

| Backbone | Pretrain Method |
| --- | --- |
| Point Transformer | [Point-MAE](https://drive.google.com/file/d/1-wlRIz0GM8o6BuPTJz4kTt6c_z1Gh6LX/view?usp=sharing) |
| Point Transformer | [Point-Bert](https://cloud.tsinghua.edu.cn/f/202b29805eea45d7be92/?dl=1) |
| ViT-b/8 | [DINO](https://drive.google.com/file/d/17s6lwfxwG_nf1td6LXunL-LjRaX67iyK/view?usp=sharing) |
| ViT-b/8 | [Supervised ImageNet 1K](https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz) |
| ViT-b/8 | [Supervised ImageNet 21K](https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz) |
| ViT-s/8 | [DINO](https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth) |
| UFF | [UFF Module](https://drive.google.com/file/d/1Z2AkfPqenJEv-IdWhVdRcvVQAsJC4DxW/view?usp=sharing) |

Put the checkpoint files in the `checkpoints` folder.
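For intuition before training it below: UFF learns its fusion with the patch-wise contrastive loss $\mathcal{L_{con}}$, which treats the projected point and RGB features of the same patch as a positive pair and all other patches in the batch as negatives. A minimal InfoNCE-style sketch of such an objective (an illustration under our own naming, not the exact code in `fusion_pretrain.py`):

```python
import torch
import torch.nn.functional as F

def patchwise_contrastive_loss(z_pt, z_rgb, tau=0.07):
    """InfoNCE over corresponding patches (hypothetical helper).

    z_pt, z_rgb: (N, D) projected point / RGB patch features, where row i
    of both tensors comes from the same patch.
    """
    z_pt = F.normalize(z_pt, dim=1)
    z_rgb = F.normalize(z_rgb, dim=1)
    logits = z_pt @ z_rgb.T / tau                              # (N, N) similarities
    labels = torch.arange(z_pt.shape[0], device=z_pt.device)   # positives on the diagonal
    return F.cross_entropy(logits, labels)
```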
## Train and Test

Train and test the double-memory-bank (double lib) version and save the features for UFF training:
```bash
mkdir -p datasets/patch_lib
python3 main.py \
--method_name DINO+Point_MAE \
--memory_bank multiple \
--rgb_backbone_name vit_base_patch8_224_dino \
--xyz_backbone_name Point_MAE \
--save_feature
```

Train the UFF:
```bash
OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=1 fusion_pretrain.py \
--accum_iter 16 \
--lr 0.003 \
--batch_size 16 \
--data_path datasets/patch_lib \
--output_dir checkpoints
```

Train and test the full setting with the following command:
```bash
python3 main.py \
--method_name DINO+Point_MAE+Fusion \
--use_uff \
--memory_bank multiple \
--rgb_backbone_name vit_base_patch8_224_dino \
--xyz_backbone_name Point_MAE \
--fusion_module_path checkpoints/{FUSION_CHECKPOINT}.pth
```

Note: if you set `--method_name DINO` or `--method_name Point_MAE`, also set `--memory_bank single`.

If you find this repository useful for your research, please use the following citation.

```bibtex
@inproceedings{wang2023multimodal,
  title={Multimodal Industrial Anomaly Detection via Hybrid Fusion},
  author={Wang, Yue and Peng, Jinlong and Zhang, Jiangning and Yi, Ran and Wang, Yabiao and Wang, Chengjie},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={8032--8041},
  year={2023}
}
```

## Thanks

Our repo is built on [3D-ADS](https://github.com/eliahuhorwitz/3D-ADS) and [MoCo-v3](https://github.com/facebookresearch/moco-v3); thanks for their extraordinary work!

================================================
FILE: dataset.py
================================================

import os import torch from PIL import Image from torchvision import transforms import glob from torch.utils.data import Dataset from utils.mvtec3d_util import * from torch.utils.data import DataLoader import numpy as np def eyecandies_classes(): return [ 'CandyCane', 'ChocolateCookie', 'ChocolatePraline', 'Confetto', 'GummyBear', 'HazelnutTruffle', 'LicoriceSandwich', 'Lollipop', 'Marshmallow', 'PeppermintCandy', ] def mvtec3d_classes(): return [ "bagel", "cable_gland", "carrot", "cookie", "dowel", "foam", "peach", "potato", "rope", "tire", ] RGB_SIZE = 224 class BaseAnomalyDetectionDataset(Dataset): def __init__(self, split, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): self.IMAGENET_MEAN = [0.485, 0.456, 0.406] self.IMAGENET_STD = [0.229, 0.224, 0.225] self.cls = class_name self.size = img_size self.img_path = os.path.join(dataset_path, self.cls, split) self.rgb_transform = transforms.Compose( [transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)]) class PreTrainTensorDataset(Dataset): def __init__(self, root_path): super().__init__() self.root_path = root_path self.tensor_paths = os.listdir(self.root_path) def __len__(self): return len(self.tensor_paths) def __getitem__(self, idx): tensor_path = self.tensor_paths[idx] tensor = torch.load(os.path.join(self.root_path, tensor_path)) label = 0 return tensor, label class TrainDataset(BaseAnomalyDetectionDataset): def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path) self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 def load_dataset(self): img_tot_paths = [] tot_labels = [] rgb_paths =
glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png") tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff") rgb_paths.sort() tiff_paths.sort() sample_paths = list(zip(rgb_paths, tiff_paths)) img_tot_paths.extend(sample_paths) tot_labels.extend([0] * len(sample_paths)) return img_tot_paths, tot_labels def __len__(self): return len(self.img_paths) def __getitem__(self, idx): img_path, label = self.img_paths[idx], self.labels[idx] rgb_path = img_path[0] tiff_path = img_path[1] img = Image.open(rgb_path).convert('RGB') img = self.rgb_transform(img) organized_pc = read_tiff_organized_pc(tiff_path) depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2) resized_depth_map_3channel = resize_organized_pc(depth_map_3channel) resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size) resized_organized_pc = resized_organized_pc.clone().detach().float() return (img, resized_organized_pc, resized_depth_map_3channel), label class TestDataset(BaseAnomalyDetectionDataset): def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'): super().__init__(split="test", class_name=class_name, img_size=img_size, dataset_path=dataset_path) self.gt_transform = transforms.Compose([ transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST), transforms.ToTensor()]) self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1 def load_dataset(self): img_tot_paths = [] gt_tot_paths = [] tot_labels = [] defect_types = os.listdir(self.img_path) for defect_type in defect_types: if defect_type == 'good': rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png") tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff") rgb_paths.sort() tiff_paths.sort() sample_paths = list(zip(rgb_paths, tiff_paths)) img_tot_paths.extend(sample_paths) gt_tot_paths.extend([0] * len(sample_paths)) tot_labels.extend([0] * len(sample_paths)) else: rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png") tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff") gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png") rgb_paths.sort() tiff_paths.sort() gt_paths.sort() sample_paths = list(zip(rgb_paths, tiff_paths)) img_tot_paths.extend(sample_paths) gt_tot_paths.extend(gt_paths) tot_labels.extend([1] * len(sample_paths)) assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!" 
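# Note: gt_tot_paths mixes the integer placeholder 0 (good samples) with ground-truth mask paths (anomalous samples); __getitem__ below branches on `gt == 0` to build an all-zero mask for good samples.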
return img_tot_paths, gt_tot_paths, tot_labels def __len__(self): return len(self.img_paths) def __getitem__(self, idx): img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx] rgb_path = img_path[0] tiff_path = img_path[1] img_original = Image.open(rgb_path).convert('RGB') img = self.rgb_transform(img_original) organized_pc = read_tiff_organized_pc(tiff_path) depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2) resized_depth_map_3channel = resize_organized_pc(depth_map_3channel) resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size) resized_organized_pc = resized_organized_pc.clone().detach().float() if gt == 0: gt = torch.zeros( [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-2]]) else: gt = Image.open(gt).convert('L') gt = self.gt_transform(gt) gt = torch.where(gt > 0.5, 1., .0) return (img, resized_organized_pc, resized_depth_map_3channel), gt[:1], label, rgb_path def get_data_loader(split, class_name, img_size, args): if split in ['train']: dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path) elif split in ['test']: dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path) data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=False, pin_memory=True) return data_loader ================================================ FILE: engine_fusion_pretrain.py ================================================ import math import sys from typing import Iterable import torch import utils.misc as misc import utils.lr_sched as lr_sched def train_one_epoch(model: torch.nn.Module, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, loss_scaler, log_writer=None, args=None): model.train(True) metric_logger = misc.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 20 accum_iter = args.accum_iter optimizer.zero_grad() if log_writer is not None: print('log_dir: {}'.format(log_writer.log_dir)) for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): # we use a per iteration (instead of per epoch) lr scheduler if data_iter_step % accum_iter == 0: lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) xyz_samples = samples[:,:,:1152].to(device, non_blocking=True) rgb_samples = samples[:,:,1152:].to(device, non_blocking=True) with torch.cuda.amp.autocast(): loss = model(xyz_samples, rgb_samples) loss_value = loss.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) sys.exit(1) loss /= accum_iter loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1) % accum_iter == 0) if (data_iter_step + 1) % accum_iter == 0: optimizer.zero_grad() torch.cuda.synchronize() metric_logger.update(loss=loss_value) lr = optimizer.param_groups[0]["lr"] metric_logger.update(lr=lr) loss_value_reduce = misc.all_reduce_mean(loss_value) if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: """ We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes. 
""" epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) log_writer.add_scalar('lr', lr, epoch_1000x) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} ================================================ FILE: feature_extractors/features.py ================================================ """ PatchCore logic based on https://github.com/rvorias/ind_knn_ad """ import torch import numpy as np import os from tqdm import tqdm from matplotlib import pyplot as plt from sklearn import random_projection from sklearn import linear_model from sklearn.svm import OneClassSVM from sklearn.ensemble import IsolationForest from sklearn.metrics import roc_auc_score from timm.models.layers import DropPath, trunc_normal_ from pointnet2_ops import pointnet2_utils from knn_cuda import KNN from utils.utils import KNNGaussianBlur from utils.utils import set_seeds from utils.au_pro_util import calculate_au_pro from models.pointnet2_utils import interpolating_points from models.feature_fusion import FeatureFusionBlock from models.models import Model class Features(torch.nn.Module): def __init__(self, args, image_size=224, f_coreset=0.1, coreset_eps=0.9): super().__init__() self.device = "cuda" if torch.cuda.is_available() else "cpu" self.deep_feature_extractor = Model( device=self.device, rgb_backbone_name=args.rgb_backbone_name, xyz_backbone_name=args.xyz_backbone_name, group_size = args.group_size, num_group=args.num_group ) self.deep_feature_extractor.to(self.device) self.args = args self.image_size = args.img_size self.f_coreset = args.f_coreset self.coreset_eps = args.coreset_eps self.blur = KNNGaussianBlur(4) self.n_reweight = 3 set_seeds(0) self.patch_xyz_lib = [] self.patch_rgb_lib = [] self.patch_fusion_lib = [] self.patch_lib = [] self.random_state = args.random_state self.xyz_dim = 0 self.rgb_dim = 0 self.xyz_mean=0 self.xyz_std=0 self.rgb_mean=0 self.rgb_std=0 self.fusion_mean=0 self.fusion_std=0 self.average = torch.nn.AvgPool2d(3, stride=1) # torch.nn.AvgPool2d(1, stride=1) # self.resize = torch.nn.AdaptiveAvgPool2d((56, 56)) self.resize2 = torch.nn.AdaptiveAvgPool2d((56, 56)) self.image_preds = list() self.image_labels = list() self.pixel_preds = list() self.pixel_labels = list() self.gts = [] self.predictions = [] self.image_rocauc = 0 self.pixel_rocauc = 0 self.au_pro = 0 self.ins_id = 0 self.rgb_layernorm = torch.nn.LayerNorm(768, elementwise_affine=False) if self.args.use_uff: self.fusion = FeatureFusionBlock(1152, 768, mlp_ratio=4.) ckpt = torch.load(args.fusion_module_path)['model'] incompatible = self.fusion.load_state_dict(ckpt, strict=False) print('[Fusion Block]', incompatible) self.detect_fuser = linear_model.SGDOneClassSVM(random_state=42, nu=args.ocsvm_nu, max_iter=args.ocsvm_maxiter) self.seg_fuser = linear_model.SGDOneClassSVM(random_state=42, nu=args.ocsvm_nu, max_iter=args.ocsvm_maxiter) self.s_lib = [] self.s_map_lib = [] def __call__(self, rgb, xyz): # Extract the desired feature maps using the backbone model. 
rgb = rgb.to(self.device) xyz = xyz.to(self.device) with torch.no_grad(): rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz) interpolate = True if interpolate: interpolated_feature_maps = interpolating_points(xyz, center.permute(0,2,1), xyz_feature_maps).to("cpu") xyz_feature_maps = [fmap.to("cpu") for fmap in [xyz_feature_maps]] rgb_feature_maps = [fmap.to("cpu") for fmap in [rgb_feature_maps]] if interpolate: return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx, interpolated_feature_maps else: return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx def add_sample_to_mem_bank(self, sample): raise NotImplementedError def predict(self, sample, mask, label): raise NotImplementedError def add_sample_to_late_fusion_mem_bank(self, sample): raise NotImplementedError def interpolate_points(self, rgb, xyz): with torch.no_grad(): rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz) return xyz_feature_maps, center, xyz def compute_s_s_map(self, xyz_patch, rgb_patch, fusion_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx): raise NotImplementedError def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'): raise NotImplementedError def run_coreset(self): raise NotImplementedError def calculate_metrics(self): self.image_preds = np.stack(self.image_preds) self.image_labels = np.stack(self.image_labels) self.pixel_preds = np.array(self.pixel_preds) self.image_rocauc = roc_auc_score(self.image_labels, self.image_preds) self.pixel_rocauc = roc_auc_score(self.pixel_labels, self.pixel_preds) self.au_pro, _ = calculate_au_pro(self.gts, self.predictions) def save_prediction_maps(self, output_path, rgb_path, save_num=5): for i in range(min(save_num, len(self.predictions))): # fig = plt.figure(dpi=300) fig = plt.figure() ax3 = fig.add_subplot(1,3,1) gt = plt.imread(rgb_path[i][0]) ax3.imshow(gt) ax2 = fig.add_subplot(1,3,2) im2 = ax2.imshow(self.gts[i], cmap=plt.cm.gray) ax = fig.add_subplot(1,3,3) im = ax.imshow(self.predictions[i], cmap=plt.cm.jet) class_dir = os.path.join(output_path, rgb_path[i][0].split('/')[-5]) if not os.path.exists(class_dir): os.mkdir(class_dir) ad_dir = os.path.join(class_dir, rgb_path[i][0].split('/')[-3]) if not os.path.exists(ad_dir): os.mkdir(ad_dir) plt.savefig(os.path.join(ad_dir, str(self.image_preds[i]) + '_pred_' + rgb_path[i][0].split('/')[-1] + '.jpg')) plt.close(fig) def run_late_fusion(self): self.s_lib = torch.cat(self.s_lib, 0) self.s_map_lib = torch.cat(self.s_map_lib, 0) self.detect_fuser.fit(self.s_lib) self.seg_fuser.fit(self.s_map_lib) def get_coreset_idx_randomp(self, z_lib, n=1000, eps=0.90, float16=True, force_cpu=False): print(f" Fitting random projections. Start dim = {z_lib.shape}.") try: transformer = random_projection.SparseRandomProjection(eps=eps, random_state=self.random_state) z_lib = torch.tensor(transformer.fit_transform(z_lib)) print(f" DONE. Transformed dim = {z_lib.shape}.") except ValueError: print(" Error: could not project vectors.
Please increase `eps`.") select_idx = 0 last_item = z_lib[select_idx:select_idx + 1] coreset_idx = [torch.tensor(select_idx)] min_distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdims=True) if float16: last_item = last_item.half() z_lib = z_lib.half() min_distances = min_distances.half() if torch.cuda.is_available() and not force_cpu: last_item = last_item.to("cuda") z_lib = z_lib.to("cuda") min_distances = min_distances.to("cuda") for _ in tqdm(range(n - 1)): distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdims=True) # broadcasting step min_distances = torch.minimum(distances, min_distances) # iterative step select_idx = torch.argmax(min_distances) # selection step # bookkeeping last_item = z_lib[select_idx:select_idx + 1] min_distances[select_idx] = 0 coreset_idx.append(select_idx.to("cpu")) return torch.stack(coreset_idx) ================================================ FILE: feature_extractors/multiple_features.py ================================================ import torch from feature_extractors.features import Features from utils.mvtec3d_util import * import numpy as np import math import os class RGBFeatures(Features): def add_sample_to_mem_bank(self, sample): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, _, _, center_idx, _ = self(sample[0],unorganized_pc_no_zeros.contiguous()) rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T self.patch_lib.append(rgb_patch) def predict(self, sample, mask, label): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, _ = self(sample[0], unorganized_pc_no_zeros.contiguous()) rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T self.compute_s_s_map(rgb_patch, rgb_feature_maps[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx) def run_coreset(self): self.patch_lib = torch.cat(self.patch_lib, 0) self.mean = torch.mean(self.patch_lib) self.std = torch.std(self.patch_lib) self.patch_lib = (self.patch_lib - self.mean)/self.std # self.patch_lib = self.rgb_layernorm(self.patch_lib) if self.f_coreset < 1: self.coreset_idx = self.get_coreset_idx_randomp(self.patch_lib, n=int(self.f_coreset * self.patch_lib.shape[0]), eps=self.coreset_eps, ) self.patch_lib = self.patch_lib[self.coreset_idx] def compute_s_s_map(self, patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx, nonzero_patch_indices = None): ''' center: point group center position neighbour_idx: each group point index nonzero_indices: point indices of original point clouds xyz: nonzero point clouds ''' patch = (patch - self.mean)/self.std # self.patch_lib = self.rgb_layernorm(self.patch_lib) dist = torch.cdist(patch, self.patch_lib) min_val, min_idx = torch.min(dist, dim=1) # 
print(min_val.shape) s_idx = torch.argmax(min_val) s_star = torch.max(min_val) # reweighting m_test = patch[s_idx].unsqueeze(0) # anomalous patch m_star = self.patch_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour w_dist = torch.cdist(m_star, self.patch_lib) # find knn to m_star pt.1 _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2 m_star_knn = torch.linalg.norm(m_test - self.patch_lib[nn_idx[0, 1:]], dim=1) D = torch.sqrt(torch.tensor(patch.shape[1])) w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D)) + 1e-5)) s = w * s_star # segmentation map s_map = min_val.view(1, 1, *feature_map_dims) s_map = torch.nn.functional.interpolate(s_map, size=(224, 224), mode='bilinear') s_map = self.blur(s_map) self.image_preds.append(s.numpy()) self.image_labels.append(label) self.pixel_preds.extend(s_map.flatten().numpy()) self.pixel_labels.extend(mask.flatten().numpy()) self.predictions.append(s_map.detach().cpu().squeeze().numpy()) self.gts.append(mask.detach().cpu().squeeze().numpy()) class PointFeatures(Features): def add_sample_to_mem_bank(self, sample): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T self.patch_lib.append(xyz_patch) def predict(self, sample, mask, label): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T self.compute_s_s_map(xyz_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx) def run_coreset(self): self.patch_lib = torch.cat(self.patch_lib, 0) if self.args.rm_zero_for_project: self.patch_lib = self.patch_lib[torch.nonzero(torch.all(self.patch_lib!=0, dim=1))[:,0]] if self.f_coreset < 1: self.coreset_idx = self.get_coreset_idx_randomp(self.patch_lib, 
n=int(self.f_coreset * self.patch_lib.shape[0]), eps=self.coreset_eps, ) self.patch_lib = self.patch_lib[self.coreset_idx] if self.args.rm_zero_for_project: self.patch_lib = self.patch_lib[torch.nonzero(torch.all(self.patch_lib!=0, dim=1))[:,0]] self.patch_lib = torch.cat((self.patch_lib, torch.zeros(1, self.patch_lib.shape[1])), 0) def compute_s_s_map(self, patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx, nonzero_patch_indices = None): ''' center: point group center position neighbour_idx: each group point index nonzero_indices: point indices of original point clouds xyz: nonzero point clouds ''' dist = torch.cdist(patch, self.patch_lib) min_val, min_idx = torch.min(dist, dim=1) # print(min_val.shape) s_idx = torch.argmax(min_val) s_star = torch.max(min_val) # reweighting m_test = patch[s_idx].unsqueeze(0) # anomalous patch m_star = self.patch_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour w_dist = torch.cdist(m_star, self.patch_lib) # find knn to m_star pt.1 _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2 m_star_knn = torch.linalg.norm(m_test - self.patch_lib[nn_idx[0, 1:]], dim=1) D = torch.sqrt(torch.tensor(patch.shape[1])) w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D)) + 1e-5)) s = w * s_star # segmentation map s_map = min_val.view(1, 1, *feature_map_dims) s_map = torch.nn.functional.interpolate(s_map, size=(224, 224), mode='bilinear') s_map = self.blur(s_map) self.image_preds.append(s.numpy()) self.image_labels.append(label) self.pixel_preds.extend(s_map.flatten().numpy()) self.pixel_labels.extend(mask.flatten().numpy()) self.predictions.append(s_map.detach().cpu().squeeze().numpy()) self.gts.append(mask.detach().cpu().squeeze().numpy()) FUSION_BLOCK= True class FusionFeatures(Features): def add_sample_to_mem_bank(self, sample, class_name=None): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T rgb_patch_size = int(math.sqrt(rgb_patch.shape[0])) rgb_patch2 = self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size)) rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T if FUSION_BLOCK: with torch.no_grad(): fusion_patch = self.fusion.feature_fusion(xyz_patch.unsqueeze(0), rgb_patch2.unsqueeze(0)) fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach() else: fusion_patch = torch.cat([xyz_patch, rgb_patch2], dim=1) if class_name is not None: torch.save(fusion_patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt')) self.ins_id += 
1 self.patch_lib.append(fusion_patch) def predict(self, sample, mask, label): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T rgb_patch_size = int(math.sqrt(rgb_patch.shape[0])) rgb_patch2 = self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size)) rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T if FUSION_BLOCK: with torch.no_grad(): fusion_patch = self.fusion.feature_fusion(xyz_patch.unsqueeze(0), rgb_patch2.unsqueeze(0)) fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach() else: fusion_patch = torch.cat([xyz_patch, rgb_patch2], dim=1) self.compute_s_s_map(fusion_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx) def compute_s_s_map(self, patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx): ''' center: point group center position neighbour_idx: each group point index nonzero_indices: point indices of original point clouds xyz: nonzero point clouds ''' dist = torch.cdist(patch, self.patch_lib) min_val, min_idx = torch.min(dist, dim=1) s_idx = torch.argmax(min_val) s_star = torch.max(min_val) # reweighting m_test = patch[s_idx].unsqueeze(0) # anomalous patch m_star = self.patch_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour w_dist = torch.cdist(m_star, self.patch_lib) # find knn to m_star pt.1 _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2 m_star_knn = torch.linalg.norm(m_test - self.patch_lib[nn_idx[0, 1:]], dim=1) D = torch.sqrt(torch.tensor(patch.shape[1])) w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D)))) s = w * s_star # segmentation map s_map = min_val.view(1, 1, *feature_map_dims) s_map = torch.nn.functional.interpolate(s_map, size=(self.image_size, self.image_size), mode='bilinear') s_map = self.blur(s_map) self.image_preds.append(s.numpy()) self.image_labels.append(label) self.pixel_preds.extend(s_map.flatten().numpy()) self.pixel_labels.extend(mask.flatten().numpy()) self.predictions.append(s_map.detach().cpu().squeeze().numpy()) self.gts.append(mask.detach().cpu().squeeze().numpy()) def run_coreset(self): self.patch_lib = torch.cat(self.patch_lib, 0) if self.f_coreset < 1: self.coreset_idx = self.get_coreset_idx_randomp(self.patch_lib, n=int(self.f_coreset * self.patch_lib.shape[0]), eps=self.coreset_eps) self.patch_lib = self.patch_lib[self.coreset_idx] class DoubleRGBPointFeatures(Features): def add_sample_to_mem_bank(self, sample, class_name=None): 
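# Double memory bank: stores this training sample's point (xyz) and RGB patch features in two separate libraries; when class_name is given, the concatenated xyz+RGB patches are additionally saved to disk as .pt files for UFF pretraining.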
organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T rgb_patch_resize = rgb_patch.repeat(4, 1).reshape(784, 4, -1).permute(1, 0, 2).reshape(784*4, -1) patch = torch.cat([xyz_patch, rgb_patch_resize], dim=1) if class_name is not None: torch.save(patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt')) self.ins_id += 1 self.patch_xyz_lib.append(xyz_patch) self.patch_rgb_lib.append(rgb_patch) def predict(self, sample, mask, label): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T self.compute_s_s_map(xyz_patch, rgb_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx) def add_sample_to_late_fusion_mem_bank(self, sample): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, 
self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T # 2D dist xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib) dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib) rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0]))) xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0]))) s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz') s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb') s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb]]) s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb], dim=0).squeeze().reshape(2, -1).permute(1, 0) self.s_lib.append(s) self.s_map_lib.append(s_map) def compute_s_s_map(self, xyz_patch, rgb_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx): ''' center: point group center position neighbour_idx: each group point index nonzero_indices: point indices of original point clouds xyz: nonzero point clouds ''' # 2D dist xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib) dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib) rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0]))) xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0]))) s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz') s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb') s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb]]) s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb], dim=0).squeeze().reshape(2, -1).permute(1, 0) s = torch.tensor(self.detect_fuser.score_samples(s)) s_map = torch.tensor(self.seg_fuser.score_samples(s_map)) s_map = s_map.view(1, 224, 224) self.image_preds.append(s.numpy()) self.image_labels.append(label) self.pixel_preds.extend(s_map.flatten().numpy()) self.pixel_labels.extend(mask.flatten().numpy()) self.predictions.append(s_map.detach().cpu().squeeze().numpy()) self.gts.append(mask.detach().cpu().squeeze().numpy()) def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'): min_val, min_idx = torch.min(dist, dim=1) s_idx = torch.argmax(min_val) s_star = torch.max(min_val)/1000 # reweighting m_test = patch[s_idx].unsqueeze(0) # anomalous patch if modal=='xyz': m_star = self.patch_xyz_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour w_dist = torch.cdist(m_star, self.patch_xyz_lib) # find knn to m_star pt.1 else: m_star = self.patch_rgb_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour w_dist = torch.cdist(m_star, self.patch_rgb_lib) # find knn to m_star pt.1 _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2 if modal=='xyz': m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1:]], dim=1)/1000 else: m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1)/1000 D = 
torch.sqrt(torch.tensor(patch.shape[1])) w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D)))) s = w * s_star # segmentation map s_map = min_val.view(1, 1, *feature_map_dims) s_map = torch.nn.functional.interpolate(s_map, size=(224, 224), mode='bilinear') s_map = self.blur(s_map) return s, s_map def run_coreset(self): self.patch_xyz_lib = torch.cat(self.patch_xyz_lib, 0) self.patch_rgb_lib = torch.cat(self.patch_rgb_lib, 0) self.xyz_mean = torch.mean(self.patch_xyz_lib) self.xyz_std = torch.std(self.patch_rgb_lib) self.rgb_mean = torch.mean(self.patch_xyz_lib) self.rgb_std = torch.std(self.patch_rgb_lib) self.patch_xyz_lib = (self.patch_xyz_lib - self.xyz_mean)/self.xyz_std self.patch_rgb_lib = (self.patch_rgb_lib - self.rgb_mean)/self.rgb_std if self.f_coreset < 1: self.coreset_idx = self.get_coreset_idx_randomp(self.patch_xyz_lib, n=int(self.f_coreset * self.patch_xyz_lib.shape[0]), eps=self.coreset_eps, ) self.patch_xyz_lib = self.patch_xyz_lib[self.coreset_idx] self.coreset_idx = self.get_coreset_idx_randomp(self.patch_rgb_lib, n=int(self.f_coreset * self.patch_xyz_lib.shape[0]), eps=self.coreset_eps, ) self.patch_rgb_lib = self.patch_rgb_lib[self.coreset_idx] class DoubleRGBPointFeatures_add(Features): def add_sample_to_mem_bank(self, sample, class_name=None): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T rgb_patch_resize = rgb_patch.repeat(4, 1).reshape(784, 4, -1).permute(1, 0, 2).reshape(784*4, -1) patch = torch.cat([xyz_patch, rgb_patch_resize], dim=1) if class_name is not None: torch.save(patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt')) self.ins_id += 1 self.patch_xyz_lib.append(xyz_patch) self.patch_rgb_lib.append(rgb_patch) def predict(self, sample, mask, label): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], 
self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T self.compute_s_s_map(xyz_patch, rgb_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx) def add_sample_to_late_fusion_mem_bank(self, sample): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T # 2D dist xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib) dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib) rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0]))) xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0]))) s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz') s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb') s = torch.tensor([[s_xyz, s_rgb]]) s_map = torch.cat([s_map_xyz, s_map_rgb], dim=0).squeeze().reshape(2, -1).permute(1, 0) self.s_lib.append(s) self.s_map_lib.append(s_map) def run_coreset(self): self.patch_xyz_lib = torch.cat(self.patch_xyz_lib, 0) self.patch_rgb_lib = torch.cat(self.patch_rgb_lib, 0) self.xyz_mean = torch.mean(self.patch_xyz_lib) self.xyz_std = torch.std(self.patch_rgb_lib) self.rgb_mean = torch.mean(self.patch_xyz_lib) self.rgb_std = torch.std(self.patch_rgb_lib) self.patch_xyz_lib = (self.patch_xyz_lib - self.xyz_mean)/self.xyz_std self.patch_rgb_lib = (self.patch_rgb_lib - self.rgb_mean)/self.rgb_std if self.f_coreset < 1: self.coreset_idx = self.get_coreset_idx_randomp(self.patch_xyz_lib, n=int(self.f_coreset * self.patch_xyz_lib.shape[0]), eps=self.coreset_eps, ) self.patch_xyz_lib = self.patch_xyz_lib[self.coreset_idx] self.coreset_idx = self.get_coreset_idx_randomp(self.patch_rgb_lib, n=int(self.f_coreset * self.patch_xyz_lib.shape[0]), eps=self.coreset_eps, ) self.patch_rgb_lib = self.patch_rgb_lib[self.coreset_idx] def compute_s_s_map(self, xyz_patch, rgb_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx): ''' center: point group center position neighbour_idx: each group point index nonzero_indices: point indices of original point clouds xyz: nonzero point clouds ''' # 2D dist xyz_patch = 
(xyz_patch - self.xyz_mean)/self.xyz_std rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib) dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib) rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0]))) xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0]))) s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz') s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb') s = s_xyz + s_rgb s_map = s_map_xyz + s_map_rgb s_map = s_map.view(1, self.image_size, self.image_size) self.image_preds.append(s.numpy()) self.image_labels.append(label) self.pixel_preds.extend(s_map.flatten().numpy()) self.pixel_labels.extend(mask.flatten().numpy()) self.predictions.append(s_map.detach().cpu().squeeze().numpy()) self.gts.append(mask.detach().cpu().squeeze().numpy()) def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'): min_val, min_idx = torch.min(dist, dim=1) s_idx = torch.argmax(min_val) s_star = torch.max(min_val) # reweighting m_test = patch[s_idx].unsqueeze(0) # anomalous patch if modal=='xyz': m_star = self.patch_xyz_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour w_dist = torch.cdist(m_star, self.patch_xyz_lib) # find knn to m_star pt.1 else: m_star = self.patch_rgb_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour w_dist = torch.cdist(m_star, self.patch_rgb_lib) # find knn to m_star pt.1 _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2 if modal=='xyz': m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1:]], dim=1) else: m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1) D = torch.sqrt(torch.tensor(patch.shape[1])) w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D)))) s = w * s_star # segmentation map s_map = min_val.view(1, 1, *feature_map_dims) s_map = torch.nn.functional.interpolate(s_map, size=(self.image_size, self.image_size), mode='bilinear', align_corners=False) s_map = self.blur(s_map) return s, s_map class TripleFeatures(Features): def add_sample_to_mem_bank(self, sample, class_name=None): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T rgb_patch_size = int(math.sqrt(rgb_patch.shape[0])) rgb_patch2 = self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size)) rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T self.patch_rgb_lib.append(rgb_patch) if self.args.asy_memory_bank is None or len(self.patch_xyz_lib) < self.args.asy_memory_bank: xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) 
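# Point Feature Alignment: the interpolated point features were scattered into a (C, 224, 224) plane at their pixel positions (zeros elsewhere), smoothed with a 3x3 average pool, and adaptively pooled down to the patch grid, so they can be compared patch-by-patch with the ViT's RGB features.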
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T xyz_patch_full_resized2 = self.resize2(self.average(xyz_patch_full_2d)) xyz_patch2 = xyz_patch_full_resized2.reshape(xyz_patch_full_resized2.shape[1], -1).T if FUSION_BLOCK: with torch.no_grad(): fusion_patch = self.fusion.feature_fusion(xyz_patch2.unsqueeze(0), rgb_patch2.unsqueeze(0)) fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach() else: fusion_patch = torch.cat([xyz_patch2, rgb_patch2], dim=1) self.patch_xyz_lib.append(xyz_patch) self.patch_fusion_lib.append(fusion_patch) if class_name is not None: torch.save(fusion_patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt')) self.ins_id += 1 def predict(self, sample, mask, label): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d)) xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T xyz_patch_full_resized2 = self.resize2(self.average(xyz_patch_full_2d)) xyz_patch2 = xyz_patch_full_resized2.reshape(xyz_patch_full_resized2.shape[1], -1).T rgb_patch = torch.cat(rgb_feature_maps, 1) rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T rgb_patch_size = int(math.sqrt(rgb_patch.shape[0])) rgb_patch2 = self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size)) rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T if FUSION_BLOCK: with torch.no_grad(): fusion_patch = self.fusion.feature_fusion(xyz_patch2.unsqueeze(0), rgb_patch2.unsqueeze(0)) fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach() else: fusion_patch = torch.cat([xyz_patch2, rgb_patch2], dim=1) self.compute_s_s_map(xyz_patch, rgb_patch, fusion_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx) def add_sample_to_late_fusion_mem_bank(self, sample): organized_pc = sample[1] organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy() unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np) nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0] unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1) rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous()) xyz_patch = torch.cat(xyz_feature_maps, 1) xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype) xyz_patch_full[:,:,nonzero_indices] = interpolated_pc xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size) xyz_patch_full_resized = 
self.resize(self.average(xyz_patch_full_2d))
        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T

        xyz_patch_full_resized2 = self.resize2(self.average(xyz_patch_full_2d))
        xyz_patch2 = xyz_patch_full_resized2.reshape(xyz_patch_full_resized2.shape[1], -1).T

        rgb_patch = torch.cat(rgb_feature_maps, 1)
        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T

        rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))
        rgb_patch2 = self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))
        rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T

        if FUSION_BLOCK:
            with torch.no_grad():
                fusion_patch = self.fusion.feature_fusion(xyz_patch2.unsqueeze(0), rgb_patch2.unsqueeze(0))
                fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()
        else:
            fusion_patch = torch.cat([xyz_patch2, rgb_patch2], dim=1)

        # 3D dist
        xyz_patch = (xyz_patch - self.xyz_mean) / self.xyz_std
        rgb_patch = (rgb_patch - self.rgb_mean) / self.rgb_std
        fusion_patch = (fusion_patch - self.fusion_mean) / self.fusion_std

        dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)
        dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)
        dist_fusion = torch.cdist(fusion_patch, self.patch_fusion_lib)

        rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))
        xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))
        fusion_feat_size = (int(math.sqrt(fusion_patch.shape[0])), int(math.sqrt(fusion_patch.shape[0])))

        # 3 memory bank results
        s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')
        s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')
        s_fusion, s_map_fusion = self.compute_single_s_s_map(fusion_patch, dist_fusion, fusion_feat_size, modal='fusion')

        s = torch.tensor([[self.args.xyz_s_lambda * s_xyz,
                           self.args.rgb_s_lambda * s_rgb,
                           self.args.fusion_s_lambda * s_fusion]])
        s_map = torch.cat([self.args.xyz_smap_lambda * s_map_xyz,
                           self.args.rgb_smap_lambda * s_map_rgb,
                           self.args.fusion_smap_lambda * s_map_fusion],
                          dim=0).squeeze().reshape(3, -1).permute(1, 0)

        self.s_lib.append(s)
        self.s_map_lib.append(s_map)

    def run_coreset(self):
        self.patch_xyz_lib = torch.cat(self.patch_xyz_lib, 0)
        self.patch_rgb_lib = torch.cat(self.patch_rgb_lib, 0)
        self.patch_fusion_lib = torch.cat(self.patch_fusion_lib, 0)

        # each memory bank is normalized with its own mean/std
        self.xyz_mean = torch.mean(self.patch_xyz_lib)
        self.xyz_std = torch.std(self.patch_xyz_lib)
        self.rgb_mean = torch.mean(self.patch_rgb_lib)
        self.rgb_std = torch.std(self.patch_rgb_lib)
        self.fusion_mean = torch.mean(self.patch_fusion_lib)
        self.fusion_std = torch.std(self.patch_fusion_lib)

        self.patch_xyz_lib = (self.patch_xyz_lib - self.xyz_mean) / self.xyz_std
        self.patch_rgb_lib = (self.patch_rgb_lib - self.rgb_mean) / self.rgb_std
        self.patch_fusion_lib = (self.patch_fusion_lib - self.fusion_mean) / self.fusion_std

        if self.f_coreset < 1:
            # subsample each bank to f_coreset of its own size
            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_xyz_lib,
                                                            n=int(self.f_coreset * self.patch_xyz_lib.shape[0]),
                                                            eps=self.coreset_eps)
            self.patch_xyz_lib = self.patch_xyz_lib[self.coreset_idx]
            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_rgb_lib,
                                                            n=int(self.f_coreset * self.patch_rgb_lib.shape[0]),
                                                            eps=self.coreset_eps)
            self.patch_rgb_lib = self.patch_rgb_lib[self.coreset_idx]
            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_fusion_lib,
                                                            n=int(self.f_coreset * self.patch_fusion_lib.shape[0]),
                                                            eps=self.coreset_eps)
            self.patch_fusion_lib = self.patch_fusion_lib[self.coreset_idx]

        self.patch_xyz_lib =
self.patch_xyz_lib[torch.nonzero(torch.all(self.patch_xyz_lib!=0, dim=1))[:,0]] self.patch_xyz_lib = torch.cat((self.patch_xyz_lib, torch.zeros(1, self.patch_xyz_lib.shape[1])), 0) def compute_s_s_map(self, xyz_patch, rgb_patch, fusion_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx): ''' center: point group center position neighbour_idx: each group point index nonzero_indices: point indices of original point clouds xyz: nonzero point clouds ''' # 3D dist xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std fusion_patch = (fusion_patch - self.fusion_mean)/self.fusion_std dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib) dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib) dist_fusion = torch.cdist(fusion_patch, self.patch_fusion_lib) rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0]))) xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0]))) fusion_feat_size = (int(math.sqrt(fusion_patch.shape[0])), int(math.sqrt(fusion_patch.shape[0]))) s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz') s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb') s_fusion, s_map_fusion = self.compute_single_s_s_map(fusion_patch, dist_fusion, fusion_feat_size, modal='fusion') s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb, self.args.fusion_s_lambda*s_fusion]]) s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb, self.args.fusion_smap_lambda*s_map_fusion], dim=0).squeeze().reshape(3, -1).permute(1, 0) s = torch.tensor(self.detect_fuser.score_samples(s)) s_map = torch.tensor(self.seg_fuser.score_samples(s_map)) s_map = s_map.view(1, self.image_size, self.image_size) self.image_preds.append(s.numpy()) self.image_labels.append(label) self.pixel_preds.extend(s_map.flatten().numpy()) self.pixel_labels.extend(mask.flatten().numpy()) self.predictions.append(s_map.detach().cpu().squeeze().numpy()) self.gts.append(mask.detach().cpu().squeeze().numpy()) def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'): min_val, min_idx = torch.min(dist, dim=1) s_idx = torch.argmax(min_val) s_star = torch.max(min_val) # reweighting m_test = patch[s_idx].unsqueeze(0) # anomalous patch if modal=='xyz': m_star = self.patch_xyz_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour w_dist = torch.cdist(m_star, self.patch_xyz_lib) # find knn to m_star pt.1 elif modal=='rgb': m_star = self.patch_rgb_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour w_dist = torch.cdist(m_star, self.patch_rgb_lib) # find knn to m_star pt.1 else: m_star = self.patch_fusion_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour w_dist = torch.cdist(m_star, self.patch_fusion_lib) # find knn to m_star pt.1 _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2 # equation 7 from the paper if modal=='xyz': m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1:]], dim=1) elif modal=='rgb': m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1) else: m_star_knn = torch.linalg.norm(m_test - self.patch_fusion_lib[nn_idx[0, 1:]], dim=1) # sparse reweight # if modal=='rgb': # _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2 # else: # _, nn_idx = torch.topk(w_dist, k=4*self.n_reweight, largest=False) # pt.2 # if modal=='xyz': # m_star_knn = 
torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1::4]], dim=1) # elif modal=='rgb': # m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1) # else: # m_star_knn = torch.linalg.norm(m_test - self.patch_fusion_lib[nn_idx[0, 1::4]], dim=1) # Softmax normalization trick as in transformers. # As the patch vectors grow larger, their norm might differ a lot. # exp(norm) can give infinities. D = torch.sqrt(torch.tensor(patch.shape[1])) w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D)))) s = w * s_star # segmentation map s_map = min_val.view(1, 1, *feature_map_dims) s_map = torch.nn.functional.interpolate(s_map, size=(self.image_size, self.image_size), mode='bilinear', align_corners=False) s_map = self.blur(s_map) return s, s_map ================================================ FILE: fusion_pretrain.py ================================================ import argparse import datetime import json import numpy as np import os import time from pathlib import Path import torch import torch.backends.cudnn as cudnn from torch.utils.tensorboard import SummaryWriter import torchvision.transforms as transforms import timm import timm.optim.optim_factory as optim_factory import utils.misc as misc from utils.misc import NativeScalerWithGradNormCount as NativeScaler from engine_fusion_pretrain import train_one_epoch import dataset import torch from models.feature_fusion import FeatureFusionBlock def get_args_parser(): parser = argparse.ArgumentParser('MAE pre-training', add_help=False) parser.add_argument('--batch_size', default=64, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') parser.add_argument('--epochs', default=3, type=int) parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') # Model parameters parser.add_argument('--input_size', default=224, type=int, help='images input size') # Optimizer parameters parser.add_argument('--weight_decay', type=float, default=1.5e-6, help='weight decay (default: 0.05)') parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') parser.add_argument('--blr', type=float, default=0.002, metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') parser.add_argument('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N', help='epochs to warmup LR') # Dataset parameters parser.add_argument('--data_path', default='', type=str, help='dataset path') parser.add_argument('--output_dir', default='./output_dir', help='path where to save, empty for no saving') parser.add_argument('--log_dir', default='./output_dir', help='path where to tensorboard log') parser.add_argument('--device', default='cuda', help='device to use for training / testing') parser.add_argument('--seed', default=0, type=int) parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch') parser.add_argument('--num_workers', default=10, type=int) parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') parser.set_defaults(pin_mem=True) # distributed training 
parameters parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--dist_on_itp', action='store_true') parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') return parser def main(args): misc.init_distributed_mode(args) print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) print("{}".format(args).replace(', ', ',\n')) device = torch.device(args.device) # fix the seed for reproducibility seed = args.seed + misc.get_rank() torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True dataset_train = dataset.PreTrainTensorDataset(args.data_path) print(dataset_train) if True: # args.distributed: num_tasks = misc.get_world_size() global_rank = misc.get_rank() sampler_train = torch.utils.data.DistributedSampler( dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True ) print("Sampler_train = %s" % str(sampler_train)) else: sampler_train = torch.utils.data.RandomSampler(dataset_train) if global_rank == 0 and args.log_dir is not None: os.makedirs(args.log_dir, exist_ok=True) log_writer = SummaryWriter(log_dir=args.log_dir) else: log_writer = None data_loader_train = torch.utils.data.DataLoader( dataset_train, sampler=sampler_train, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=args.pin_mem, drop_last=True, ) eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() if args.lr is None: # only base_lr is specified args.lr = args.blr * eff_batch_size / 256 print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) print("actual lr: %.2e" % args.lr) print("accumulate grad iterations: %d" % args.accum_iter) print("effective batch size: %d" % eff_batch_size) model = FeatureFusionBlock(1152, 768) model.to(device) if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) model_without_ddp = model.module # following timm: set wd as 0 for bias and norm layers optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr, betas=(0.9, 0.95)) print(optimizer) loss_scaler = NativeScaler() misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) print(f"Start training for {args.epochs} epochs") start_time = time.time() for epoch in range(args.start_epoch, args.epochs): if args.distributed: data_loader_train.sampler.set_epoch(epoch) train_stats = train_one_epoch( model, data_loader_train, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args ) if args.output_dir and (epoch % 1 == 0 or epoch + 1 == args.epochs): misc.save_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch) log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch,} if args.output_dir and misc.is_main_process(): if log_writer is not None: log_writer.flush() with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) if __name__ == '__main__': args = get_args_parser() args = args.parse_args() if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) main(args) ================================================ FILE: m3dm_runner.py 
================================================
import torch
from tqdm import tqdm
import os

from feature_extractors import multiple_features
from dataset import get_data_loader


class M3DM():
    def __init__(self, args):
        self.args = args
        self.image_size = args.img_size
        self.count = args.max_sample
        if args.method_name == 'DINO':
            self.methods = {
                "DINO": multiple_features.RGBFeatures(args),
            }
        elif args.method_name == 'Point_MAE':
            self.methods = {
                "Point_MAE": multiple_features.PointFeatures(args),
            }
        elif args.method_name == 'Fusion':
            self.methods = {
                "Fusion": multiple_features.FusionFeatures(args),
            }
        elif args.method_name == 'DINO+Point_MAE':
            self.methods = {
                "DINO+Point_MAE": multiple_features.DoubleRGBPointFeatures(args),
            }
        elif args.method_name == 'DINO+Point_MAE+add':
            self.methods = {
                "DINO+Point_MAE": multiple_features.DoubleRGBPointFeatures_add(args),
            }
        elif args.method_name == 'DINO+Point_MAE+Fusion':
            self.methods = {
                "DINO+Point_MAE+Fusion": multiple_features.TripleFeatures(args),
            }

    def fit(self, class_name):
        train_loader = get_data_loader("train", class_name=class_name, img_size=self.image_size, args=self.args)
        flag = 0
        for sample, _ in tqdm(train_loader, desc=f'Extracting train features for class {class_name}'):
            for method in self.methods.values():
                if self.args.save_feature:
                    method.add_sample_to_mem_bank(sample, class_name=class_name)
                else:
                    method.add_sample_to_mem_bank(sample)
            flag += 1
            if flag > self.count:
                flag = 0
                break

        for method_name, method in self.methods.items():
            print(f'\n\nRunning coreset for {method_name} on class {class_name}...')
            method.run_coreset()

        if self.args.memory_bank == 'multiple':
            flag = 0
            for sample, _ in tqdm(train_loader, desc=f'Running late fusion for {method_name} on class {class_name}...'):
                for method_name, method in self.methods.items():
                    method.add_sample_to_late_fusion_mem_bank(sample)
                flag += 1
                if flag > self.count:
                    flag = 0
                    break

            for method_name, method in self.methods.items():
                print(f'\n\nTraining Decision Layer Fusion for {method_name} on class {class_name}...')
                method.run_late_fusion()

    def evaluate(self, class_name):
        image_rocaucs = dict()
        pixel_rocaucs = dict()
        au_pros = dict()
        test_loader = get_data_loader("test", class_name=class_name, img_size=self.image_size, args=self.args)
        path_list = []
        with torch.no_grad():
            for sample, mask, label, rgb_path in tqdm(test_loader, desc=f'Extracting test features for class {class_name}'):
                for method in self.methods.values():
                    method.predict(sample, mask, label)
                path_list.append(rgb_path)

        for method_name, method in self.methods.items():
            method.calculate_metrics()
            image_rocaucs[method_name] = round(method.image_rocauc, 3)
            pixel_rocaucs[method_name] = round(method.pixel_rocauc, 3)
            au_pros[method_name] = round(method.au_pro, 3)
            print(f'Class: {class_name}, {method_name} Image ROCAUC: {method.image_rocauc:.3f}, '
                  f'{method_name} Pixel ROCAUC: {method.pixel_rocauc:.3f}, {method_name} AU-PRO: {method.au_pro:.3f}')
            if self.args.save_preds:
                method.save_prediction_maps('./pred_maps', path_list)
        return image_rocaucs, pixel_rocaucs, au_pros


================================================
FILE: main.py
================================================
import argparse
from m3dm_runner import M3DM
from dataset import eyecandies_classes, mvtec3d_classes
import pandas as pd


def run_3d_ads(args):
    if args.dataset_type == 'eyecandies':
        classes = eyecandies_classes()
    elif args.dataset_type == 'mvtec3d':
        classes = mvtec3d_classes()

    METHOD_NAMES = [args.method_name]

    image_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])
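`run_3d_ads` collects one row per method and one column per class; the two remaining tables below are assembled the same way. A toy sketch of the `Series.map` pattern it relies on, with made-up scores:

```python
import pandas as pd

methods = ['DINO+Point_MAE+Fusion']
df = pd.DataFrame(methods, columns=['Method'])
# per-class dicts as returned by M3DM.evaluate (values here are illustrative)
fake_scores = {'bagel': {'DINO+Point_MAE+Fusion': 0.994},
               'cable_gland': {'DINO+Point_MAE+Fusion': 0.909}}
for cls, scores in fake_scores.items():
    df[cls.title()] = df['Method'].map(scores)   # method name -> score
df['Mean'] = round(df.iloc[:, 1:].mean(axis=1), 3)
print(df.to_markdown(index=False))
```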
pixel_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method']) au_pros_df = pd.DataFrame(METHOD_NAMES, columns=['Method']) for cls in classes: model = M3DM(args) model.fit(cls) image_rocaucs, pixel_rocaucs, au_pros = model.evaluate(cls) image_rocaucs_df[cls.title()] = image_rocaucs_df['Method'].map(image_rocaucs) pixel_rocaucs_df[cls.title()] = pixel_rocaucs_df['Method'].map(pixel_rocaucs) au_pros_df[cls.title()] = au_pros_df['Method'].map(au_pros) print(f"\nFinished running on class {cls}") print("################################################################################\n\n") image_rocaucs_df['Mean'] = round(image_rocaucs_df.iloc[:, 1:].mean(axis=1),3) pixel_rocaucs_df['Mean'] = round(pixel_rocaucs_df.iloc[:, 1:].mean(axis=1),3) au_pros_df['Mean'] = round(au_pros_df.iloc[:, 1:].mean(axis=1),3) print("\n\n################################################################################") print("############################# Image ROCAUC Results #############################") print("################################################################################\n") print(image_rocaucs_df.to_markdown(index=False)) print("\n\n################################################################################") print("############################# Pixel ROCAUC Results #############################") print("################################################################################\n") print(pixel_rocaucs_df.to_markdown(index=False)) print("\n\n##########################################################################") print("############################# AU PRO Results #############################") print("##########################################################################\n") print(au_pros_df.to_markdown(index=False)) with open("results/image_rocauc_results.md", "a") as tf: tf.write(image_rocaucs_df.to_markdown(index=False)) with open("results/pixel_rocauc_results.md", "a") as tf: tf.write(pixel_rocaucs_df.to_markdown(index=False)) with open("results/aupro_results.md", "a") as tf: tf.write(au_pros_df.to_markdown(index=False)) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Process some integers.') parser.add_argument('--method_name', default='DINO+Point_MAE+Fusion', type=str, choices=['DINO', 'Point_MAE', 'Fusion', 'DINO+Point_MAE', 'DINO+Point_MAE+Fusion', 'DINO+Point_MAE+add'], help='Anomaly detection modal name.') parser.add_argument('--max_sample', default=400, type=int, help='Max sample number.') parser.add_argument('--memory_bank', default='multiple', type=str, choices=["multiple", "single"], help='memory bank mode: "multiple", "single".') parser.add_argument('--rgb_backbone_name', default='vit_base_patch8_224_dino', type=str, choices=['vit_base_patch8_224_dino', 'vit_base_patch8_224', 'vit_base_patch8_224_in21k', 'vit_small_patch8_224_dino'], help='Timm checkpoints name of RGB backbone.') parser.add_argument('--xyz_backbone_name', default='Point_MAE', type=str, choices=['Point_MAE', 'Point_Bert'], help='Checkpoints name of RGB backbone[Point_MAE, Point_Bert].') parser.add_argument('--fusion_module_path', default='checkpoints/checkpoint-0.pth', type=str, help='Checkpoints for fusion module.') parser.add_argument('--save_feature', default=False, action='store_true', help='Save feature for training fusion block.') parser.add_argument('--use_uff', default=False, action='store_true', help='Use UFF module.') parser.add_argument('--save_feature_path', default='datasets/patch_lib', type=str, help='Save feature for training fusion block.') 
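At decision level, the three memory-bank scores are weighted by the `*_s_lambda` options defined below and fused by a learned one-class classifier; `detect_fuser` and `seg_fuser` are built in `feature_extractors/features.py` (outside this excerpt) and are assumed here to be scikit-learn's `SGDOneClassSVM`, matching the `--ocsvm_nu` and `--ocsvm_maxiter` options. A toy sketch of that fusion step with illustrative numbers:

```python
import torch
from sklearn.linear_model import SGDOneClassSVM

# single-bank scores for one test image (illustrative values)
s_xyz, s_rgb, s_fusion = 4.2, 1.1, 3.7
lam = {'xyz': 1.0, 'rgb': 0.1, 'fusion': 1.0}          # the parser defaults
s = torch.tensor([[lam['xyz'] * s_xyz, lam['rgb'] * s_rgb, lam['fusion'] * s_fusion]])

fuser = SGDOneClassSVM(nu=0.5, max_iter=1000)           # cf. --ocsvm_nu / --ocsvm_maxiter
fuser.fit(torch.randn(64, 3).abs().numpy())             # stand-in for the anomaly-free s_lib
print(fuser.score_samples(s.numpy()))                   # fused image-level score
```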
parser.add_argument('--save_preds', default=False, action='store_true',
                        help='Save prediction results.')
    parser.add_argument('--group_size', default=128, type=int,
                        help='Point group size of Point Transformer.')
    parser.add_argument('--num_group', default=1024, type=int,
                        help='Point groups number of Point Transformer.')
    parser.add_argument('--random_state', default=None, type=int,
                        help='random_state for the random projection.')
    parser.add_argument('--dataset_type', default='mvtec3d', type=str, choices=['mvtec3d', 'eyecandies'],
                        help='Dataset type for training or testing.')
    parser.add_argument('--dataset_path', default='datasets/mvtec3d', type=str,
                        help='Dataset store path.')
    parser.add_argument('--img_size', default=224, type=int,
                        help='Images size for model.')
    parser.add_argument('--xyz_s_lambda', default=1.0, type=float,
                        help='xyz_s_lambda')
    parser.add_argument('--xyz_smap_lambda', default=1.0, type=float,
                        help='xyz_smap_lambda')
    parser.add_argument('--rgb_s_lambda', default=0.1, type=float,
                        help='rgb_s_lambda')
    parser.add_argument('--rgb_smap_lambda', default=0.1, type=float,
                        help='rgb_smap_lambda')
    parser.add_argument('--fusion_s_lambda', default=1.0, type=float,
                        help='fusion_s_lambda')
    parser.add_argument('--fusion_smap_lambda', default=1.0, type=float,
                        help='fusion_smap_lambda')
    parser.add_argument('--coreset_eps', default=0.9, type=float,
                        help='eps for the sparse random projection.')
    parser.add_argument('--f_coreset', default=0.1, type=float,
                        help='Fraction of the memory bank kept after coreset subsampling.')
    parser.add_argument('--asy_memory_bank', default=None, type=int,
                        help='Build an asymmetric memory bank for point clouds.')
    parser.add_argument('--ocsvm_nu', default=0.5, type=float,
                        help='ocsvm nu')
    parser.add_argument('--ocsvm_maxiter', default=1000, type=int,
                        help='ocsvm maxiter')
    parser.add_argument('--rm_zero_for_project', default=False, action='store_true',
                        help='Remove all-zero patches before the projection.')
    args = parser.parse_args()
    run_3d_ads(args)


================================================
FILE: models/feature_fusion.py
================================================
import torch
import torch.nn as nn
import math


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class FeatureFusionBlock(nn.Module):
    def __init__(self, xyz_dim, rgb_dim, mlp_ratio=4.):
        super().__init__()
        self.xyz_dim = xyz_dim
        self.rgb_dim = rgb_dim
        self.xyz_norm = nn.LayerNorm(xyz_dim)
        self.xyz_mlp = Mlp(in_features=xyz_dim, hidden_features=int(xyz_dim * mlp_ratio), act_layer=nn.GELU, drop=0.)
        self.rgb_norm = nn.LayerNorm(rgb_dim)
        self.rgb_mlp = Mlp(in_features=rgb_dim, hidden_features=int(rgb_dim * mlp_ratio), act_layer=nn.GELU, drop=0.)
        self.rgb_head = nn.Linear(rgb_dim, 256)
        self.xyz_head = nn.Linear(xyz_dim, 256)
        self.T = 1

    def feature_fusion(self, xyz_feature, rgb_feature):
        xyz_feature = self.xyz_mlp(self.xyz_norm(xyz_feature))
        rgb_feature = self.rgb_mlp(self.rgb_norm(rgb_feature))
        feature = torch.cat([xyz_feature, rgb_feature], dim=2)
        return feature

    def contrastive_loss(self, q, k):
        # normalize
        q = nn.functional.normalize(q, dim=1)
        k = nn.functional.normalize(k, dim=1)
        # gather all targets
        # Einstein sum is more intuitive
        logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
        N = logits.shape[0]  # batch size per GPU
        labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()
        return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)

    def reparameterize(self, mu, logvar):
        """
        Is a single z enough to compute the expectation for the loss?
        :param mu: (Tensor) Mean of the latent Gaussian
        :param logvar: (Tensor) Log-variance of the latent Gaussian
        :return: (Tensor) Sampled latent vector
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, xyz_feature, rgb_feature):
        feature = self.feature_fusion(xyz_feature, rgb_feature)
        feature_xyz = feature[:, :, :self.xyz_dim]
        feature_rgb = feature[:, :, self.xyz_dim:]

        q = self.rgb_head(feature_rgb.view(-1, feature_rgb.shape[2]))
        k = self.xyz_head(feature_xyz.view(-1, feature_xyz.shape[2]))

        xyz_feature = xyz_feature.view(-1, xyz_feature.shape[2])
        rgb_feature = rgb_feature.view(-1, rgb_feature.shape[2])

        patch_no_zeros_indices = torch.nonzero(torch.all(xyz_feature != 0, dim=1))
        loss = self.contrastive_loss(q[patch_no_zeros_indices, :].squeeze(),
                                     k[patch_no_zeros_indices, :].squeeze())
        return loss


================================================
FILE: models/models.py
================================================
import torch
import torch.nn as nn
import timm
from timm.models.layers import DropPath, trunc_normal_
from pointnet2_ops import pointnet2_utils
from knn_cuda import KNN


class Model(torch.nn.Module):

    def __init__(self, device, rgb_backbone_name='vit_base_patch8_224_dino', out_indices=None,
                 checkpoint_path='', pool_last=False, xyz_backbone_name='Point_MAE',
                 group_size=128, num_group=1024):
        super().__init__()
        # 'vit_base_patch8_224_dino'
        # Determine whether to output intermediate features.
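`contrastive_loss` above is a standard InfoNCE objective with temperature `T = 1`; the `torch.distributed.get_rank()` offset keeps the positive pairs on the diagonal when batches are sharded across GPUs. A single-process sketch that drops the rank offset and runs on CPU (toy inputs):

```python
import torch
import torch.nn as nn

def info_nce(q, k, T=1.0):
    q = nn.functional.normalize(q, dim=1)
    k = nn.functional.normalize(k, dim=1)
    logits = torch.einsum('nc,mc->nm', [q, k]) / T          # N x N similarities
    labels = torch.arange(q.shape[0], dtype=torch.long)     # positives on the diagonal
    return nn.CrossEntropyLoss()(logits, labels) * (2 * T)

print(info_nce(torch.randn(8, 256), torch.randn(8, 256)))
```

With the dimensions used in this repo (`xyz_dim=1152`, `rgb_dim=768`), `feature_fusion` returns 1920-dimensional fused patches, and the two 256-dimensional projection heads above produce the `q`/`k` pairs.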
self.device = device kwargs = {'features_only': True if out_indices else False} if out_indices: kwargs.update({'out_indices': out_indices}) ## RGB backbone self.rgb_backbone = timm.create_model(model_name=rgb_backbone_name, pretrained=True, checkpoint_path=checkpoint_path, **kwargs) ## XYZ backbone if xyz_backbone_name=='Point_MAE': self.xyz_backbone=PointTransformer(group_size=group_size, num_group=num_group) self.xyz_backbone.load_model_from_ckpt("checkpoints/pointmae_pretrain.pth") elif xyz_backbone_name=='Point_Bert': self.xyz_backbone=PointTransformer(group_size=group_size, num_group=num_group, encoder_dims=256) self.xyz_backbone.load_model_from_pb_ckpt("checkpoints/Point-BERT.pth") def forward_rgb_features(self, x): x = self.rgb_backbone.patch_embed(x) x = self.rgb_backbone._pos_embed(x) x = self.rgb_backbone.norm_pre(x) if self.rgb_backbone.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.rgb_backbone.blocks(x) x = self.rgb_backbone.norm(x) feat = x[:,1:].permute(0, 2, 1).view(1, -1, 28, 28) return feat def forward(self, rgb, xyz): rgb_features = self.forward_rgb_features(rgb) xyz_features, center, ori_idx, center_idx = self.xyz_backbone(xyz) return rgb_features, xyz_features, center, ori_idx, center_idx def fps(data, number): ''' data B N 3 number int ''' fps_idx = pointnet2_utils.furthest_point_sample(data, number) fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous() return fps_data, fps_idx class Group(nn.Module): def __init__(self, num_group, group_size): super().__init__() self.num_group = num_group self.group_size = group_size self.knn = KNN(k=self.group_size, transpose_mode=True) def forward(self, xyz): ''' input: B N 3 --------------------------- output: B G M 3 center : B G 3 ''' batch_size, num_points, _ = xyz.shape # fps the centers out center, center_idx = fps(xyz.contiguous(), self.num_group) # B G 3 # knn to get the neighborhood _, idx = self.knn(xyz, center) # B G M assert idx.size(1) == self.num_group assert idx.size(2) == self.group_size ori_idx = idx idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points idx = idx + idx_base idx = idx.view(-1) neighborhood = xyz.reshape(batch_size * num_points, -1)[idx, :] neighborhood = neighborhood.reshape(batch_size, self.num_group, self.group_size, 3).contiguous() # normalize neighborhood = neighborhood - center.unsqueeze(2) return neighborhood, center, ori_idx, center_idx class Encoder(nn.Module): def __init__(self, encoder_channel): super().__init__() self.encoder_channel = encoder_channel self.first_conv = nn.Sequential( nn.Conv1d(3, 128, 1), nn.BatchNorm1d(128), nn.ReLU(inplace=True), nn.Conv1d(128, 256, 1) ) self.second_conv = nn.Sequential( nn.Conv1d(512, 512, 1), nn.BatchNorm1d(512), nn.ReLU(inplace=True), nn.Conv1d(512, self.encoder_channel, 1) ) def forward(self, point_groups): ''' point_groups : B G N 3 ----------------- feature_global : B G C ''' bs, g, n, _ = point_groups.shape point_groups = point_groups.reshape(bs * g, n, 3) # encoder feature = self.first_conv(point_groups.transpose(2, 1)) feature_global = torch.max(feature, dim=2, keepdim=True)[0] feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1) feature = self.second_conv(feature) feature_global = torch.max(feature, dim=2, keepdim=False)[0] return feature_global.reshape(bs, g, self.encoder_channel) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, 
act_layer=nn.GELU, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features self.fc1 = nn.Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class Attention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) attn = (q * self.scale) @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, C) x = self.proj(x) x = self.proj_drop(x) return x class Block(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.norm1 = norm_layer(dim) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) def forward(self, x): x = x + self.drop_path(self.attn(self.norm1(x))) x = x + self.drop_path(self.mlp(self.norm2(x))) return x class TransformerEncoder(nn.Module): """ Transformer Encoder without hierarchical structure """ def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.): super().__init__() self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate ) for i in range(depth)]) def forward(self, x, pos): feature_list = [] fetch_idx = [3, 7, 11] for i, block in enumerate(self.blocks): x = block(x + pos) if i in fetch_idx: feature_list.append(x) return feature_list class PointTransformer(nn.Module): def __init__(self, group_size=128, num_group=1024, encoder_dims=384): super().__init__() self.trans_dim = 384 self.depth = 12 self.drop_path_rate = 0.1 self.num_heads = 6 self.group_size = group_size self.num_group = num_group # grouper self.group_divider = Group(num_group=self.num_group, group_size=self.group_size) # define the encoder self.encoder_dims = encoder_dims if self.encoder_dims != self.trans_dim: self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim)) self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim)) self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim) self.encoder = Encoder(encoder_channel=self.encoder_dims) # bridge encoder and transformer self.pos_embed = 
nn.Sequential( nn.Linear(3, 128), nn.GELU(), nn.Linear(128, self.trans_dim) ) dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] self.blocks = TransformerEncoder( embed_dim=self.trans_dim, depth=self.depth, drop_path_rate=dpr, num_heads=self.num_heads ) self.norm = nn.LayerNorm(self.trans_dim) def load_model_from_ckpt(self, bert_ckpt_path): if bert_ckpt_path is not None: ckpt = torch.load(bert_ckpt_path) base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()} for k in list(base_ckpt.keys()): if k.startswith('MAE_encoder'): base_ckpt[k[len('MAE_encoder.'):]] = base_ckpt[k] del base_ckpt[k] elif k.startswith('base_model'): base_ckpt[k[len('base_model.'):]] = base_ckpt[k] del base_ckpt[k] incompatible = self.load_state_dict(base_ckpt, strict=False) #if incompatible.missing_keys: # print('missing_keys') # print( # incompatible.missing_keys # ) #if incompatible.unexpected_keys: # print('unexpected_keys') # print( # incompatible.unexpected_keys # ) # print(f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}') def load_model_from_pb_ckpt(self, bert_ckpt_path): ckpt = torch.load(bert_ckpt_path) base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()} for k in list(base_ckpt.keys()): if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'): base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k] elif k.startswith('base_model'): base_ckpt[k[len('base_model.'):]] = base_ckpt[k] del base_ckpt[k] incompatible = self.load_state_dict(base_ckpt, strict=False) if incompatible.missing_keys: print('missing_keys') print( incompatible.missing_keys ) if incompatible.unexpected_keys: print('unexpected_keys') print( incompatible.unexpected_keys ) print(f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}') def forward(self, pts): if self.encoder_dims != self.trans_dim: B,C,N = pts.shape pts = pts.transpose(-1, -2) # B N 3 # divide the point clo ud in the same form. This is important neighborhood, center, ori_idx, center_idx = self.group_divider(pts) # # generate mask # bool_masked_pos = self._mask_center(center, no_mask = False) # B G # encoder the input cloud blocks group_input_tokens = self.encoder(neighborhood) # B G N group_input_tokens = self.reduce_dim(group_input_tokens) # prepare cls cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1) cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1) # add pos embedding pos = self.pos_embed(center) # final input x = torch.cat((cls_tokens, group_input_tokens), dim=1) pos = torch.cat((cls_pos, pos), dim=1) # transformer feature_list = self.blocks(x, pos) feature_list = [self.norm(x)[:,1:].transpose(-1, -2).contiguous() for x in feature_list] x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152 return x, center, ori_idx, center_idx else: B, C, N = pts.shape pts = pts.transpose(-1, -2) # B N 3 # divide the point clo ud in the same form. 
This is important neighborhood, center, ori_idx, center_idx = self.group_divider(pts) group_input_tokens = self.encoder(neighborhood) # B G N pos = self.pos_embed(center) # final input x = group_input_tokens # transformer feature_list = self.blocks(x, pos) feature_list = [self.norm(x).transpose(-1, -2).contiguous() for x in feature_list] x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152 return x, center, ori_idx, center_idx ================================================ FILE: models/pointnet2_utils.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from time import time import numpy as np def timeit(tag, t): print("{}: {}s".format(tag, time() - t)) return time() def pc_normalize(pc): l = pc.shape[0] centroid = np.mean(pc, axis=0) pc = pc - centroid m = np.max(np.sqrt(np.sum(pc**2, axis=1))) pc = pc / m return pc def square_distance(src, dst): """ Calculate Euclid distance between each two points. src^T * dst = xn * xm + yn * ym + zn * zm; sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst Input: src: source points, [B, N, C] dst: target points, [B, M, C] Output: dist: per-point square distance, [B, N, M] """ B, N, _ = src.shape _, M, _ = dst.shape dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) dist += torch.sum(src ** 2, -1).view(B, N, 1) dist += torch.sum(dst ** 2, -1).view(B, 1, M) return dist def index_points(points, idx): """ Input: points: input points data, [B, N, C] idx: sample index data, [B, S] Return: new_points:, indexed points data, [B, S, C] """ device = points.device B = points.shape[0] view_shape = list(idx.shape) view_shape[1:] = [1] * (len(view_shape) - 1) repeat_shape = list(idx.shape) repeat_shape[0] = 1 batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) new_points = points[batch_indices, idx, :] return new_points def farthest_point_sample(xyz, npoint): """ Input: xyz: pointcloud data, [B, N, 3] npoint: number of samples Return: centroids: sampled pointcloud index, [B, npoint] """ device = xyz.device B, N, C = xyz.shape centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) distance = torch.ones(B, N).to(device) * 1e10 farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) batch_indices = torch.arange(B, dtype=torch.long).to(device) for i in range(npoint): centroids[:, i] = farthest centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) dist = torch.sum((xyz - centroid) ** 2, -1) mask = dist < distance distance[mask] = dist[mask] farthest = torch.max(distance, -1)[1] return centroids def query_ball_point(radius, nsample, xyz, new_xyz): """ Input: radius: local region radius nsample: max sample number in local region xyz: all points, [B, N, 3] new_xyz: query points, [B, S, 3] Return: group_idx: grouped points index, [B, S, nsample] """ device = xyz.device B, N, C = xyz.shape _, S, _ = new_xyz.shape group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) sqrdists = square_distance(new_xyz, xyz) group_idx[sqrdists > radius ** 2] = N group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) mask = group_idx == N group_idx[mask] = group_first[mask] return group_idx def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): """ Input: npoint: radius: 
nsample: xyz: input points position data, [B, N, 3] points: input points data, [B, N, D] Return: new_xyz: sampled points position data, [B, npoint, nsample, 3] new_points: sampled points data, [B, npoint, nsample, 3+D] """ B, N, C = xyz.shape S = npoint fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] new_xyz = index_points(xyz, fps_idx) idx = query_ball_point(radius, nsample, xyz, new_xyz) grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) if points is not None: grouped_points = index_points(points, idx) new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] else: new_points = grouped_xyz_norm if returnfps: return new_xyz, new_points, grouped_xyz, fps_idx else: return new_xyz, new_points def sample_and_group_all(xyz, points): """ Input: xyz: input points position data, [B, N, 3] points: input points data, [B, N, D] Return: new_xyz: sampled points position data, [B, 1, 3] new_points: sampled points data, [B, 1, N, 3+D] """ device = xyz.device B, N, C = xyz.shape new_xyz = torch.zeros(B, 1, C).to(device) grouped_xyz = xyz.view(B, 1, N, C) if points is not None: new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) else: new_points = grouped_xyz return new_xyz, new_points def interpolating_points(xyz1, xyz2, points2): """ Input: xyz1: input points position data, [B, C, N] xyz2: sampled input points position data, [B, C, S] points2: input points data, [B, D, S] Return: new_points: upsampled points data, [B, D', N] """ xyz1 = xyz1.permute(0, 2, 1) xyz2 = xyz2.permute(0, 2, 1) points2 = points2.permute(0, 2, 1) B, N, C = xyz1.shape _, S, _ = xyz2.shape if S == 1: interpolated_points = points2.repeat(1, N, 1) else: dists = square_distance(xyz1, xyz2) dists, idx = dists.sort(dim=-1) dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] dist_recip = 1.0 / (dists + 1e-8) norm = torch.sum(dist_recip, dim=2, keepdim=True) weight = dist_recip / norm interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) interpolated_points = interpolated_points.permute(0, 2, 1) return interpolated_points ================================================ FILE: requirements.txt ================================================ numpy Pillow scikit-learn scipy timm torch torchvision tqdm wget tifffile scikit-image kornia imageio tensorboard opencv-python setuptools==59.5.0; ================================================ FILE: utils/au_pro_util.py ================================================ """ Code based on the official MVTec 3D-AD evaluation code found at https://www.mydrive.ch/shares/45924/9ce7a138c69bbd4c8d648b72151f839d/download/428846918-1643297332/evaluation_code.tar.xz Utility functions that compute a PRO curve and its definite integral, given pairs of anomaly and ground truth maps. The PRO curve can also be integrated up to a constant integration limit. """ import numpy as np from scipy.ndimage.measurements import label from bisect import bisect class GroundTruthComponent: """ Stores sorted anomaly scores of a single ground truth component. Used to efficiently compute the region overlap for many increasing thresholds. """ def __init__(self, anomaly_scores): """ Initialize the module. Args: anomaly_scores: List of all anomaly scores within the ground truth component as numpy array. """ # Keep a sorted list of all anomaly scores within the component. 
        self.anomaly_scores = anomaly_scores.copy()
        self.anomaly_scores.sort()

        # Pointer to the anomaly score where the current threshold divides the component into OK / NOK pixels.
        self.index = 0

        # The last evaluated threshold.
        self.last_threshold = None

    def compute_overlap(self, threshold):
        """
        Compute the region overlap for a specific threshold. Thresholds must be passed in increasing order.

        Args:
            threshold: Threshold to compute the region overlap.

        Returns:
            Region overlap for the specified threshold.
        """
        if self.last_threshold is not None:
            assert self.last_threshold <= threshold

        # Increase the index until it points to an anomaly score that is just above the specified threshold.
        while (self.index < len(self.anomaly_scores) and self.anomaly_scores[self.index] <= threshold):
            self.index += 1

        # Compute the fraction of component pixels that are correctly segmented as anomalous.
        return 1.0 - self.index / len(self.anomaly_scores)


def trapezoid(x, y, x_max=None):
    """
    This function calculates the definite integral of a curve given by x- and corresponding y-values.
    In contrast to, e.g., 'numpy.trapz()', this function allows defining an upper bound for the
    integration range by setting a value x_max.

    Points that do not have a finite x or y value will be ignored with a warning.

    Args:
        x: Samples from the domain of the function to integrate, sorted in ascending order. May contain
           the same value multiple times. In that case, the order of the corresponding y values will
           affect the integration with the trapezoidal rule.
        y: Values of the function corresponding to x values.
        x_max: Upper limit of the integration. The y value at x_max will be determined by interpolating
               between its neighbors. Must not lie outside of the range of x.

    Returns:
        Area under the curve.
    """
    x = np.array(x)
    y = np.array(y)
    finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y))
    if not finite_mask.all():
        print("""WARNING: Not all x and y values passed to trapezoid are finite. Will continue with only the finite values.""")
        x = x[finite_mask]
        y = y[finite_mask]

    # Introduce a correction term if x_max is not an element of x.
    correction = 0.
    if x_max is not None:
        if x_max not in x:
            # Get the insertion index that would keep x sorted after np.insert(x, ins, x_max).
            ins = bisect(x, x_max)
            # x_max must be between the minimum and the maximum, so the insertion point cannot be zero or len(x).
            assert 0 < ins < len(x)

            # Calculate the correction term, which is the integral between the last x[ins-1] and x_max. Since we
            # do not know the exact value of y at x_max, we interpolate between y[ins] and y[ins-1].
            y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) * (x_max - x[ins - 1]) / (x[ins] - x[ins - 1]))
            correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1])

        # Cut off at x_max.
        mask = x <= x_max
        x = x[mask]
        y = y[mask]

    # Return area under the curve using the trapezoidal rule.
    return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction


def collect_anomaly_scores(anomaly_maps, ground_truth_maps):
    """
    Extract anomaly scores for each ground truth connected component as well as anomaly scores for each
    potential false positive pixel from anomaly maps.

    Args:
        anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at
                      each pixel.
        ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground
                           truth labels for each pixel. 0 indicates that a pixel is anomaly-free.
                           1 indicates that a pixel contains an anomaly.

    Returns:
        ground_truth_components: A list of all ground truth connected components that appear in the dataset.
                                 For each component, a sorted list of its anomaly scores is stored.
        anomaly_scores_ok_pixels: A sorted list of anomaly scores of all anomaly-free pixels of the dataset.
                                  This list can be used to quickly select thresholds that fix a certain
                                  false positive rate.
    """
    # Make sure an anomaly map is present for each ground truth map.
    assert len(anomaly_maps) == len(ground_truth_maps)

    # Initialize ground truth components and scores of potential fp pixels.
    ground_truth_components = []
    anomaly_scores_ok_pixels = np.zeros(len(ground_truth_maps) * ground_truth_maps[0].size)

    # Structuring element for computing connected components.
    structure = np.ones((3, 3), dtype=int)

    # Collect anomaly scores within each ground truth region and for all potential fp pixels.
    ok_index = 0
    for gt_map, prediction in zip(ground_truth_maps, anomaly_maps):
        # Compute the connected components in the ground truth map.
        labeled, n_components = label(gt_map, structure)

        # Store all potential fp scores.
        num_ok_pixels = len(prediction[labeled == 0])
        anomaly_scores_ok_pixels[ok_index:ok_index + num_ok_pixels] = prediction[labeled == 0].copy()
        ok_index += num_ok_pixels

        # Fetch anomaly scores within each GT component.
        for k in range(n_components):
            component_scores = prediction[labeled == (k + 1)]
            ground_truth_components.append(GroundTruthComponent(component_scores))

    # Sort all potential false positive scores.
    anomaly_scores_ok_pixels = np.resize(anomaly_scores_ok_pixels, ok_index)
    anomaly_scores_ok_pixels.sort()

    return ground_truth_components, anomaly_scores_ok_pixels


def compute_pro(anomaly_maps, ground_truth_maps, num_thresholds):
    """
    Compute the PRO curve at equidistant interpolation points for a set of anomaly maps with corresponding
    ground truth maps. The number of interpolation points can be set manually.

    Args:
        anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at
                      each pixel.
        ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground
                           truth labels for each pixel. 0 indicates that a pixel is anomaly-free.
                           1 indicates that a pixel contains an anomaly.
        num_thresholds: Number of thresholds to compute the PRO curve.

    Returns:
        fprs: List of false positive rates.
        pros: List of corresponding PRO values.
    """
    # Fetch sorted anomaly scores.
    ground_truth_components, anomaly_scores_ok_pixels = collect_anomaly_scores(anomaly_maps, ground_truth_maps)

    # Select equidistant thresholds.
    threshold_positions = np.linspace(0, len(anomaly_scores_ok_pixels) - 1, num=num_thresholds, dtype=int)

    fprs = [1.0]
    pros = [1.0]
    for pos in threshold_positions:
        threshold = anomaly_scores_ok_pixels[pos]

        # Compute the false positive rate for this threshold.
        fpr = 1.0 - (pos + 1) / len(anomaly_scores_ok_pixels)

        # Compute the PRO value for this threshold.
        pro = 0.0
        for component in ground_truth_components:
            pro += component.compute_overlap(threshold)
        pro /= len(ground_truth_components)

        fprs.append(fpr)
        pros.append(pro)

    # Return (FPR, PRO) pairs in increasing FPR order.
    fprs = fprs[::-1]
    pros = pros[::-1]

    return fprs, pros


def calculate_au_pro(gts, predictions, integration_limit=0.3, num_thresholds=100):
    """
    Compute the area under the PRO curve for a set of ground truth images and corresponding anomaly images.

    Args:
        gts: List of tensors that contain the ground truth images for a single dataset object.
        predictions: List of tensors containing anomaly images for each ground truth image.
integration_limit: Integration limit to use when computing the area under the PRO curve. num_thresholds: Number of thresholds to use to sample the area under the PRO curve. Returns: au_pro: Area under the PRO curve computed up to the given integration limit. pro_curve: PRO curve values for localization (fpr,pro). """ # Compute the PRO curve. pro_curve = compute_pro(anomaly_maps=predictions, ground_truth_maps=gts, num_thresholds=num_thresholds) # Compute the area under the PRO curve. au_pro = trapezoid(pro_curve[0], pro_curve[1], x_max=integration_limit) au_pro /= integration_limit # Return the evaluation metrics. return au_pro, pro_curve ================================================ FILE: utils/lr_sched.py ================================================ import math def adjust_learning_rate(optimizer, epoch, args): """Decay the learning rate with half-cycle cosine after warmup""" if epoch < args.warmup_epochs: lr = args.lr * epoch / args.warmup_epochs else: lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) for param_group in optimizer.param_groups: if "lr_scale" in param_group: param_group["lr"] = lr * param_group["lr_scale"] else: param_group["lr"] = lr return lr ================================================ FILE: utils/misc.py ================================================ import builtins import datetime import os import time from collections import defaultdict, deque from pathlib import Path import torch import torch.distributed as dist from torch._six import inf class SmoothedValue(object): """Track a series of values and provide access to smoothed values over a window or the global series average. """ def __init__(self, window_size=20, fmt=None): if fmt is None: fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) self.total = 0.0 self.count = 0 self.fmt = fmt def update(self, value, n=1): self.deque.append(value) self.count += n self.total += value * n def synchronize_between_processes(self): """ Warning: does not synchronize the deque! 
""" if not is_dist_avail_and_initialized(): return t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') dist.barrier() dist.all_reduce(t) t = t.tolist() self.count = int(t[0]) self.total = t[1] @property def median(self): d = torch.tensor(list(self.deque)) return d.median().item() @property def avg(self): d = torch.tensor(list(self.deque), dtype=torch.float32) return d.mean().item() @property def global_avg(self): return self.total / self.count @property def max(self): return max(self.deque) @property def value(self): return self.deque[-1] def __str__(self): return self.fmt.format( median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value) class MetricLogger(object): def __init__(self, delimiter="\t"): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter def update(self, **kwargs): for k, v in kwargs.items(): if v is None: continue if isinstance(v, torch.Tensor): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) def __getattr__(self, attr): if attr in self.meters: return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, attr)) def __str__(self): loss_str = [] for name, meter in self.meters.items(): loss_str.append( "{}: {}".format(name, str(meter)) ) return self.delimiter.join(loss_str) def synchronize_between_processes(self): for meter in self.meters.values(): meter.synchronize_between_processes() def add_meter(self, name, meter): self.meters[name] = meter def log_every(self, iterable, print_freq, header=None): i = 0 if not header: header = '' start_time = time.time() end = time.time() iter_time = SmoothedValue(fmt='{avg:.4f}') data_time = SmoothedValue(fmt='{avg:.4f}') space_fmt = ':' + str(len(str(len(iterable)))) + 'd' log_msg = [ header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 'time: {time}', 'data: {data}' ] if torch.cuda.is_available(): log_msg.append('max mem: {memory:.0f}') log_msg = self.delimiter.join(log_msg) MB = 1024.0 * 1024.0 for obj in iterable: data_time.update(time.time() - end) yield obj iter_time.update(time.time() - end) if i % print_freq == 0 or i == len(iterable) - 1: eta_seconds = iter_time.global_avg * (len(iterable) - i) eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) if torch.cuda.is_available(): print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time), memory=torch.cuda.max_memory_allocated() / MB)) else: print(log_msg.format( i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time))) i += 1 end = time.time() total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('{} Total time: {} ({:.4f} s / it)'.format( header, total_time_str, total_time / len(iterable))) def setup_for_distributed(is_master): """ This function disables printing when not in master process """ builtin_print = builtins.print def print(*args, **kwargs): force = kwargs.pop('force', False) force = force or (get_world_size() > 8) if is_master or force: now = datetime.datetime.now().time() builtin_print('[{}] '.format(now), end='') # print with time stamp builtin_print(*args, **kwargs) builtins.print = print def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True def get_world_size(): if not is_dist_avail_and_initialized(): return 1 return 
dist.get_world_size() def get_rank(): if not is_dist_avail_and_initialized(): return 0 return dist.get_rank() def is_main_process(): return get_rank() == 0 def save_on_master(*args, **kwargs): if is_main_process(): torch.save(*args, **kwargs) def init_distributed_mode(args): if args.dist_on_itp: args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) os.environ['LOCAL_RANK'] = str(args.gpu) os.environ['RANK'] = str(args.rank) os.environ['WORLD_SIZE'] = str(args.world_size) # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ['WORLD_SIZE']) args.gpu = int(os.environ['LOCAL_RANK']) elif 'SLURM_PROCID' in os.environ: args.rank = int(os.environ['SLURM_PROCID']) args.gpu = args.rank % torch.cuda.device_count() else: print('Not using distributed mode') setup_for_distributed(is_master=True) # hack args.distributed = False return args.distributed = True torch.cuda.set_device(args.gpu) args.dist_backend = 'nccl' print('| distributed init (rank {}): {}, gpu {}'.format( args.rank, args.dist_url, args.gpu), flush=True) torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) torch.distributed.barrier() setup_for_distributed(args.rank == 0) class NativeScalerWithGradNormCount: state_dict_key = "amp_scaler" def __init__(self): self._scaler = torch.cuda.amp.GradScaler() def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): self._scaler.scale(loss).backward(create_graph=create_graph) if update_grad: if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) else: self._scaler.unscale_(optimizer) norm = get_grad_norm_(parameters) self._scaler.step(optimizer) self._scaler.update() else: norm = None return norm def state_dict(self): return self._scaler.state_dict() def load_state_dict(self, state_dict): self._scaler.load_state_dict(state_dict) def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: if isinstance(parameters, torch.Tensor): parameters = [parameters] parameters = [p for p in parameters if p.grad is not None] norm_type = float(norm_type) if len(parameters) == 0: return torch.tensor(0.) 
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
                                norm_type)
    return total_norm


def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }
            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)


def save_model_gan(args, epoch, model, discriminator, model_without_ddp, discriminator_without_ddp,
                   optimizer_g, optimizer_d, loss_scaler):
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                # key renamed from 'discriminator_without_ddp' so that
                # load_model_gan below can find it under 'discriminator'
                'discriminator': discriminator_without_ddp.state_dict(),
                'optimizer_g': optimizer_g.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }
            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
        discriminator.save_checkpoint(save_dir=args.output_dir, tag="checkpoint_d-%s" % epoch_name,
                                      client_state=client_state)


def load_model(args, model_without_ddp, optimizer, loss_scaler):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")


def load_model_gan(args, model_without_ddp, discriminator_without_ddp, optimizer_g, optimizer_d, loss_scaler):
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        discriminator_without_ddp.load_state_dict(checkpoint['discriminator'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer_d' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        if 'optimizer_g' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer_g.load_state_dict(checkpoint['optimizer_g'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")


def all_reduce_mean(x):
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x
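
# Editor's note: a minimal sketch of averaging a per-rank scalar for logging;
# `loss` is a hypothetical scalar tensor produced on each rank.
def _example_reduced_logging(loss):
    loss_value = loss.item()
    # returns the same averaged value on every rank (no-op when world_size == 1)
    return all_reduce_mean(loss_value)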
================================================
FILE: utils/mvtec3d_util.py
================================================
import tifffile as tiff
import torch


def organized_pc_to_unorganized_pc(organized_pc):
    return organized_pc.reshape(organized_pc.shape[0] * organized_pc.shape[1], organized_pc.shape[2])


def read_tiff_organized_pc(path):
    tiff_img = tiff.imread(path)
    return tiff_img


def resize_organized_pc(organized_pc, target_height=224, target_width=224, tensor_out=True):
    torch_organized_pc = torch.tensor(organized_pc).permute(2, 0, 1).unsqueeze(dim=0).contiguous()
    torch_resized_organized_pc = torch.nn.functional.interpolate(torch_organized_pc,
                                                                 size=(target_height, target_width),
                                                                 mode='nearest')
    if tensor_out:
        return torch_resized_organized_pc.squeeze(dim=0).contiguous()
    else:
        return torch_resized_organized_pc.squeeze().permute(1, 2, 0).contiguous().numpy()


def organized_pc_to_depth_map(organized_pc):
    return organized_pc[:, :, 2]


================================================
FILE: utils/preprocess_eyecandies.py
================================================
import os
from shutil import copyfile
import cv2
import numpy as np
import tifffile
import yaml
import imageio.v3 as iio
import math
import argparse

# The same camera has been used for all the images
FOCAL_LENGTH = 711.11


def load_and_convert_depth(depth_img, info_depth):
    with open(info_depth) as f:
        data = yaml.safe_load(f)
    mind, maxd = data["normalization"]["min"], data["normalization"]["max"]
    dimg = iio.imread(depth_img)
    dimg = dimg.astype(np.float32)
    dimg = dimg / 65535.0 * (maxd - mind) + mind
    return dimg


def depth_to_pointcloud(depth_img, info_depth, pose_txt, focal_length):
    # input depth map (in meters) --- cfr previous section
    depth_mt = load_and_convert_depth(depth_img, info_depth)
    # input pose
    pose = np.loadtxt(pose_txt)
    # camera intrinsics
    height, width = depth_mt.shape[:2]
    intrinsics_4x4 = np.array([
        [focal_length, 0, width / 2, 0],
        [0, focal_length, height / 2, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1]]
    )
    # build the camera projection matrix
    camera_proj = intrinsics_4x4 @ pose
    # build the (u, v, 1, 1/depth) vectors (non optimized version)
    camera_vectors = np.zeros((width * height, 4))
    count = 0
    for j in range(height):
        for i in range(width):
            camera_vectors[count, :] = np.array([i, j, 1, 1 / depth_mt[j, i]])
            count += 1
    # invert and apply to each 4-vector
    hom_3d_pts = np.linalg.inv(camera_proj) @ camera_vectors.T
    # remove the homogeneous coordinate
    pcd = depth_mt.reshape(-1, 1) * hom_3d_pts.T
    return pcd[:, :3]


def remove_point_cloud_background(pc):
    # The second dim is z
    dz = pc[256, 1] - pc[-256, 1]
    dy = pc[256, 2] - pc[-256, 2]
    norm = math.sqrt(dz ** 2 + dy ** 2)
    start_points = np.array([0, pc[-256, 1], pc[-256, 2]])
    cos_theta = dy / norm
    sin_theta = dz / norm
    # Transform and rotation
    rotation_matrix = np.array([[1, 0, 0],
                                [0, cos_theta, -sin_theta],
                                [0, sin_theta, cos_theta]])
    processed_pc = (rotation_matrix @ (pc - start_points).T).T
    # Remove background point
    for i in range(processed_pc.shape[0]):
        if processed_pc[i, 1] > -0.02:
            processed_pc[i, :] = -start_points
        if processed_pc[i, 2] > 1.8:
            processed_pc[i, :] = -start_points
        elif processed_pc[i, 0] > 1 or processed_pc[i, 0] < -1:
            processed_pc[i, :] = -start_points
    processed_pc = (rotation_matrix.T @ processed_pc.T).T + start_points
    index = [0, 2, 1]
    processed_pc = processed_pc[:, index]
    return processed_pc * [0.1, -0.1, 0.1]
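
# Editor's note: a minimal sketch of the per-view conversion chain used in the
# loops below; the three file names are hypothetical examples of the dataset's
# naming scheme.
def _example_single_view():
    pc = depth_to_pointcloud('000_depth.png', '000_info_depth.yaml',
                             '000_pose.txt', FOCAL_LENGTH)
    pc = remove_point_cloud_background(pc)
    return pc.reshape(512, 512, 3)  # back to an organized (H, W, 3) grid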
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Preprocess the Eyecandies dataset.')
    parser.add_argument('--dataset_path', default='datasets/eyecandies', type=str,
                        help="Original Eyecandies dataset path.")
    parser.add_argument('--target_dir', default='datasets/eyecandies_preprocessed', type=str,
                        help="Processed Eyecandies dataset path")
    args = parser.parse_args()

    os.mkdir(args.target_dir)
    categories_list = os.listdir(args.dataset_path)
    for category_dir in categories_list:
        category_root_path = os.path.join(args.dataset_path, category_dir)
        # note: the subfolders are passed as separate components; a leading
        # slash (as in '/train/data') would make os.path.join discard
        # category_root_path entirely
        category_train_path = os.path.join(category_root_path, 'train', 'data')
        category_test_path = os.path.join(category_root_path, 'test_public', 'data')
        category_target_path = os.path.join(args.target_dir, category_dir)
        os.mkdir(category_target_path)

        os.mkdir(os.path.join(category_target_path, 'train'))
        category_target_train_good_path = os.path.join(category_target_path, 'train/good')
        category_target_train_good_rgb_path = os.path.join(category_target_train_good_path, 'rgb')
        category_target_train_good_xyz_path = os.path.join(category_target_train_good_path, 'xyz')
        os.mkdir(category_target_train_good_path)
        os.mkdir(category_target_train_good_rgb_path)
        os.mkdir(category_target_train_good_xyz_path)

        os.mkdir(os.path.join(category_target_path, 'test'))
        category_target_test_good_path = os.path.join(category_target_path, 'test/good')
        category_target_test_good_rgb_path = os.path.join(category_target_test_good_path, 'rgb')
        category_target_test_good_xyz_path = os.path.join(category_target_test_good_path, 'xyz')
        category_target_test_good_gt_path = os.path.join(category_target_test_good_path, 'gt')
        os.mkdir(category_target_test_good_path)
        os.mkdir(category_target_test_good_rgb_path)
        os.mkdir(category_target_test_good_xyz_path)
        os.mkdir(category_target_test_good_gt_path)

        category_target_test_bad_path = os.path.join(category_target_path, 'test/bad')
        category_target_test_bad_rgb_path = os.path.join(category_target_test_bad_path, 'rgb')
        category_target_test_bad_xyz_path = os.path.join(category_target_test_bad_path, 'xyz')
        category_target_test_bad_gt_path = os.path.join(category_target_test_bad_path, 'gt')
        os.mkdir(category_target_test_bad_path)
        os.mkdir(category_target_test_bad_rgb_path)
        os.mkdir(category_target_test_bad_xyz_path)
        os.mkdir(category_target_test_bad_gt_path)

        category_train_files = os.listdir(category_train_path)
        num_train_files = len(category_train_files) // 17  # the code assumes 17 files per sample on disk
        for i in range(0, num_train_files):
            pc = depth_to_pointcloud(
                os.path.join(category_train_path, str(i).zfill(3) + '_depth.png'),
                os.path.join(category_train_path, str(i).zfill(3) + '_info_depth.yaml'),
                os.path.join(category_train_path, str(i).zfill(3) + '_pose.txt'),
                FOCAL_LENGTH,
            )
            pc = remove_point_cloud_background(pc)
            pc = pc.reshape(512, 512, 3)
            tifffile.imwrite(os.path.join(category_target_train_good_xyz_path, str(i).zfill(3) + '.tiff'), pc)
            copyfile(os.path.join(category_train_path, str(i).zfill(3) + '_image_4.png'),
                     os.path.join(category_target_train_good_rgb_path, str(i).zfill(3) + '.png'))
        category_test_files = os.listdir(category_test_path)
        num_test_files = len(category_test_files) // 17
        for i in range(0, num_test_files):
            # test_public files use 2-digit indices; outputs are renumbered with 3 digits
            mask = cv2.imread(os.path.join(category_test_path, str(i).zfill(2) + '_mask.png'))
            if np.any(mask):
                pc = depth_to_pointcloud(
                    os.path.join(category_test_path, str(i).zfill(2) + '_depth.png'),
                    os.path.join(category_test_path, str(i).zfill(2) + '_info_depth.yaml'),
                    os.path.join(category_test_path, str(i).zfill(2) + '_pose.txt'),
                    FOCAL_LENGTH,
                )
                pc = remove_point_cloud_background(pc)
                pc = pc.reshape(512, 512, 3)
                tifffile.imwrite(os.path.join(category_target_test_bad_xyz_path, str(i).zfill(3) + '.tiff'), pc)
                cv2.imwrite(os.path.join(category_target_test_bad_gt_path, str(i).zfill(3) + '.png'), mask)
                copyfile(os.path.join(category_test_path, str(i).zfill(2) + '_image_4.png'),
                         os.path.join(category_target_test_bad_rgb_path, str(i).zfill(3) + '.png'))
            else:
                pc = depth_to_pointcloud(
                    os.path.join(category_test_path, str(i).zfill(2) + '_depth.png'),
                    os.path.join(category_test_path, str(i).zfill(2) + '_info_depth.yaml'),
                    os.path.join(category_test_path, str(i).zfill(2) + '_pose.txt'),
                    FOCAL_LENGTH,
                )
                pc = remove_point_cloud_background(pc)
                pc = pc.reshape(512, 512, 3)
                tifffile.imwrite(os.path.join(category_target_test_good_xyz_path, str(i).zfill(3) + '.tiff'), pc)
                cv2.imwrite(os.path.join(category_target_test_good_gt_path, str(i).zfill(3) + '.png'), mask)
                copyfile(os.path.join(category_test_path, str(i).zfill(2) + '_image_4.png'),
                         os.path.join(category_target_test_good_rgb_path, str(i).zfill(3) + '.png'))


================================================
FILE: utils/preprocessing.py
================================================
import os
import numpy as np
import tifffile as tiff
import open3d as o3d
from pathlib import Path
from PIL import Image
import math
import mvtec3d_util as mvt_util
import argparse


def get_edges_of_pc(organized_pc):
    # stack the 10-pixel border strips (top, bottom, left, right) into one
    # unorganized (N, 3) array and drop all-zero points
    unorganized_edges_pc = organized_pc[0:10, :, :].reshape(-1, organized_pc.shape[2])
    unorganized_edges_pc = np.concatenate([unorganized_edges_pc,
                                           organized_pc[-10:, :, :].reshape(-1, organized_pc.shape[2])], axis=0)
    unorganized_edges_pc = np.concatenate([unorganized_edges_pc,
                                           organized_pc[:, 0:10, :].reshape(-1, organized_pc.shape[2])], axis=0)
    unorganized_edges_pc = np.concatenate([unorganized_edges_pc,
                                           organized_pc[:, -10:, :].reshape(-1, organized_pc.shape[2])], axis=0)
    unorganized_edges_pc = unorganized_edges_pc[np.nonzero(np.all(unorganized_edges_pc != 0, axis=1))[0], :]
    return unorganized_edges_pc


def get_plane_eq(unorganized_pc, ransac_n_pts=50):
    o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc))
    plane_model, inliers = o3d_pc.segment_plane(distance_threshold=0.004,
                                                ransac_n=ransac_n_pts,
                                                num_iterations=1000)
    return plane_model
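
# Editor's note: a minimal sketch of how the RANSAC plane model is consumed.
# segment_plane returns [a, b, c, d] with a unit normal, so the distance of a
# point (x, y, z) to the plane is |ax + by + cz + d|. `pts` is a hypothetical
# (N, 3) array.
def _example_plane_distances(organized_pc, pts):
    plane = np.array(get_plane_eq(get_edges_of_pc(organized_pc)))
    hom = np.hstack([pts, np.ones((pts.shape[0], 1))])  # homogeneous coordinates
    return np.abs(hom @ plane)  # per-point distance to the fitted plane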
def remove_plane(organized_pc_clean, organized_rgb, distance_threshold=0.005):
    # PREP PC
    unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc_clean)
    unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)
    clean_planeless_unorganized_pc = unorganized_pc.copy()
    planeless_unorganized_rgb = unorganized_rgb.copy()

    # REMOVE PLANE
    plane_model = get_plane_eq(get_edges_of_pc(organized_pc_clean))
    distances = np.abs(np.dot(np.array(plane_model),
                              np.hstack((clean_planeless_unorganized_pc,
                                         np.ones((clean_planeless_unorganized_pc.shape[0], 1)))).T))
    plane_indices = np.argwhere(distances < distance_threshold)
    planeless_unorganized_rgb[plane_indices] = 0
    clean_planeless_unorganized_pc[plane_indices] = 0
    clean_planeless_organized_pc = clean_planeless_unorganized_pc.reshape(organized_pc_clean.shape[0],
                                                                          organized_pc_clean.shape[1],
                                                                          organized_pc_clean.shape[2])
    planeless_organized_rgb = planeless_unorganized_rgb.reshape(organized_rgb.shape[0],
                                                                organized_rgb.shape[1],
                                                                organized_rgb.shape[2])
    return clean_planeless_organized_pc, planeless_organized_rgb


def connected_components_cleaning(organized_pc, organized_rgb, image_path):
    unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc)
    unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)
    nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
    unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :]
    o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros))
    labels = np.array(o3d_pc.cluster_dbscan(eps=0.006, min_points=30, print_progress=False))
    unique_cluster_ids, cluster_size = np.unique(labels, return_counts=True)
    max_label = labels.max()
    if max_label > 0:
        print("##########################################################################")
        print(f"Point cloud file {image_path} has {max_label + 1} clusters")
        print(f"Cluster ids: {unique_cluster_ids}. Cluster size {cluster_size}")
        print("##########################################################################\n\n")
    largest_cluster_id = unique_cluster_ids[np.argmax(cluster_size)]
    outlier_indices_nonzero_array = np.argwhere(labels != largest_cluster_id)
    outlier_indices_original_pc_array = nonzero_indices[outlier_indices_nonzero_array]
    unorganized_pc[outlier_indices_original_pc_array] = 0
    unorganized_rgb[outlier_indices_original_pc_array] = 0
    organized_clustered_pc = unorganized_pc.reshape(organized_pc.shape[0], organized_pc.shape[1],
                                                    organized_pc.shape[2])
    organized_clustered_rgb = unorganized_rgb.reshape(organized_rgb.shape[0], organized_rgb.shape[1],
                                                      organized_rgb.shape[2])
    return organized_clustered_pc, organized_clustered_rgb


def roundup_next_100(x):
    return int(math.ceil(x / 100.0)) * 100


def pad_cropped_pc(cropped_pc, single_channel=False):
    orig_h, orig_w = cropped_pc.shape[0], cropped_pc.shape[1]
    round_orig_h = roundup_next_100(orig_h)
    round_orig_w = roundup_next_100(orig_w)
    large_side = max(round_orig_h, round_orig_w)
    a = (large_side - orig_h) // 2
    aa = large_side - a - orig_h
    b = (large_side - orig_w) // 2
    bb = large_side - b - orig_w
    if single_channel:
        return np.pad(cropped_pc, pad_width=((a, aa), (b, bb)), mode='constant')
    else:
        return np.pad(cropped_pc, pad_width=((a, aa), (b, bb), (0, 0)), mode='constant')


def preprocess_pc(tiff_path):
    # READ FILES
    organized_pc = mvt_util.read_tiff_organized_pc(tiff_path)
    rgb_path = str(tiff_path).replace("xyz", "rgb").replace("tiff", "png")
    gt_path = str(tiff_path).replace("xyz", "gt").replace("tiff", "png")
    organized_rgb = np.array(Image.open(rgb_path))
    organized_gt = None
    gt_exists = os.path.isfile(gt_path)
    if gt_exists:
        organized_gt = np.array(Image.open(gt_path))

    # REMOVE PLANE
    planeless_organized_pc, planeless_organized_rgb = remove_plane(organized_pc, organized_rgb)

    # PAD WITH ZEROS TO LARGEST SIDE (SO THAT THE FINAL IMAGE IS SQUARE)
    padded_planeless_organized_pc = pad_cropped_pc(planeless_organized_pc, single_channel=False)
    padded_planeless_organized_rgb = pad_cropped_pc(planeless_organized_rgb, single_channel=False)
    if gt_exists:
        padded_organized_gt = pad_cropped_pc(organized_gt, single_channel=True)

    organized_clustered_pc, organized_clustered_rgb = connected_components_cleaning(padded_planeless_organized_pc,
                                                                                    padded_planeless_organized_rgb,
                                                                                    tiff_path)

    # SAVE PREPROCESSED FILES
    tiff.imwrite(tiff_path, organized_clustered_pc)  # imwrite replaces the deprecated imsave alias
    Image.fromarray(organized_clustered_rgb).save(rgb_path)
    if gt_exists:
        Image.fromarray(padded_organized_gt).save(gt_path)
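
# Editor's note: a minimal sketch of the per-sample pipeline that preprocess_pc
# wires together, for a hypothetical in-memory (organized_pc, organized_rgb)
# pair of shape (H, W, 3); 'sample.tiff' is only a placeholder label for logging.
def _example_pipeline(organized_pc, organized_rgb):
    planeless_pc, planeless_rgb = remove_plane(organized_pc, organized_rgb)
    square_pc = pad_cropped_pc(planeless_pc)   # zero-pad to a square canvas
    square_rgb = pad_cropped_pc(planeless_rgb)
    return connected_components_cleaning(square_pc, square_rgb, 'sample.tiff')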
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Preprocess MVTec 3D-AD')
    parser.add_argument('dataset_path', type=str,
                        help='The root path of the MVTec 3D-AD dataset. Preprocessing is done in place '
                             '(the preprocessed files overwrite the originals).')
    args = parser.parse_args()

    root_path = args.dataset_path
    paths = list(Path(root_path).rglob('*.tiff'))  # materialize once and reuse below
    print(f"Found {len(paths)} tiff files in {root_path}")
    processed_files = 0
    for path in paths:
        preprocess_pc(path)
        processed_files += 1
        if processed_files % 50 == 0:
            print(f"Processed {processed_files} tiff files...")


================================================
FILE: utils/utils.py
================================================
import numpy as np
import random
import torch
from torchvision import transforms
from PIL import ImageFilter


def set_seeds(seed: int = 0) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)


class KNNGaussianBlur(torch.nn.Module):
    def __init__(self, radius: int = 4):
        super().__init__()
        self.radius = radius
        self.unload = transforms.ToPILImage()
        self.load = transforms.ToTensor()
        # fixed: previously hard-coded radius=4, silently ignoring the argument
        self.blur_kernel = ImageFilter.GaussianBlur(radius=radius)

    def __call__(self, img):
        map_max = img.max()
        final_map = self.load(self.unload(img[0] / map_max).filter(self.blur_kernel)) * map_max
        return final_map
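
# Editor's note: a minimal sketch of smoothing a predicted anomaly map with
# KNNGaussianBlur; the (1, H, W) score map here is random, purely illustrative.
def _example_blur_anomaly_map():
    blur = KNNGaussianBlur(radius=4)
    score_map = torch.rand(1, 224, 224)  # hypothetical non-negative anomaly scores
    return blur(score_map)  # normalized, PIL-blurred, rescaled to the original max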