Repository: nomewang/M3DM
Branch: main
Commit: 62b7f468aa00
Files: 20
Total size: 156.1 KB
Directory structure:
gitextract_ziog4jk2/
├── LICENSE
├── README.md
├── dataset.py
├── engine_fusion_pretrain.py
├── feature_extractors/
│ ├── features.py
│ └── multiple_features.py
├── fusion_pretrain.py
├── m3dm_runner.py
├── main.py
├── models/
│ ├── feature_fusion.py
│ ├── models.py
│ └── pointnet2_utils.py
├── requirements.txt
└── utils/
├── au_pro_util.py
├── lr_sched.py
├── misc.py
├── mvtec3d_util.py
├── preprocess_eyecandies.py
├── preprocessing.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 nomewang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Multimodal Industrial Anomaly Detection via Hybrid Fusion (CVPR 2023)
## Abstract
> 2D-based Industrial Anomaly Detection has been widely discussed, however, multimodal industrial anomaly detection based on 3D point clouds and RGB images still has many untouched fields. Existing multimodal industrial anomaly detection methods directly concatenate the multimodal features, which leads to a strong disturbance between features and harms the detection performance. In this paper, we propose **Multi-3D-Memory** (**M3DM**), a novel multimodal anomaly detection method with hybrid fusion scheme: firstly, we design an unsupervised feature fusion with patch-wise contrastive learning to encourage the interaction of different modal features; secondly, we use a decision layer fusion with multiple memory banks to avoid loss of information and additional novelty classifiers to make the final decision. We further propose a point feature alignment operation to better align the point cloud and RGB features. Extensive experiments show that our multimodal industrial anomaly detection model outperforms the state-of-the-art (SOTA) methods on both detection and segmentation precision on MVTec-3D AD dataset.

- `The pipeline of Multi-3D-Memory (M3DM).` Our M3DM contains three important parts: (1) **Point Feature Alignment** (PFA) converts Point Group features to plane features with interpolation and projection operations, $\text{FPS}$ is the farthest point sampling and $\mathcal{F_{pt}}$ is a pretrained Point Transformer; (2) **Unsupervised Feature Fusion** (UFF) fuses point features and image features together with a patch-wise contrastive loss $\mathcal{L_{con}}$, where $\mathcal{F_{rgb}}$ is a Vision Transformer, $\chi_{rgb},\chi_{pt}$ are MLP layers and $\sigma_r, \sigma_p$ are single fully connected layers; (3) **Decision Layer Fusion** (DLF) combines multimodal information with multiple memory banks and makes the final decision with 2 learnable modules $\mathcal D_a, \mathcal{D_s}$ for anomaly detection and segmentation, where $\mathcal{M_{rgb}}$, $\mathcal{M_{fs}}$, $\mathcal{M_{pt}}$ are memory banks, $\phi, \psi$ are score functions for single memory bank detection and segmentation, and $\mathcal{P}$ is the memory bank building algorithm.
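
For orientation, below is a minimal, self-contained sketch of the decision layer fusion idea with toy tensors (hypothetical shapes and names, not the actual `m3dm_runner.py` code): each memory bank contributes a nearest-neighbour distance per patch, and the stacked per-bank scores are fused by a one-class SVM, as in DLF.

```python
import torch
from sklearn import linear_model

def bank_scores(patches: torch.Tensor, bank: torch.Tensor) -> torch.Tensor:
    """Distance from each test patch to its nearest memory-bank entry."""
    dist = torch.cdist(patches, bank)   # (n_patches, n_bank)
    return torch.min(dist, dim=1).values

torch.manual_seed(0)
# toy stand-ins for the rgb / xyz / fusion patch libraries and features
banks = [torch.randn(100, 16) for _ in range(3)]
train_feats = [torch.randn(50, 16) for _ in range(3)]
test_feats = [torch.randn(50, 16) for _ in range(3)]

# one score per memory bank, stacked into a 3-dim feature per sample
s_train = torch.stack([bank_scores(p, b) for p, b in zip(train_feats, banks)], dim=1)
s_test = torch.stack([bank_scores(p, b) for p, b in zip(test_feats, banks)], dim=1)

# a one-class SVM fuses the per-bank scores into the final decision
fuser = linear_model.SGDOneClassSVM(random_state=42)
fuser.fit(s_train.numpy())
print(fuser.score_samples(s_test.numpy())[:5])
```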
### [Paper](https://arxiv.org/pdf/2303.00601.pdf)
## Setup
We implemented this repo with the following environment:
- Ubuntu 18.04
- Python 3.8
- Pytorch 1.9.0
- CUDA 11.3
Install the other packages via:
``` bash
pip install -r requirements.txt
# install knn_cuda
pip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl
# install pointnet2_ops_lib
pip install "git+git://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"
```
## Data Download and Preprocess
### Dataset
- The `MVTec-3D AD` dataset can be downloaded from the [Official Website of MVTec-3D AD](https://www.mvtec.com/company/research/datasets/mvtec-3d-ad).
- The `Eyecandies` dataset can be downloaded from the [Official Website of Eyecandies](https://eyecan-ai.github.io/eyecandies/).
After downloading, put the datasets in the `datasets` folder.
### Data Preprocessing
To run the preprocessing:
```bash
python utils/preprocessing.py datasets/mvtec3d/
```
It may take a few hours to run the preprocessing.
### Checkpoints
The following table lists the pretrained models used in M3DM:
| Backbone | Pretrain Method |
| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| Point Transformer | [Point-MAE](https://drive.google.com/file/d/1-wlRIz0GM8o6BuPTJz4kTt6c_z1Gh6LX/view?usp=sharing) |
| Point Transformer | [Point-Bert](https://cloud.tsinghua.edu.cn/f/202b29805eea45d7be92/?dl=1) |
| ViT-b/8 | [DINO](https://drive.google.com/file/d/17s6lwfxwG_nf1td6LXunL-LjRaX67iyK/view?usp=sharing) |
| ViT-b/8 | [Supervised ImageNet 1K](https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz) |
| ViT-b/8 | [Supervised ImageNet 21K](https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz) |
| ViT-s/8 | [DINO](https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth) |
| UFF | [UFF Module](https://drive.google.com/file/d/1Z2AkfPqenJEv-IdWhVdRcvVQAsJC4DxW/view?usp=sharing) |
Put the checkpoint files in the `checkpoints` folder.
## Train and Test
Train and test the double memory bank version and save the features for UFF training:
```bash
mkdir -p datasets/patch_lib
python3 main.py \
--method_name DINO+Point_MAE \
--memory_bank multiple \
--rgb_backbone_name vit_base_patch8_224_dino \
--xyz_backbone_name Point_MAE \
--save_feature \
```
Train the UFF:
```bash
OMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=1 fusion_pretrain.py \
--accum_iter 16 \
--lr 0.003 \
--batch_size 16 \
--data_path datasets/patch_lib \
--output_dir checkpoints \
```
Train and test the full setting with the following command:
```bash
python3 main.py \
--method_name DINO+Point_MAE+Fusion \
--use_uff \
--memory_bank multiple \
--rgb_backbone_name vit_base_patch8_224_dino \
--xyz_backbone_name Point_MAE \
--fusion_module_path checkpoints/{FUSION_CHECKPOINT}.pth \
```
Note: if you set `--method_name DINO` or `--method_name Point_MAE`, also set `--memory_bank single`.
If you find this repository useful for your research, please cite:
```bibtex
@inproceedings{wang2023multimodal,
title={Multimodal Industrial Anomaly Detection via Hybrid Fusion},
author={Wang, Yue and Peng, Jinlong and Zhang, Jiangning and Yi, Ran and Wang, Yabiao and Wang, Chengjie},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={8032--8041},
year={2023}
}
```
## Thanks
Our repo is built on [3D-ADS](https://github.com/eliahuhorwitz/3D-ADS) and [MoCo-v3](https://github.com/facebookresearch/moco-v3), thanks for their extraordinary work!
================================================
FILE: dataset.py
================================================
import os
from PIL import Image
from torchvision import transforms
import glob
from torch.utils.data import Dataset
from utils.mvtec3d_util import *
from torch.utils.data import DataLoader
import numpy as np
import torch
def eyecandies_classes():
return [
'CandyCane',
'ChocolateCookie',
'ChocolatePraline',
'Confetto',
'GummyBear',
'HazelnutTruffle',
'LicoriceSandwich',
'Lollipop',
'Marshmallow',
'PeppermintCandy',
]
def mvtec3d_classes():
return [
"bagel",
"cable_gland",
"carrot",
"cookie",
"dowel",
"foam",
"peach",
"potato",
"rope",
"tire",
]
RGB_SIZE = 224
class BaseAnomalyDetectionDataset(Dataset):
def __init__(self, split, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
self.IMAGENET_MEAN = [0.485, 0.456, 0.406]
self.IMAGENET_STD = [0.229, 0.224, 0.225]
self.cls = class_name
self.size = img_size
self.img_path = os.path.join(dataset_path, self.cls, split)
self.rgb_transform = transforms.Compose(
[transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)])
class PreTrainTensorDataset(Dataset):
def __init__(self, root_path):
super().__init__()
self.root_path = root_path
self.tensor_paths = os.listdir(self.root_path)
def __len__(self):
return len(self.tensor_paths)
def __getitem__(self, idx):
tensor_path = self.tensor_paths[idx]
tensor = torch.load(os.path.join(self.root_path, tensor_path))
label = 0
return tensor, label
class TrainDataset(BaseAnomalyDetectionDataset):
def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
super().__init__(split="train", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
self.img_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
def load_dataset(self):
img_tot_paths = []
tot_labels = []
rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + "/*.png")
tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + "/*.tiff")
rgb_paths.sort()
tiff_paths.sort()
sample_paths = list(zip(rgb_paths, tiff_paths))
img_tot_paths.extend(sample_paths)
tot_labels.extend([0] * len(sample_paths))
return img_tot_paths, tot_labels
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
img_path, label = self.img_paths[idx], self.labels[idx]
rgb_path = img_path[0]
tiff_path = img_path[1]
img = Image.open(rgb_path).convert('RGB')
img = self.rgb_transform(img)
organized_pc = read_tiff_organized_pc(tiff_path)
depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
resized_organized_pc = resized_organized_pc.clone().detach().float()
return (img, resized_organized_pc, resized_depth_map_3channel), label
class TestDataset(BaseAnomalyDetectionDataset):
def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):
super().__init__(split="test", class_name=class_name, img_size=img_size, dataset_path=dataset_path)
self.gt_transform = transforms.Compose([
transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST),
transforms.ToTensor()])
self.img_paths, self.gt_paths, self.labels = self.load_dataset() # self.labels => good : 0, anomaly : 1
def load_dataset(self):
img_tot_paths = []
gt_tot_paths = []
tot_labels = []
defect_types = os.listdir(self.img_path)
for defect_type in defect_types:
if defect_type == 'good':
rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
rgb_paths.sort()
tiff_paths.sort()
sample_paths = list(zip(rgb_paths, tiff_paths))
img_tot_paths.extend(sample_paths)
gt_tot_paths.extend([0] * len(sample_paths))
tot_labels.extend([0] * len(sample_paths))
else:
rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + "/*.png")
tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + "/*.tiff")
gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + "/*.png")
rgb_paths.sort()
tiff_paths.sort()
gt_paths.sort()
sample_paths = list(zip(rgb_paths, tiff_paths))
img_tot_paths.extend(sample_paths)
gt_tot_paths.extend(gt_paths)
tot_labels.extend([1] * len(sample_paths))
assert len(img_tot_paths) == len(gt_tot_paths), "Something wrong with test and ground truth pair!"
return img_tot_paths, gt_tot_paths, tot_labels
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
rgb_path = img_path[0]
tiff_path = img_path[1]
img_original = Image.open(rgb_path).convert('RGB')
img = self.rgb_transform(img_original)
organized_pc = read_tiff_organized_pc(tiff_path)
depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)
resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)
resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)
resized_organized_pc = resized_organized_pc.clone().detach().float()
        if gt == 0:
            gt = torch.zeros(
                [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-1]])
else:
gt = Image.open(gt).convert('L')
gt = self.gt_transform(gt)
gt = torch.where(gt > 0.5, 1., .0)
return (img, resized_organized_pc, resized_depth_map_3channel), gt[:1], label, rgb_path
def get_data_loader(split, class_name, img_size, args):
if split in ['train']:
dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path)
elif split in ['test']:
dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path)
data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=False,
pin_memory=True)
return data_loader
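# Minimal smoke test (illustrative sketch, not part of the original pipeline).
# It assumes the MVTec-3D AD "bagel" class has already been preprocessed into
# datasets/mvtec3d/ as described in the README; `args` is a hypothetical
# stand-in for the argparse namespace built in main.py.
if __name__ == "__main__":
    import argparse
    args = argparse.Namespace(dataset_path="datasets/mvtec3d")
    loader = get_data_loader("train", class_name="bagel", img_size=224, args=args)
    (img, organized_pc, depth), label = next(iter(loader))
    print(img.shape, organized_pc.shape, depth.shape, label)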
================================================
FILE: engine_fusion_pretrain.py
================================================
import math
import sys
from typing import Iterable
import torch
import utils.misc as misc
import utils.lr_sched as lr_sched
def train_one_epoch(model: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,
log_writer=None,
args=None):
model.train(True)
metric_logger = misc.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 20
accum_iter = args.accum_iter
optimizer.zero_grad()
if log_writer is not None:
print('log_dir: {}'.format(log_writer.log_dir))
for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
# we use a per iteration (instead of per epoch) lr scheduler
if data_iter_step % accum_iter == 0:
lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
xyz_samples = samples[:,:,:1152].to(device, non_blocking=True)
rgb_samples = samples[:,:,1152:].to(device, non_blocking=True)
with torch.cuda.amp.autocast():
loss = model(xyz_samples, rgb_samples)
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
loss /= accum_iter
loss_scaler(loss, optimizer, parameters=model.parameters(),
update_grad=(data_iter_step + 1) % accum_iter == 0)
if (data_iter_step + 1) % accum_iter == 0:
optimizer.zero_grad()
torch.cuda.synchronize()
metric_logger.update(loss=loss_value)
lr = optimizer.param_groups[0]["lr"]
metric_logger.update(lr=lr)
loss_value_reduce = misc.all_reduce_mean(loss_value)
if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
""" We use epoch_1000x as the x-axis in tensorboard.
This calibrates different curves when batch size changes.
"""
epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
log_writer.add_scalar('lr', lr, epoch_1000x)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
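# Sketch of how this engine is driven, mirroring the MAE-style loop in
# fusion_pretrain.py (NativeScalerWithGradNormCount is assumed to be the usual
# MAE loss scaler exported by utils/misc.py):
#
#   model = FeatureFusionBlock(1152, 768, mlp_ratio=4.).to(device)
#   optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
#   loss_scaler = misc.NativeScalerWithGradNormCount()
#   for epoch in range(args.epochs):
#       train_one_epoch(model, data_loader, optimizer, device, epoch,
#                       loss_scaler, log_writer=None, args=args)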
================================================
FILE: feature_extractors/features.py
================================================
"""
PatchCore logic based on https://github.com/rvorias/ind_knn_ad
"""
import torch
import numpy as np
import os
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn import random_projection
from sklearn import linear_model
from sklearn.svm import OneClassSVM
from sklearn.ensemble import IsolationForest
from sklearn.metrics import roc_auc_score
from timm.models.layers import DropPath, trunc_normal_
from pointnet2_ops import pointnet2_utils
from knn_cuda import KNN
from utils.utils import KNNGaussianBlur
from utils.utils import set_seeds
from utils.au_pro_util import calculate_au_pro
from models.pointnet2_utils import interpolating_points
from models.feature_fusion import FeatureFusionBlock
from models.models import Model
class Features(torch.nn.Module):
def __init__(self, args, image_size=224, f_coreset=0.1, coreset_eps=0.9):
super().__init__()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.deep_feature_extractor = Model(
device=self.device,
rgb_backbone_name=args.rgb_backbone_name,
xyz_backbone_name=args.xyz_backbone_name,
group_size = args.group_size,
num_group=args.num_group
)
self.deep_feature_extractor.to(self.device)
self.args = args
self.image_size = args.img_size
self.f_coreset = args.f_coreset
self.coreset_eps = args.coreset_eps
self.blur = KNNGaussianBlur(4)
self.n_reweight = 3
set_seeds(0)
self.patch_xyz_lib = []
self.patch_rgb_lib = []
self.patch_fusion_lib = []
self.patch_lib = []
self.random_state = args.random_state
self.xyz_dim = 0
self.rgb_dim = 0
self.xyz_mean=0
self.xyz_std=0
self.rgb_mean=0
self.rgb_std=0
self.fusion_mean=0
self.fusion_std=0
self.average = torch.nn.AvgPool2d(3, stride=1) # torch.nn.AvgPool2d(1, stride=1) #
self.resize = torch.nn.AdaptiveAvgPool2d((56, 56))
self.resize2 = torch.nn.AdaptiveAvgPool2d((56, 56))
self.image_preds = list()
self.image_labels = list()
self.pixel_preds = list()
self.pixel_labels = list()
self.gts = []
self.predictions = []
self.image_rocauc = 0
self.pixel_rocauc = 0
self.au_pro = 0
self.ins_id = 0
self.rgb_layernorm = torch.nn.LayerNorm(768, elementwise_affine=False)
if self.args.use_uff:
self.fusion = FeatureFusionBlock(1152, 768, mlp_ratio=4.)
ckpt = torch.load(args.fusion_module_path)['model']
incompatible = self.fusion.load_state_dict(ckpt, strict=False)
print('[Fusion Block]', incompatible)
self.detect_fuser = linear_model.SGDOneClassSVM(random_state=42, nu=args.ocsvm_nu, max_iter=args.ocsvm_maxiter)
self.seg_fuser = linear_model.SGDOneClassSVM(random_state=42, nu=args.ocsvm_nu, max_iter=args.ocsvm_maxiter)
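        # SGD-based one-class SVMs that fuse the per-memory-bank anomaly scores
        # (detection) and score maps (segmentation); they are fit on normal
        # training samples in run_late_fusion below.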
self.s_lib = []
self.s_map_lib = []
def __call__(self, rgb, xyz):
# Extract the desired feature maps using the backbone model.
rgb = rgb.to(self.device)
xyz = xyz.to(self.device)
with torch.no_grad():
rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz)
interpolate = True
if interpolate:
interpolated_feature_maps = interpolating_points(xyz, center.permute(0,2,1), xyz_feature_maps).to("cpu")
xyz_feature_maps = [fmap.to("cpu") for fmap in [xyz_feature_maps]]
rgb_feature_maps = [fmap.to("cpu") for fmap in [rgb_feature_maps]]
if interpolate:
return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx, interpolated_feature_maps
else:
return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx
def add_sample_to_mem_bank(self, sample):
raise NotImplementedError
def predict(self, sample, mask, label):
raise NotImplementedError
def add_sample_to_late_fusion_mem_bank(self, sample):
raise NotImplementedError
def interpolate_points(self, rgb, xyz):
with torch.no_grad():
rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz)
return xyz_feature_maps, center, xyz
def compute_s_s_map(self, xyz_patch, rgb_patch, fusion_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):
raise NotImplementedError
def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'):
raise NotImplementedError
def run_coreset(self):
raise NotImplementedError
def calculate_metrics(self):
self.image_preds = np.stack(self.image_preds)
self.image_labels = np.stack(self.image_labels)
self.pixel_preds = np.array(self.pixel_preds)
self.image_rocauc = roc_auc_score(self.image_labels, self.image_preds)
self.pixel_rocauc = roc_auc_score(self.pixel_labels, self.pixel_preds)
self.au_pro, _ = calculate_au_pro(self.gts, self.predictions)
def save_prediction_maps(self, output_path, rgb_path, save_num=5):
        for i in range(min(save_num, len(self.predictions))):  # save at most save_num maps
# fig = plt.figure(dpi=300)
fig = plt.figure()
ax3 = fig.add_subplot(1,3,1)
gt = plt.imread(rgb_path[i][0])
ax3.imshow(gt)
ax2 = fig.add_subplot(1,3,2)
im2 = ax2.imshow(self.gts[i], cmap=plt.cm.gray)
ax = fig.add_subplot(1,3,3)
im = ax.imshow(self.predictions[i], cmap=plt.cm.jet)
class_dir = os.path.join(output_path, rgb_path[i][0].split('/')[-5])
if not os.path.exists(class_dir):
os.mkdir(class_dir)
ad_dir = os.path.join(class_dir, rgb_path[i][0].split('/')[-3])
if not os.path.exists(ad_dir):
os.mkdir(ad_dir)
plt.savefig(os.path.join(ad_dir, str(self.image_preds[i]) + '_pred_' + rgb_path[i][0].split('/')[-1] + '.jpg'))
def run_late_fusion(self):
self.s_lib = torch.cat(self.s_lib, 0)
self.s_map_lib = torch.cat(self.s_map_lib, 0)
self.detect_fuser.fit(self.s_lib)
self.seg_fuser.fit(self.s_map_lib)
def get_coreset_idx_randomp(self, z_lib, n=1000, eps=0.90, float16=True, force_cpu=False):
print(f" Fitting random projections. Start dim = {z_lib.shape}.")
try:
transformer = random_projection.SparseRandomProjection(eps=eps, random_state=self.random_state)
z_lib = torch.tensor(transformer.fit_transform(z_lib))
print(f" DONE. Transformed dim = {z_lib.shape}.")
except ValueError:
print(" Error: could not project vectors. Please increase `eps`.")
select_idx = 0
last_item = z_lib[select_idx:select_idx + 1]
coreset_idx = [torch.tensor(select_idx)]
min_distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdims=True)
if float16:
last_item = last_item.half()
z_lib = z_lib.half()
min_distances = min_distances.half()
if torch.cuda.is_available() and not force_cpu:
last_item = last_item.to("cuda")
z_lib = z_lib.to("cuda")
min_distances = min_distances.to("cuda")
for _ in tqdm(range(n - 1)):
distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdims=True) # broadcasting step
min_distances = torch.minimum(distances, min_distances) # iterative step
select_idx = torch.argmax(min_distances) # selection step
# bookkeeping
last_item = z_lib[select_idx:select_idx + 1]
min_distances[select_idx] = 0
coreset_idx.append(select_idx.to("cpu"))
return torch.stack(coreset_idx)
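# Note on get_coreset_idx_randomp: this is the greedy k-center (farthest-point)
# selection used by PatchCore. Starting from index 0, each iteration adds the
# patch farthest from the current coreset, so the kept subset covers the
# feature space with roughly uniform spacing. A hypothetical toy invocation,
# for illustration only:
#
#   z_lib = torch.randn(1000, 64)
#   f = Features.__new__(Features)  # bypass __init__; only random_state is used
#   f.random_state = None
#   idx = f.get_coreset_idx_randomp(z_lib, n=100, eps=0.9)
#   coreset = z_lib[idx]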
================================================
FILE: feature_extractors/multiple_features.py
================================================
import torch
from feature_extractors.features import Features
from utils.mvtec3d_util import *
import numpy as np
import math
import os
class RGBFeatures(Features):
def add_sample_to_mem_bank(self, sample):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, _, _, center_idx, _ = self(sample[0],unorganized_pc_no_zeros.contiguous())
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
self.patch_lib.append(rgb_patch)
def predict(self, sample, mask, label):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, _ = self(sample[0], unorganized_pc_no_zeros.contiguous())
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
self.compute_s_s_map(rgb_patch, rgb_feature_maps[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)
def run_coreset(self):
self.patch_lib = torch.cat(self.patch_lib, 0)
self.mean = torch.mean(self.patch_lib)
self.std = torch.std(self.patch_lib)
self.patch_lib = (self.patch_lib - self.mean)/self.std
# self.patch_lib = self.rgb_layernorm(self.patch_lib)
if self.f_coreset < 1:
self.coreset_idx = self.get_coreset_idx_randomp(self.patch_lib,
n=int(self.f_coreset * self.patch_lib.shape[0]),
eps=self.coreset_eps, )
self.patch_lib = self.patch_lib[self.coreset_idx]
def compute_s_s_map(self, patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx, nonzero_patch_indices = None):
'''
center: point group center position
neighbour_idx: each group point index
nonzero_indices: point indices of original point clouds
xyz: nonzero point clouds
'''
patch = (patch - self.mean)/self.std
# self.patch_lib = self.rgb_layernorm(self.patch_lib)
dist = torch.cdist(patch, self.patch_lib)
min_val, min_idx = torch.min(dist, dim=1)
# print(min_val.shape)
s_idx = torch.argmax(min_val)
s_star = torch.max(min_val)
# reweighting
m_test = patch[s_idx].unsqueeze(0) # anomalous patch
m_star = self.patch_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour
w_dist = torch.cdist(m_star, self.patch_lib) # find knn to m_star pt.1
_, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2
m_star_knn = torch.linalg.norm(m_test - self.patch_lib[nn_idx[0, 1:]], dim=1)
D = torch.sqrt(torch.tensor(patch.shape[1]))
w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D)) + 1e-5))
s = w * s_star
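        # PatchCore-style reweighting: w shrinks s_star when the matched memory
        # patch m_star lies in a dense neighbourhood, i.e. when its nearest bank
        # entries are also close to the test patch.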
# segmentation map
s_map = min_val.view(1, 1, *feature_map_dims)
s_map = torch.nn.functional.interpolate(s_map, size=(224, 224), mode='bilinear')
s_map = self.blur(s_map)
self.image_preds.append(s.numpy())
self.image_labels.append(label)
self.pixel_preds.extend(s_map.flatten().numpy())
self.pixel_labels.extend(mask.flatten().numpy())
self.predictions.append(s_map.detach().cpu().squeeze().numpy())
self.gts.append(mask.detach().cpu().squeeze().numpy())
class PointFeatures(Features):
def add_sample_to_mem_bank(self, sample):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
self.patch_lib.append(xyz_patch)
def predict(self, sample, mask, label):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
self.compute_s_s_map(xyz_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)
def run_coreset(self):
self.patch_lib = torch.cat(self.patch_lib, 0)
if self.args.rm_zero_for_project:
self.patch_lib = self.patch_lib[torch.nonzero(torch.all(self.patch_lib!=0, dim=1))[:,0]]
if self.f_coreset < 1:
self.coreset_idx = self.get_coreset_idx_randomp(self.patch_lib,
n=int(self.f_coreset * self.patch_lib.shape[0]),
eps=self.coreset_eps, )
self.patch_lib = self.patch_lib[self.coreset_idx]
if self.args.rm_zero_for_project:
self.patch_lib = self.patch_lib[torch.nonzero(torch.all(self.patch_lib!=0, dim=1))[:,0]]
self.patch_lib = torch.cat((self.patch_lib, torch.zeros(1, self.patch_lib.shape[1])), 0)
def compute_s_s_map(self, patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx, nonzero_patch_indices = None):
'''
center: point group center position
neighbour_idx: each group point index
nonzero_indices: point indices of original point clouds
xyz: nonzero point clouds
'''
dist = torch.cdist(patch, self.patch_lib)
min_val, min_idx = torch.min(dist, dim=1)
# print(min_val.shape)
s_idx = torch.argmax(min_val)
s_star = torch.max(min_val)
# reweighting
m_test = patch[s_idx].unsqueeze(0) # anomalous patch
m_star = self.patch_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour
w_dist = torch.cdist(m_star, self.patch_lib) # find knn to m_star pt.1
_, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2
m_star_knn = torch.linalg.norm(m_test - self.patch_lib[nn_idx[0, 1:]], dim=1)
D = torch.sqrt(torch.tensor(patch.shape[1]))
w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D)) + 1e-5))
s = w * s_star
# segmentation map
s_map = min_val.view(1, 1, *feature_map_dims)
s_map = torch.nn.functional.interpolate(s_map, size=(224, 224), mode='bilinear')
s_map = self.blur(s_map)
self.image_preds.append(s.numpy())
self.image_labels.append(label)
self.pixel_preds.extend(s_map.flatten().numpy())
self.pixel_labels.extend(mask.flatten().numpy())
self.predictions.append(s_map.detach().cpu().squeeze().numpy())
self.gts.append(mask.detach().cpu().squeeze().numpy())
FUSION_BLOCK= True
class FusionFeatures(Features):
def add_sample_to_mem_bank(self, sample, class_name=None):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))
rgb_patch2 = self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))
rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T
if FUSION_BLOCK:
with torch.no_grad():
fusion_patch = self.fusion.feature_fusion(xyz_patch.unsqueeze(0), rgb_patch2.unsqueeze(0))
fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()
else:
fusion_patch = torch.cat([xyz_patch, rgb_patch2], dim=1)
if class_name is not None:
torch.save(fusion_patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt'))
self.ins_id += 1
self.patch_lib.append(fusion_patch)
def predict(self, sample, mask, label):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))
rgb_patch2 = self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))
rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T
if FUSION_BLOCK:
with torch.no_grad():
fusion_patch = self.fusion.feature_fusion(xyz_patch.unsqueeze(0), rgb_patch2.unsqueeze(0))
fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()
else:
fusion_patch = torch.cat([xyz_patch, rgb_patch2], dim=1)
self.compute_s_s_map(fusion_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)
def compute_s_s_map(self, patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):
'''
center: point group center position
neighbour_idx: each group point index
nonzero_indices: point indices of original point clouds
xyz: nonzero point clouds
'''
dist = torch.cdist(patch, self.patch_lib)
min_val, min_idx = torch.min(dist, dim=1)
s_idx = torch.argmax(min_val)
s_star = torch.max(min_val)
# reweighting
m_test = patch[s_idx].unsqueeze(0) # anomalous patch
m_star = self.patch_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour
w_dist = torch.cdist(m_star, self.patch_lib) # find knn to m_star pt.1
_, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2
m_star_knn = torch.linalg.norm(m_test - self.patch_lib[nn_idx[0, 1:]], dim=1)
D = torch.sqrt(torch.tensor(patch.shape[1]))
w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D))))
s = w * s_star
# segmentation map
s_map = min_val.view(1, 1, *feature_map_dims)
s_map = torch.nn.functional.interpolate(s_map, size=(self.image_size, self.image_size), mode='bilinear')
s_map = self.blur(s_map)
self.image_preds.append(s.numpy())
self.image_labels.append(label)
self.pixel_preds.extend(s_map.flatten().numpy())
self.pixel_labels.extend(mask.flatten().numpy())
self.predictions.append(s_map.detach().cpu().squeeze().numpy())
self.gts.append(mask.detach().cpu().squeeze().numpy())
def run_coreset(self):
self.patch_lib = torch.cat(self.patch_lib, 0)
if self.f_coreset < 1:
self.coreset_idx = self.get_coreset_idx_randomp(self.patch_lib,
n=int(self.f_coreset * self.patch_lib.shape[0]),
eps=self.coreset_eps)
self.patch_lib = self.patch_lib[self.coreset_idx]
class DoubleRGBPointFeatures(Features):
def add_sample_to_mem_bank(self, sample, class_name=None):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
rgb_patch_resize = rgb_patch.repeat(4, 1).reshape(784, 4, -1).permute(1, 0, 2).reshape(784*4, -1)
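        # Tile the 784 (28x28) ViT tokens up to 784*4 = 3136 rows so they can be
        # concatenated with the 56x56 = 3136 point-feature patches saved for UFF
        # training below.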
patch = torch.cat([xyz_patch, rgb_patch_resize], dim=1)
if class_name is not None:
torch.save(patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt'))
self.ins_id += 1
self.patch_xyz_lib.append(xyz_patch)
self.patch_rgb_lib.append(rgb_patch)
def predict(self, sample, mask, label):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
self.compute_s_s_map(xyz_patch, rgb_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)
def add_sample_to_late_fusion_mem_bank(self, sample):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
# 2D dist
xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std
rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std
dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)
dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)
rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))
xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))
s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')
s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')
s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb]])
s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb], dim=0).squeeze().reshape(2, -1).permute(1, 0)
self.s_lib.append(s)
self.s_map_lib.append(s_map)
def compute_s_s_map(self, xyz_patch, rgb_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):
'''
center: point group center position
neighbour_idx: each group point index
nonzero_indices: point indices of original point clouds
xyz: nonzero point clouds
'''
# 2D dist
xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std
rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std
dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)
dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)
rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))
xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))
s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')
s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')
s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb]])
s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb], dim=0).squeeze().reshape(2, -1).permute(1, 0)
s = torch.tensor(self.detect_fuser.score_samples(s))
s_map = torch.tensor(self.seg_fuser.score_samples(s_map))
s_map = s_map.view(1, 224, 224)
self.image_preds.append(s.numpy())
self.image_labels.append(label)
self.pixel_preds.extend(s_map.flatten().numpy())
self.pixel_labels.extend(mask.flatten().numpy())
self.predictions.append(s_map.detach().cpu().squeeze().numpy())
self.gts.append(mask.detach().cpu().squeeze().numpy())
def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'):
min_val, min_idx = torch.min(dist, dim=1)
s_idx = torch.argmax(min_val)
s_star = torch.max(min_val)/1000
# reweighting
m_test = patch[s_idx].unsqueeze(0) # anomalous patch
if modal=='xyz':
m_star = self.patch_xyz_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour
w_dist = torch.cdist(m_star, self.patch_xyz_lib) # find knn to m_star pt.1
else:
m_star = self.patch_rgb_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour
w_dist = torch.cdist(m_star, self.patch_rgb_lib) # find knn to m_star pt.1
_, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2
if modal=='xyz':
m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1:]], dim=1)/1000
else:
m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1)/1000
D = torch.sqrt(torch.tensor(patch.shape[1]))
w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D))))
s = w * s_star
# segmentation map
s_map = min_val.view(1, 1, *feature_map_dims)
s_map = torch.nn.functional.interpolate(s_map, size=(224, 224), mode='bilinear')
s_map = self.blur(s_map)
return s, s_map
def run_coreset(self):
self.patch_xyz_lib = torch.cat(self.patch_xyz_lib, 0)
self.patch_rgb_lib = torch.cat(self.patch_rgb_lib, 0)
        # per-modality normalisation statistics
        self.xyz_mean = torch.mean(self.patch_xyz_lib)
        self.xyz_std = torch.std(self.patch_xyz_lib)
        self.rgb_mean = torch.mean(self.patch_rgb_lib)
        self.rgb_std = torch.std(self.patch_rgb_lib)
self.patch_xyz_lib = (self.patch_xyz_lib - self.xyz_mean)/self.xyz_std
self.patch_rgb_lib = (self.patch_rgb_lib - self.rgb_mean)/self.rgb_std
if self.f_coreset < 1:
self.coreset_idx = self.get_coreset_idx_randomp(self.patch_xyz_lib,
n=int(self.f_coreset * self.patch_xyz_lib.shape[0]),
eps=self.coreset_eps, )
self.patch_xyz_lib = self.patch_xyz_lib[self.coreset_idx]
            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_rgb_lib,
                                                            n=int(self.f_coreset * self.patch_rgb_lib.shape[0]),
                                                            eps=self.coreset_eps, )
self.patch_rgb_lib = self.patch_rgb_lib[self.coreset_idx]
class DoubleRGBPointFeatures_add(Features):
def add_sample_to_mem_bank(self, sample, class_name=None):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
rgb_patch_resize = rgb_patch.repeat(4, 1).reshape(784, 4, -1).permute(1, 0, 2).reshape(784*4, -1)
patch = torch.cat([xyz_patch, rgb_patch_resize], dim=1)
if class_name is not None:
torch.save(patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt'))
self.ins_id += 1
self.patch_xyz_lib.append(xyz_patch)
self.patch_rgb_lib.append(rgb_patch)
def predict(self, sample, mask, label):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
self.compute_s_s_map(xyz_patch, rgb_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)
def add_sample_to_late_fusion_mem_bank(self, sample):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
# 2D dist
xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std
rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std
dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)
dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)
rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))
xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))
s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')
s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')
s = torch.tensor([[s_xyz, s_rgb]])
s_map = torch.cat([s_map_xyz, s_map_rgb], dim=0).squeeze().reshape(2, -1).permute(1, 0)
self.s_lib.append(s)
self.s_map_lib.append(s_map)
def run_coreset(self):
self.patch_xyz_lib = torch.cat(self.patch_xyz_lib, 0)
self.patch_rgb_lib = torch.cat(self.patch_rgb_lib, 0)
        # per-modality normalisation statistics
        self.xyz_mean = torch.mean(self.patch_xyz_lib)
        self.xyz_std = torch.std(self.patch_xyz_lib)
        self.rgb_mean = torch.mean(self.patch_rgb_lib)
        self.rgb_std = torch.std(self.patch_rgb_lib)
self.patch_xyz_lib = (self.patch_xyz_lib - self.xyz_mean)/self.xyz_std
self.patch_rgb_lib = (self.patch_rgb_lib - self.rgb_mean)/self.rgb_std
if self.f_coreset < 1:
self.coreset_idx = self.get_coreset_idx_randomp(self.patch_xyz_lib,
n=int(self.f_coreset * self.patch_xyz_lib.shape[0]),
eps=self.coreset_eps, )
self.patch_xyz_lib = self.patch_xyz_lib[self.coreset_idx]
            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_rgb_lib,
                                                            n=int(self.f_coreset * self.patch_rgb_lib.shape[0]),
                                                            eps=self.coreset_eps, )
self.patch_rgb_lib = self.patch_rgb_lib[self.coreset_idx]
def compute_s_s_map(self, xyz_patch, rgb_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):
'''
center: point group center position
neighbour_idx: each group point index
nonzero_indices: point indices of original point clouds
xyz: nonzero point clouds
'''
# 2D dist
xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std
rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std
dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)
dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)
rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))
xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))
s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')
s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')
s = s_xyz + s_rgb
s_map = s_map_xyz + s_map_rgb
s_map = s_map.view(1, self.image_size, self.image_size)
self.image_preds.append(s.numpy())
self.image_labels.append(label)
self.pixel_preds.extend(s_map.flatten().numpy())
self.pixel_labels.extend(mask.flatten().numpy())
self.predictions.append(s_map.detach().cpu().squeeze().numpy())
self.gts.append(mask.detach().cpu().squeeze().numpy())
def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'):
min_val, min_idx = torch.min(dist, dim=1)
s_idx = torch.argmax(min_val)
s_star = torch.max(min_val)
# reweighting
m_test = patch[s_idx].unsqueeze(0) # anomalous patch
if modal=='xyz':
m_star = self.patch_xyz_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour
w_dist = torch.cdist(m_star, self.patch_xyz_lib) # find knn to m_star pt.1
else:
m_star = self.patch_rgb_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour
w_dist = torch.cdist(m_star, self.patch_rgb_lib) # find knn to m_star pt.1
_, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2
if modal=='xyz':
m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1:]], dim=1)
else:
m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1)
D = torch.sqrt(torch.tensor(patch.shape[1]))
w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D))))
s = w * s_star
# segmentation map
s_map = min_val.view(1, 1, *feature_map_dims)
s_map = torch.nn.functional.interpolate(s_map, size=(self.image_size, self.image_size), mode='bilinear', align_corners=False)
s_map = self.blur(s_map)
return s, s_map
class TripleFeatures(Features):
def add_sample_to_mem_bank(self, sample, class_name=None):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))
rgb_patch2 = self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))
rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T
self.patch_rgb_lib.append(rgb_patch)
if self.args.asy_memory_bank is None or len(self.patch_xyz_lib) < self.args.asy_memory_bank:
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
xyz_patch_full_resized2 = self.resize2(self.average(xyz_patch_full_2d))
xyz_patch2 = xyz_patch_full_resized2.reshape(xyz_patch_full_resized2.shape[1], -1).T
if FUSION_BLOCK:
with torch.no_grad():
fusion_patch = self.fusion.feature_fusion(xyz_patch2.unsqueeze(0), rgb_patch2.unsqueeze(0))
fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()
else:
fusion_patch = torch.cat([xyz_patch2, rgb_patch2], dim=1)
self.patch_xyz_lib.append(xyz_patch)
self.patch_fusion_lib.append(fusion_patch)
if class_name is not None:
torch.save(fusion_patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt'))
self.ins_id += 1
def predict(self, sample, mask, label):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
xyz_patch_full_resized2 = self.resize2(self.average(xyz_patch_full_2d))
xyz_patch2 = xyz_patch_full_resized2.reshape(xyz_patch_full_resized2.shape[1], -1).T
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))
rgb_patch2 = self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))
rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T
if FUSION_BLOCK:
with torch.no_grad():
fusion_patch = self.fusion.feature_fusion(xyz_patch2.unsqueeze(0), rgb_patch2.unsqueeze(0))
fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()
else:
fusion_patch = torch.cat([xyz_patch2, rgb_patch2], dim=1)
self.compute_s_s_map(xyz_patch, rgb_patch, fusion_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)
def add_sample_to_late_fusion_mem_bank(self, sample):
organized_pc = sample[1]
organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()
unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)
rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())
xyz_patch = torch.cat(xyz_feature_maps, 1)
xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)
xyz_patch_full[:,:,nonzero_indices] = interpolated_pc
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)
xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))
xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T
xyz_patch_full_resized2 = self.resize2(self.average(xyz_patch_full_2d))
xyz_patch2 = xyz_patch_full_resized2.reshape(xyz_patch_full_resized2.shape[1], -1).T
rgb_patch = torch.cat(rgb_feature_maps, 1)
rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T
rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))
rgb_patch2 = self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))
rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T
if FUSION_BLOCK:
with torch.no_grad():
fusion_patch = self.fusion.feature_fusion(xyz_patch2.unsqueeze(0), rgb_patch2.unsqueeze(0))
fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()
else:
fusion_patch = torch.cat([xyz_patch2, rgb_patch2], dim=1)
# 3D dist
xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std
rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std
fusion_patch = (fusion_patch - self.fusion_mean)/self.fusion_std
dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)
dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)
dist_fusion = torch.cdist(fusion_patch, self.patch_fusion_lib)
rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))
xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))
fusion_feat_size = (int(math.sqrt(fusion_patch.shape[0])), int(math.sqrt(fusion_patch.shape[0])))
# 3 memory bank results
s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')
s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')
s_fusion, s_map_fusion = self.compute_single_s_s_map(fusion_patch, dist_fusion, fusion_feat_size, modal='fusion')
s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb, self.args.fusion_s_lambda*s_fusion]])
s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb, self.args.fusion_smap_lambda*s_map_fusion], dim=0).squeeze().reshape(3, -1).permute(1, 0)
self.s_lib.append(s)
self.s_map_lib.append(s_map)
def run_coreset(self):
self.patch_xyz_lib = torch.cat(self.patch_xyz_lib, 0)
self.patch_rgb_lib = torch.cat(self.patch_rgb_lib, 0)
self.patch_fusion_lib = torch.cat(self.patch_fusion_lib, 0)
self.xyz_mean = torch.mean(self.patch_xyz_lib)
# normalize each memory bank with its own modality's statistics
self.xyz_std = torch.std(self.patch_xyz_lib)
self.rgb_mean = torch.mean(self.patch_rgb_lib)
self.rgb_std = torch.std(self.patch_rgb_lib)
self.fusion_mean = torch.mean(self.patch_fusion_lib)
self.fusion_std = torch.std(self.patch_fusion_lib)
self.patch_xyz_lib = (self.patch_xyz_lib - self.xyz_mean)/self.xyz_std
self.patch_rgb_lib = (self.patch_rgb_lib - self.rgb_mean)/self.rgb_std
self.patch_fusion_lib = (self.patch_fusion_lib - self.fusion_mean)/self.fusion_std
if self.f_coreset < 1:
self.coreset_idx = self.get_coreset_idx_randomp(self.patch_xyz_lib,
n=int(self.f_coreset * self.patch_xyz_lib.shape[0]),
eps=self.coreset_eps, )
self.patch_xyz_lib = self.patch_xyz_lib[self.coreset_idx]
self.coreset_idx = self.get_coreset_idx_randomp(self.patch_rgb_lib,
n=int(self.f_coreset * self.patch_rgb_lib.shape[0]),
eps=self.coreset_eps, )
self.patch_rgb_lib = self.patch_rgb_lib[self.coreset_idx]
self.coreset_idx = self.get_coreset_idx_randomp(self.patch_fusion_lib,
n=int(self.f_coreset * self.patch_fusion_lib.shape[0]),
eps=self.coreset_eps, )
self.patch_fusion_lib = self.patch_fusion_lib[self.coreset_idx]
self.patch_xyz_lib = self.patch_xyz_lib[torch.nonzero(torch.all(self.patch_xyz_lib!=0, dim=1))[:,0]]
self.patch_xyz_lib = torch.cat((self.patch_xyz_lib, torch.zeros(1, self.patch_xyz_lib.shape[1])), 0)
def compute_s_s_map(self, xyz_patch, rgb_patch, fusion_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):
'''
center: point group center position
neighbour_idx: each group point index
nonzero_indices: point indices of original point clouds
xyz: nonzero point clouds
'''
# 3D dist
xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std
rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std
fusion_patch = (fusion_patch - self.fusion_mean)/self.fusion_std
dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)
dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)
dist_fusion = torch.cdist(fusion_patch, self.patch_fusion_lib)
rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))
xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))
fusion_feat_size = (int(math.sqrt(fusion_patch.shape[0])), int(math.sqrt(fusion_patch.shape[0])))
s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')
s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')
s_fusion, s_map_fusion = self.compute_single_s_s_map(fusion_patch, dist_fusion, fusion_feat_size, modal='fusion')
s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb, self.args.fusion_s_lambda*s_fusion]])
s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb, self.args.fusion_smap_lambda*s_map_fusion], dim=0).squeeze().reshape(3, -1).permute(1, 0)
s = torch.tensor(self.detect_fuser.score_samples(s))
s_map = torch.tensor(self.seg_fuser.score_samples(s_map))
s_map = s_map.view(1, self.image_size, self.image_size)
self.image_preds.append(s.numpy())
self.image_labels.append(label)
self.pixel_preds.extend(s_map.flatten().numpy())
self.pixel_labels.extend(mask.flatten().numpy())
self.predictions.append(s_map.detach().cpu().squeeze().numpy())
self.gts.append(mask.detach().cpu().squeeze().numpy())
def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'):
min_val, min_idx = torch.min(dist, dim=1)
s_idx = torch.argmax(min_val)
s_star = torch.max(min_val)
# reweighting
m_test = patch[s_idx].unsqueeze(0) # anomalous patch
if modal=='xyz':
m_star = self.patch_xyz_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour
w_dist = torch.cdist(m_star, self.patch_xyz_lib) # find knn to m_star pt.1
elif modal=='rgb':
m_star = self.patch_rgb_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour
w_dist = torch.cdist(m_star, self.patch_rgb_lib) # find knn to m_star pt.1
else:
m_star = self.patch_fusion_lib[min_idx[s_idx]].unsqueeze(0) # closest neighbour
w_dist = torch.cdist(m_star, self.patch_fusion_lib) # find knn to m_star pt.1
_, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2
# equation 7 from the paper
if modal=='xyz':
m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1:]], dim=1)
elif modal=='rgb':
m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1)
else:
m_star_knn = torch.linalg.norm(m_test - self.patch_fusion_lib[nn_idx[0, 1:]], dim=1)
# sparse reweight
# if modal=='rgb':
# _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False) # pt.2
# else:
# _, nn_idx = torch.topk(w_dist, k=4*self.n_reweight, largest=False) # pt.2
# if modal=='xyz':
# m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1::4]], dim=1)
# elif modal=='rgb':
# m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1)
# else:
# m_star_knn = torch.linalg.norm(m_test - self.patch_fusion_lib[nn_idx[0, 1::4]], dim=1)
# Softmax normalization trick as in transformers.
# As the patch vectors grow larger, their norm might differ a lot.
# exp(norm) can give infinities.
D = torch.sqrt(torch.tensor(patch.shape[1]))
w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D))))
s = w * s_star
# segmentation map
s_map = min_val.view(1, 1, *feature_map_dims)
s_map = torch.nn.functional.interpolate(s_map, size=(self.image_size, self.image_size), mode='bilinear', align_corners=False)
s_map = self.blur(s_map)
return s, s_map
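# Editor's sketch: a minimal, self-contained illustration of the
# nearest-neighbour scoring and softmax reweighting implemented in
# compute_single_s_s_map above (cf. "equation 7"). The memory bank,
# test patches, and k below are hypothetical, not the repo's settings.
if __name__ == '__main__':
    import torch
    torch.manual_seed(0)
    mem_bank = torch.randn(500, 64)        # memory bank: 500 patches, dim 64
    test_patches = torch.randn(784, 64)    # one test image: 28 x 28 patches
    dist = torch.cdist(test_patches, mem_bank)
    min_val, min_idx = torch.min(dist, dim=1)    # NN distance per test patch
    s_idx = torch.argmax(min_val)                # most anomalous test patch
    s_star = torch.max(min_val)
    m_test = test_patches[s_idx].unsqueeze(0)
    m_star = mem_bank[min_idx[s_idx]].unsqueeze(0)       # closest neighbour
    w_dist = torch.cdist(m_star, mem_bank)
    _, nn_idx = torch.topk(w_dist, k=10, largest=False)  # k plays n_reweight
    m_star_knn = torch.linalg.norm(m_test - mem_bank[nn_idx[0, 1:]], dim=1)
    D = torch.sqrt(torch.tensor(float(test_patches.shape[1])))
    w = 1 - torch.exp(s_star / D) / torch.sum(torch.exp(m_star_knn / D))
    print('image-level score s = w * s_star =', (w * s_star).item())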
================================================
FILE: fusion_pretrain.py
================================================
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import timm
import timm.optim.optim_factory as optim_factory
import utils.misc as misc
from utils.misc import NativeScalerWithGradNormCount as NativeScaler
from engine_fusion_pretrain import train_one_epoch
import dataset
import torch
from models.feature_fusion import FeatureFusionBlock
def get_args_parser():
parser = argparse.ArgumentParser('MAE pre-training', add_help=False)
parser.add_argument('--batch_size', default=64, type=int,
help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
parser.add_argument('--epochs', default=3, type=int)
parser.add_argument('--accum_iter', default=1, type=int,
help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
# Model parameters
parser.add_argument('--input_size', default=224, type=int,
help='images input size')
# Optimizer parameters
parser.add_argument('--weight_decay', type=float, default=1.5e-6,
help='weight decay (default: 1.5e-6)')
parser.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate (absolute lr)')
parser.add_argument('--blr', type=float, default=0.002, metavar='LR',
help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0')
parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N',
help='epochs to warmup LR')
# Dataset parameters
parser.add_argument('--data_path', default='', type=str,
help='dataset path')
parser.add_argument('--output_dir', default='./output_dir',
help='path where to save, empty for no saving')
parser.add_argument('--log_dir', default='./output_dir',
help='path where to tensorboard log')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='',
help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin_mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
parser.set_defaults(pin_mem=True)
# distributed training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--local_rank', default=-1, type=int)
parser.add_argument('--dist_on_itp', action='store_true')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
return parser
def main(args):
misc.init_distributed_mode(args)
print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
print("{}".format(args).replace(', ', ',\n'))
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
cudnn.benchmark = True
dataset_train = dataset.PreTrainTensorDataset(args.data_path)
print(dataset_train)
if True: # args.distributed:
num_tasks = misc.get_world_size()
global_rank = misc.get_rank()
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
print("Sampler_train = %s" % str(sampler_train))
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
if global_rank == 0 and args.log_dir is not None:
os.makedirs(args.log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=args.log_dir)
else:
log_writer = None
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
if args.lr is None: # only base_lr is specified
args.lr = args.blr * eff_batch_size / 256
print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
print("actual lr: %.2e" % args.lr)
print("accumulate grad iterations: %d" % args.accum_iter)
print("effective batch size: %d" % eff_batch_size)
model = FeatureFusionBlock(1152, 768)
model.to(device)
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
model_without_ddp = model.module
# note: all parameters share a single group here (no timm-style per-layer weight decay)
optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr, betas=(0.9, 0.95))
print(optimizer)
loss_scaler = NativeScaler()
misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, data_loader_train,
optimizer, device, epoch, loss_scaler,
log_writer=log_writer,
args=args
)
if args.output_dir and (epoch % 1 == 0 or epoch + 1 == args.epochs):
misc.save_model(
args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
loss_scaler=loss_scaler, epoch=epoch)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
'epoch': epoch,}
if args.output_dir and misc.is_main_process():
if log_writer is not None:
log_writer.flush()
with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
args = get_args_parser()
args = args.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
main(args)
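# Editor's note: an illustrative single-machine invocation, assuming patch
# features were first dumped via `main.py --save_feature` into the default
# save_feature_path (all paths and hyperparameter values are examples only):
#   python fusion_pretrain.py --accum_iter 16 --batch_size 16 \
#       --data_path datasets/patch_lib --output_dir checkpoints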
================================================
FILE: m3dm_runner.py
================================================
import torch
from tqdm import tqdm
import os
from feature_extractors import multiple_features
from dataset import get_data_loader
class M3DM():
def __init__(self, args):
self.args = args
self.image_size = args.img_size
self.count = args.max_sample
if args.method_name == 'DINO':
self.methods = {
"DINO": multiple_features.RGBFeatures(args),
}
elif args.method_name == 'Point_MAE':
self.methods = {
"Point_MAE": multiple_features.PointFeatures(args),
}
elif args.method_name == 'Fusion':
self.methods = {
"Fusion": multiple_features.FusionFeatures(args),
}
elif args.method_name == 'DINO+Point_MAE':
self.methods = {
"DINO+Point_MAE": multiple_features.DoubleRGBPointFeatures(args),
}
elif args.method_name == 'DINO+Point_MAE+add':
self.methods = {
"DINO+Point_MAE": multiple_features.DoubleRGBPointFeatures_add(args),
}
elif args.method_name == 'DINO+Point_MAE+Fusion':
self.methods = {
"DINO+Point_MAE+Fusion": multiple_features.TripleFeatures(args),
}
def fit(self, class_name):
train_loader = get_data_loader("train", class_name=class_name, img_size=self.image_size, args=self.args)
flag = 0
for sample, _ in tqdm(train_loader, desc=f'Extracting train features for class {class_name}'):
for method in self.methods.values():
if self.args.save_feature:
method.add_sample_to_mem_bank(sample, class_name=class_name)
else:
method.add_sample_to_mem_bank(sample)
flag += 1
if flag > self.count:
flag = 0
break
for method_name, method in self.methods.items():
print(f'\n\nRunning coreset for {method_name} on class {class_name}...')
method.run_coreset()
if self.args.memory_bank == 'multiple':
flag = 0
for sample, _ in tqdm(train_loader, desc=f'Running late fusion for {method_name} on class {class_name}..'):
for method_name, method in self.methods.items():
method.add_sample_to_late_fusion_mem_bank(sample)
flag += 1
if flag > self.count:
flag = 0
break
for method_name, method in self.methods.items():
print(f'\n\nTraining Decision Layer Fusion for {method_name} on class {class_name}...')
method.run_late_fusion()
def evaluate(self, class_name):
image_rocaucs = dict()
pixel_rocaucs = dict()
au_pros = dict()
test_loader = get_data_loader("test", class_name=class_name, img_size=self.image_size, args=self.args)
path_list = []
with torch.no_grad():
for sample, mask, label, rgb_path in tqdm(test_loader, desc=f'Extracting test features for class {class_name}'):
for method in self.methods.values():
method.predict(sample, mask, label)
path_list.append(rgb_path)
for method_name, method in self.methods.items():
method.calculate_metrics()
image_rocaucs[method_name] = round(method.image_rocauc, 3)
pixel_rocaucs[method_name] = round(method.pixel_rocauc, 3)
au_pros[method_name] = round(method.au_pro, 3)
print(
f'Class: {class_name}, {method_name} Image ROCAUC: {method.image_rocauc:.3f}, {method_name} Pixel ROCAUC: {method.pixel_rocauc:.3f}, {method_name} AU-PRO: {method.au_pro:.3f}')
if self.args.save_preds:
method.save_prediction_maps('./pred_maps', path_list)
return image_rocaucs, pixel_rocaucs, au_pros
================================================
FILE: main.py
================================================
import argparse
from m3dm_runner import M3DM
from dataset import eyecandies_classes, mvtec3d_classes
import pandas as pd
def run_3d_ads(args):
if args.dataset_type=='eyecandies':
classes = eyecandies_classes()
elif args.dataset_type=='mvtec3d':
classes = mvtec3d_classes()
METHOD_NAMES = [args.method_name]
image_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])
pixel_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])
au_pros_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])
for cls in classes:
model = M3DM(args)
model.fit(cls)
image_rocaucs, pixel_rocaucs, au_pros = model.evaluate(cls)
image_rocaucs_df[cls.title()] = image_rocaucs_df['Method'].map(image_rocaucs)
pixel_rocaucs_df[cls.title()] = pixel_rocaucs_df['Method'].map(pixel_rocaucs)
au_pros_df[cls.title()] = au_pros_df['Method'].map(au_pros)
print(f"\nFinished running on class {cls}")
print("################################################################################\n\n")
image_rocaucs_df['Mean'] = round(image_rocaucs_df.iloc[:, 1:].mean(axis=1),3)
pixel_rocaucs_df['Mean'] = round(pixel_rocaucs_df.iloc[:, 1:].mean(axis=1),3)
au_pros_df['Mean'] = round(au_pros_df.iloc[:, 1:].mean(axis=1),3)
print("\n\n################################################################################")
print("############################# Image ROCAUC Results #############################")
print("################################################################################\n")
print(image_rocaucs_df.to_markdown(index=False))
print("\n\n################################################################################")
print("############################# Pixel ROCAUC Results #############################")
print("################################################################################\n")
print(pixel_rocaucs_df.to_markdown(index=False))
print("\n\n##########################################################################")
print("############################# AU PRO Results #############################")
print("##########################################################################\n")
print(au_pros_df.to_markdown(index=False))
with open("results/image_rocauc_results.md", "a") as tf:
tf.write(image_rocaucs_df.to_markdown(index=False))
with open("results/pixel_rocauc_results.md", "a") as tf:
tf.write(pixel_rocaucs_df.to_markdown(index=False))
with open("results/aupro_results.md", "a") as tf:
tf.write(au_pros_df.to_markdown(index=False))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='M3DM multimodal industrial anomaly detection.')
parser.add_argument('--method_name', default='DINO+Point_MAE+Fusion', type=str,
choices=['DINO', 'Point_MAE', 'Fusion', 'DINO+Point_MAE', 'DINO+Point_MAE+Fusion', 'DINO+Point_MAE+add'],
help='Anomaly detection method name.')
parser.add_argument('--max_sample', default=400, type=int,
help='Max sample number.')
parser.add_argument('--memory_bank', default='multiple', type=str,
choices=["multiple", "single"],
help='memory bank mode: "multiple", "single".')
parser.add_argument('--rgb_backbone_name', default='vit_base_patch8_224_dino', type=str,
choices=['vit_base_patch8_224_dino', 'vit_base_patch8_224', 'vit_base_patch8_224_in21k', 'vit_small_patch8_224_dino'],
help='Timm checkpoints name of RGB backbone.')
parser.add_argument('--xyz_backbone_name', default='Point_MAE', type=str, choices=['Point_MAE', 'Point_Bert'],
help='Checkpoint name of the point cloud (xyz) backbone [Point_MAE, Point_Bert].')
parser.add_argument('--fusion_module_path', default='checkpoints/checkpoint-0.pth', type=str,
help='Checkpoints for fusion module.')
parser.add_argument('--save_feature', default=False, action='store_true',
help='Save features for training the fusion block.')
parser.add_argument('--use_uff', default=False, action='store_true',
help='Use UFF module.')
parser.add_argument('--save_feature_path', default='datasets/patch_lib', type=str,
help='Path where features for training the fusion block are saved.')
parser.add_argument('--save_preds', default=False, action='store_true',
help='Save prediction results.')
parser.add_argument('--group_size', default=128, type=int,
help='Point group size of Point Transformer.')
parser.add_argument('--num_group', default=1024, type=int,
help='Point groups number of Point Transformer.')
parser.add_argument('--random_state', default=None, type=int,
help='random_state for random projection')
parser.add_argument('--dataset_type', default='mvtec3d', type=str, choices=['mvtec3d', 'eyecandies'],
help='Dataset type for training or testing')
parser.add_argument('--dataset_path', default='datasets/mvtec3d', type=str,
help='Dataset store path')
parser.add_argument('--img_size', default=224, type=int,
help='Images size for model')
parser.add_argument('--xyz_s_lambda', default=1.0, type=float,
help='xyz_s_lambda')
parser.add_argument('--xyz_smap_lambda', default=1.0, type=float,
help='xyz_smap_lambda')
parser.add_argument('--rgb_s_lambda', default=0.1, type=float,
help='rgb_s_lambda')
parser.add_argument('--rgb_smap_lambda', default=0.1, type=float,
help='rgb_smap_lambda')
parser.add_argument('--fusion_s_lambda', default=1.0, type=float,
help='fusion_s_lambda')
parser.add_argument('--fusion_smap_lambda', default=1.0, type=float,
help='fusion_smap_lambda')
parser.add_argument('--coreset_eps', default=0.9, type=float,
help='eps for sparse random projection')
parser.add_argument('--f_coreset', default=0.1, type=float,
help='Fraction of the memory bank kept after coreset subsampling.')
parser.add_argument('--asy_memory_bank', default=None, type=int,
help='build an asymmetric memory bank for point clouds')
parser.add_argument('--ocsvm_nu', default=0.5, type=float,
help='ocsvm nu')
parser.add_argument('--ocsvm_maxiter', default=1000, type=int,
help='ocsvm maxiter')
parser.add_argument('--rm_zero_for_project', default=False, action='store_true',
help='Remove zero points before random projection.')
args = parser.parse_args()
run_3d_ads(args)
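# Editor's note: an illustrative invocation using the flags defined above
# (the dataset path and method choice are examples, not prescribed settings):
#   python main.py --method_name DINO+Point_MAE+Fusion --use_uff \
#       --dataset_path datasets/mvtec3d --memory_bank multiple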
================================================
FILE: models/feature_fusion.py
================================================
import torch
import torch.nn as nn
import math
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class FeatureFusionBlock(nn.Module):
def __init__(self, xyz_dim, rgb_dim, mlp_ratio=4.):
super().__init__()
self.xyz_dim = xyz_dim
self.rgb_dim = rgb_dim
self.xyz_norm = nn.LayerNorm(xyz_dim)
self.xyz_mlp = Mlp(in_features=xyz_dim, hidden_features=int(xyz_dim * mlp_ratio), act_layer=nn.GELU, drop=0.)
self.rgb_norm = nn.LayerNorm(rgb_dim)
self.rgb_mlp = Mlp(in_features=rgb_dim, hidden_features=int(rgb_dim * mlp_ratio), act_layer=nn.GELU, drop=0.)
self.rgb_head = nn.Linear(rgb_dim, 256)
self.xyz_head = nn.Linear(xyz_dim, 256)
self.T = 1
def feature_fusion(self, xyz_feature, rgb_feature):
xyz_feature = self.xyz_mlp(self.xyz_norm(xyz_feature))
rgb_feature = self.rgb_mlp(self.rgb_norm(rgb_feature))
feature = torch.cat([xyz_feature, rgb_feature], dim=2)
return feature
def contrastive_loss(self, q, k):
# normalize
q = nn.functional.normalize(q, dim=1)
k = nn.functional.normalize(k, dim=1)
# gather all targets
# Einstein sum is more intuitive
logits = torch.einsum('nc,mc->nm', [q, k]) / self.T
N = logits.shape[0] # batch size per GPU
labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()
return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)
def reparameterize(self, mu, logvar):
"""
Will a single z be enough to compute the expectation
for the loss?
:param mu: (Tensor) Mean of the latent Gaussian
:param logvar: (Tensor) Log-variance of the latent Gaussian
:return:
"""
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return eps * std + mu
def forward(self, xyz_feature, rgb_feature):
feature = self.feature_fusion(xyz_feature, rgb_feature)
feature_xyz = feature[:,:, :self.xyz_dim]
feature_rgb = feature[:,:, self.xyz_dim:]
q = self.rgb_head(feature_rgb.view(-1, feature_rgb.shape[2]))
k = self.xyz_head(feature_xyz.view(-1, feature_xyz.shape[2]))
xyz_feature = xyz_feature.view(-1, xyz_feature.shape[2])
rgb_feature = rgb_feature.view(-1, rgb_feature.shape[2])
patch_no_zeros_indices = torch.nonzero(torch.all(xyz_feature != 0, dim=1))
loss = self.contrastive_loss(q[patch_no_zeros_indices,:].squeeze(), k[patch_no_zeros_indices,:].squeeze())
return loss
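# Editor's sketch: a minimal CPU check of the fusion path only (the full
# forward needs torch.distributed for the contrastive labels). The 1152/768
# dims follow FeatureFusionBlock(1152, 768) as used in fusion_pretrain.py;
# the patch count is hypothetical.
if __name__ == '__main__':
    block = FeatureFusionBlock(1152, 768)
    xyz = torch.randn(1, 784, 1152)   # B x N_patches x xyz_dim
    rgb = torch.randn(1, 784, 768)    # B x N_patches x rgb_dim
    fused = block.feature_fusion(xyz, rgb)
    print(fused.shape)                # torch.Size([1, 784, 1920])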
================================================
FILE: models/models.py
================================================
import torch
import torch.nn as nn
import timm
from timm.models.layers import DropPath, trunc_normal_
from timm.models.helpers import checkpoint_seq  # used in forward_rgb_features; module path may vary across timm versions
from pointnet2_ops import pointnet2_utils
from knn_cuda import KNN
class Model(torch.nn.Module):
def __init__(self, device, rgb_backbone_name='vit_base_patch8_224_dino', out_indices=None, checkpoint_path='',
pool_last=False, xyz_backbone_name='Point_MAE', group_size=128, num_group=1024):
super().__init__()
# 'vit_base_patch8_224_dino'
# Determine if to output features.
self.device = device
kwargs = {'features_only': True if out_indices else False}
if out_indices:
kwargs.update({'out_indices': out_indices})
## RGB backbone
self.rgb_backbone = timm.create_model(model_name=rgb_backbone_name, pretrained=True, checkpoint_path=checkpoint_path,
**kwargs)
## XYZ backbone
if xyz_backbone_name=='Point_MAE':
self.xyz_backbone=PointTransformer(group_size=group_size, num_group=num_group)
self.xyz_backbone.load_model_from_ckpt("checkpoints/pointmae_pretrain.pth")
elif xyz_backbone_name=='Point_Bert':
self.xyz_backbone=PointTransformer(group_size=group_size, num_group=num_group, encoder_dims=256)
self.xyz_backbone.load_model_from_pb_ckpt("checkpoints/Point-BERT.pth")
def forward_rgb_features(self, x):
x = self.rgb_backbone.patch_embed(x)
x = self.rgb_backbone._pos_embed(x)
x = self.rgb_backbone.norm_pre(x)
if self.rgb_backbone.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.rgb_backbone.blocks, x)
else:
x = self.rgb_backbone.blocks(x)
x = self.rgb_backbone.norm(x)
feat = x[:,1:].permute(0, 2, 1).view(1, -1, 28, 28)
return feat
def forward(self, rgb, xyz):
rgb_features = self.forward_rgb_features(rgb)
xyz_features, center, ori_idx, center_idx = self.xyz_backbone(xyz)
return rgb_features, xyz_features, center, ori_idx, center_idx
def fps(data, number):
'''
data B N 3
number int
'''
fps_idx = pointnet2_utils.furthest_point_sample(data, number)
fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()
return fps_data, fps_idx
class Group(nn.Module):
def __init__(self, num_group, group_size):
super().__init__()
self.num_group = num_group
self.group_size = group_size
self.knn = KNN(k=self.group_size, transpose_mode=True)
def forward(self, xyz):
'''
input: B N 3
---------------------------
output: B G M 3
center : B G 3
'''
batch_size, num_points, _ = xyz.shape
# fps the centers out
center, center_idx = fps(xyz.contiguous(), self.num_group) # B G 3
# knn to get the neighborhood
_, idx = self.knn(xyz, center) # B G M
assert idx.size(1) == self.num_group
assert idx.size(2) == self.group_size
ori_idx = idx
idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
idx = idx + idx_base
idx = idx.view(-1)
neighborhood = xyz.reshape(batch_size * num_points, -1)[idx, :]
neighborhood = neighborhood.reshape(batch_size, self.num_group, self.group_size, 3).contiguous()
# normalize
neighborhood = neighborhood - center.unsqueeze(2)
return neighborhood, center, ori_idx, center_idx
class Encoder(nn.Module):
def __init__(self, encoder_channel):
super().__init__()
self.encoder_channel = encoder_channel
self.first_conv = nn.Sequential(
nn.Conv1d(3, 128, 1),
nn.BatchNorm1d(128),
nn.ReLU(inplace=True),
nn.Conv1d(128, 256, 1)
)
self.second_conv = nn.Sequential(
nn.Conv1d(512, 512, 1),
nn.BatchNorm1d(512),
nn.ReLU(inplace=True),
nn.Conv1d(512, self.encoder_channel, 1)
)
def forward(self, point_groups):
'''
point_groups : B G N 3
-----------------
feature_global : B G C
'''
bs, g, n, _ = point_groups.shape
point_groups = point_groups.reshape(bs * g, n, 3)
# encoder
feature = self.first_conv(point_groups.transpose(2, 1))
feature_global = torch.max(feature, dim=2, keepdim=True)[0]
feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)
feature = self.second_conv(feature)
feature_global = torch.max(feature, dim=2, keepdim=False)[0]
return feature_global.reshape(bs, g, self.encoder_channel)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q * self.scale) @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class TransformerEncoder(nn.Module):
""" Transformer Encoder without hierarchical structure
"""
def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
super().__init__()
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate
)
for i in range(depth)])
def forward(self, x, pos):
feature_list = []
fetch_idx = [3, 7, 11]
for i, block in enumerate(self.blocks):
x = block(x + pos)
if i in fetch_idx:
feature_list.append(x)
return feature_list
class PointTransformer(nn.Module):
def __init__(self, group_size=128, num_group=1024, encoder_dims=384):
super().__init__()
self.trans_dim = 384
self.depth = 12
self.drop_path_rate = 0.1
self.num_heads = 6
self.group_size = group_size
self.num_group = num_group
# grouper
self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
# define the encoder
self.encoder_dims = encoder_dims
if self.encoder_dims != self.trans_dim:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
self.reduce_dim = nn.Linear(self.encoder_dims, self.trans_dim)
self.encoder = Encoder(encoder_channel=self.encoder_dims)
# bridge encoder and transformer
self.pos_embed = nn.Sequential(
nn.Linear(3, 128),
nn.GELU(),
nn.Linear(128, self.trans_dim)
)
dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]
self.blocks = TransformerEncoder(
embed_dim=self.trans_dim,
depth=self.depth,
drop_path_rate=dpr,
num_heads=self.num_heads
)
self.norm = nn.LayerNorm(self.trans_dim)
def load_model_from_ckpt(self, bert_ckpt_path):
if bert_ckpt_path is not None:
ckpt = torch.load(bert_ckpt_path)
base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()}
for k in list(base_ckpt.keys()):
if k.startswith('MAE_encoder'):
base_ckpt[k[len('MAE_encoder.'):]] = base_ckpt[k]
del base_ckpt[k]
elif k.startswith('base_model'):
base_ckpt[k[len('base_model.'):]] = base_ckpt[k]
del base_ckpt[k]
incompatible = self.load_state_dict(base_ckpt, strict=False)
#if incompatible.missing_keys:
# print('missing_keys')
# print(
# incompatible.missing_keys
# )
#if incompatible.unexpected_keys:
# print('unexpected_keys')
# print(
# incompatible.unexpected_keys
# )
# print(f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}')
def load_model_from_pb_ckpt(self, bert_ckpt_path):
ckpt = torch.load(bert_ckpt_path)
base_ckpt = {k.replace("module.", ""): v for k, v in ckpt['base_model'].items()}
for k in list(base_ckpt.keys()):
if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'):
base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k]
del base_ckpt[k]
elif k.startswith('base_model'):
base_ckpt[k[len('base_model.'):]] = base_ckpt[k]
del base_ckpt[k]
incompatible = self.load_state_dict(base_ckpt, strict=False)
if incompatible.missing_keys:
print('missing_keys')
print(
incompatible.missing_keys
)
if incompatible.unexpected_keys:
print('unexpected_keys')
print(
incompatible.unexpected_keys
)
print(f'[Transformer] Successfully loaded the ckpt from {bert_ckpt_path}')
def forward(self, pts):
if self.encoder_dims != self.trans_dim:
B,C,N = pts.shape
pts = pts.transpose(-1, -2) # B N 3
# divide the point cloud in the same form. This is important
neighborhood, center, ori_idx, center_idx = self.group_divider(pts)
# # generate mask
# bool_masked_pos = self._mask_center(center, no_mask = False) # B G
# encoder the input cloud blocks
group_input_tokens = self.encoder(neighborhood) # B G N
group_input_tokens = self.reduce_dim(group_input_tokens)
# prepare cls
cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)
cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)
# add pos embedding
pos = self.pos_embed(center)
# final input
x = torch.cat((cls_tokens, group_input_tokens), dim=1)
pos = torch.cat((cls_pos, pos), dim=1)
# transformer
feature_list = self.blocks(x, pos)
feature_list = [self.norm(x)[:,1:].transpose(-1, -2).contiguous() for x in feature_list]
x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152
return x, center, ori_idx, center_idx
else:
B, C, N = pts.shape
pts = pts.transpose(-1, -2) # B N 3
# divide the point cloud in the same form. This is important
neighborhood, center, ori_idx, center_idx = self.group_divider(pts)
group_input_tokens = self.encoder(neighborhood) # B G N
pos = self.pos_embed(center)
# final input
x = group_input_tokens
# transformer
feature_list = self.blocks(x, pos)
feature_list = [self.norm(x).transpose(-1, -2).contiguous() for x in feature_list]
x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152
return x, center, ori_idx, center_idx
================================================
FILE: models/pointnet2_utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np
def timeit(tag, t):
print("{}: {}s".format(tag, time() - t))
return time()
def pc_normalize(pc):
l = pc.shape[0]
centroid = np.mean(pc, axis=0)
pc = pc - centroid
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
pc = pc / m
return pc
def square_distance(src, dst):
"""
Calculate Euclid distance between each two points.
src^T * dst = xn * xm + yn * ym + zn * zm;
sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
= sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
Input:
src: source points, [B, N, C]
dst: target points, [B, M, C]
Output:
dist: per-point square distance, [B, N, M]
"""
B, N, _ = src.shape
_, M, _ = dst.shape
dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
dist += torch.sum(src ** 2, -1).view(B, N, 1)
dist += torch.sum(dst ** 2, -1).view(B, 1, M)
return dist
def index_points(points, idx):
"""
Input:
points: input points data, [B, N, C]
idx: sample index data, [B, S]
Return:
new_points:, indexed points data, [B, S, C]
"""
device = points.device
B = points.shape[0]
view_shape = list(idx.shape)
view_shape[1:] = [1] * (len(view_shape) - 1)
repeat_shape = list(idx.shape)
repeat_shape[0] = 1
batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
new_points = points[batch_indices, idx, :]
return new_points
def farthest_point_sample(xyz, npoint):
"""
Input:
xyz: pointcloud data, [B, N, 3]
npoint: number of samples
Return:
centroids: sampled pointcloud index, [B, npoint]
"""
device = xyz.device
B, N, C = xyz.shape
centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
distance = torch.ones(B, N).to(device) * 1e10
farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
batch_indices = torch.arange(B, dtype=torch.long).to(device)
for i in range(npoint):
centroids[:, i] = farthest
centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
dist = torch.sum((xyz - centroid) ** 2, -1)
mask = dist < distance
distance[mask] = dist[mask]
farthest = torch.max(distance, -1)[1]
return centroids
def query_ball_point(radius, nsample, xyz, new_xyz):
"""
Input:
radius: local region radius
nsample: max sample number in local region
xyz: all points, [B, N, 3]
new_xyz: query points, [B, S, 3]
Return:
group_idx: grouped points index, [B, S, nsample]
"""
device = xyz.device
B, N, C = xyz.shape
_, S, _ = new_xyz.shape
group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
sqrdists = square_distance(new_xyz, xyz)
group_idx[sqrdists > radius ** 2] = N
group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
mask = group_idx == N
group_idx[mask] = group_first[mask]
return group_idx
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
"""
Input:
npoint:
radius:
nsample:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, npoint, nsample, 3]
new_points: sampled points data, [B, npoint, nsample, 3+D]
"""
B, N, C = xyz.shape
S = npoint
fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C]
new_xyz = index_points(xyz, fps_idx)
idx = query_ball_point(radius, nsample, xyz, new_xyz)
grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]
grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)
if points is not None:
grouped_points = index_points(points, idx)
new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]
else:
new_points = grouped_xyz_norm
if returnfps:
return new_xyz, new_points, grouped_xyz, fps_idx
else:
return new_xyz, new_points
def sample_and_group_all(xyz, points):
"""
Input:
xyz: input points position data, [B, N, 3]
points: input points data, [B, N, D]
Return:
new_xyz: sampled points position data, [B, 1, 3]
new_points: sampled points data, [B, 1, N, 3+D]
"""
device = xyz.device
B, N, C = xyz.shape
new_xyz = torch.zeros(B, 1, C).to(device)
grouped_xyz = xyz.view(B, 1, N, C)
if points is not None:
new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
else:
new_points = grouped_xyz
return new_xyz, new_points
def interpolating_points(xyz1, xyz2, points2):
"""
Input:
xyz1: input points position data, [B, C, N]
xyz2: sampled input points position data, [B, C, S]
points2: input points data, [B, D, S]
Return:
new_points: upsampled points data, [B, D', N]
"""
xyz1 = xyz1.permute(0, 2, 1)
xyz2 = xyz2.permute(0, 2, 1)
points2 = points2.permute(0, 2, 1)
B, N, C = xyz1.shape
_, S, _ = xyz2.shape
if S == 1:
interpolated_points = points2.repeat(1, N, 1)
else:
dists = square_distance(xyz1, xyz2)
dists, idx = dists.sort(dim=-1)
dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3]
dist_recip = 1.0 / (dists + 1e-8)
norm = torch.sum(dist_recip, dim=2, keepdim=True)
weight = dist_recip / norm
interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
interpolated_points = interpolated_points.permute(0, 2, 1)
return interpolated_points
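# Editor's sketch: a CPU smoke test of farthest-point sampling followed by
# inverse-distance feature interpolation; all shapes are hypothetical but
# follow the docstrings above.
if __name__ == '__main__':
    torch.manual_seed(0)
    xyz = torch.rand(2, 1024, 3)                         # B N 3 full cloud
    idx = farthest_point_sample(xyz, 128)                # B npoint
    sampled = index_points(xyz, idx)                     # B 128 3
    feats = torch.rand(2, 32, 128)                       # B D S features at samples
    up = interpolating_points(xyz.permute(0, 2, 1),      # xyz1: B C N
                              sampled.permute(0, 2, 1),  # xyz2: B C S
                              feats)                     # points2: B D S
    print(up.shape)                                      # torch.Size([2, 32, 1024])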
================================================
FILE: requirements.txt
================================================
numpy
Pillow
scikit-learn
scipy
timm
torch
torchvision
tqdm
wget
tifffile
scikit-image
kornia
imageio
tensorboard
opencv-python
setuptools==59.5.0
================================================
FILE: utils/au_pro_util.py
================================================
"""
Code based on the official MVTec 3D-AD evaluation code found at
https://www.mydrive.ch/shares/45924/9ce7a138c69bbd4c8d648b72151f839d/download/428846918-1643297332/evaluation_code.tar.xz
Utility functions that compute a PRO curve and its definite integral, given
pairs of anomaly and ground truth maps.
The PRO curve can also be integrated up to a constant integration limit.
"""
import numpy as np
from scipy.ndimage import label
from bisect import bisect
class GroundTruthComponent:
"""
Stores sorted anomaly scores of a single ground truth component.
Used to efficiently compute the region overlap for many increasing thresholds.
"""
def __init__(self, anomaly_scores):
"""
Initialize the module.
Args:
anomaly_scores: List of all anomaly scores within the ground truth
component as numpy array.
"""
# Keep a sorted list of all anomaly scores within the component.
self.anomaly_scores = anomaly_scores.copy()
self.anomaly_scores.sort()
# Pointer to the anomaly score where the current threshold divides the component into OK / NOK pixels.
self.index = 0
# The last evaluated threshold.
self.last_threshold = None
def compute_overlap(self, threshold):
"""
Compute the region overlap for a specific threshold.
Thresholds must be passed in increasing order.
Args:
threshold: Threshold to compute the region overlap.
Returns:
Region overlap for the specified threshold.
"""
if self.last_threshold is not None:
assert self.last_threshold <= threshold
# Increase the index until it points to an anomaly score that is just above the specified threshold.
while (self.index < len(self.anomaly_scores) and self.anomaly_scores[self.index] <= threshold):
self.index += 1
# Compute the fraction of component pixels that are correctly segmented as anomalous.
return 1.0 - self.index / len(self.anomaly_scores)
def trapezoid(x, y, x_max=None):
"""
This function calculates the definite integral of a curve given by x- and corresponding y-values.
In contrast to, e.g., 'numpy.trapz()', this function allows defining an upper bound for the integration range by
setting a value x_max.
Points that do not have a finite x or y value will be ignored with a warning.
Args:
x: Samples from the domain of the function to integrate. They need to be sorted in ascending order and may contain
the same value multiple times. In that case, the order of the corresponding y values will affect the
integration with the trapezoidal rule.
y: Values of the function corresponding to x values.
x_max: Upper limit of the integration. The y value at x_max will be determined by interpolating between its
neighbors. Must not lie outside of the range of x.
Returns:
Area under the curve.
"""
x = np.array(x)
y = np.array(y)
finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y))
if not finite_mask.all():
print(
"""WARNING: Not all x and y values passed to trapezoid are finite. Will continue with only the finite values.""")
x = x[finite_mask]
y = y[finite_mask]
# Introduce a correction term if x_max is not an element of x.
correction = 0.
if x_max is not None:
if x_max not in x:
# Get the insertion index that would keep x sorted after np.insert(x, ins, x_max).
ins = bisect(x, x_max)
# x_max must be between the minimum and the maximum, so the insertion_point cannot be zero or len(x).
assert 0 < ins < len(x)
# Calculate the correction term, which is the integral between the last sample x[ins - 1] and x_max. Since we do not
# know the exact value of y at x_max, we interpolate between y[ins] and y[ins-1].
y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) * (x_max - x[ins - 1]) / (x[ins] - x[ins - 1]))
correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1])
# Cut off at x_max.
mask = x <= x_max
x = x[mask]
y = y[mask]
# Return area under the curve using the trapezoidal rule.
return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction
def collect_anomaly_scores(anomaly_maps, ground_truth_maps):
"""
Extract anomaly scores for each ground truth connected component as well as anomaly scores for each potential false
positive pixel from anomaly maps.
Args:
anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel.
ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels
for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains
an anomaly.
Returns:
ground_truth_components: A list of all ground truth connected components that appear in the dataset. For each
component, a sorted list of its anomaly scores is stored.
anomaly_scores_ok_pixels: A sorted list of anomaly scores of all anomaly-free pixels of the dataset. This list
can be used to quickly select thresholds that fix a certain false positive rate.
"""
# Make sure an anomaly map is present for each ground truth map.
assert len(anomaly_maps) == len(ground_truth_maps)
# Initialize ground truth components and scores of potential fp pixels.
ground_truth_components = []
anomaly_scores_ok_pixels = np.zeros(len(ground_truth_maps) * ground_truth_maps[0].size)
# Structuring element for computing connected components.
structure = np.ones((3, 3), dtype=int)
# Collect anomaly scores within each ground truth region and for all potential fp pixels.
ok_index = 0
for gt_map, prediction in zip(ground_truth_maps, anomaly_maps):
# Compute the connected components in the ground truth map.
labeled, n_components = label(gt_map, structure)
# Store all potential fp scores.
num_ok_pixels = len(prediction[labeled == 0])
anomaly_scores_ok_pixels[ok_index:ok_index + num_ok_pixels] = prediction[labeled == 0].copy()
ok_index += num_ok_pixels
# Fetch anomaly scores within each GT component.
for k in range(n_components):
component_scores = prediction[labeled == (k + 1)]
ground_truth_components.append(GroundTruthComponent(component_scores))
# Sort all potential false positive scores.
anomaly_scores_ok_pixels = np.resize(anomaly_scores_ok_pixels, ok_index)
anomaly_scores_ok_pixels.sort()
return ground_truth_components, anomaly_scores_ok_pixels
def compute_pro(anomaly_maps, ground_truth_maps, num_thresholds):
"""
Compute the PRO curve at equidistant interpolation points for a set of anomaly maps with corresponding ground
truth maps. The number of interpolation points can be set manually.
Args:
anomaly_maps: List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel.
ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels
for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains
an anomaly.
num_thresholds: Number of thresholds to compute the PRO curve.
Returns:
fprs: List of false positive rates.
pros: List of corresponding PRO values.
"""
# Fetch sorted anomaly scores.
ground_truth_components, anomaly_scores_ok_pixels = collect_anomaly_scores(anomaly_maps, ground_truth_maps)
# Select equidistant thresholds.
threshold_positions = np.linspace(0, len(anomaly_scores_ok_pixels) - 1, num=num_thresholds, dtype=int)
fprs = [1.0]
pros = [1.0]
for pos in threshold_positions:
threshold = anomaly_scores_ok_pixels[pos]
# Compute the false positive rate for this threshold.
fpr = 1.0 - (pos + 1) / len(anomaly_scores_ok_pixels)
# Compute the PRO value for this threshold.
pro = 0.0
for component in ground_truth_components:
pro += component.compute_overlap(threshold)
pro /= len(ground_truth_components)
fprs.append(fpr)
pros.append(pro)
# Return (FPR/PRO) pairs in increasing FPR order.
fprs = fprs[::-1]
pros = pros[::-1]
return fprs, pros
def calculate_au_pro(gts, predictions, integration_limit=0.3, num_thresholds=100):
"""
Compute the area under the PRO curve for a set of ground truth images and corresponding anomaly images.
Args:
gts: List of tensors that contain the ground truth images for a single dataset object.
predictions: List of tensors containing anomaly images for each ground truth image.
integration_limit: Integration limit to use when computing the area under the PRO curve.
num_thresholds: Number of thresholds to use to sample the area under the PRO curve.
Returns:
au_pro: Area under the PRO curve computed up to the given integration limit.
pro_curve: PRO curve values for localization (fpr,pro).
"""
# Compute the PRO curve.
pro_curve = compute_pro(anomaly_maps=predictions, ground_truth_maps=gts, num_thresholds=num_thresholds)
# Compute the area under the PRO curve.
au_pro = trapezoid(pro_curve[0], pro_curve[1], x_max=integration_limit)
au_pro /= integration_limit
# Return the evaluation metrics.
return au_pro, pro_curve
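# Editor's sketch: a tiny synthetic sanity check. One square defect whose
# anomaly scores dominate all background pixels should yield an AU-PRO
# close to 1. All values here are synthetic.
if __name__ == '__main__':
    rng = np.random.default_rng(0)
    gt = np.zeros((64, 64), dtype=int)
    gt[20:30, 20:30] = 1                      # one anomalous component
    pred = rng.random((64, 64)) * 0.1         # low scores on normal pixels
    pred[20:30, 20:30] += 1.0                 # high scores on the defect
    au_pro, _ = calculate_au_pro([gt], [pred])
    print(f'AU-PRO: {au_pro:.3f}')            # expected to be near 1.0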
================================================
FILE: utils/lr_sched.py
================================================
import math
def adjust_learning_rate(optimizer, epoch, args):
"""Decay the learning rate with half-cycle cosine after warmup"""
if epoch < args.warmup_epochs:
lr = args.lr * epoch / args.warmup_epochs
else:
lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
(1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
for param_group in optimizer.param_groups:
if "lr_scale" in param_group:
param_group["lr"] = lr * param_group["lr_scale"]
else:
param_group["lr"] = lr
return lr
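# Editor's sketch: printing the schedule for a dummy optimizer makes the
# linear warmup and half-cycle cosine decay visible; the args values below
# are illustrative only.
if __name__ == '__main__':
    import argparse
    import torch
    args = argparse.Namespace(lr=0.002, min_lr=0.0, warmup_epochs=1, epochs=3)
    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
    for epoch in [0, 0.5, 1, 1.5, 2, 2.5]:
        lr = adjust_learning_rate(opt, epoch, args)
        print(f'epoch {epoch:>4}: lr = {lr:.6f}')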
================================================
FILE: utils/misc.py
================================================
import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path
import torch
import torch.distributed as dist
from math import inf
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if v is None:
continue
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
log_msg = [
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
]
if torch.cuda.is_available():
log_msg.append('max mem: {memory:.0f}')
log_msg = self.delimiter.join(log_msg)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
builtin_print = builtins.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
force = force or (get_world_size() > 8)
if is_master or force:
now = datetime.datetime.now().time()
builtin_print('[{}] '.format(now), end='') # print with time stamp
builtin_print(*args, **kwargs)
builtins.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if args.dist_on_itp:
args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
os.environ['LOCAL_RANK'] = str(args.gpu)
os.environ['RANK'] = str(args.rank)
os.environ['WORLD_SIZE'] = str(args.world_size)
# ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
setup_for_distributed(is_master=True) # hack
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
args.rank, args.dist_url, args.gpu), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
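# Note (assumption about the launcher, not stated in this repo): the
# RANK/WORLD_SIZE/LOCAL_RANK branch above matches the environment variables
# populated by torchrun, e.g.:
#   torchrun --nproc_per_node=4 fusion_pretrain.py ...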
class NativeScalerWithGradNormCount:
state_dict_key = "amp_scaler"
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
self._scaler.scale(loss).backward(create_graph=create_graph)
if update_grad:
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
else:
self._scaler.unscale_(optimizer)
norm = get_grad_norm_(parameters)
self._scaler.step(optimizer)
self._scaler.update()
else:
norm = None
return norm
def state_dict(self):
return self._scaler.state_dict()
def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict)
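# Illustrative usage (not part of the original file): one AMP step with the
# scaler wrapper; it scales the loss, optionally clips, and returns the grad norm.
#   loss_scaler = NativeScalerWithGradNormCount()
#   with torch.cuda.amp.autocast():
#       loss = model(inputs)  # hypothetical model returning a scalar loss
#   optimizer.zero_grad()
#   grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0,
#                           parameters=model.parameters())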
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
parameters = [p for p in parameters if p.grad is not None]
norm_type = float(norm_type)
if len(parameters) == 0:
return torch.tensor(0.)
device = parameters[0].grad.device
if norm_type == inf:
total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
else:
total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
return total_norm
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
output_dir = Path(args.output_dir)
epoch_name = str(epoch)
if loss_scaler is not None:
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
for checkpoint_path in checkpoint_paths:
to_save = {
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'scaler': loss_scaler.state_dict(),
'args': args,
}
save_on_master(to_save, checkpoint_path)
else:
client_state = {'epoch': epoch}
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
def save_model_gan(args, epoch, model, discriminator, model_without_ddp, discriminator_without_ddp,
optimizer_g, optimizer_d, loss_scaler):
output_dir = Path(args.output_dir)
epoch_name = str(epoch)
if loss_scaler is not None:
checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
for checkpoint_path in checkpoint_paths:
to_save = {
'model': model_without_ddp.state_dict(),
'discriminator_without_ddp': discriminator_without_ddp.state_dict(),
'optimizer_g': optimizer_g.state_dict(),
'optimizer_d': optimizer_d.state_dict(),
'epoch': epoch,
'scaler': loss_scaler.state_dict(),
'args': args,
}
save_on_master(to_save, checkpoint_path)
else:
client_state = {'epoch': epoch}
model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
discriminator.save_checkpoint(save_dir=args.output_dir, tag="checkpoint_d-%s" % epoch_name, client_state=client_state)
def load_model(args, model_without_ddp, optimizer, loss_scaler):
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
print("Resume checkpoint %s" % args.resume)
if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
optimizer.load_state_dict(checkpoint['optimizer'])
args.start_epoch = checkpoint['epoch'] + 1
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
print("With optim & sched!")
def load_model_gan(args, model_without_ddp, discriminator_without_ddp,
optimizer_g, optimizer_d, loss_scaler):
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.resume, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
discriminator_without_ddp.load_state_dict(checkpoint['discriminator'])
print("Resume checkpoint %s" % args.resume)
if 'optimizer_d' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
optimizer_d.load_state_dict(checkpoint['optimizer_d'])
if 'optimizer_g' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
optimizer_g.load_state_dict(checkpoint['optimizer_g'])
args.start_epoch = checkpoint['epoch'] + 1
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
print("With optim & sched!")
def all_reduce_mean(x):
world_size = get_world_size()
if world_size > 1:
x_reduce = torch.tensor(x).cuda()
dist.all_reduce(x_reduce)
x_reduce /= world_size
return x_reduce.item()
else:
return x
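# Illustrative usage (not part of the original file): average a per-rank scalar
# (e.g. a loss value) across all processes for logging.
#   loss_value_reduced = all_reduce_mean(loss.item())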
================================================
FILE: utils/mvtec3d_util.py
================================================
import tifffile as tiff
import torch
def organized_pc_to_unorganized_pc(organized_pc):
return organized_pc.reshape(organized_pc.shape[0] * organized_pc.shape[1], organized_pc.shape[2])
def read_tiff_organized_pc(path):
tiff_img = tiff.imread(path)
return tiff_img
def resize_organized_pc(organized_pc, target_height=224, target_width=224, tensor_out=True):
torch_organized_pc = torch.tensor(organized_pc).permute(2, 0, 1).unsqueeze(dim=0).contiguous()
torch_resized_organized_pc = torch.nn.functional.interpolate(torch_organized_pc, size=(target_height, target_width),
mode='nearest')
if tensor_out:
return torch_resized_organized_pc.squeeze(dim=0).contiguous()
else:
return torch_resized_organized_pc.squeeze().permute(1, 2, 0).contiguous().numpy()
def organized_pc_to_depth_map(organized_pc):
return organized_pc[:, :, 2]
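# Illustrative usage (not part of the original file): load an organized point
# cloud and prepare it for a 224x224 backbone; 'nearest' interpolation avoids
# blending real points with zero-valued (missing) ones.
#   organized_pc = read_tiff_organized_pc('sample.tiff')   # (H, W, 3) array
#   resized_pc = resize_organized_pc(organized_pc)         # (3, 224, 224) tensor
#   depth_map = organized_pc_to_depth_map(organized_pc)    # (H, W) z-channel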
================================================
FILE: utils/preprocess_eyecandies.py
================================================
import os
from shutil import copyfile
import cv2
import numpy as np
import tifffile
import yaml
import imageio.v3 as iio
import math
import argparse
# The same camera has been used for all the images
FOCAL_LENGTH = 711.11
def load_and_convert_depth(depth_img, info_depth):
with open(info_depth) as f:
data = yaml.safe_load(f)
mind, maxd = data["normalization"]["min"], data["normalization"]["max"]
dimg = iio.imread(depth_img)
dimg = dimg.astype(np.float32)
dimg = dimg / 65535.0 * (maxd - mind) + mind
return dimg
def depth_to_pointcloud(depth_img, info_depth, pose_txt, focal_length):
# input depth map (in meters) --- cfr previous section
depth_mt = load_and_convert_depth(depth_img, info_depth)
# input pose
pose = np.loadtxt(pose_txt)
# camera intrinsics
height, width = depth_mt.shape[:2]
intrinsics_4x4 = np.array([
[focal_length, 0, width / 2, 0],
[0, focal_length, height / 2, 0],
[0, 0, 1, 0],
[0, 0, 0, 1]]
)
# build the camera projection matrix
camera_proj = intrinsics_4x4 @ pose
# build the (u, v, 1, 1/depth) vectors (non optimized version)
camera_vectors = np.zeros((width * height, 4))
    count = 0
for j in range(height):
for i in range(width):
camera_vectors[count, :] = np.array([i, j, 1, 1/depth_mt[j, i]])
count += 1
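    # Note (equivalent vectorized form, not in the original): the loop above can
    # be replaced with
    #   jj, ii = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    #   camera_vectors = np.stack([ii.ravel(), jj.ravel(),
    #                              np.ones(height * width),
    #                              1.0 / depth_mt.ravel()], axis=1)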
# invert and apply to each 4-vector
    hom_3d_pts = np.linalg.inv(camera_proj) @ camera_vectors.T
# print(hom_3d_pts.shape)
# remove the homogeneous coordinate
pcd = depth_mt.reshape(-1, 1) * hom_3d_pts.T
return pcd[:, :3]
def remove_point_cloud_background(pc):
# The second dim is z
dz = pc[256,1] - pc[-256,1]
dy = pc[256,2] - pc[-256,2]
norm = math.sqrt(dz**2 + dy**2)
start_points = np.array([0, pc[-256, 1], pc[-256, 2]])
cos_theta = dy / norm
sin_theta = dz / norm
# Transform and rotation
rotation_matrix = np.array([[1, 0, 0], [0, cos_theta, -sin_theta],[0, sin_theta, cos_theta]])
processed_pc = (rotation_matrix @ (pc - start_points).T).T
# Remove background point
for i in range(processed_pc.shape[0]):
if processed_pc[i,1] > -0.02:
processed_pc[i, :] = -start_points
if processed_pc[i,2] > 1.8:
processed_pc[i, :] = -start_points
elif processed_pc[i,0] > 1 or processed_pc[i,0] < -1:
processed_pc[i, :] = -start_points
processed_pc = (rotation_matrix.T @ processed_pc.T).T + start_points
index = [0, 2, 1]
processed_pc = processed_pc[:,index]
return processed_pc*[0.1, -0.1, 0.1]
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--dataset_path', default='datasets/eyecandies', type=str, help="Original Eyecandies dataset path.")
parser.add_argument('--target_dir', default='datasets/eyecandies_preprocessed', type=str, help="Processed Eyecandies dataset path")
args = parser.parse_args()
os.mkdir(args.target_dir)
categories_list = os.listdir(args.dataset_path)
for category_dir in categories_list:
category_root_path = os.path.join(args.dataset_path, category_dir)
        # NOTE: 'train'/'test_public' must not start with '/', otherwise
        # os.path.join discards category_root_path entirely.
        category_train_path = os.path.join(category_root_path, 'train', 'data')
        category_test_path = os.path.join(category_root_path, 'test_public', 'data')
category_target_path = os.path.join(args.target_dir, category_dir)
os.mkdir(category_target_path)
os.mkdir(os.path.join(category_target_path, 'train'))
category_target_train_good_path = os.path.join(category_target_path, 'train/good')
category_target_train_good_rgb_path = os.path.join(category_target_train_good_path, 'rgb')
category_target_train_good_xyz_path = os.path.join(category_target_train_good_path, 'xyz')
os.mkdir(category_target_train_good_path)
os.mkdir(category_target_train_good_rgb_path)
os.mkdir(category_target_train_good_xyz_path)
os.mkdir(os.path.join(category_target_path, 'test'))
category_target_test_good_path = os.path.join(category_target_path, 'test/good')
category_target_test_good_rgb_path = os.path.join(category_target_test_good_path, 'rgb')
category_target_test_good_xyz_path = os.path.join(category_target_test_good_path, 'xyz')
category_target_test_good_gt_path = os.path.join(category_target_test_good_path, 'gt')
os.mkdir(category_target_test_good_path)
os.mkdir(category_target_test_good_rgb_path)
os.mkdir(category_target_test_good_xyz_path)
os.mkdir(category_target_test_good_gt_path)
category_target_test_bad_path = os.path.join(category_target_path, 'test/bad')
category_target_test_bad_rgb_path = os.path.join(category_target_test_bad_path, 'rgb')
category_target_test_bad_xyz_path = os.path.join(category_target_test_bad_path, 'xyz')
category_target_test_bad_gt_path = os.path.join(category_target_test_bad_path, 'gt')
os.mkdir(category_target_test_bad_path)
os.mkdir(category_target_test_bad_rgb_path)
os.mkdir(category_target_test_bad_xyz_path)
os.mkdir(category_target_test_bad_gt_path)
category_train_files = os.listdir(category_train_path)
num_train_files = len(category_train_files)//17
for i in range(0, num_train_files):
pc = depth_to_pointcloud(
os.path.join(category_train_path,str(i).zfill(3)+'_depth.png'),
os.path.join(category_train_path,str(i).zfill(3)+'_info_depth.yaml'),
os.path.join(category_train_path,str(i).zfill(3)+'_pose.txt'),
FOCAL_LENGTH,
)
pc = remove_point_cloud_background(pc)
pc = pc.reshape(512,512,3)
tifffile.imwrite(os.path.join(category_target_train_good_xyz_path, str(i).zfill(3)+'.tiff'), pc)
copyfile(os.path.join(category_train_path,str(i).zfill(3)+'_image_4.png'),os.path.join(category_target_train_good_rgb_path, str(i).zfill(3)+'.png'))
category_test_files = os.listdir(category_test_path)
num_test_files = len(category_test_files)//17
for i in range(0, num_test_files):
mask = cv2.imread(os.path.join(category_test_path,str(i).zfill(2)+'_mask.png'))
if np.any(mask):
pc = depth_to_pointcloud(
os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'),
os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'),
os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'),
FOCAL_LENGTH,
)
pc = remove_point_cloud_background(pc)
pc = pc.reshape(512,512,3)
tifffile.imwrite(os.path.join(category_target_test_bad_xyz_path, str(i).zfill(3)+'.tiff'), pc)
cv2.imwrite(os.path.join(category_target_test_bad_gt_path, str(i).zfill(3)+'.png'), mask)
copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_bad_rgb_path, str(i).zfill(3)+'.png'))
else:
pc = depth_to_pointcloud(
os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'),
os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'),
os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'),
FOCAL_LENGTH,
)
pc = remove_point_cloud_background(pc)
pc = pc.reshape(512,512,3)
tifffile.imwrite(os.path.join(category_target_test_good_xyz_path, str(i).zfill(3)+'.tiff'), pc)
cv2.imwrite(os.path.join(category_target_test_good_gt_path, str(i).zfill(3)+'.png'), mask)
copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_good_rgb_path, str(i).zfill(3)+'.png'))
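# Example invocation (illustrative), matching the argparse defaults above:
#   python utils/preprocess_eyecandies.py --dataset_path datasets/eyecandies \
#       --target_dir datasets/eyecandies_preprocessed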
================================================
FILE: utils/preprocessing.py
================================================
import os
import numpy as np
import tifffile as tiff
import open3d as o3d
from pathlib import Path
from PIL import Image
import math
import mvtec3d_util as mvt_util
import argparse
def get_edges_of_pc(organized_pc):
    # Gather a 10-pixel-wide strip from each border of the organized point cloud
    # and flatten each strip to an (N, C) array of points.
    border_strips = [
        organized_pc[0:10, :, :],
        organized_pc[-10:, :, :],
        organized_pc[:, 0:10, :],
        organized_pc[:, -10:, :],
    ]
    unorganized_edges_pc = np.concatenate(
        [strip.reshape(-1, strip.shape[2]) for strip in border_strips], axis=0)
    # Keep only points with no zero coordinate (zeros mark missing data).
    unorganized_edges_pc = unorganized_edges_pc[np.nonzero(np.all(unorganized_edges_pc != 0, axis=1))[0], :]
    return unorganized_edges_pc
def get_plane_eq(unorganized_pc, ransac_n_pts=50):
o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc))
plane_model, inliers = o3d_pc.segment_plane(distance_threshold=0.004, ransac_n=ransac_n_pts, num_iterations=1000)
return plane_model
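# plane_model is (a, b, c, d) for the plane ax + by + cz + d = 0; the distance
# test in remove_plane below treats (a, b, c) as (approximately) unit length,
# which Open3D's segment_plane is expected to return.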
def remove_plane(organized_pc_clean, organized_rgb, distance_threshold=0.005):
# PREP PC
unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc_clean)
unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)
clean_planeless_unorganized_pc = unorganized_pc.copy()
planeless_unorganized_rgb = unorganized_rgb.copy()
# REMOVE PLANE
plane_model = get_plane_eq(get_edges_of_pc(organized_pc_clean))
distances = np.abs(np.dot(np.array(plane_model), np.hstack((clean_planeless_unorganized_pc, np.ones((clean_planeless_unorganized_pc.shape[0], 1)))).T))
plane_indices = np.argwhere(distances < distance_threshold)
planeless_unorganized_rgb[plane_indices] = 0
clean_planeless_unorganized_pc[plane_indices] = 0
clean_planeless_organized_pc = clean_planeless_unorganized_pc.reshape(organized_pc_clean.shape[0],
organized_pc_clean.shape[1],
organized_pc_clean.shape[2])
planeless_organized_rgb = planeless_unorganized_rgb.reshape(organized_rgb.shape[0],
organized_rgb.shape[1],
organized_rgb.shape[2])
return clean_planeless_organized_pc, planeless_organized_rgb
def connected_components_cleaning(organized_pc, organized_rgb, image_path):
unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc)
unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)
nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]
unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :]
o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros))
labels = np.array(o3d_pc.cluster_dbscan(eps=0.006, min_points=30, print_progress=False))
unique_cluster_ids, cluster_size = np.unique(labels,return_counts=True)
max_label = labels.max()
if max_label>0:
print("##########################################################################")
print(f"Point cloud file {image_path} has {max_label + 1} clusters")
print(f"Cluster ids: {unique_cluster_ids}. Cluster size {cluster_size}")
print("##########################################################################\n\n")
largest_cluster_id = unique_cluster_ids[np.argmax(cluster_size)]
outlier_indices_nonzero_array = np.argwhere(labels != largest_cluster_id)
outlier_indices_original_pc_array = nonzero_indices[outlier_indices_nonzero_array]
unorganized_pc[outlier_indices_original_pc_array] = 0
unorganized_rgb[outlier_indices_original_pc_array] = 0
organized_clustered_pc = unorganized_pc.reshape(organized_pc.shape[0],
organized_pc.shape[1],
organized_pc.shape[2])
organized_clustered_rgb = unorganized_rgb.reshape(organized_rgb.shape[0],
organized_rgb.shape[1],
organized_rgb.shape[2])
return organized_clustered_pc, organized_clustered_rgb
def roundup_next_100(x):
return int(math.ceil(x / 100.0)) * 100
def pad_cropped_pc(cropped_pc, single_channel=False):
orig_h, orig_w = cropped_pc.shape[0], cropped_pc.shape[1]
round_orig_h = roundup_next_100(orig_h)
round_orig_w = roundup_next_100(orig_w)
large_side = max(round_orig_h, round_orig_w)
a = (large_side - orig_h) // 2
aa = large_side - a - orig_h
b = (large_side - orig_w) // 2
bb = large_side - b - orig_w
if single_channel:
return np.pad(cropped_pc, pad_width=((a, aa), (b, bb)), mode='constant')
else:
return np.pad(cropped_pc, pad_width=((a, aa), (b, bb), (0, 0)), mode='constant')
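# Illustrative example (not part of the original file): a 380x420 crop pads to
# a 500x500 square, since both sides round up to the next 100 and the larger wins.
#   dummy = np.zeros((380, 420, 3))
#   pad_cropped_pc(dummy).shape   # (500, 500, 3)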
def preprocess_pc(tiff_path):
# READ FILES
organized_pc = mvt_util.read_tiff_organized_pc(tiff_path)
rgb_path = str(tiff_path).replace("xyz", "rgb").replace("tiff", "png")
gt_path = str(tiff_path).replace("xyz", "gt").replace("tiff", "png")
organized_rgb = np.array(Image.open(rgb_path))
organized_gt = None
gt_exists = os.path.isfile(gt_path)
if gt_exists:
organized_gt = np.array(Image.open(gt_path))
# REMOVE PLANE
planeless_organized_pc, planeless_organized_rgb = remove_plane(organized_pc, organized_rgb)
# PAD WITH ZEROS TO LARGEST SIDE (SO THAT THE FINAL IMAGE IS SQUARE)
padded_planeless_organized_pc = pad_cropped_pc(planeless_organized_pc, single_channel=False)
padded_planeless_organized_rgb = pad_cropped_pc(planeless_organized_rgb, single_channel=False)
if gt_exists:
padded_organized_gt = pad_cropped_pc(organized_gt, single_channel=True)
organized_clustered_pc, organized_clustered_rgb = connected_components_cleaning(padded_planeless_organized_pc, padded_planeless_organized_rgb, tiff_path)
# SAVE PREPROCESSED FILES
    tiff.imwrite(tiff_path, organized_clustered_pc)  # imwrite replaces the deprecated tiff.imsave
Image.fromarray(organized_clustered_rgb).save(rgb_path)
if gt_exists:
Image.fromarray(padded_organized_gt).save(gt_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Preprocess MVTec 3D-AD')
parser.add_argument('dataset_path', type=str, help='The root path of the MVTec 3D-AD. The preprocessing is done inplace (i.e. the preprocessed dataset overrides the existing one)')
args = parser.parse_args()
root_path = args.dataset_path
    paths = list(Path(root_path).rglob('*.tiff'))  # materialize once; rglob returns a one-shot generator
    print(f"Found {len(paths)} tiff files in {root_path}")
    processed_files = 0
    for path in paths:
preprocess_pc(path)
processed_files += 1
if processed_files % 50 == 0:
print(f"Processed {processed_files} tiff files...")
================================================
FILE: utils/utils.py
================================================
import numpy as np
import random
import torch
from torchvision import transforms
from PIL import ImageFilter
def set_seeds(seed: int = 0) -> None:
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
class KNNGaussianBlur(torch.nn.Module):
def __init__(self, radius : int = 4):
super().__init__()
self.radius = radius
self.unload = transforms.ToPILImage()
self.load = transforms.ToTensor()
        self.blur_kernel = ImageFilter.GaussianBlur(radius=radius)  # use the passed radius instead of a hard-coded 4
def __call__(self, img):
map_max = img.max()
final_map = self.load(self.unload(img[0] / map_max).filter(self.blur_kernel)) * map_max
return final_map
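# Illustrative usage (not part of the original file): smooth a predicted
# anomaly map before computing image/pixel-level scores.
#   blur = KNNGaussianBlur(radius=4)
#   anomaly_map = torch.rand(1, 224, 224)   # hypothetical score map
#   smoothed = blur(anomaly_map)            # same shape, Gaussian-smoothed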