[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 nomewang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Multimodal Industrial Anomaly Detection via Hybrid Fusion (CVPR 2023)\n\n## Abstract\n> 2D-based Industrial Anomaly Detection has been widely discussed, however, multimodal industrial anomaly detection based on 3D point clouds and RGB images still has many untouched fields. Existing multimodal industrial anomaly detection methods directly concatenate the multimodal features, which leads to a strong disturbance between features and harms the detection performance. In this paper, we propose **Multi-3D-Memory** (**M3DM**), a novel multimodal anomaly detection method with hybrid fusion scheme: firstly, we design an unsupervised feature fusion with patch-wise contrastive learning to encourage the interaction of different modal features; secondly, we use a decision layer fusion with multiple memory banks to avoid loss of information and additional novelty classifiers to make the final decision. We further propose a point feature alignment operation to better align the point cloud and RGB features. Extensive experiments show that our multimodal industrial anomaly detection model outperforms the state-of-the-art (SOTA) methods on both detection and segmentation precision on  MVTec-3D AD dataset.\n\n![piplien](figures/pipeline.png)\n- `The pipeline of  Multi-3D-Memory (M3DM).` Our M3DM contains three important parts: (1) **Point Feature Alignment** (PFA) converts Point Group features to plane features with interpolation and project operation, $\\text{FPS}$ is the farthest point sampling and $\\mathcal{F_{pt}}$ is a pretrained Point Transformer; (2) **Unsupervised Feature Fusion** (UFF) fuses point feature and image feature together with a patch-wise contrastive loss $\\mathcal{L_{con}}$, where $\\mathcal{F_{rgb}}$ is a Vision Transformer, $\\chi_{rgb},\\chi_{pt}$ are MLP layers and $\\sigma_r, \\sigma_p$ are single fully connected layers; (3) **Decision Layer Fusion** (DLF) combines multimodal information with multiple memory banks and makes the final decision with 2 learnable modules $\\mathcal D_a, \\mathcal{D_s}$ for anomaly detection and segmentation, where $\\mathcal{M_{rgb}}$, $\\mathcal{M_{fs}}$, $\\mathcal{M_{pt}}$ are memory banks, $\\phi, \\psi$ are score function for single memory bank detection and segmentation, and  $\\mathcal{P}$ is the memory bank building algorithm.\n\n### [Paper](https://arxiv.org/pdf/2303.00601.pdf)\n\n## Setup\n\nWe implement this repo with the following environment:\n- Ubuntu 18.04\n- Python 3.8\n- Pytorch 1.9.0\n- CUDA 11.3\n\nInstall the other package via:\n\n``` bash\npip install -r requirement.txt\n# install knn_cuda\npip install --upgrade https://github.com/unlimblue/KNN_CUDA/releases/download/0.2/KNN_CUDA-0.2-py3-none-any.whl\n# install pointnet2_ops_lib\npip install \"git+git://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib\"\n```\n\n## Data Download and Preprocess\n\n### Dataset\n\n- The `MVTec-3D AD` dataset can be download from the [Official Website of MVTec-3D AD](https://www.mvtec.com/company/research/datasets/mvtec-3d-ad). \n\n- The `Eyecandies` dataset can be download from the [Official Website of Eyecandies](https://eyecan-ai.github.io/eyecandies/). \n\nAfter download, put the dataset in `dataset` folder.\n\n### Datapreprocess\n\n\nTo run the preprocessing \n```bash\npython utils/preprocessing.py datasets/mvtec3d/\n```\n\nIt may take a few hours to run the preprocessing. 
\n\n### Checkpoints\n\nThe following table lists the pretrain model used in M3DM:\n\n| Backbone          | Pretrain Method                                                                                                                                                                 |\n| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| Point Transformer | [Point-MAE](https://drive.google.com/file/d/1-wlRIz0GM8o6BuPTJz4kTt6c_z1Gh6LX/view?usp=sharing)                                                                                       |\n| Point Transformer | [Point-Bert](https://cloud.tsinghua.edu.cn/f/202b29805eea45d7be92/?dl=1)                                                                                                        |\n| ViT-b/8           | [DINO](https://drive.google.com/file/d/17s6lwfxwG_nf1td6LXunL-LjRaX67iyK/view?usp=sharing)                                                                                   |\n| ViT-b/8           | [Supervised ImageNet 1K](https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz) |\n| ViT-b/8           | [Supervised ImageNet 21K](https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz)                                        |\n| ViT-s/8           | [DINO](https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth)                                                                               |\n| UFF                | [UFF Module](https://drive.google.com/file/d/1Z2AkfPqenJEv-IdWhVdRcvVQAsJC4DxW/view?usp=sharing)                                                                               |\n\nPut the checkpoint files in `checkpoints` folder.\n\n## Train and Test\n\nTrain and test the double lib version and save the feature for UFF training:\n\n```bash\nmkdir -p datasets/patch_lib\npython3 main.py \\\n--method_name DINO+Point_MAE \\\n--memory_bank multiple \\\n--rgb_backbone_name vit_base_patch8_224_dino \\\n--xyz_backbone_name Point_MAE \\\n--save_feature \\\n```\n\nTrain the UFF:\n\n```bash\nOMP_NUM_THREADS=1 python3 -m torch.distributed.launch --nproc_per_node=1 fusion_pretrain.py    \\\n--accum_iter 16 \\\n--lr 0.003 \\\n--batch_size 16 \\\n--data_path datasets/patch_lib \\\n--output_dir checkpoints \\\n```\n\nTrain and test the full setting with the following command:\n\n```bash\npython3 main.py \\\n--method_name DINO+Point_MAE+Fusion \\\n--use_uff \\\n--memory_bank multiple \\\n--rgb_backbone_name vit_base_patch8_224_dino \\\n--xyz_backbone_name Point_MAE \\\n--fusion_module_path checkpoints/{FUSION_CHECKPOINT}.pth \\\n```\n\nNote: if you set `--method_name DINO` or `--method_name Point_MAE`, set `--memory_bank single` at the same time. 
\n\n\n\nIf you find this repository useful for your research, please use the following.\n\n```bibtex\n@inproceedings{wang2023multimodal,\n  title={Multimodal Industrial Anomaly Detection via Hybrid Fusion},\n  author={Wang, Yue and Peng, Jinlong and Zhang, Jiangning and Yi, Ran and Wang, Yabiao and Wang, Chengjie},\n  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n  pages={8032--8041},\n  year={2023}\n}\n```\n\n## Thanks\n\nOur repo is built on [3D-ADS](https://github.com/eliahuhorwitz/3D-ADS) and [MoCo-v3](https://github.com/facebookresearch/moco-v3), thanks their extraordinary works!\n"
  },
  {
    "path": "dataset.py",
    "content": "import os\nfrom PIL import Image\nfrom torchvision import transforms\nimport glob\nfrom torch.utils.data import Dataset\nfrom utils.mvtec3d_util import *\nfrom torch.utils.data import DataLoader\nimport numpy as np\n\ndef eyecandies_classes():\n    return [\n        'CandyCane',\n        'ChocolateCookie',\n        'ChocolatePraline',\n        'Confetto',\n        'GummyBear',\n        'HazelnutTruffle',\n        'LicoriceSandwich',\n        'Lollipop',\n        'Marshmallow',\n        'PeppermintCandy',   \n    ]\n\ndef mvtec3d_classes():\n    return [\n        \"bagel\",\n        \"cable_gland\",\n        \"carrot\",\n        \"cookie\",\n        \"dowel\",\n        \"foam\",\n        \"peach\",\n        \"potato\",\n        \"rope\",\n        \"tire\",\n    ]\n\nRGB_SIZE = 224\n\nclass BaseAnomalyDetectionDataset(Dataset):\n\n    def __init__(self, split, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):\n        self.IMAGENET_MEAN = [0.485, 0.456, 0.406]\n        self.IMAGENET_STD = [0.229, 0.224, 0.225]\n        self.cls = class_name\n        self.size = img_size\n        self.img_path = os.path.join(dataset_path, self.cls, split)\n        self.rgb_transform = transforms.Compose(\n            [transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.BICUBIC),\n             transforms.ToTensor(),\n             transforms.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)])\n\nclass PreTrainTensorDataset(Dataset):\n    def __init__(self, root_path):\n        super().__init__()\n        self.root_path = root_path\n        self.tensor_paths = os.listdir(self.root_path)\n\n\n    def __len__(self):\n        return len(self.tensor_paths)\n\n    def __getitem__(self, idx):\n        tensor_path = self.tensor_paths[idx]\n\n        tensor = torch.load(os.path.join(self.root_path, tensor_path))\n\n        label = 0\n\n        return tensor, label\n\nclass TrainDataset(BaseAnomalyDetectionDataset):\n    def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):\n        super().__init__(split=\"train\", class_name=class_name, img_size=img_size, dataset_path=dataset_path)\n        self.img_paths, self.labels = self.load_dataset()  # self.labels => good : 0, anomaly : 1\n\n    def load_dataset(self):\n        img_tot_paths = []\n        tot_labels = []\n        rgb_paths = glob.glob(os.path.join(self.img_path, 'good', 'rgb') + \"/*.png\")\n        tiff_paths = glob.glob(os.path.join(self.img_path, 'good', 'xyz') + \"/*.tiff\")\n        rgb_paths.sort()\n        tiff_paths.sort()\n        sample_paths = list(zip(rgb_paths, tiff_paths))\n        img_tot_paths.extend(sample_paths)\n        tot_labels.extend([0] * len(sample_paths))\n        return img_tot_paths, tot_labels\n\n    def __len__(self):\n        return len(self.img_paths)\n\n    def __getitem__(self, idx):\n        img_path, label = self.img_paths[idx], self.labels[idx]\n        rgb_path = img_path[0]\n        tiff_path = img_path[1]\n        img = Image.open(rgb_path).convert('RGB')\n\n        img = self.rgb_transform(img)\n        organized_pc = read_tiff_organized_pc(tiff_path)\n        \n        depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)\n        resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)\n        resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)\n        resized_organized_pc = 
resized_organized_pc.clone().detach().float()\n\n        return (img, resized_organized_pc, resized_depth_map_3channel), label\n\n\nclass TestDataset(BaseAnomalyDetectionDataset):\n    def __init__(self, class_name, img_size, dataset_path='datasets/eyecandies_preprocessed'):\n        super().__init__(split=\"test\", class_name=class_name, img_size=img_size, dataset_path=dataset_path)\n        self.gt_transform = transforms.Compose([\n            transforms.Resize((RGB_SIZE, RGB_SIZE), interpolation=transforms.InterpolationMode.NEAREST),\n            transforms.ToTensor()])\n        self.img_paths, self.gt_paths, self.labels = self.load_dataset()  # self.labels => good : 0, anomaly : 1\n\n    def load_dataset(self):\n        img_tot_paths = []\n        gt_tot_paths = []\n        tot_labels = []\n        defect_types = os.listdir(self.img_path)\n\n        for defect_type in defect_types:\n            if defect_type == 'good':\n                rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + \"/*.png\")\n                tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + \"/*.tiff\")\n                rgb_paths.sort()\n                tiff_paths.sort()\n                sample_paths = list(zip(rgb_paths, tiff_paths))\n                img_tot_paths.extend(sample_paths)\n                gt_tot_paths.extend([0] * len(sample_paths))\n                tot_labels.extend([0] * len(sample_paths))\n            else:\n                rgb_paths = glob.glob(os.path.join(self.img_path, defect_type, 'rgb') + \"/*.png\")\n                tiff_paths = glob.glob(os.path.join(self.img_path, defect_type, 'xyz') + \"/*.tiff\")\n                gt_paths = glob.glob(os.path.join(self.img_path, defect_type, 'gt') + \"/*.png\")\n                rgb_paths.sort()\n                tiff_paths.sort()\n                gt_paths.sort()\n                sample_paths = list(zip(rgb_paths, tiff_paths))\n\n                img_tot_paths.extend(sample_paths)\n                gt_tot_paths.extend(gt_paths)\n                tot_labels.extend([1] * len(sample_paths))\n\n        assert len(img_tot_paths) == len(gt_tot_paths), \"Something wrong with test and ground truth pair!\"\n\n        return img_tot_paths, gt_tot_paths, tot_labels\n\n    def __len__(self):\n        return len(self.img_paths)\n\n    def __getitem__(self, idx):\n        img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]\n        rgb_path = img_path[0]\n        tiff_path = img_path[1]\n        img_original = Image.open(rgb_path).convert('RGB')\n        img = self.rgb_transform(img_original)\n\n        organized_pc = read_tiff_organized_pc(tiff_path)\n        depth_map_3channel = np.repeat(organized_pc_to_depth_map(organized_pc)[:, :, np.newaxis], 3, axis=2)\n        resized_depth_map_3channel = resize_organized_pc(depth_map_3channel)\n        resized_organized_pc = resize_organized_pc(organized_pc, target_height=self.size, target_width=self.size)\n        resized_organized_pc = resized_organized_pc.clone().detach().float()\n        \n\n        \n\n        if gt == 0:\n            gt = torch.zeros(\n                [1, resized_depth_map_3channel.size()[-2], resized_depth_map_3channel.size()[-2]])\n        else:\n            gt = Image.open(gt).convert('L')\n            gt = self.gt_transform(gt)\n            gt = torch.where(gt > 0.5, 1., .0)\n\n        return (img, resized_organized_pc, resized_depth_map_3channel), gt[:1], label, rgb_path\n\n\ndef get_data_loader(split, class_name, img_size, 
args):\n    if split in ['train']:\n        dataset = TrainDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path)\n    elif split in ['test']:\n        dataset = TestDataset(class_name=class_name, img_size=img_size, dataset_path=args.dataset_path)\n\n    data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=1, drop_last=False,\n                             pin_memory=True)\n    return data_loader\n"
  },
  {
    "path": "engine_fusion_pretrain.py",
    "content": "import math\nimport sys\nfrom typing import Iterable\n\nimport torch\n\nimport utils.misc as misc\nimport utils.lr_sched as lr_sched\n\n\ndef train_one_epoch(model: torch.nn.Module,\n                    data_loader: Iterable, optimizer: torch.optim.Optimizer,\n                    device: torch.device, epoch: int, loss_scaler,\n                    log_writer=None,\n                    args=None):\n    model.train(True)\n    metric_logger = misc.MetricLogger(delimiter=\"  \")\n    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))\n    header = 'Epoch: [{}]'.format(epoch)\n    print_freq = 20\n\n    accum_iter = args.accum_iter\n\n    optimizer.zero_grad()\n\n    if log_writer is not None:\n        print('log_dir: {}'.format(log_writer.log_dir))\n\n    for data_iter_step, (samples, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):\n\n        # we use a per iteration (instead of per epoch) lr scheduler\n        if data_iter_step % accum_iter == 0:\n            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)\n\n        \n        xyz_samples = samples[:,:,:1152].to(device, non_blocking=True)\n        rgb_samples = samples[:,:,1152:].to(device, non_blocking=True)\n\n        with torch.cuda.amp.autocast():\n            loss = model(xyz_samples, rgb_samples)\n\n        loss_value = loss.item()\n\n        if not math.isfinite(loss_value):\n            print(\"Loss is {}, stopping training\".format(loss_value))\n            sys.exit(1)\n\n        loss /= accum_iter\n        loss_scaler(loss, optimizer, parameters=model.parameters(),\n                    update_grad=(data_iter_step + 1) % accum_iter == 0)\n        if (data_iter_step + 1) % accum_iter == 0:\n            optimizer.zero_grad()\n\n        torch.cuda.synchronize()\n\n        metric_logger.update(loss=loss_value)\n\n        lr = optimizer.param_groups[0][\"lr\"]\n        metric_logger.update(lr=lr)\n\n        \n        loss_value_reduce = misc.all_reduce_mean(loss_value)\n        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:\n            \"\"\" We use epoch_1000x as the x-axis in tensorboard.\n            This calibrates different curves when batch size changes.\n            \"\"\"\n            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)\n            log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)\n            log_writer.add_scalar('lr', lr, epoch_1000x)\n\n\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    print(\"Averaged stats:\", metric_logger)\n    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}"
  },
  {
    "path": "feature_extractors/features.py",
    "content": "\"\"\"\r\nPatchCore logic based on https://github.com/rvorias/ind_knn_ad\r\n\"\"\"\r\nimport torch\r\nimport numpy as np\r\nimport os\r\nfrom tqdm import tqdm\r\nfrom matplotlib import pyplot as plt\r\n\r\nfrom sklearn import random_projection\r\nfrom sklearn import linear_model\r\nfrom sklearn.svm import OneClassSVM\r\nfrom sklearn.ensemble import IsolationForest\r\nfrom sklearn.metrics import roc_auc_score\r\n\r\nfrom timm.models.layers import DropPath, trunc_normal_\r\nfrom pointnet2_ops import pointnet2_utils\r\nfrom knn_cuda import KNN\r\n\r\nfrom utils.utils import KNNGaussianBlur\r\nfrom utils.utils import set_seeds\r\nfrom utils.au_pro_util import calculate_au_pro\r\n\r\nfrom models.pointnet2_utils import interpolating_points\r\nfrom models.feature_fusion import FeatureFusionBlock\r\nfrom models.models import Model\r\n\r\nclass Features(torch.nn.Module):\r\n\r\n    def __init__(self, args, image_size=224, f_coreset=0.1, coreset_eps=0.9):\r\n        super().__init__()\r\n        self.device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\r\n        self.deep_feature_extractor = Model(\r\n                                device=self.device, \r\n                                rgb_backbone_name=args.rgb_backbone_name, \r\n                                xyz_backbone_name=args.xyz_backbone_name, \r\n                                group_size = args.group_size, \r\n                                num_group=args.num_group\r\n                                )\r\n        self.deep_feature_extractor.to(self.device)\r\n\r\n        self.args = args\r\n        self.image_size = args.img_size\r\n        self.f_coreset = args.f_coreset\r\n        self.coreset_eps = args.coreset_eps\r\n        \r\n        self.blur = KNNGaussianBlur(4)\r\n        self.n_reweight = 3\r\n        set_seeds(0)\r\n        self.patch_xyz_lib = []\r\n        self.patch_rgb_lib = []\r\n        self.patch_fusion_lib = []\r\n        self.patch_lib = []\r\n        self.random_state = args.random_state\r\n\r\n        self.xyz_dim = 0\r\n        self.rgb_dim = 0\r\n\r\n        self.xyz_mean=0\r\n        self.xyz_std=0\r\n        self.rgb_mean=0\r\n        self.rgb_std=0\r\n        self.fusion_mean=0\r\n        self.fusion_std=0\r\n\r\n        self.average = torch.nn.AvgPool2d(3, stride=1) # torch.nn.AvgPool2d(1, stride=1) #\r\n        self.resize = torch.nn.AdaptiveAvgPool2d((56, 56))\r\n        self.resize2 = torch.nn.AdaptiveAvgPool2d((56, 56))\r\n\r\n        self.image_preds = list()\r\n        self.image_labels = list()\r\n        self.pixel_preds = list()\r\n        self.pixel_labels = list()\r\n        self.gts = []\r\n        self.predictions = []\r\n        self.image_rocauc = 0\r\n        self.pixel_rocauc = 0\r\n        self.au_pro = 0\r\n        self.ins_id = 0\r\n        self.rgb_layernorm = torch.nn.LayerNorm(768, elementwise_affine=False)\r\n\r\n        if self.args.use_uff:\r\n            self.fusion = FeatureFusionBlock(1152, 768, mlp_ratio=4.)\r\n\r\n            ckpt = torch.load(args.fusion_module_path)['model']\r\n\r\n            incompatible = self.fusion.load_state_dict(ckpt, strict=False)\r\n\r\n            print('[Fusion Block]', incompatible)\r\n\r\n        self.detect_fuser = linear_model.SGDOneClassSVM(random_state=42, nu=args.ocsvm_nu,  max_iter=args.ocsvm_maxiter)\r\n        self.seg_fuser = linear_model.SGDOneClassSVM(random_state=42, nu=args.ocsvm_nu,  max_iter=args.ocsvm_maxiter)\r\n\r\n        self.s_lib = []\r\n        self.s_map_lib = []\r\n\r\n    def __call__(self, rgb, 
xyz):\r\n        # Extract the desired feature maps using the backbone model.\r\n        rgb = rgb.to(self.device)\r\n        xyz = xyz.to(self.device)\r\n        with torch.no_grad():\r\n            rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz)\r\n\r\n        interpolate = True    \r\n        if interpolate:\r\n            interpolated_feature_maps = interpolating_points(xyz, center.permute(0,2,1), xyz_feature_maps).to(\"cpu\")\r\n\r\n        xyz_feature_maps = [fmap.to(\"cpu\") for fmap in [xyz_feature_maps]]\r\n        rgb_feature_maps = [fmap.to(\"cpu\") for fmap in [rgb_feature_maps]]\r\n\r\n        if interpolate:\r\n            return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx, interpolated_feature_maps\r\n        else:\r\n            return rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx\r\n\r\n    def add_sample_to_mem_bank(self, sample):\r\n        raise NotImplementedError\r\n\r\n    def predict(self, sample, mask, label):\r\n        raise NotImplementedError\r\n\r\n    def add_sample_to_late_fusion_mem_bank(self, sample):\r\n        raise NotImplementedError\r\n\r\n    def interpolate_points(self, rgb, xyz):\r\n        with torch.no_grad():\r\n            rgb_feature_maps, xyz_feature_maps, center, ori_idx, center_idx = self.deep_feature_extractor(rgb, xyz)\r\n        return xyz_feature_maps, center, xyz\r\n    \r\n    def compute_s_s_map(self, xyz_patch, rgb_patch, fusion_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):\r\n        raise NotImplementedError\r\n    \r\n    def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'):\r\n        raise NotImplementedError\r\n    \r\n    def run_coreset(self):\r\n        raise NotImplementedError\r\n\r\n    def calculate_metrics(self):\r\n        self.image_preds = np.stack(self.image_preds)\r\n        self.image_labels = np.stack(self.image_labels)\r\n        self.pixel_preds = np.array(self.pixel_preds)\r\n\r\n        self.image_rocauc = roc_auc_score(self.image_labels, self.image_preds)\r\n        self.pixel_rocauc = roc_auc_score(self.pixel_labels, self.pixel_preds)\r\n        self.au_pro, _ = calculate_au_pro(self.gts, self.predictions)\r\n\r\n    def save_prediction_maps(self, output_path, rgb_path, save_num=5):\r\n        for i in range(max(save_num, len(self.predictions))):            \r\n            # fig = plt.figure(dpi=300)\r\n            fig = plt.figure()\r\n        \r\n            ax3 = fig.add_subplot(1,3,1)\r\n            gt = plt.imread(rgb_path[i][0])    \r\n            ax3.imshow(gt)\r\n\r\n            ax2 = fig.add_subplot(1,3,2)\r\n            im2 = ax2.imshow(self.gts[i], cmap=plt.cm.gray)\r\n            \r\n            ax = fig.add_subplot(1,3,3)\r\n            im = ax.imshow(self.predictions[i], cmap=plt.cm.jet)\r\n            \r\n            class_dir = os.path.join(output_path, rgb_path[i][0].split('/')[-5])\r\n            if not os.path.exists(class_dir):\r\n                os.mkdir(class_dir)\r\n\r\n            ad_dir = os.path.join(class_dir, rgb_path[i][0].split('/')[-3])\r\n            if not os.path.exists(ad_dir):\r\n                os.mkdir(ad_dir)\r\n            \r\n            plt.savefig(os.path.join(ad_dir,  str(self.image_preds[i]) + '_pred_' + rgb_path[i][0].split('/')[-1] + '.jpg'))\r\n            \r\n    def run_late_fusion(self):\r\n        self.s_lib = torch.cat(self.s_lib, 0)\r\n        self.s_map_lib = torch.cat(self.s_map_lib, 
0)\r\n        self.detect_fuser.fit(self.s_lib)\r\n        self.seg_fuser.fit(self.s_map_lib)\r\n\r\n    def get_coreset_idx_randomp(self, z_lib, n=1000, eps=0.90, float16=True, force_cpu=False):\r\n\r\n        print(f\"   Fitting random projections. Start dim = {z_lib.shape}.\")\r\n        try:\r\n            transformer = random_projection.SparseRandomProjection(eps=eps, random_state=self.random_state)\r\n            z_lib = torch.tensor(transformer.fit_transform(z_lib))\r\n\r\n            print(f\"   DONE.                 Transformed dim = {z_lib.shape}.\")\r\n        except ValueError:\r\n            print(\"   Error: could not project vectors. Please increase `eps`.\")\r\n\r\n        select_idx = 0\r\n        last_item = z_lib[select_idx:select_idx + 1]\r\n        coreset_idx = [torch.tensor(select_idx)]\r\n        min_distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdims=True)\r\n\r\n        if float16:\r\n            last_item = last_item.half()\r\n            z_lib = z_lib.half()\r\n            min_distances = min_distances.half()\r\n        if torch.cuda.is_available() and not force_cpu:\r\n            last_item = last_item.to(\"cuda\")\r\n            z_lib = z_lib.to(\"cuda\")\r\n            min_distances = min_distances.to(\"cuda\")\r\n\r\n        for _ in tqdm(range(n - 1)):\r\n            distances = torch.linalg.norm(z_lib - last_item, dim=1, keepdims=True)  # broadcasting step\r\n            min_distances = torch.minimum(distances, min_distances)  # iterative step\r\n            select_idx = torch.argmax(min_distances)  # selection step\r\n\r\n            # bookkeeping\r\n            last_item = z_lib[select_idx:select_idx + 1]\r\n            min_distances[select_idx] = 0\r\n            coreset_idx.append(select_idx.to(\"cpu\"))\r\n        return torch.stack(coreset_idx)\r\n"
  },
  {
    "path": "feature_extractors/multiple_features.py",
    "content": "import torch\r\nfrom feature_extractors.features import Features\r\nfrom utils.mvtec3d_util import *\r\nimport numpy as np\r\nimport math\r\nimport os\r\n\r\nclass RGBFeatures(Features):\r\n\r\n    def add_sample_to_mem_bank(self, sample):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, _, _, center_idx, _ = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        self.patch_lib.append(rgb_patch)\r\n\r\n    def predict(self, sample, mask, label):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, _ = self(sample[0], unorganized_pc_no_zeros.contiguous())\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        self.compute_s_s_map(rgb_patch, rgb_feature_maps[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)\r\n\r\n    def run_coreset(self):\r\n\r\n        self.patch_lib = torch.cat(self.patch_lib, 0)\r\n        self.mean = torch.mean(self.patch_lib)\r\n        self.std = torch.std(self.patch_lib)\r\n        self.patch_lib = (self.patch_lib - self.mean)/self.std\r\n\r\n        # self.patch_lib = self.rgb_layernorm(self.patch_lib)\r\n\r\n        if self.f_coreset < 1:\r\n            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_lib,\r\n                                                            n=int(self.f_coreset * self.patch_lib.shape[0]),\r\n                                                            eps=self.coreset_eps, )\r\n            self.patch_lib = self.patch_lib[self.coreset_idx]\r\n\r\n\r\n    def compute_s_s_map(self, patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx, nonzero_patch_indices = None):\r\n        '''\r\n        center: point group center position\r\n        neighbour_idx: each group point index\r\n        nonzero_indices: point indices of original point clouds\r\n        xyz: nonzero point clouds\r\n        '''\r\n\r\n        patch = (patch - self.mean)/self.std\r\n\r\n        # self.patch_lib = self.rgb_layernorm(self.patch_lib)\r\n        dist = torch.cdist(patch, self.patch_lib)\r\n\r\n        min_val, min_idx = torch.min(dist, dim=1)\r\n\r\n        # print(min_val.shape)\r\n        s_idx = torch.argmax(min_val)\r\n        s_star = torch.max(min_val)\r\n\r\n        # reweighting\r\n        m_test = patch[s_idx].unsqueeze(0)  # anomalous patch\r\n        m_star = self.patch_lib[min_idx[s_idx]].unsqueeze(0)  # closest neighbour\r\n        w_dist = torch.cdist(m_star, 
self.patch_lib)  # find knn to m_star pt.1\r\n        _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False)  # pt.2\r\n\r\n        m_star_knn = torch.linalg.norm(m_test - self.patch_lib[nn_idx[0, 1:]], dim=1)\r\n        D = torch.sqrt(torch.tensor(patch.shape[1]))\r\n        w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D)) + 1e-5))\r\n        s = w * s_star\r\n\r\n        # segmentation map\r\n        s_map = min_val.view(1, 1, *feature_map_dims)\r\n        s_map = torch.nn.functional.interpolate(s_map, size=(224, 224), mode='bilinear')\r\n        s_map = self.blur(s_map)\r\n\r\n        self.image_preds.append(s.numpy())\r\n        self.image_labels.append(label)\r\n        self.pixel_preds.extend(s_map.flatten().numpy())\r\n        self.pixel_labels.extend(mask.flatten().numpy())\r\n        self.predictions.append(s_map.detach().cpu().squeeze().numpy())\r\n        self.gts.append(mask.detach().cpu().squeeze().numpy())\r\n\r\nclass PointFeatures(Features):\r\n\r\n    def add_sample_to_mem_bank(self, sample):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n \r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n        self.patch_lib.append(xyz_patch)\r\n\r\n\r\n    def predict(self, sample, mask, label):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n\r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n        self.compute_s_s_map(xyz_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), 
center_idx)\r\n\r\n    def run_coreset(self):\r\n\r\n        self.patch_lib = torch.cat(self.patch_lib, 0)\r\n\r\n        if self.args.rm_zero_for_project:\r\n            self.patch_lib = self.patch_lib[torch.nonzero(torch.all(self.patch_lib!=0, dim=1))[:,0]]\r\n\r\n        if self.f_coreset < 1:\r\n            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_lib,\r\n                                                            n=int(self.f_coreset * self.patch_lib.shape[0]),\r\n                                                            eps=self.coreset_eps, )\r\n            self.patch_lib = self.patch_lib[self.coreset_idx]\r\n            \r\n        if self.args.rm_zero_for_project:\r\n\r\n            self.patch_lib = self.patch_lib[torch.nonzero(torch.all(self.patch_lib!=0, dim=1))[:,0]]\r\n            self.patch_lib = torch.cat((self.patch_lib, torch.zeros(1, self.patch_lib.shape[1])), 0)\r\n\r\n\r\n    def compute_s_s_map(self, patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx, nonzero_patch_indices = None):\r\n        '''\r\n        center: point group center position\r\n        neighbour_idx: each group point index\r\n        nonzero_indices: point indices of original point clouds\r\n        xyz: nonzero point clouds\r\n        '''\r\n\r\n\r\n        dist = torch.cdist(patch, self.patch_lib)\r\n\r\n        min_val, min_idx = torch.min(dist, dim=1)\r\n\r\n        # print(min_val.shape)\r\n        s_idx = torch.argmax(min_val)\r\n        s_star = torch.max(min_val)\r\n\r\n        # reweighting\r\n        m_test = patch[s_idx].unsqueeze(0)  # anomalous patch\r\n        m_star = self.patch_lib[min_idx[s_idx]].unsqueeze(0)  # closest neighbour\r\n        w_dist = torch.cdist(m_star, self.patch_lib)  # find knn to m_star pt.1\r\n        _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False)  # pt.2\r\n\r\n        m_star_knn = torch.linalg.norm(m_test - self.patch_lib[nn_idx[0, 1:]], dim=1)\r\n        D = torch.sqrt(torch.tensor(patch.shape[1]))\r\n        w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D)) + 1e-5))\r\n        s = w * s_star\r\n\r\n        # segmentation map\r\n        s_map = min_val.view(1, 1, *feature_map_dims)\r\n        s_map = torch.nn.functional.interpolate(s_map, size=(224, 224), mode='bilinear')\r\n        s_map = self.blur(s_map)\r\n\r\n        self.image_preds.append(s.numpy())\r\n        self.image_labels.append(label)\r\n        self.pixel_preds.extend(s_map.flatten().numpy())\r\n        self.pixel_labels.extend(mask.flatten().numpy())\r\n        self.predictions.append(s_map.detach().cpu().squeeze().numpy())\r\n        self.gts.append(mask.detach().cpu().squeeze().numpy())\r\n\r\nFUSION_BLOCK= True\r\n\r\nclass FusionFeatures(Features):\r\n\r\n    def add_sample_to_mem_bank(self, sample, class_name=None):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n   
     rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n        \r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n\r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n        rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))\r\n        rgb_patch2 =  self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))\r\n        rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        if FUSION_BLOCK:\r\n            with torch.no_grad():\r\n                fusion_patch = self.fusion.feature_fusion(xyz_patch.unsqueeze(0), rgb_patch2.unsqueeze(0))\r\n            fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()\r\n        else:\r\n            fusion_patch = torch.cat([xyz_patch, rgb_patch2], dim=1)\r\n\r\n        if class_name is not None:\r\n            torch.save(fusion_patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt'))\r\n            self.ins_id += 1\r\n\r\n        self.patch_lib.append(fusion_patch)\r\n\r\n    def predict(self, sample, mask, label):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n\r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))\r\n        rgb_patch2 =  self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))\r\n        rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        if FUSION_BLOCK:\r\n            with torch.no_grad():\r\n                fusion_patch = self.fusion.feature_fusion(xyz_patch.unsqueeze(0), rgb_patch2.unsqueeze(0))\r\n            fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()\r\n        else:\r\n            fusion_patch = torch.cat([xyz_patch, rgb_patch2], dim=1)\r\n\r\n        self.compute_s_s_map(fusion_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)\r\n\r\n    def compute_s_s_map(self, patch, feature_map_dims, mask, label, 
center, neighbour_idx, nonzero_indices, xyz, center_idx):\r\n        '''\r\n        center: point group center position\r\n        neighbour_idx: each group point index\r\n        nonzero_indices: point indices of original point clouds\r\n        xyz: nonzero point clouds\r\n        '''\r\n\r\n        dist = torch.cdist(patch, self.patch_lib)\r\n\r\n        min_val, min_idx = torch.min(dist, dim=1)\r\n\r\n        s_idx = torch.argmax(min_val)\r\n        s_star = torch.max(min_val)\r\n\r\n        # reweighting\r\n        m_test = patch[s_idx].unsqueeze(0)  # anomalous patch\r\n        m_star = self.patch_lib[min_idx[s_idx]].unsqueeze(0)  # closest neighbour\r\n        w_dist = torch.cdist(m_star, self.patch_lib)  # find knn to m_star pt.1\r\n        _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False)  # pt.2\r\n\r\n        m_star_knn = torch.linalg.norm(m_test - self.patch_lib[nn_idx[0, 1:]], dim=1)\r\n        D = torch.sqrt(torch.tensor(patch.shape[1]))\r\n        w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D))))\r\n        s = w * s_star\r\n\r\n        # segmentation map\r\n        s_map = min_val.view(1, 1, *feature_map_dims)\r\n        s_map = torch.nn.functional.interpolate(s_map, size=(self.image_size, self.image_size), mode='bilinear')\r\n        s_map = self.blur(s_map)\r\n\r\n        self.image_preds.append(s.numpy())\r\n        self.image_labels.append(label)\r\n        self.pixel_preds.extend(s_map.flatten().numpy())\r\n        self.pixel_labels.extend(mask.flatten().numpy())\r\n        self.predictions.append(s_map.detach().cpu().squeeze().numpy())\r\n        self.gts.append(mask.detach().cpu().squeeze().numpy())\r\n\r\n    def run_coreset(self):\r\n        self.patch_lib = torch.cat(self.patch_lib, 0)\r\n\r\n        if self.f_coreset < 1:\r\n            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_lib,\r\n                                                            n=int(self.f_coreset * self.patch_lib.shape[0]),\r\n                                                            eps=self.coreset_eps)\r\n            self.patch_lib = self.patch_lib[self.coreset_idx]\r\n\r\nclass DoubleRGBPointFeatures(Features):\r\n\r\n    def add_sample_to_mem_bank(self, sample, class_name=None):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        rgb_patch_resize = 
rgb_patch.repeat(4, 1).reshape(784, 4, -1).permute(1, 0, 2).reshape(784*4, -1)\r\n\r\n        patch = torch.cat([xyz_patch, rgb_patch_resize], dim=1)\r\n\r\n        if class_name is not None:\r\n            torch.save(patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt'))\r\n            self.ins_id += 1\r\n\r\n        self.patch_xyz_lib.append(xyz_patch)\r\n        self.patch_rgb_lib.append(rgb_patch)\r\n\r\n    def predict(self, sample, mask, label):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        self.compute_s_s_map(xyz_patch, rgb_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)\r\n\r\n    def add_sample_to_late_fusion_mem_bank(self, sample):\r\n\r\n\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n    \r\n        # 2D dist \r\n        xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std\r\n        rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std\r\n        dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)\r\n        dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)\r\n\r\n        \r\n        rgb_feat_size = 
(int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))\r\n        xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))\r\n\r\n        s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')\r\n        s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')\r\n\r\n        s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb]])\r\n \r\n        s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb], dim=0).squeeze().reshape(2, -1).permute(1, 0)\r\n\r\n        self.s_lib.append(s)\r\n        self.s_map_lib.append(s_map)\r\n\r\n    def compute_s_s_map(self, xyz_patch, rgb_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):\r\n        '''\r\n        center: point group center position\r\n        neighbour_idx: each group point index\r\n        nonzero_indices: point indices of original point clouds\r\n        xyz: nonzero point clouds\r\n        '''\r\n\r\n        # 2D dist \r\n        xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std\r\n        rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std\r\n        dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)\r\n        dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)\r\n\r\n        rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))\r\n        xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))\r\n        s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')\r\n        s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')\r\n\r\n        s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb]])\r\n        s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb], dim=0).squeeze().reshape(2, -1).permute(1, 0)\r\n\r\n        \r\n        s = torch.tensor(self.detect_fuser.score_samples(s))\r\n\r\n        s_map = torch.tensor(self.seg_fuser.score_samples(s_map))\r\n        s_map = s_map.view(1, 224, 224)\r\n\r\n\r\n        self.image_preds.append(s.numpy())\r\n        self.image_labels.append(label)\r\n        self.pixel_preds.extend(s_map.flatten().numpy())\r\n        self.pixel_labels.extend(mask.flatten().numpy())\r\n        self.predictions.append(s_map.detach().cpu().squeeze().numpy())\r\n        self.gts.append(mask.detach().cpu().squeeze().numpy())\r\n\r\n    def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'):\r\n\r\n        min_val, min_idx = torch.min(dist, dim=1)\r\n\r\n        s_idx = torch.argmax(min_val)\r\n        s_star = torch.max(min_val)/1000\r\n\r\n        # reweighting\r\n        m_test = patch[s_idx].unsqueeze(0)  # anomalous patch\r\n\r\n        if modal=='xyz':\r\n            m_star = self.patch_xyz_lib[min_idx[s_idx]].unsqueeze(0)  # closest neighbour\r\n            w_dist = torch.cdist(m_star, self.patch_xyz_lib)  # find knn to m_star pt.1\r\n        else:\r\n            m_star = self.patch_rgb_lib[min_idx[s_idx]].unsqueeze(0)  # closest neighbour\r\n            w_dist = torch.cdist(m_star, self.patch_rgb_lib)  # find knn to m_star pt.1\r\n\r\n        _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False)  # pt.2\r\n\r\n        if modal=='xyz':\r\n            m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 
1:]], dim=1)/1000\r\n        else:\r\n            m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1)/1000\r\n\r\n        D = torch.sqrt(torch.tensor(patch.shape[1]))\r\n        w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D))))\r\n        s = w * s_star\r\n\r\n        # segmentation map\r\n        s_map = min_val.view(1, 1, *feature_map_dims)\r\n        s_map = torch.nn.functional.interpolate(s_map, size=(224, 224), mode='bilinear')\r\n        s_map = self.blur(s_map)\r\n\r\n        return s, s_map\r\n\r\n    def run_coreset(self):\r\n        self.patch_xyz_lib = torch.cat(self.patch_xyz_lib, 0)\r\n        self.patch_rgb_lib = torch.cat(self.patch_rgb_lib, 0)\r\n\r\n        # normalise each modality with its own statistics\r\n        self.xyz_mean = torch.mean(self.patch_xyz_lib)\r\n        self.xyz_std = torch.std(self.patch_xyz_lib)\r\n        self.rgb_mean = torch.mean(self.patch_rgb_lib)\r\n        self.rgb_std = torch.std(self.patch_rgb_lib)\r\n\r\n        self.patch_xyz_lib = (self.patch_xyz_lib - self.xyz_mean)/self.xyz_std\r\n\r\n        self.patch_rgb_lib = (self.patch_rgb_lib - self.rgb_mean)/self.rgb_std\r\n\r\n        if self.f_coreset < 1:\r\n            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_xyz_lib,\r\n                                                            n=int(self.f_coreset * self.patch_xyz_lib.shape[0]),\r\n                                                            eps=self.coreset_eps, )\r\n            self.patch_xyz_lib = self.patch_xyz_lib[self.coreset_idx]\r\n            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_rgb_lib,\r\n                                                            n=int(self.f_coreset * self.patch_rgb_lib.shape[0]),\r\n                                                            eps=self.coreset_eps, )\r\n            self.patch_rgb_lib = self.patch_rgb_lib[self.coreset_idx]\r\n\r\nclass DoubleRGBPointFeatures_add(Features):\r\n\r\n    def add_sample_to_mem_bank(self, sample, class_name=None):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        # upsample the 28x28 (784) RGB tokens 4x to match the 56x56 xyz grid\r\n        rgb_patch_resize = rgb_patch.repeat(4, 1).reshape(784, 4, -1).permute(1, 0, 2).reshape(784*4, -1)\r\n\r\n        patch = torch.cat([xyz_patch, rgb_patch_resize], dim=1)\r\n\r\n        if class_name is not None:\r\n            torch.save(patch, os.path.join(self.args.save_feature_path, class_name+ 
str(self.ins_id) + '.pt'))\r\n            self.ins_id += 1\r\n\r\n        self.patch_xyz_lib.append(xyz_patch)\r\n        self.patch_rgb_lib.append(rgb_patch)\r\n\r\n\r\n    def predict(self, sample, mask, label):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        self.compute_s_s_map(xyz_patch, rgb_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)\r\n\r\n    def add_sample_to_late_fusion_mem_bank(self, sample):\r\n\r\n\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n    \r\n        # 2D dist \r\n        xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std\r\n        rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std\r\n        dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)\r\n        dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)\r\n\r\n        \r\n        rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))\r\n        xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))\r\n\r\n        s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, 
modal='xyz')\r\n        s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')\r\n\r\n        s = torch.tensor([[s_xyz, s_rgb]])\r\n        s_map = torch.cat([s_map_xyz, s_map_rgb], dim=0).squeeze().reshape(2, -1).permute(1, 0)\r\n\r\n        self.s_lib.append(s)\r\n        self.s_map_lib.append(s_map)\r\n\r\n    def run_coreset(self):\r\n        self.patch_xyz_lib = torch.cat(self.patch_xyz_lib, 0)\r\n        self.patch_rgb_lib = torch.cat(self.patch_rgb_lib, 0)\r\n\r\n        # normalize each memory bank by its own statistics\r\n        self.xyz_mean = torch.mean(self.patch_xyz_lib)\r\n        self.xyz_std = torch.std(self.patch_xyz_lib)\r\n        self.rgb_mean = torch.mean(self.patch_rgb_lib)\r\n        self.rgb_std = torch.std(self.patch_rgb_lib)\r\n\r\n        self.patch_xyz_lib = (self.patch_xyz_lib - self.xyz_mean)/self.xyz_std\r\n\r\n        self.patch_rgb_lib = (self.patch_rgb_lib - self.rgb_mean)/self.rgb_std\r\n\r\n        if self.f_coreset < 1:\r\n            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_xyz_lib,\r\n                                                            n=int(self.f_coreset * self.patch_xyz_lib.shape[0]),\r\n                                                            eps=self.coreset_eps, )\r\n            self.patch_xyz_lib = self.patch_xyz_lib[self.coreset_idx]\r\n            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_rgb_lib,\r\n                                                            n=int(self.f_coreset * self.patch_rgb_lib.shape[0]),\r\n                                                            eps=self.coreset_eps, )\r\n            self.patch_rgb_lib = self.patch_rgb_lib[self.coreset_idx]\r\n\r\n    def compute_s_s_map(self, xyz_patch, rgb_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):\r\n        '''\r\n        center: point group center position\r\n        neighbour_idx: each group point index\r\n        nonzero_indices: point indices of original point clouds\r\n        xyz: nonzero point clouds\r\n        '''\r\n\r\n        # 2D dist \r\n        xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std\r\n        rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std\r\n        dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)\r\n        dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)\r\n\r\n        rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))\r\n        xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))\r\n        s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')\r\n        s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')\r\n\r\n        s = s_xyz + s_rgb\r\n        s_map = s_map_xyz + s_map_rgb\r\n        s_map = s_map.view(1, self.image_size, self.image_size)\r\n\r\n\r\n        self.image_preds.append(s.numpy())\r\n        self.image_labels.append(label)\r\n        self.pixel_preds.extend(s_map.flatten().numpy())\r\n        self.pixel_labels.extend(mask.flatten().numpy())\r\n        self.predictions.append(s_map.detach().cpu().squeeze().numpy())\r\n        self.gts.append(mask.detach().cpu().squeeze().numpy())\r\n\r\n    def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'):\r\n\r\n        min_val, min_idx = torch.min(dist, dim=1)\r\n\r\n        s_idx = torch.argmax(min_val)\r\n        s_star = torch.max(min_val)\r\n\r\n        # reweighting\r\n        m_test = patch[s_idx].unsqueeze(0)  
# anomalous patch\r\n\r\n        if modal=='xyz':\r\n            m_star = self.patch_xyz_lib[min_idx[s_idx]].unsqueeze(0)  # closest neighbour\r\n            w_dist = torch.cdist(m_star, self.patch_xyz_lib)  # find knn to m_star pt.1\r\n        else:\r\n            m_star = self.patch_rgb_lib[min_idx[s_idx]].unsqueeze(0)  # closest neighbour\r\n            w_dist = torch.cdist(m_star, self.patch_rgb_lib)  # find knn to m_star pt.1\r\n\r\n        _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False)  # pt.2\r\n\r\n        if modal=='xyz':\r\n            m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1:]], dim=1) \r\n        else:\r\n            m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1)\r\n\r\n        D = torch.sqrt(torch.tensor(patch.shape[1]))\r\n        w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D))))\r\n        s = w * s_star\r\n        \r\n\r\n        # segmentation map\r\n        s_map = min_val.view(1, 1, *feature_map_dims)\r\n        s_map = torch.nn.functional.interpolate(s_map, size=(self.image_size, self.image_size), mode='bilinear', align_corners=False)\r\n        s_map = self.blur(s_map)\r\n\r\n        return s, s_map\r\n\r\nclass TripleFeatures(Features):\r\n\r\n    def add_sample_to_mem_bank(self, sample, class_name=None):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))\r\n        rgb_patch2 =  self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))\r\n        rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        self.patch_rgb_lib.append(rgb_patch)\r\n\r\n        if self.args.asy_memory_bank is None or len(self.patch_xyz_lib) < self.args.asy_memory_bank:\r\n\r\n            xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n            xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n            xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n            xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n            xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n            xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n            xyz_patch_full_resized2 = self.resize2(self.average(xyz_patch_full_2d))\r\n            xyz_patch2 = xyz_patch_full_resized2.reshape(xyz_patch_full_resized2.shape[1], -1).T\r\n\r\n            if FUSION_BLOCK:\r\n                with torch.no_grad():\r\n                    fusion_patch = self.fusion.feature_fusion(xyz_patch2.unsqueeze(0), rgb_patch2.unsqueeze(0))\r\n                fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()\r\n            else:\r\n                fusion_patch = 
torch.cat([xyz_patch2, rgb_patch2], dim=1)\r\n\r\n            self.patch_xyz_lib.append(xyz_patch)\r\n            self.patch_fusion_lib.append(fusion_patch)\r\n    \r\n\r\n        if class_name is not None:\r\n            torch.save(fusion_patch, os.path.join(self.args.save_feature_path, class_name+ str(self.ins_id) + '.pt'))\r\n            self.ins_id += 1\r\n\r\n        \r\n    def predict(self, sample, mask, label):\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n        xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n        xyz_patch_full_resized2 = self.resize2(self.average(xyz_patch_full_2d))\r\n        xyz_patch2 = xyz_patch_full_resized2.reshape(xyz_patch_full_resized2.shape[1], -1).T\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n        rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))\r\n        rgb_patch2 =  self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))\r\n        rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        if FUSION_BLOCK:\r\n            with torch.no_grad():\r\n                fusion_patch = self.fusion.feature_fusion(xyz_patch2.unsqueeze(0), rgb_patch2.unsqueeze(0))\r\n            fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()\r\n        else:\r\n            fusion_patch = torch.cat([xyz_patch2, rgb_patch2], dim=1)\r\n    \r\n\r\n        self.compute_s_s_map(xyz_patch, rgb_patch, fusion_patch, xyz_patch_full_resized[0].shape[-2:], mask, label, center, neighbor_idx, nonzero_indices, unorganized_pc_no_zeros.contiguous(), center_idx)\r\n\r\n    def add_sample_to_late_fusion_mem_bank(self, sample):\r\n\r\n\r\n        organized_pc = sample[1]\r\n        organized_pc_np = organized_pc.squeeze().permute(1, 2, 0).numpy()\r\n        unorganized_pc = organized_pc_to_unorganized_pc(organized_pc=organized_pc_np)\r\n        nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\r\n        \r\n        unorganized_pc_no_zeros = torch.tensor(unorganized_pc[nonzero_indices, :]).unsqueeze(dim=0).permute(0, 2, 1)\r\n        rgb_feature_maps, xyz_feature_maps, center, neighbor_idx, center_idx, interpolated_pc = self(sample[0],unorganized_pc_no_zeros.contiguous())\r\n\r\n        xyz_patch = torch.cat(xyz_feature_maps, 1)\r\n        xyz_patch_full = torch.zeros((1, interpolated_pc.shape[1], self.image_size*self.image_size), dtype=xyz_patch.dtype)\r\n        xyz_patch_full[:,:,nonzero_indices] = interpolated_pc\r\n        
xyz_patch_full_2d = xyz_patch_full.view(1, interpolated_pc.shape[1], self.image_size, self.image_size)\r\n        xyz_patch_full_resized = self.resize(self.average(xyz_patch_full_2d))\r\n        xyz_patch = xyz_patch_full_resized.reshape(xyz_patch_full_resized.shape[1], -1).T\r\n\r\n        xyz_patch_full_resized2 = self.resize2(self.average(xyz_patch_full_2d))\r\n        xyz_patch2 = xyz_patch_full_resized2.reshape(xyz_patch_full_resized2.shape[1], -1).T\r\n\r\n        rgb_patch = torch.cat(rgb_feature_maps, 1)\r\n        \r\n        rgb_patch = rgb_patch.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        rgb_patch_size = int(math.sqrt(rgb_patch.shape[0]))\r\n        rgb_patch2 =  self.resize2(rgb_patch.permute(1, 0).reshape(-1, rgb_patch_size, rgb_patch_size))\r\n        rgb_patch2 = rgb_patch2.reshape(rgb_patch.shape[1], -1).T\r\n\r\n        if FUSION_BLOCK:\r\n            with torch.no_grad():\r\n                fusion_patch = self.fusion.feature_fusion(xyz_patch2.unsqueeze(0), rgb_patch2.unsqueeze(0))\r\n            fusion_patch = fusion_patch.reshape(-1, fusion_patch.shape[2]).detach()\r\n        else:\r\n            fusion_patch = torch.cat([xyz_patch2, rgb_patch2], dim=1)\r\n    \r\n        # 3D dist \r\n        xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std\r\n        rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std\r\n        fusion_patch = (fusion_patch - self.fusion_mean)/self.fusion_std\r\n\r\n        dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)\r\n        dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)\r\n        dist_fusion = torch.cdist(fusion_patch, self.patch_fusion_lib)\r\n        \r\n        rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))\r\n        xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))\r\n        fusion_feat_size =  (int(math.sqrt(fusion_patch.shape[0])), int(math.sqrt(fusion_patch.shape[0])))\r\n\r\n        # 3 memory bank results\r\n        s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')\r\n        s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')\r\n        s_fusion, s_map_fusion = self.compute_single_s_s_map(fusion_patch, dist_fusion, fusion_feat_size, modal='fusion')\r\n        \r\n\r\n        s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb, self.args.fusion_s_lambda*s_fusion]])\r\n \r\n        s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb, self.args.fusion_smap_lambda*s_map_fusion], dim=0).squeeze().reshape(3, -1).permute(1, 0)\r\n\r\n        self.s_lib.append(s)\r\n        self.s_map_lib.append(s_map)\r\n\r\n    def run_coreset(self):\r\n        self.patch_xyz_lib = torch.cat(self.patch_xyz_lib, 0)\r\n        self.patch_rgb_lib = torch.cat(self.patch_rgb_lib, 0)\r\n        self.patch_fusion_lib = torch.cat(self.patch_fusion_lib, 0)\r\n\r\n        # normalize each memory bank by its own statistics\r\n        self.xyz_mean = torch.mean(self.patch_xyz_lib)\r\n        self.xyz_std = torch.std(self.patch_xyz_lib)\r\n        self.rgb_mean = torch.mean(self.patch_rgb_lib)\r\n        self.rgb_std = torch.std(self.patch_rgb_lib)\r\n        self.fusion_mean = torch.mean(self.patch_fusion_lib)\r\n        self.fusion_std = torch.std(self.patch_fusion_lib)\r\n\r\n        self.patch_xyz_lib = (self.patch_xyz_lib - self.xyz_mean)/self.xyz_std\r\n        self.patch_rgb_lib = (self.patch_rgb_lib - self.rgb_mean)/self.rgb_std\r\n        self.patch_fusion_lib = 
(self.patch_fusion_lib - self.fusion_mean)/self.fusion_std\r\n\r\n        if self.f_coreset < 1:\r\n            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_xyz_lib,\r\n                                                            n=int(self.f_coreset * self.patch_xyz_lib.shape[0]),\r\n                                                            eps=self.coreset_eps, )\r\n            self.patch_xyz_lib = self.patch_xyz_lib[self.coreset_idx]\r\n            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_rgb_lib,\r\n                                                            n=int(self.f_coreset * self.patch_rgb_lib.shape[0]),\r\n                                                            eps=self.coreset_eps, )\r\n            self.patch_rgb_lib = self.patch_rgb_lib[self.coreset_idx]\r\n            self.coreset_idx = self.get_coreset_idx_randomp(self.patch_fusion_lib,\r\n                                                            n=int(self.f_coreset * self.patch_fusion_lib.shape[0]),\r\n                                                            eps=self.coreset_eps, )\r\n            self.patch_fusion_lib = self.patch_fusion_lib[self.coreset_idx]\r\n\r\n\r\n        self.patch_xyz_lib = self.patch_xyz_lib[torch.nonzero(torch.all(self.patch_xyz_lib!=0, dim=1))[:,0]]\r\n        self.patch_xyz_lib = torch.cat((self.patch_xyz_lib, torch.zeros(1, self.patch_xyz_lib.shape[1])), 0)\r\n\r\n\r\n    def compute_s_s_map(self, xyz_patch, rgb_patch, fusion_patch, feature_map_dims, mask, label, center, neighbour_idx, nonzero_indices, xyz, center_idx):\r\n        '''\r\n        center: point group center position\r\n        neighbour_idx: each group point index\r\n        nonzero_indices: point indices of original point clouds\r\n        xyz: nonzero point clouds\r\n        '''\r\n\r\n        # 3D dist \r\n        xyz_patch = (xyz_patch - self.xyz_mean)/self.xyz_std\r\n        rgb_patch = (rgb_patch - self.rgb_mean)/self.rgb_std\r\n        fusion_patch = (fusion_patch - self.fusion_mean)/self.fusion_std\r\n\r\n        dist_xyz = torch.cdist(xyz_patch, self.patch_xyz_lib)\r\n        dist_rgb = torch.cdist(rgb_patch, self.patch_rgb_lib)\r\n        dist_fusion = torch.cdist(fusion_patch, self.patch_fusion_lib)\r\n        \r\n        rgb_feat_size = (int(math.sqrt(rgb_patch.shape[0])), int(math.sqrt(rgb_patch.shape[0])))\r\n        xyz_feat_size = (int(math.sqrt(xyz_patch.shape[0])), int(math.sqrt(xyz_patch.shape[0])))\r\n        fusion_feat_size =  (int(math.sqrt(fusion_patch.shape[0])), int(math.sqrt(fusion_patch.shape[0])))\r\n\r\n  \r\n        s_xyz, s_map_xyz = self.compute_single_s_s_map(xyz_patch, dist_xyz, xyz_feat_size, modal='xyz')\r\n        s_rgb, s_map_rgb = self.compute_single_s_s_map(rgb_patch, dist_rgb, rgb_feat_size, modal='rgb')\r\n        s_fusion, s_map_fusion = self.compute_single_s_s_map(fusion_patch, dist_fusion, fusion_feat_size, modal='fusion')\r\n\r\n        s = torch.tensor([[self.args.xyz_s_lambda*s_xyz, self.args.rgb_s_lambda*s_rgb, self.args.fusion_s_lambda*s_fusion]])\r\n \r\n        s_map = torch.cat([self.args.xyz_smap_lambda*s_map_xyz, self.args.rgb_smap_lambda*s_map_rgb, self.args.fusion_smap_lambda*s_map_fusion], dim=0).squeeze().reshape(3, -1).permute(1, 0)\r\n \r\n        s = torch.tensor(self.detect_fuser.score_samples(s))\r\n\r\n        s_map = torch.tensor(self.seg_fuser.score_samples(s_map))\r\n  \r\n        s_map = s_map.view(1, self.image_size, self.image_size)\r\n\r\n\r\n        self.image_preds.append(s.numpy())\r\n        
self.image_labels.append(label)\r\n        self.pixel_preds.extend(s_map.flatten().numpy())\r\n        self.pixel_labels.extend(mask.flatten().numpy())\r\n        self.predictions.append(s_map.detach().cpu().squeeze().numpy())\r\n        self.gts.append(mask.detach().cpu().squeeze().numpy())\r\n\r\n    def compute_single_s_s_map(self, patch, dist, feature_map_dims, modal='xyz'):\r\n\r\n        min_val, min_idx = torch.min(dist, dim=1)\r\n\r\n        s_idx = torch.argmax(min_val)\r\n        s_star = torch.max(min_val)\r\n\r\n        # reweighting\r\n        m_test = patch[s_idx].unsqueeze(0)  # anomalous patch\r\n\r\n        if modal=='xyz':\r\n            m_star = self.patch_xyz_lib[min_idx[s_idx]].unsqueeze(0)  # closest neighbour\r\n            w_dist = torch.cdist(m_star, self.patch_xyz_lib)  # find knn to m_star pt.1\r\n        elif modal=='rgb':\r\n            m_star = self.patch_rgb_lib[min_idx[s_idx]].unsqueeze(0)  # closest neighbour\r\n            w_dist = torch.cdist(m_star, self.patch_rgb_lib)  # find knn to m_star pt.1\r\n        else:\r\n            m_star = self.patch_fusion_lib[min_idx[s_idx]].unsqueeze(0)  # closest neighbour\r\n            w_dist = torch.cdist(m_star, self.patch_fusion_lib)  # find knn to m_star pt.1\r\n        _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False)  # pt.2\r\n\r\n        # equation 7 from the paper\r\n        if modal=='xyz':\r\n            m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1:]], dim=1) \r\n        elif modal=='rgb':\r\n            m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1)\r\n        else:\r\n            m_star_knn = torch.linalg.norm(m_test - self.patch_fusion_lib[nn_idx[0, 1:]], dim=1)\r\n\r\n        # sparse reweight\r\n        # if modal=='rgb':\r\n        #     _, nn_idx = torch.topk(w_dist, k=self.n_reweight, largest=False)  # pt.2\r\n        # else:\r\n        #     _, nn_idx = torch.topk(w_dist, k=4*self.n_reweight, largest=False)  # pt.2\r\n\r\n        # if modal=='xyz':\r\n        #     m_star_knn = torch.linalg.norm(m_test - self.patch_xyz_lib[nn_idx[0, 1::4]], dim=1) \r\n        # elif modal=='rgb':\r\n        #     m_star_knn = torch.linalg.norm(m_test - self.patch_rgb_lib[nn_idx[0, 1:]], dim=1)\r\n        # else:\r\n        #     m_star_knn = torch.linalg.norm(m_test - self.patch_fusion_lib[nn_idx[0, 1::4]], dim=1)\r\n        # Softmax normalization trick as in transformers.\r\n        # As the patch vectors grow larger, their norm might differ a lot.\r\n        # exp(norm) can give infinities.\r\n        D = torch.sqrt(torch.tensor(patch.shape[1]))\r\n        w = 1 - (torch.exp(s_star / D) / (torch.sum(torch.exp(m_star_knn / D))))\r\n\r\n        s = w * s_star\r\n\r\n        # segmentation map\r\n        s_map = min_val.view(1, 1, *feature_map_dims)\r\n        s_map = torch.nn.functional.interpolate(s_map, size=(self.image_size, self.image_size), mode='bilinear', align_corners=False)\r\n        s_map = self.blur(s_map)\r\n\r\n        return s, s_map\r\n"
  },
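All of the `compute_single_s_s_map` implementations above share the PatchCore-style score reweighting (the "equation 7" comment): the raw score `s_star` of the most anomalous patch is rescaled by how densely the memory bank populates the neighbourhood of its nearest memory item. A minimal, self-contained sketch of that computation, assuming a plain 2-D memory bank and random test patches (all names below are illustrative, not repo API):

```python
import torch

def reweighted_score(patches, mem_bank, n_reweight=3):
    """PatchCore-style reweighting, mirroring compute_single_s_s_map."""
    dist = torch.cdist(patches, mem_bank)           # (n_patches, n_mem) pairwise distances
    min_val, min_idx = torch.min(dist, dim=1)       # nearest-neighbour distance per patch
    s_idx = torch.argmax(min_val)                   # index of the most anomalous patch
    s_star = torch.max(min_val)                     # its raw anomaly score
    m_test = patches[s_idx].unsqueeze(0)            # anomalous patch
    m_star = mem_bank[min_idx[s_idx]].unsqueeze(0)  # its closest memory item
    w_dist = torch.cdist(m_star, mem_bank)          # distances from m_star to the whole bank
    _, nn_idx = torch.topk(w_dist, k=n_reweight, largest=False)
    m_star_knn = torch.linalg.norm(m_test - mem_bank[nn_idx[0, 1:]], dim=1)
    D = torch.sqrt(torch.tensor(float(patches.shape[1])))  # sqrt(dim) softmax temperature
    w = 1 - torch.exp(s_star / D) / torch.sum(torch.exp(m_star_knn / D))
    return w * s_star, min_val                      # image-level score, per-patch map values

torch.manual_seed(0)
s, patch_scores = reweighted_score(torch.randn(16, 8), torch.randn(100, 8))
print(float(s), patch_scores.shape)
```

The segmentation map is simply these per-patch nearest-neighbour distances (`min_val`) reshaped to the feature grid, upsampled to the image size, and blurred.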
  {
    "path": "fusion_pretrain.py",
    "content": "import argparse\nimport datetime\nimport json\nimport numpy as np\nimport os\nimport time\nfrom pathlib import Path\n\nimport torch\nimport torch.backends.cudnn as cudnn\nfrom torch.utils.tensorboard import SummaryWriter\nimport torchvision.transforms as transforms\n\nimport timm\n\nimport timm.optim.optim_factory as optim_factory\n\nimport utils.misc as misc\nfrom utils.misc import NativeScalerWithGradNormCount as NativeScaler\n\n\nfrom engine_fusion_pretrain import train_one_epoch\n\nimport dataset\n\nimport torch\nfrom models.feature_fusion import FeatureFusionBlock\n\n\n\n\ndef get_args_parser():\n    parser = argparse.ArgumentParser('MAE pre-training', add_help=False)\n    parser.add_argument('--batch_size', default=64, type=int,\n                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')\n    parser.add_argument('--epochs', default=3, type=int)\n    parser.add_argument('--accum_iter', default=1, type=int,\n                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')\n\n    # Model parameters\n\n    parser.add_argument('--input_size', default=224, type=int,\n                        help='images input size')\n\n\n    # Optimizer parameters\n    parser.add_argument('--weight_decay', type=float, default=1.5e-6,\n                        help='weight decay (default: 0.05)')\n\n    parser.add_argument('--lr', type=float, default=None, metavar='LR',\n                        help='learning rate (absolute lr)')\n    parser.add_argument('--blr', type=float, default=0.002, metavar='LR',\n                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')\n    parser.add_argument('--min_lr', type=float, default=0., metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0')\n\n    parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N',\n                        help='epochs to warmup LR')\n\n    # Dataset parameters\n    parser.add_argument('--data_path', default='', type=str,\n                        help='dataset path')\n\n    parser.add_argument('--output_dir', default='./output_dir',\n                        help='path where to save, empty for no saving')\n    parser.add_argument('--log_dir', default='./output_dir',\n                        help='path where to tensorboard log')\n    parser.add_argument('--device', default='cuda',\n                        help='device to use for training / testing')\n    parser.add_argument('--seed', default=0, type=int)\n    parser.add_argument('--resume', default='',\n                        help='resume from checkpoint')\n\n    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('--num_workers', default=10, type=int)\n    parser.add_argument('--pin_mem', action='store_true',\n                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')\n    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')\n    parser.set_defaults(pin_mem=True)\n\n    # distributed training parameters\n    parser.add_argument('--world_size', default=1, type=int,\n                        help='number of distributed processes')\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--dist_on_itp', action='store_true')\n    parser.add_argument('--dist_url', default='env://',\n              
          help='url used to set up distributed training')\n\n    return parser\n\n\n\ndef main(args):\n    misc.init_distributed_mode(args)\n\n    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))\n    print(\"{}\".format(args).replace(', ', ',\\n'))\n\n    device = torch.device(args.device)\n\n    # fix the seed for reproducibility\n    seed = args.seed + misc.get_rank()\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n\n    cudnn.benchmark = True\n\n    \n    dataset_train = dataset.PreTrainTensorDataset(args.data_path)\n\n    print(dataset_train)\n\n\n    if True:  # args.distributed:\n        num_tasks = misc.get_world_size()\n        global_rank = misc.get_rank()\n        sampler_train = torch.utils.data.DistributedSampler(\n            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True\n        )\n        print(\"Sampler_train = %s\" % str(sampler_train))\n    else:\n        sampler_train = torch.utils.data.RandomSampler(dataset_train)\n\n    if global_rank == 0 and args.log_dir is not None:\n        os.makedirs(args.log_dir, exist_ok=True)\n        log_writer = SummaryWriter(log_dir=args.log_dir)\n    else:\n        log_writer = None\n\n    data_loader_train = torch.utils.data.DataLoader(\n        dataset_train, sampler=sampler_train,\n        batch_size=args.batch_size,\n        num_workers=args.num_workers,\n        pin_memory=args.pin_mem,\n        drop_last=True,\n    )\n    \n\n    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()\n    \n    if args.lr is None:  # only base_lr is specified\n        args.lr = args.blr * eff_batch_size / 256\n\n    print(\"base lr: %.2e\" % (args.lr * 256 / eff_batch_size))\n    print(\"actual lr: %.2e\" % args.lr)\n\n    print(\"accumulate grad iterations: %d\" % args.accum_iter)\n    print(\"effective batch size: %d\" % eff_batch_size)\n\n    model = FeatureFusionBlock(1152, 768)\n\n    model.to(device)\n\n    model_without_ddp = model  # keep a handle on the raw module when DDP is not used\n    if args.distributed:\n        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)\n        model_without_ddp = model.module\n    \n    # following timm: set wd as 0 for bias and norm layers\n    param_groups = optim_factory.add_weight_decay(model_without_ddp, args.weight_decay)\n    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))\n    print(optimizer)\n    loss_scaler = NativeScaler()\n\n    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)\n\n    print(f\"Start training for {args.epochs} epochs\")\n    start_time = time.time()\n    for epoch in range(args.start_epoch, args.epochs):\n        if args.distributed:\n            data_loader_train.sampler.set_epoch(epoch)\n        train_stats = train_one_epoch(\n            model, data_loader_train,\n            optimizer, device, epoch, loss_scaler,\n            log_writer=log_writer,\n            args=args\n        )\n        if args.output_dir and (epoch % 1 == 0 or epoch + 1 == args.epochs):\n            misc.save_model(\n                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,\n                loss_scaler=loss_scaler, epoch=epoch)\n\n        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},\n                        'epoch': epoch,}\n\n        if args.output_dir and misc.is_main_process():\n            if log_writer is not None:\n                log_writer.flush()\n            with open(os.path.join(args.output_dir, \"log.txt\"), mode=\"a\", encoding=\"utf-8\") as f:\n                
f.write(json.dumps(log_stats) + \"\\n\")\n\n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    print('Training time {}'.format(total_time_str))\n\n\nif __name__ == '__main__':\n    args = get_args_parser()\n    args = args.parse_args()\n    if args.output_dir:\n        Path(args.output_dir).mkdir(parents=True, exist_ok=True)\n    main(args)\n"
  },
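`fusion_pretrain.py` follows the MAE convention of deriving the absolute learning rate from a base rate and the effective batch size. A quick check of the arithmetic with the defaults above (`world_size = 4` is just an example value, not a repo default):

```python
# absolute_lr = base_lr * effective_batch_size / 256 (MAE-style linear scaling)
batch_size, accum_iter, world_size = 64, 1, 4   # world_size = 4 GPUs is illustrative
blr = 0.002                                     # --blr default above
eff_batch_size = batch_size * accum_iter * world_size
lr = blr * eff_batch_size / 256
print(f"effective batch size: {eff_batch_size}, actual lr: {lr:.2e}")  # 256, 2.00e-03
```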
  {
    "path": "m3dm_runner.py",
    "content": "import torch\nfrom tqdm import tqdm\nimport os\n\nfrom feature_extractors import multiple_features\n        \nfrom dataset import get_data_loader\n\nclass M3DM():\n    def __init__(self, args):\n        self.args = args\n        self.image_size = args.img_size\n        self.count = args.max_sample\n        if args.method_name == 'DINO':\n            self.methods = {\n                \"DINO\": multiple_features.RGBFeatures(args),\n            }\n        elif args.method_name == 'Point_MAE':\n            self.methods = {\n                \"Point_MAE\": multiple_features.PointFeatures(args),\n            }\n        elif args.method_name == 'Fusion':\n            self.methods = {\n                \"Fusion\": multiple_features.FusionFeatures(args),\n            }\n        elif args.method_name == 'DINO+Point_MAE':\n            self.methods = {\n                \"DINO+Point_MAE\": multiple_features.DoubleRGBPointFeatures(args),\n            }\n        elif args.method_name == 'DINO+Point_MAE+add':\n            self.methods = {\n                \"DINO+Point_MAE\": multiple_features.DoubleRGBPointFeatures_add(args),\n            }\n        elif args.method_name == 'DINO+Point_MAE+Fusion':\n            self.methods = {\n                \"DINO+Point_MAE+Fusion\": multiple_features.TripleFeatures(args),\n            }\n\n\n    def fit(self, class_name):\n        train_loader = get_data_loader(\"train\", class_name=class_name, img_size=self.image_size, args=self.args)\n\n        flag = 0\n        for sample, _ in tqdm(train_loader, desc=f'Extracting train features for class {class_name}'):\n            for method in self.methods.values():\n                if self.args.save_feature:\n                    method.add_sample_to_mem_bank(sample, class_name=class_name)\n                else:\n                    method.add_sample_to_mem_bank(sample)\n                flag += 1\n            if flag > self.count:\n                flag = 0\n                break\n                \n        for method_name, method in self.methods.items():\n            print(f'\\n\\nRunning coreset for {method_name} on class {class_name}...')\n            method.run_coreset()\n            \n\n        if self.args.memory_bank == 'multiple':    \n            flag = 0\n            for sample, _ in tqdm(train_loader, desc=f'Running late fusion for {method_name} on class {class_name}..'):\n                for method_name, method in self.methods.items():\n                    method.add_sample_to_late_fusion_mem_bank(sample)\n                    flag += 1\n                if flag > self.count:\n                    flag = 0\n                    break\n        \n            for method_name, method in self.methods.items():\n                print(f'\\n\\nTraining Dicision Layer Fusion for {method_name} on class {class_name}...')\n                method.run_late_fusion()\n\n    def evaluate(self, class_name):\n        image_rocaucs = dict()\n        pixel_rocaucs = dict()\n        au_pros = dict()\n        test_loader = get_data_loader(\"test\", class_name=class_name, img_size=self.image_size, args=self.args)\n        path_list = []\n        with torch.no_grad():\n        \n            for sample, mask, label, rgb_path in tqdm(test_loader, desc=f'Extracting test features for class {class_name}'):\n                for method in self.methods.values():\n                    method.predict(sample, mask, label)\n                    path_list.append(rgb_path)\n                        \n\n        for method_name, method in 
self.methods.items():\n            method.calculate_metrics()\n            image_rocaucs[method_name] = round(method.image_rocauc, 3)\n            pixel_rocaucs[method_name] = round(method.pixel_rocauc, 3)\n            au_pros[method_name] = round(method.au_pro, 3)\n            print(\n                f'Class: {class_name}, {method_name} Image ROCAUC: {method.image_rocauc:.3f}, {method_name} Pixel ROCAUC: {method.pixel_rocauc:.3f}, {method_name} AU-PRO: {method.au_pro:.3f}')\n            if self.args.save_preds:\n                method.save_prediction_maps('./pred_maps', path_list)\n        return image_rocaucs, pixel_rocaucs, au_pros\n"
  },
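During `fit`, the second pass over the training data stacks the per-modality scores into `s_lib`/`s_map_lib`, and `run_late_fusion` (defined in the `Features` base class, outside this excerpt) fits the `detect_fuser`/`seg_fuser` that `compute_s_s_map` later calls. Judging from the `--ocsvm_nu`/`--ocsvm_maxiter` flags in `main.py`, a linear one-class SVM is a plausible stand-in; a hedged sketch of that decision-layer fusion step:

```python
import torch
from sklearn.linear_model import SGDOneClassSVM

# Sketch of decision-layer fusion (DLF); SGDOneClassSVM is an assumption based on
# the --ocsvm_nu / --ocsvm_maxiter flags, not a confirmed repo detail.
s_lib = [torch.rand(1, 3) for _ in range(32)]        # per-sample [s_xyz, s_rgb, s_fusion]
detect_fuser = SGDOneClassSVM(nu=0.5, max_iter=1000)
detect_fuser.fit(torch.cat(s_lib, 0).numpy())        # fit on nominal training scores only
fused = detect_fuser.score_samples([[0.5, 0.1, 0.4]])
print(fused.shape)                                   # (1,): one fused anomaly score
```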
  {
    "path": "main.py",
    "content": "import argparse\nfrom m3dm_runner import M3DM\nfrom dataset import eyecandies_classes, mvtec3d_classes\nimport pandas as pd\n\n\ndef run_3d_ads(args):\n    if args.dataset_type=='eyecandies':\n        classes = eyecandies_classes()\n    elif args.dataset_type=='mvtec3d':\n        classes = mvtec3d_classes()\n\n    METHOD_NAMES = [args.method_name]\n\n    image_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])\n    pixel_rocaucs_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])\n    au_pros_df = pd.DataFrame(METHOD_NAMES, columns=['Method'])\n    for cls in classes:\n        model = M3DM(args)\n        model.fit(cls)\n        image_rocaucs, pixel_rocaucs, au_pros = model.evaluate(cls)\n        image_rocaucs_df[cls.title()] = image_rocaucs_df['Method'].map(image_rocaucs)\n        pixel_rocaucs_df[cls.title()] = pixel_rocaucs_df['Method'].map(pixel_rocaucs)\n        au_pros_df[cls.title()] = au_pros_df['Method'].map(au_pros)\n\n        print(f\"\\nFinished running on class {cls}\")\n        print(\"################################################################################\\n\\n\")\n\n    image_rocaucs_df['Mean'] = round(image_rocaucs_df.iloc[:, 1:].mean(axis=1),3)\n    pixel_rocaucs_df['Mean'] = round(pixel_rocaucs_df.iloc[:, 1:].mean(axis=1),3)\n    au_pros_df['Mean'] = round(au_pros_df.iloc[:, 1:].mean(axis=1),3)\n\n    print(\"\\n\\n################################################################################\")\n    print(\"############################# Image ROCAUC Results #############################\")\n    print(\"################################################################################\\n\")\n    print(image_rocaucs_df.to_markdown(index=False))\n\n    print(\"\\n\\n################################################################################\")\n    print(\"############################# Pixel ROCAUC Results #############################\")\n    print(\"################################################################################\\n\")\n    print(pixel_rocaucs_df.to_markdown(index=False))\n\n    print(\"\\n\\n##########################################################################\")\n    print(\"############################# AU PRO Results #############################\")\n    print(\"##########################################################################\\n\")\n    print(au_pros_df.to_markdown(index=False))\n\n\n\n    with open(\"results/image_rocauc_results.md\", \"a\") as tf:\n        tf.write(image_rocaucs_df.to_markdown(index=False))\n    with open(\"results/pixel_rocauc_results.md\", \"a\") as tf:\n        tf.write(pixel_rocaucs_df.to_markdown(index=False))\n    with open(\"results/aupro_results.md\", \"a\") as tf:\n        tf.write(au_pros_df.to_markdown(index=False))\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Process some integers.')\n\n    parser.add_argument('--method_name', default='DINO+Point_MAE+Fusion', type=str, \n                        choices=['DINO', 'Point_MAE', 'Fusion', 'DINO+Point_MAE', 'DINO+Point_MAE+Fusion', 'DINO+Point_MAE+add'],\n                        help='Anomaly detection modal name.')\n    parser.add_argument('--max_sample', default=400, type=int,\n                        help='Max sample number.')\n    parser.add_argument('--memory_bank', default='multiple', type=str,\n                        choices=[\"multiple\", \"single\"],\n                        help='memory bank mode: \"multiple\", \"single\".')\n    parser.add_argument('--rgb_backbone_name', 
default='vit_base_patch8_224_dino', type=str, \n                        choices=['vit_base_patch8_224_dino', 'vit_base_patch8_224', 'vit_base_patch8_224_in21k', 'vit_small_patch8_224_dino'],\n                        help='Timm checkpoints name of RGB backbone.')\n    parser.add_argument('--xyz_backbone_name', default='Point_MAE', type=str, choices=['Point_MAE', 'Point_Bert'],\n                        help='Checkpoint name of the xyz (point cloud) backbone [Point_MAE, Point_Bert].')\n    parser.add_argument('--fusion_module_path', default='checkpoints/checkpoint-0.pth', type=str,\n                        help='Checkpoints for fusion module.')\n    parser.add_argument('--save_feature', default=False, action='store_true',\n                        help='Save feature for training fusion block.')\n    parser.add_argument('--use_uff', default=False, action='store_true',\n                        help='Use UFF module.')\n    parser.add_argument('--save_feature_path', default='datasets/patch_lib', type=str,\n                        help='Path to save features for training the fusion block.')\n    parser.add_argument('--save_preds', default=False, action='store_true',\n                        help='Save prediction results.')\n    parser.add_argument('--group_size', default=128, type=int,\n                        help='Point group size of Point Transformer.')\n    parser.add_argument('--num_group', default=1024, type=int,\n                        help='Point groups number of Point Transformer.')\n    parser.add_argument('--random_state', default=None, type=int,\n                        help='random_state for random projection.')\n    parser.add_argument('--dataset_type', default='mvtec3d', type=str, choices=['mvtec3d', 'eyecandies'], \n                        help='Dataset type for training or testing')\n    parser.add_argument('--dataset_path', default='datasets/mvtec3d', type=str, \n                        help='Dataset store path')\n    parser.add_argument('--img_size', default=224, type=int,\n                        help='Images size for model')\n    parser.add_argument('--xyz_s_lambda', default=1.0, type=float,\n                        help='xyz_s_lambda')\n    parser.add_argument('--xyz_smap_lambda', default=1.0, type=float,\n                        help='xyz_smap_lambda')\n    parser.add_argument('--rgb_s_lambda', default=0.1, type=float,\n                        help='rgb_s_lambda')\n    parser.add_argument('--rgb_smap_lambda', default=0.1, type=float,\n                        help='rgb_smap_lambda')\n    parser.add_argument('--fusion_s_lambda', default=1.0, type=float,\n                        help='fusion_s_lambda')\n    parser.add_argument('--fusion_smap_lambda', default=1.0, type=float,\n                        help='fusion_smap_lambda')\n    parser.add_argument('--coreset_eps', default=0.9, type=float,\n                        help='eps for sparse random projection.')\n    parser.add_argument('--f_coreset', default=0.1, type=float,\n                        help='Coreset sampling ratio.')\n    parser.add_argument('--asy_memory_bank', default=None, type=int,\n                        help='build an asymmetric memory bank for point clouds')\n    parser.add_argument('--ocsvm_nu', default=0.5, type=float,\n                        help='ocsvm nu')\n    parser.add_argument('--ocsvm_maxiter', default=1000, type=int,\n                        help='ocsvm maxiter')\n    parser.add_argument('--rm_zero_for_project', default=False, action='store_true',\n                        help='Remove zero patches before random projection.')\n  \n\n\n    args = parser.parse_args()\n    
run_3d_ads(args)\n"
  },
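For reference, this is how `run_3d_ads` assembles each results table: one row per method, one column per class filled via `Series.map`, then a `Mean` over the class columns. The numbers below are made up purely to show the mechanics, not reported results:

```python
import pandas as pd

scores = {'DINO+Point_MAE+Fusion': 0.5}              # illustrative, not a reported result
df = pd.DataFrame(['DINO+Point_MAE+Fusion'], columns=['Method'])
df['Bagel'] = df['Method'].map(scores)               # one column per MVTec-3D class
df['Cable_Gland'] = df['Method'].map({'DINO+Point_MAE+Fusion': 0.7})
df['Mean'] = round(df.iloc[:, 1:].mean(axis=1), 3)   # mean over class columns only
print(df.to_markdown(index=False))                   # needs `tabulate` installed
```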
  {
    "path": "models/feature_fusion.py",
    "content": "import torch\nimport torch.nn as nn\nimport math\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\nclass FeatureFusionBlock(nn.Module):\n    def __init__(self, xyz_dim, rgb_dim, mlp_ratio=4.):\n        super().__init__()\n\n        self.xyz_dim = xyz_dim\n        self.rgb_dim = rgb_dim\n\n        self.xyz_norm = nn.LayerNorm(xyz_dim)\n        self.xyz_mlp = Mlp(in_features=xyz_dim, hidden_features=int(xyz_dim * mlp_ratio), act_layer=nn.GELU, drop=0.)\n\n        self.rgb_norm = nn.LayerNorm(rgb_dim)\n        self.rgb_mlp = Mlp(in_features=rgb_dim, hidden_features=int(rgb_dim * mlp_ratio), act_layer=nn.GELU, drop=0.)\n\n        self.rgb_head = nn.Linear(rgb_dim, 256)\n        self.xyz_head = nn.Linear(xyz_dim, 256)\n        \n        self.T = 1\n\n    def feature_fusion(self, xyz_feature, rgb_feature):\n\n        xyz_feature  = self.xyz_mlp(self.xyz_norm(xyz_feature))\n        rgb_feature  = self.rgb_mlp(self.rgb_norm(rgb_feature))\n\n        feature = torch.cat([xyz_feature, rgb_feature], dim=2)\n\n        return feature\n\n    def contrastive_loss(self, q, k):\n        # normalize\n        q = nn.functional.normalize(q, dim=1)\n        k = nn.functional.normalize(k, dim=1)\n        # gather all targets\n        # Einstein sum is more intuitive\n        logits = torch.einsum('nc,mc->nm', [q, k]) / self.T\n        N = logits.shape[0]  # batch size per GPU\n        labels = (torch.arange(N, dtype=torch.long) + N * torch.distributed.get_rank()).cuda()\n        return nn.CrossEntropyLoss()(logits, labels) * (2 * self.T)\n\n    def reparameterize(self, mu, logvar):\n        \"\"\"\n        Will a single z be enough ti compute the expectation\n        for the loss??\n        :param mu: (Tensor) Mean of the latent Gaussian\n        :param logvar: (Tensor) Standard deviation of the latent Gaussian\n        :return:\n        \"\"\"\n        std = torch.exp(0.5 * logvar)\n        eps = torch.randn_like(std)\n        return eps * std + mu\n\n\n    def forward(self, xyz_feature, rgb_feature):\n\n\n        feature = self.feature_fusion(xyz_feature, rgb_feature)\n\n        feature_xyz = feature[:,:, :self.xyz_dim]\n        feature_rgb = feature[:,:, self.xyz_dim:]\n\n        q = self.rgb_head(feature_rgb.view(-1, feature_rgb.shape[2]))\n        k = self.xyz_head(feature_xyz.view(-1, feature_xyz.shape[2]))\n\n        xyz_feature = xyz_feature.view(-1, xyz_feature.shape[2])\n        rgb_feature = rgb_feature.view(-1, rgb_feature.shape[2])\n\n        patch_no_zeros_indices = torch.nonzero(torch.all(xyz_feature != 0, dim=1))\n        \n        loss = self.contrastive_loss(q[patch_no_zeros_indices,:].squeeze(), k[patch_no_zeros_indices,:].squeeze())\n\n        return loss\n\n"
  },
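`contrastive_loss` above assumes `torch.distributed` is initialized and CUDA is available: the `N * rank` offset points each process at its own block of positives in the gathered batch. For intuition (and for single-GPU experiments), a single-process, CPU-only version of the same InfoNCE objective might look like this:

```python
import torch
import torch.nn as nn

def contrastive_loss_single(q, k, T=1.0):
    # Single-process variant of FeatureFusionBlock.contrastive_loss:
    # positives are the aligned (same-patch) pairs on the logits diagonal.
    q = nn.functional.normalize(q, dim=1)
    k = nn.functional.normalize(k, dim=1)
    logits = torch.einsum('nc,mc->nm', [q, k]) / T
    labels = torch.arange(q.shape[0], dtype=torch.long)  # no rank offset needed
    return nn.CrossEntropyLoss()(logits, labels) * (2 * T)

q, k = torch.randn(8, 256), torch.randn(8, 256)
print(float(contrastive_loss_single(q, k)))
```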
  {
    "path": "models/models.py",
    "content": "import torch\nimport torch.nn as nn\nimport timm\nfrom timm.models.layers import DropPath, trunc_normal_\nfrom pointnet2_ops import pointnet2_utils\nfrom knn_cuda import KNN\n\nclass Model(torch.nn.Module):\n\n    def __init__(self, device, rgb_backbone_name='vit_base_patch8_224_dino', out_indices=None, checkpoint_path='',\n                 pool_last=False, xyz_backbone_name='Point_MAE', group_size=128, num_group=1024):\n        super().__init__()\n        # 'vit_base_patch8_224_dino'\n        # Determine if to output features.\n        self.device = device\n\n        kwargs = {'features_only': True if out_indices else False}\n        if out_indices:\n            kwargs.update({'out_indices': out_indices})\n\n        ## RGB backbone\n        self.rgb_backbone = timm.create_model(model_name=rgb_backbone_name, pretrained=True, checkpoint_path=checkpoint_path,\n                                          **kwargs)\n        \n        ## XYZ backbone\n        \n        if xyz_backbone_name=='Point_MAE':\n            self.xyz_backbone=PointTransformer(group_size=group_size, num_group=num_group)\n            self.xyz_backbone.load_model_from_ckpt(\"checkpoints/pointmae_pretrain.pth\")\n        elif xyz_backbone_name=='Point_Bert':\n            self.xyz_backbone=PointTransformer(group_size=group_size, num_group=num_group, encoder_dims=256)\n            self.xyz_backbone.load_model_from_pb_ckpt(\"checkpoints/Point-BERT.pth\")\n\n\n\n    def forward_rgb_features(self, x):\n        x = self.rgb_backbone.patch_embed(x)\n        x = self.rgb_backbone._pos_embed(x)\n        x = self.rgb_backbone.norm_pre(x)\n        if self.rgb_backbone.grad_checkpointing and not torch.jit.is_scripting():\n            x = checkpoint_seq(self.blocks, x)\n        else:\n            x = self.rgb_backbone.blocks(x)\n        x = self.rgb_backbone.norm(x)\n\n        feat = x[:,1:].permute(0, 2, 1).view(1, -1, 28, 28)\n        return feat\n\n\n    def forward(self, rgb, xyz):\n        \n        rgb_features = self.forward_rgb_features(rgb)\n\n        xyz_features, center, ori_idx, center_idx = self.xyz_backbone(xyz)\n\n        return rgb_features, xyz_features, center, ori_idx, center_idx\n\n\n\ndef fps(data, number):\n    '''\n        data B N 3\n        number int\n    '''\n    fps_idx = pointnet2_utils.furthest_point_sample(data, number)\n    fps_data = pointnet2_utils.gather_operation(data.transpose(1, 2).contiguous(), fps_idx).transpose(1, 2).contiguous()\n    return fps_data, fps_idx\n\nclass Group(nn.Module):\n    def __init__(self, num_group, group_size):\n        super().__init__()\n        self.num_group = num_group\n        self.group_size = group_size\n        self.knn = KNN(k=self.group_size, transpose_mode=True)\n\n    def forward(self, xyz):\n        '''\n            input: B N 3\n            ---------------------------\n            output: B G M 3\n            center : B G 3\n        '''\n        batch_size, num_points, _ = xyz.shape\n        # fps the centers out\n        center, center_idx = fps(xyz.contiguous(), self.num_group)  # B G 3\n        # knn to get the neighborhood\n        _, idx = self.knn(xyz, center)  # B G M\n        assert idx.size(1) == self.num_group\n        assert idx.size(2) == self.group_size\n        ori_idx = idx\n        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points\n        idx = idx + idx_base\n        idx = idx.view(-1)\n        neighborhood = xyz.reshape(batch_size * num_points, -1)[idx, :]\n        neighborhood = 
neighborhood.reshape(batch_size, self.num_group, self.group_size, 3).contiguous()\n        # normalize\n        neighborhood = neighborhood - center.unsqueeze(2)\n        return neighborhood, center, ori_idx, center_idx\n\n\nclass Encoder(nn.Module):\n    def __init__(self, encoder_channel):\n        super().__init__()\n        self.encoder_channel = encoder_channel\n        self.first_conv = nn.Sequential(\n            nn.Conv1d(3, 128, 1),\n            nn.BatchNorm1d(128),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(128, 256, 1)\n        )\n        self.second_conv = nn.Sequential(\n            nn.Conv1d(512, 512, 1),\n            nn.BatchNorm1d(512),\n            nn.ReLU(inplace=True),\n            nn.Conv1d(512, self.encoder_channel, 1)\n        )\n\n    def forward(self, point_groups):\n        '''\n            point_groups : B G N 3\n            -----------------\n            feature_global : B G C\n        '''\n        bs, g, n, _ = point_groups.shape\n        point_groups = point_groups.reshape(bs * g, n, 3)\n        # encoder\n        feature = self.first_conv(point_groups.transpose(2, 1))\n        feature_global = torch.max(feature, dim=2, keepdim=True)[0]\n        feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)\n        feature = self.second_conv(feature)\n        feature_global = torch.max(feature, dim=2, keepdim=False)[0]\n        return feature_global.reshape(bs, g, self.encoder_channel)\n\n\nclass Mlp(nn.Module):\n    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n\n\nclass Attention(nn.Module):\n    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights\n        self.scale = qk_scale or head_dim ** -0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q * self.scale) @ k.transpose(-2, -1)\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass Block(nn.Module):\n    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,\n                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n        super().__init__()\n        self.norm1 = norm_layer(dim)\n        self.drop_path = DropPath(drop_path) if 
drop_path > 0. else nn.Identity()\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)\n\n        self.attn = Attention(\n            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)\n\n    def forward(self, x):\n        x = x + self.drop_path(self.attn(self.norm1(x)))\n        x = x + self.drop_path(self.mlp(self.norm2(x)))\n        return x\n\n\nclass TransformerEncoder(nn.Module):\n    \"\"\" Transformer Encoder without hierarchical structure\n    \"\"\"\n\n    def __init__(self, embed_dim=768, depth=4, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None,\n                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):\n        super().__init__()\n\n        self.blocks = nn.ModuleList([\n            Block(\n                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,\n                drop=drop_rate, attn_drop=attn_drop_rate,\n                drop_path=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate\n            )\n            for i in range(depth)])\n\n    def forward(self, x, pos):\n        feature_list = []\n        fetch_idx = [3, 7, 11]\n        for i, block in enumerate(self.blocks):\n            x = block(x + pos)\n            if i in fetch_idx:\n                feature_list.append(x)\n        return feature_list\n\n\nclass PointTransformer(nn.Module):\n    def __init__(self, group_size=128, num_group=1024, encoder_dims=384):\n        super().__init__()\n\n        self.trans_dim = 384\n        self.depth = 12\n        self.drop_path_rate = 0.1\n        self.num_heads = 6\n\n        self.group_size = group_size\n        self.num_group = num_group\n        # grouper\n        self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)\n        # define the encoder\n        self.encoder_dims = encoder_dims\n        if self.encoder_dims != self.trans_dim:\n            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))\n            self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))\n            self.reduce_dim = nn.Linear(self.encoder_dims,  self.trans_dim)\n        self.encoder = Encoder(encoder_channel=self.encoder_dims)\n        # bridge encoder and transformer\n\n        self.pos_embed = nn.Sequential(\n            nn.Linear(3, 128),\n            nn.GELU(),\n            nn.Linear(128, self.trans_dim)\n        )\n\n        dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)]\n        self.blocks = TransformerEncoder(\n            embed_dim=self.trans_dim,\n            depth=self.depth,\n            drop_path_rate=dpr,\n            num_heads=self.num_heads\n        )\n\n        self.norm = nn.LayerNorm(self.trans_dim)\n\n    def load_model_from_ckpt(self, bert_ckpt_path):\n        if bert_ckpt_path is not None:\n            ckpt = torch.load(bert_ckpt_path)\n            base_ckpt = {k.replace(\"module.\", \"\"): v for k, v in ckpt['base_model'].items()}\n\n            for k in list(base_ckpt.keys()):\n                if k.startswith('MAE_encoder'):\n                    base_ckpt[k[len('MAE_encoder.'):]] = base_ckpt[k]\n                    del base_ckpt[k]\n                elif k.startswith('base_model'):\n                    base_ckpt[k[len('base_model.'):]] = base_ckpt[k]\n                    del base_ckpt[k]\n\n            incompatible = 
self.load_state_dict(base_ckpt, strict=False)\n\n            #if incompatible.missing_keys:\n            #    print('missing_keys')\n            #    print(\n            #            incompatible.missing_keys\n            #        )\n            #if incompatible.unexpected_keys:\n            #    print('unexpected_keys')\n            #    print(\n            #            incompatible.unexpected_keys\n\n            #        )\n\n            # print(f'[Transformer] Successfully loaded checkpoint from {bert_ckpt_path}')\n\n    def load_model_from_pb_ckpt(self, bert_ckpt_path):\n        ckpt = torch.load(bert_ckpt_path)\n        base_ckpt = {k.replace(\"module.\", \"\"): v for k, v in ckpt['base_model'].items()}\n        for k in list(base_ckpt.keys()):\n            if k.startswith('transformer_q') and not k.startswith('transformer_q.cls_head'):\n                base_ckpt[k[len('transformer_q.'):]] = base_ckpt[k]\n            elif k.startswith('base_model'):\n                base_ckpt[k[len('base_model.'):]] = base_ckpt[k]\n            del base_ckpt[k]\n\n        incompatible = self.load_state_dict(base_ckpt, strict=False)\n\n        if incompatible.missing_keys:\n            print('missing_keys')\n            print(\n                    incompatible.missing_keys\n                )\n        if incompatible.unexpected_keys:\n            print('unexpected_keys')\n            print(\n                    incompatible.unexpected_keys\n\n                )\n                \n        print(f'[Transformer] Successfully loaded checkpoint from {bert_ckpt_path}')\n\n\n    def forward(self, pts):\n        if self.encoder_dims != self.trans_dim:\n            B,C,N = pts.shape\n            pts = pts.transpose(-1, -2) # B N 3\n            # divide the point cloud in the same form. This is important\n            neighborhood,  center, ori_idx, center_idx = self.group_divider(pts)\n            # # generate mask\n            # bool_masked_pos = self._mask_center(center, no_mask = False) # B G\n            # encode the input cloud blocks\n            group_input_tokens = self.encoder(neighborhood)  #  B G N\n            group_input_tokens = self.reduce_dim(group_input_tokens)\n            # prepare cls\n            cls_tokens = self.cls_token.expand(group_input_tokens.size(0), -1, -1)  \n            cls_pos = self.cls_pos.expand(group_input_tokens.size(0), -1, -1)  \n            # add pos embedding\n            pos = self.pos_embed(center)\n            # final input\n            x = torch.cat((cls_tokens, group_input_tokens), dim=1)\n            pos = torch.cat((cls_pos, pos), dim=1)\n            # transformer\n            feature_list = self.blocks(x, pos)\n            feature_list = [self.norm(x)[:,1:].transpose(-1, -2).contiguous() for x in feature_list]\n            x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152\n            return x, center, ori_idx, center_idx \n        else:\n            B, C, N = pts.shape\n            pts = pts.transpose(-1, -2)  # B N 3\n            # divide the point cloud in the same form. 
This is important\n            neighborhood, center, ori_idx, center_idx = self.group_divider(pts)\n\n            group_input_tokens = self.encoder(neighborhood)  # B G N\n\n            pos = self.pos_embed(center)\n            # final input\n            x = group_input_tokens\n            # transformer\n            feature_list = self.blocks(x, pos)\n            feature_list = [self.norm(x).transpose(-1, -2).contiguous() for x in feature_list]\n            x = torch.cat((feature_list[0],feature_list[1],feature_list[2]), dim=1) #1152\n            return x, center, ori_idx, center_idx\n        "
  },
  {
    "path": "models/pointnet2_utils.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom time import time\nimport numpy as np\n\ndef timeit(tag, t):\n    print(\"{}: {}s\".format(tag, time() - t))\n    return time()\n\ndef pc_normalize(pc):\n    l = pc.shape[0]\n    centroid = np.mean(pc, axis=0)\n    pc = pc - centroid\n    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))\n    pc = pc / m\n    return pc\n\ndef square_distance(src, dst):\n    \"\"\"\n    Calculate Euclid distance between each two points.\n    src^T * dst = xn * xm + yn * ym + zn * zm;\n    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;\n    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;\n    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2\n         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst\n    Input:\n        src: source points, [B, N, C]\n        dst: target points, [B, M, C]\n    Output:\n        dist: per-point square distance, [B, N, M]\n    \"\"\"\n    B, N, _ = src.shape\n    _, M, _ = dst.shape\n    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))\n    dist += torch.sum(src ** 2, -1).view(B, N, 1)\n    dist += torch.sum(dst ** 2, -1).view(B, 1, M)\n    return dist\n\n\ndef index_points(points, idx):\n    \"\"\"\n    Input:\n        points: input points data, [B, N, C]\n        idx: sample index data, [B, S]\n    Return:\n        new_points:, indexed points data, [B, S, C]\n    \"\"\"\n    device = points.device\n    B = points.shape[0]\n    view_shape = list(idx.shape)\n    view_shape[1:] = [1] * (len(view_shape) - 1)\n    repeat_shape = list(idx.shape)\n    repeat_shape[0] = 1\n    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)\n    new_points = points[batch_indices, idx, :]\n    return new_points\n\n\ndef farthest_point_sample(xyz, npoint):\n    \"\"\"\n    Input:\n        xyz: pointcloud data, [B, N, 3]\n        npoint: number of samples\n    Return:\n        centroids: sampled pointcloud index, [B, npoint]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)\n    distance = torch.ones(B, N).to(device) * 1e10\n    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)\n    batch_indices = torch.arange(B, dtype=torch.long).to(device)\n    for i in range(npoint):\n        centroids[:, i] = farthest\n        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)\n        dist = torch.sum((xyz - centroid) ** 2, -1)\n        mask = dist < distance\n        distance[mask] = dist[mask]\n        farthest = torch.max(distance, -1)[1]\n    return centroids\n\n\ndef query_ball_point(radius, nsample, xyz, new_xyz):\n    \"\"\"\n    Input:\n        radius: local region radius\n        nsample: max sample number in local region\n        xyz: all points, [B, N, 3]\n        new_xyz: query points, [B, S, 3]\n    Return:\n        group_idx: grouped points index, [B, S, nsample]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    _, S, _ = new_xyz.shape\n    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])\n    sqrdists = square_distance(new_xyz, xyz)\n    group_idx[sqrdists > radius ** 2] = N\n    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]\n    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])\n    mask = group_idx == N\n    group_idx[mask] = group_first[mask]\n    return group_idx\n\n\ndef sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):\n    \"\"\"\n    Input:\n        
        npoint: number of centroids to sample\n        radius: local region radius\n        nsample: max sample number in local region\n        xyz: input points position data, [B, N, 3]\n        points: input points data, [B, N, D]\n    Return:\n        new_xyz: sampled points position data, [B, npoint, 3]\n        new_points: sampled points data, [B, npoint, nsample, 3+D]\n    \"\"\"\n    B, N, C = xyz.shape\n    S = npoint\n    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]\n    new_xyz = index_points(xyz, fps_idx)\n    idx = query_ball_point(radius, nsample, xyz, new_xyz)\n    grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]\n    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)\n\n    if points is not None:\n        grouped_points = index_points(points, idx)\n        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]\n    else:\n        new_points = grouped_xyz_norm\n    if returnfps:\n        return new_xyz, new_points, grouped_xyz, fps_idx\n    else:\n        return new_xyz, new_points\n\n\ndef sample_and_group_all(xyz, points):\n    \"\"\"\n    Input:\n        xyz: input points position data, [B, N, 3]\n        points: input points data, [B, N, D]\n    Return:\n        new_xyz: sampled points position data, [B, 1, 3]\n        new_points: sampled points data, [B, 1, N, 3+D]\n    \"\"\"\n    device = xyz.device\n    B, N, C = xyz.shape\n    new_xyz = torch.zeros(B, 1, C).to(device)\n    grouped_xyz = xyz.view(B, 1, N, C)\n    if points is not None:\n        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)\n    else:\n        new_points = grouped_xyz\n    return new_xyz, new_points\n\ndef interpolating_points(xyz1, xyz2, points2):\n    \"\"\"\n    Input:\n        xyz1: input points position data, [B, C, N]\n        xyz2: sampled input points position data, [B, C, S]\n        points2: input points data, [B, D, S]\n    Return:\n        new_points: upsampled points data, [B, D', N]\n    \"\"\"\n    xyz1 = xyz1.permute(0, 2, 1)\n    xyz2 = xyz2.permute(0, 2, 1)\n\n    points2 = points2.permute(0, 2, 1)\n    B, N, C = xyz1.shape\n    _, S, _ = xyz2.shape\n\n    if S == 1:\n        interpolated_points = points2.repeat(1, N, 1)\n    else:\n        dists = square_distance(xyz1, xyz2)\n        dists, idx = dists.sort(dim=-1)\n        dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]\n\n        dist_recip = 1.0 / (dists + 1e-8)\n        norm = torch.sum(dist_recip, dim=2, keepdim=True)\n        weight = dist_recip / norm\n        interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)\n\n    # Permute back to [B, D, N] for both branches (previously only the else branch\n    # was permuted, so the S == 1 case returned an inconsistent [B, N, D] shape).\n    interpolated_points = interpolated_points.permute(0, 2, 1)\n    return interpolated_points
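\n\n\nif __name__ == '__main__':\n    # Editor's sketch: a quick shape check of the sampling / grouping / interpolation\n    # helpers above on random data (illustrative only, not part of the original repo).\n    xyz = torch.rand(2, 1024, 3)                       # B N 3 point positions\n    feats = torch.rand(2, 1024, 32)                    # B N D per-point features\n    new_xyz, new_points = sample_and_group(npoint=128, radius=0.2, nsample=16, xyz=xyz, points=feats)\n    print(new_xyz.shape, new_points.shape)             # (2, 128, 3) (2, 128, 16, 35)\n    up = interpolating_points(xyz.permute(0, 2, 1), new_xyz.permute(0, 2, 1), torch.rand(2, 64, 128))\n    print(up.shape)                                    # (2, 64, 1024) after the permute fix above\n"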
  },
  {
    "path": "requirements.txt",
    "content": "numpy\nPillow\nscikit-learn\nscipy\ntimm\ntorch\ntorchvision\ntqdm\nwget\ntifffile\nscikit-image\nkornia\nimageio\ntensorboard\nopencv-python\nsetuptools==59.5.0;"
  },
  {
    "path": "utils/au_pro_util.py",
    "content": "\"\"\"\nCode based on the official MVTec 3D-AD evaluation code found at\nhttps://www.mydrive.ch/shares/45924/9ce7a138c69bbd4c8d648b72151f839d/download/428846918-1643297332/evaluation_code.tar.xz\n\nUtility functions that compute a PRO curve and its definite integral, given\npairs of anomaly and ground truth maps.\n\nThe PRO curve can also be integrated up to a constant integration limit.\n\"\"\"\nimport numpy as np\nfrom scipy.ndimage.measurements import label\nfrom bisect import bisect\n\n\nclass GroundTruthComponent:\n    \"\"\"\n    Stores sorted anomaly scores of a single ground truth component.\n    Used to efficiently compute the region overlap for many increasing thresholds.\n    \"\"\"\n\n    def __init__(self, anomaly_scores):\n        \"\"\"\n        Initialize the module.\n\n        Args:\n            anomaly_scores: List of all anomaly scores within the ground truth\n                            component as numpy array.\n        \"\"\"\n        # Keep a sorted list of all anomaly scores within the component.\n        self.anomaly_scores = anomaly_scores.copy()\n        self.anomaly_scores.sort()\n\n        # Pointer to the anomaly score where the current threshold divides the component into OK / NOK pixels.\n        self.index = 0\n\n        # The last evaluated threshold.\n        self.last_threshold = None\n\n    def compute_overlap(self, threshold):\n        \"\"\"\n        Compute the region overlap for a specific threshold.\n        Thresholds must be passed in increasing order.\n\n        Args:\n            threshold: Threshold to compute the region overlap.\n\n        Returns:\n            Region overlap for the specified threshold.\n        \"\"\"\n        if self.last_threshold is not None:\n            assert self.last_threshold <= threshold\n\n        # Increase the index until it points to an anomaly score that is just above the specified threshold.\n        while (self.index < len(self.anomaly_scores) and self.anomaly_scores[self.index] <= threshold):\n            self.index += 1\n\n        # Compute the fraction of component pixels that are correctly segmented as anomalous.\n        return 1.0 - self.index / len(self.anomaly_scores)\n\n\ndef trapezoid(x, y, x_max=None):\n    \"\"\"\n    This function calculates the definit integral of a curve given by x- and corresponding y-values.\n    In contrast to, e.g., 'numpy.trapz()', this function allows to define an upper bound to the integration range by\n    setting a value x_max.\n\n    Points that do not have a finite x or y value will be ignored with a warning.\n\n    Args:\n        x:     Samples from the domain of the function to integrate need to be sorted in ascending order. May contain\n               the same value multiple times. In that case, the order of the corresponding y values will affect the\n               integration with the trapezoidal rule.\n        y:     Values of the function corresponding to x values.\n        x_max: Upper limit of the integration. The y value at max_x will be determined by interpolating between its\n               neighbors. Must not lie outside of the range of x.\n\n    Returns:\n        Area under the curve.\n    \"\"\"\n\n    x = np.array(x)\n    y = np.array(y)\n    finite_mask = np.logical_and(np.isfinite(x), np.isfinite(y))\n    if not finite_mask.all():\n        print(\n            \"\"\"WARNING: Not all x and y values passed to trapezoid are finite. 
 Will continue with only the finite values.\"\"\")\n    x = x[finite_mask]\n    y = y[finite_mask]\n\n    # Introduce a correction term if x_max is not an element of x.\n    correction = 0.\n    if x_max is not None:\n        if x_max not in x:\n            # Get the insertion index that would keep x sorted after np.insert(x, ins, x_max).\n            ins = bisect(x, x_max)\n            # x_max must be between the minimum and the maximum, so the insertion point cannot be zero or len(x).\n            assert 0 < ins < len(x)\n\n            # Calculate the correction term, which is the integral between x[ins - 1] and x_max. Since we do not\n            # know the exact value of y at x_max, we interpolate between y[ins] and y[ins-1].\n            y_interp = y[ins - 1] + ((y[ins] - y[ins - 1]) * (x_max - x[ins - 1]) / (x[ins] - x[ins - 1]))\n            correction = 0.5 * (y_interp + y[ins - 1]) * (x_max - x[ins - 1])\n\n        # Cut off at x_max.\n        mask = x <= x_max\n        x = x[mask]\n        y = y[mask]\n\n    # Return area under the curve using the trapezoidal rule.\n    return np.sum(0.5 * (y[1:] + y[:-1]) * (x[1:] - x[:-1])) + correction\n\n\ndef collect_anomaly_scores(anomaly_maps, ground_truth_maps):\n    \"\"\"\n    Extract anomaly scores for each ground truth connected component as well as anomaly scores for each potential false\n    positive pixel from anomaly maps.\n\n    Args:\n        anomaly_maps:      List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel.\n\n        ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels\n                           for each pixel. 0 indicates that a pixel is anomaly-free. 1 indicates that a pixel contains\n                           an anomaly.\n\n    Returns:\n        ground_truth_components: A list of all ground truth connected components that appear in the dataset. For each\n                                 component, a sorted list of its anomaly scores is stored.\n\n        anomaly_scores_ok_pixels: A sorted list of anomaly scores of all anomaly-free pixels of the dataset.
This list\n                                  can be used to quickly select thresholds that fix a certain false positive rate.\n    \"\"\"\n    # Make sure an anomaly map is present for each ground truth map.\n    assert len(anomaly_maps) == len(ground_truth_maps)\n\n    # Initialize ground truth components and scores of potential fp pixels.\n    ground_truth_components = []\n    anomaly_scores_ok_pixels = np.zeros(len(ground_truth_maps) * ground_truth_maps[0].size)\n\n    # Structuring element for computing connected components.\n    structure = np.ones((3, 3), dtype=int)\n\n    # Collect anomaly scores within each ground truth region and for all potential fp pixels.\n    ok_index = 0\n    for gt_map, prediction in zip(ground_truth_maps, anomaly_maps):\n\n        # Compute the connected components in the ground truth map.\n        labeled, n_components = label(gt_map, structure)\n\n        # Store all potential fp scores.\n        num_ok_pixels = len(prediction[labeled == 0])\n        anomaly_scores_ok_pixels[ok_index:ok_index + num_ok_pixels] = prediction[labeled == 0].copy()\n        ok_index += num_ok_pixels\n\n        # Fetch anomaly scores within each GT component.\n        for k in range(n_components):\n            component_scores = prediction[labeled == (k + 1)]\n            ground_truth_components.append(GroundTruthComponent(component_scores))\n\n    # Sort all potential false positive scores.\n    anomaly_scores_ok_pixels = np.resize(anomaly_scores_ok_pixels, ok_index)\n    anomaly_scores_ok_pixels.sort()\n\n    return ground_truth_components, anomaly_scores_ok_pixels\n\n\ndef compute_pro(anomaly_maps, ground_truth_maps, num_thresholds):\n    \"\"\"\n    Compute the PRO curve at equidistant interpolation points for a set of anomaly maps with corresponding ground\n    truth maps. The number of interpolation points can be set manually.\n\n    Args:\n        anomaly_maps:      List of anomaly maps (2D numpy arrays) that contain a real-valued anomaly score at each pixel.\n\n        ground_truth_maps: List of ground truth maps (2D numpy arrays) that contain binary-valued ground truth labels\n                           for each pixel. 0 indicates that a pixel is anomaly-free. 
1 indicates that a pixel contains\n                           an anomaly.\n\n        num_thresholds:    Number of thresholds to compute the PRO curve.\n    Returns:\n        fprs: List of false positive rates.\n        pros: List of corresponding PRO values.\n    \"\"\"\n    # Fetch sorted anomaly scores.\n    ground_truth_components, anomaly_scores_ok_pixels = collect_anomaly_scores(anomaly_maps, ground_truth_maps)\n\n    # Select equidistant thresholds.\n    threshold_positions = np.linspace(0, len(anomaly_scores_ok_pixels) - 1, num=num_thresholds, dtype=int)\n\n    fprs = [1.0]\n    pros = [1.0]\n    for pos in threshold_positions:\n        threshold = anomaly_scores_ok_pixels[pos]\n\n        # Compute the false positive rate for this threshold.\n        fpr = 1.0 - (pos + 1) / len(anomaly_scores_ok_pixels)\n\n        # Compute the PRO value for this threshold.\n        pro = 0.0\n        for component in ground_truth_components:\n            pro += component.compute_overlap(threshold)\n        pro /= len(ground_truth_components)\n\n        fprs.append(fpr)\n        pros.append(pro)\n\n    # Return (FPR/PRO) pairs in increasing FPR order.\n    fprs = fprs[::-1]\n    pros = pros[::-1]\n\n    return fprs, pros\n\n\ndef calculate_au_pro(gts, predictions, integration_limit=0.3, num_thresholds=100):\n    \"\"\"\n    Compute the area under the PRO curve for a set of ground truth images and corresponding anomaly images.\n    Args:\n        gts:         List of tensors that contain the ground truth images for a single dataset object.\n        predictions: List of tensors containing anomaly images for each ground truth image.\n        integration_limit:    Integration limit to use when computing the area under the PRO curve.\n        num_thresholds:       Number of thresholds to use to sample the area under the PRO curve.\n\n    Returns:\n        au_pro:    Area under the PRO curve computed up to the given integration limit.\n        pro_curve: PRO curve values for localization (fpr,pro).\n    \"\"\"\n    # Compute the PRO curve.\n    pro_curve = compute_pro(anomaly_maps=predictions, ground_truth_maps=gts, num_thresholds=num_thresholds)\n\n    # Compute the area under the PRO curve.\n    au_pro = trapezoid(pro_curve[0], pro_curve[1], x_max=integration_limit)\n    au_pro /= integration_limit\n\n    # Return the evaluation metrics.\n    return au_pro, pro_curve\n
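\n\nif __name__ == '__main__':\n    # Editor's sketch: sanity-check calculate_au_pro on a toy case where the\n    # prediction equals the ground truth (illustrative only, not part of the\n    # original evaluation code).\n    gt = np.zeros((64, 64))\n    gt[8:16, 8:16] = 1\n    pred = gt.astype(np.float32)\n    au_pro, _ = calculate_au_pro([gt], [pred], integration_limit=0.3, num_thresholds=50)\n    print(f\"AU-PRO on a perfect prediction: {au_pro:.3f}\")  # expected to be close to 1\n"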
  },
  {
    "path": "utils/lr_sched.py",
    "content": "import math\n\ndef adjust_learning_rate(optimizer, epoch, args):\n    \"\"\"Decay the learning rate with half-cycle cosine after warmup\"\"\"\n    if epoch < args.warmup_epochs:\n        lr = args.lr * epoch / args.warmup_epochs \n    else:\n        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \\\n            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))\n    for param_group in optimizer.param_groups:\n        if \"lr_scale\" in param_group:\n            param_group[\"lr\"] = lr * param_group[\"lr_scale\"]\n        else:\n            param_group[\"lr\"] = lr\n    return lr\n"
  },
  {
    "path": "utils/misc.py",
    "content": "import builtins\nimport datetime\nimport os\nimport time\nfrom collections import defaultdict, deque\nfrom pathlib import Path\n\nimport torch\nimport torch.distributed as dist\nfrom torch._six import inf\n\n\nclass SmoothedValue(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n\n    def __init__(self, window_size=20, fmt=None):\n        if fmt is None:\n            fmt = \"{median:.4f} ({global_avg:.4f})\"\n        self.deque = deque(maxlen=window_size)\n        self.total = 0.0\n        self.count = 0\n        self.fmt = fmt\n\n    def update(self, value, n=1):\n        self.deque.append(value)\n        self.count += n\n        self.total += value * n\n\n    def synchronize_between_processes(self):\n        \"\"\"\n        Warning: does not synchronize the deque!\n        \"\"\"\n        if not is_dist_avail_and_initialized():\n            return\n        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')\n        dist.barrier()\n        dist.all_reduce(t)\n        t = t.tolist()\n        self.count = int(t[0])\n        self.total = t[1]\n\n    @property\n    def median(self):\n        d = torch.tensor(list(self.deque))\n        return d.median().item()\n\n    @property\n    def avg(self):\n        d = torch.tensor(list(self.deque), dtype=torch.float32)\n        return d.mean().item()\n\n    @property\n    def global_avg(self):\n        return self.total / self.count\n\n    @property\n    def max(self):\n        return max(self.deque)\n\n    @property\n    def value(self):\n        return self.deque[-1]\n\n    def __str__(self):\n        return self.fmt.format(\n            median=self.median,\n            avg=self.avg,\n            global_avg=self.global_avg,\n            max=self.max,\n            value=self.value)\n\n\nclass MetricLogger(object):\n    def __init__(self, delimiter=\"\\t\"):\n        self.meters = defaultdict(SmoothedValue)\n        self.delimiter = delimiter\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if v is None:\n                continue\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            assert isinstance(v, (float, int))\n            self.meters[k].update(v)\n\n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        if attr in self.__dict__:\n            return self.__dict__[attr]\n        raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n            type(self).__name__, attr))\n\n    def __str__(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            loss_str.append(\n                \"{}: {}\".format(name, str(meter))\n            )\n        return self.delimiter.join(loss_str)\n\n    def synchronize_between_processes(self):\n        for meter in self.meters.values():\n            meter.synchronize_between_processes()\n\n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n\n    def log_every(self, iterable, print_freq, header=None):\n        i = 0\n        if not header:\n            header = ''\n        start_time = time.time()\n        end = time.time()\n        iter_time = SmoothedValue(fmt='{avg:.4f}')\n        data_time = SmoothedValue(fmt='{avg:.4f}')\n        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'\n        log_msg = [\n            header,\n            '[{0' + space_fmt + '}/{1}]',\n            
'eta: {eta}',\n            '{meters}',\n            'time: {time}',\n            'data: {data}'\n        ]\n        if torch.cuda.is_available():\n            log_msg.append('max mem: {memory:.0f}')\n        log_msg = self.delimiter.join(log_msg)\n        MB = 1024.0 * 1024.0\n        for obj in iterable:\n            data_time.update(time.time() - end)\n            yield obj\n            iter_time.update(time.time() - end)\n            if i % print_freq == 0 or i == len(iterable) - 1:\n                eta_seconds = iter_time.global_avg * (len(iterable) - i)\n                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                if torch.cuda.is_available():\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time),\n                        memory=torch.cuda.max_memory_allocated() / MB))\n                else:\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time)))\n            i += 1\n            end = time.time()\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        print('{} Total time: {} ({:.4f} s / it)'.format(\n            header, total_time_str, total_time / len(iterable)))\n\n\ndef setup_for_distributed(is_master):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    builtin_print = builtins.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        force = force or (get_world_size() > 8)\n        if is_master or force:\n            now = datetime.datetime.now().time()\n            builtin_print('[{}] '.format(now), end='')  # print with time stamp\n            builtin_print(*args, **kwargs)\n\n    builtins.print = print\n\n\ndef is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef get_rank():\n    if not is_dist_avail_and_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_main_process():\n    return get_rank() == 0\n\n\ndef save_on_master(*args, **kwargs):\n    if is_main_process():\n        torch.save(*args, **kwargs)\n\n\ndef init_distributed_mode(args):\n    if args.dist_on_itp:\n        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])\n        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])\n        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])\n        args.dist_url = \"tcp://%s:%s\" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])\n        os.environ['LOCAL_RANK'] = str(args.gpu)\n        os.environ['RANK'] = str(args.rank)\n        os.environ['WORLD_SIZE'] = str(args.world_size)\n        # [\"RANK\", \"WORLD_SIZE\", \"MASTER_ADDR\", \"MASTER_PORT\", \"LOCAL_RANK\"]\n    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        args.rank = int(os.environ[\"RANK\"])\n        args.world_size = int(os.environ['WORLD_SIZE'])\n        args.gpu = int(os.environ['LOCAL_RANK'])\n    elif 'SLURM_PROCID' in os.environ:\n        args.rank = int(os.environ['SLURM_PROCID'])\n        args.gpu = args.rank % 
torch.cuda.device_count()\n    else:\n        print('Not using distributed mode')\n        setup_for_distributed(is_master=True)  # hack\n        args.distributed = False\n        return\n\n    args.distributed = True\n\n    torch.cuda.set_device(args.gpu)\n    args.dist_backend = 'nccl'\n    print('| distributed init (rank {}): {}, gpu {}'.format(\n        args.rank, args.dist_url, args.gpu), flush=True)\n    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,\n                                         world_size=args.world_size, rank=args.rank)\n    torch.distributed.barrier()\n    setup_for_distributed(args.rank == 0)\n\n\nclass NativeScalerWithGradNormCount:\n    state_dict_key = \"amp_scaler\"\n\n    def __init__(self):\n        self._scaler = torch.cuda.amp.GradScaler()\n\n    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):\n        self._scaler.scale(loss).backward(create_graph=create_graph)\n        if update_grad:\n            if clip_grad is not None:\n                assert parameters is not None\n                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place\n                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)\n            else:\n                self._scaler.unscale_(optimizer)\n                norm = get_grad_norm_(parameters)\n            self._scaler.step(optimizer)\n            self._scaler.update()\n        else:\n            norm = None\n        return norm\n\n    def state_dict(self):\n        return self._scaler.state_dict()\n\n    def load_state_dict(self, state_dict):\n        self._scaler.load_state_dict(state_dict)\n\n\ndef get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = [p for p in parameters if p.grad is not None]\n    norm_type = float(norm_type)\n    if len(parameters) == 0:\n        return torch.tensor(0.)\n    device = parameters[0].grad.device\n    if norm_type == inf:\n        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)\n    else:\n        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)\n    return total_norm\n\n\ndef save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):\n    output_dir = Path(args.output_dir)\n    epoch_name = str(epoch)\n    if loss_scaler is not None:\n        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]\n        for checkpoint_path in checkpoint_paths:\n            to_save = {\n                'model': model_without_ddp.state_dict(),\n                'optimizer': optimizer.state_dict(),\n                'epoch': epoch,\n                'scaler': loss_scaler.state_dict(),\n                'args': args,\n            }\n\n            save_on_master(to_save, checkpoint_path)\n    else:\n        client_state = {'epoch': epoch}\n        model.save_checkpoint(save_dir=args.output_dir, tag=\"checkpoint-%s\" % epoch_name, client_state=client_state)\n\ndef save_model_gan(args, epoch, model, discriminator, model_without_ddp, discriminator_without_ddp,\n                    optimizer_g, optimizer_d, loss_scaler):\n    output_dir = Path(args.output_dir)\n    epoch_name = str(epoch)\n    if loss_scaler is not None:\n        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]\n        for 
checkpoint_path in checkpoint_paths:\n            to_save = {\n                'model': model_without_ddp.state_dict(),\n                'discriminator_without_ddp': discriminator_without_ddp.state_dict(),\n                'optimizer_g': optimizer_g.state_dict(),\n                'optimizer_d': optimizer_d.state_dict(),\n                'epoch': epoch,\n                'scaler': loss_scaler.state_dict(),\n                'args': args,\n            }\n\n            save_on_master(to_save, checkpoint_path)\n    else:\n        client_state = {'epoch': epoch}\n        model.save_checkpoint(save_dir=args.output_dir, tag=\"checkpoint-%s\" % epoch_name, client_state=client_state)\n        discriminator.save_checkpoint(save_dir=args.output_dir, tag=\"checkpoint_d-%s\" % epoch_name, client_state=client_state)\n\ndef load_model(args, model_without_ddp, optimizer, loss_scaler):\n    if args.resume:\n        if args.resume.startswith('https'):\n            checkpoint = torch.hub.load_state_dict_from_url(\n                args.resume, map_location='cpu', check_hash=True)\n        else:\n            checkpoint = torch.load(args.resume, map_location='cpu')\n        model_without_ddp.load_state_dict(checkpoint['model'])\n        print(\"Resume checkpoint %s\" % args.resume)\n        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):\n            optimizer.load_state_dict(checkpoint['optimizer'])\n            args.start_epoch = checkpoint['epoch'] + 1\n            if 'scaler' in checkpoint:\n                loss_scaler.load_state_dict(checkpoint['scaler'])\n            print(\"With optim & sched!\")\n\ndef load_model_gan(args, model_without_ddp, discriminator_without_ddp,\n                    optimizer_g, optimizer_d, loss_scaler):\n    if args.resume:\n        if args.resume.startswith('https'):\n            checkpoint = torch.hub.load_state_dict_from_url(\n                args.resume, map_location='cpu', check_hash=True)\n        else:\n            checkpoint = torch.load(args.resume, map_location='cpu')\n        model_without_ddp.load_state_dict(checkpoint['model'])\n        # Key must match the one written by save_model_gan above.\n        discriminator_without_ddp.load_state_dict(checkpoint['discriminator_without_ddp'])\n        print(\"Resume checkpoint %s\" % args.resume)\n        if 'optimizer_d' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):\n            optimizer_d.load_state_dict(checkpoint['optimizer_d'])\n        if 'optimizer_g' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):\n            optimizer_g.load_state_dict(checkpoint['optimizer_g'])\n            args.start_epoch = checkpoint['epoch'] + 1\n            if 'scaler' in checkpoint:\n                loss_scaler.load_state_dict(checkpoint['scaler'])\n            print(\"With optim & sched!\")\n\n\n\ndef all_reduce_mean(x):\n    world_size = get_world_size()\n    if world_size > 1:\n        x_reduce = torch.tensor(x).cuda()\n        dist.all_reduce(x_reduce)\n        x_reduce /= world_size\n        return x_reduce.item()\n    else:\n        return x
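\n\n\nif __name__ == '__main__':\n    # Editor's sketch: how NativeScalerWithGradNormCount is typically driven in a\n    # training step (illustrative stand-ins; on a CPU-only machine GradScaler\n    # disables itself with a warning and the call degrades to a plain backward/step).\n    model = torch.nn.Linear(4, 1)\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n    scaler = NativeScalerWithGradNormCount()\n    loss = model(torch.randn(8, 4)).mean()\n    optimizer.zero_grad()\n    grad_norm = scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())\n    print('grad norm:', grad_norm)\n"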
  },
  {
    "path": "utils/mvtec3d_util.py",
    "content": "import tifffile as tiff\nimport torch\n\n\ndef organized_pc_to_unorganized_pc(organized_pc):\n    return organized_pc.reshape(organized_pc.shape[0] * organized_pc.shape[1], organized_pc.shape[2])\n\n\ndef read_tiff_organized_pc(path):\n    tiff_img = tiff.imread(path)\n    return tiff_img\n\n\ndef resize_organized_pc(organized_pc, target_height=224, target_width=224, tensor_out=True):\n    torch_organized_pc = torch.tensor(organized_pc).permute(2, 0, 1).unsqueeze(dim=0).contiguous()\n    torch_resized_organized_pc = torch.nn.functional.interpolate(torch_organized_pc, size=(target_height, target_width),\n                                                                 mode='nearest')\n    if tensor_out:\n        return torch_resized_organized_pc.squeeze(dim=0).contiguous()\n    else:\n        return torch_resized_organized_pc.squeeze().permute(1, 2, 0).contiguous().numpy()\n\n\ndef organized_pc_to_depth_map(organized_pc):\n    return organized_pc[:, :, 2]\n"
  },
  {
    "path": "utils/preprocess_eyecandies.py",
    "content": "import os\r\nfrom shutil import copyfile\r\nimport cv2\r\nimport numpy as np\r\nimport tifffile\r\nimport yaml\r\nimport imageio.v3 as iio\r\nimport math\r\nimport argparse\r\n\r\n# The same camera has been used for all the images\r\nFOCAL_LENGTH = 711.11\r\n\r\ndef load_and_convert_depth(depth_img, info_depth):\r\n    with open(info_depth) as f:\r\n        data = yaml.safe_load(f)\r\n    mind, maxd = data[\"normalization\"][\"min\"], data[\"normalization\"][\"max\"]\r\n\r\n    dimg = iio.imread(depth_img)\r\n    dimg = dimg.astype(np.float32)\r\n    dimg = dimg / 65535.0 * (maxd - mind) + mind\r\n    return dimg\r\n\r\ndef depth_to_pointcloud(depth_img, info_depth, pose_txt, focal_length):\r\n    # input depth map (in meters) --- cfr previous section\r\n    depth_mt = load_and_convert_depth(depth_img, info_depth)\r\n\r\n    # input pose\r\n    pose = np.loadtxt(pose_txt)\r\n\r\n    # camera intrinsics\r\n    height, width = depth_mt.shape[:2]\r\n    intrinsics_4x4 = np.array([\r\n        [focal_length, 0, width / 2, 0],\r\n        [0, focal_length, height / 2, 0],\r\n        [0, 0, 1, 0],\r\n        [0, 0, 0, 1]]\r\n    )\r\n\r\n    # build the camera projection matrix\r\n    camera_proj = intrinsics_4x4 @ pose\r\n\r\n    # build the (u, v, 1, 1/depth) vectors (non optimized version)\r\n    camera_vectors = np.zeros((width * height, 4))\r\n    count=0\r\n    for j in range(height):\r\n        for i in range(width):\r\n            camera_vectors[count, :] = np.array([i, j, 1, 1/depth_mt[j, i]])\r\n            count += 1\r\n\r\n    # invert and apply to each 4-vector\r\n    hom_3d_pts= np.linalg.inv(camera_proj) @ camera_vectors.T\r\n    # print(hom_3d_pts.shape)\r\n    # remove the homogeneous coordinate\r\n    pcd = depth_mt.reshape(-1, 1) * hom_3d_pts.T\r\n    return pcd[:, :3]\r\n\r\ndef remove_point_cloud_background(pc):\r\n\r\n    # The second dim is z\r\n    dz =  pc[256,1] - pc[-256,1]\r\n    dy =  pc[256,2] - pc[-256,2]\r\n\r\n    norm =  math.sqrt(dz**2 + dy**2)\r\n    start_points = np.array([0, pc[-256, 1], pc[-256, 2]])\r\n    cos_theta = dy / norm\r\n    sin_theta = dz / norm\r\n\r\n    # Transform and rotation\r\n    rotation_matrix = np.array([[1, 0, 0], [0, cos_theta, -sin_theta],[0, sin_theta, cos_theta]])\r\n    processed_pc = (rotation_matrix @ (pc - start_points).T).T\r\n\r\n    # Remove background point\r\n    for i in range(processed_pc.shape[0]):\r\n        if processed_pc[i,1] > -0.02:\r\n            processed_pc[i, :] = -start_points\r\n        if processed_pc[i,2] > 1.8:\r\n            processed_pc[i, :] = -start_points\r\n        elif processed_pc[i,0] > 1 or processed_pc[i,0] < -1:\r\n            processed_pc[i, :] = -start_points\r\n\r\n    processed_pc = (rotation_matrix.T @ processed_pc.T).T + start_points\r\n\r\n    index = [0, 2, 1]\r\n    processed_pc = processed_pc[:,index]\r\n    return processed_pc*[0.1, -0.1, 0.1]\r\n\r\n\r\nif __name__ == '__main__':\r\n\r\n    parser = argparse.ArgumentParser(description='Process some integers.')\r\n    parser.add_argument('--dataset_path', default='datasets/eyecandies', type=str, help=\"Original Eyecandies dataset path.\")\r\n    parser.add_argument('--target_dir', default='datasets/eyecandies_preprocessed', type=str, help=\"Processed Eyecandies dataset path\")\r\n    args = parser.parse_args()\r\n    \r\n    os.mkdir(args.target_dir)\r\n    categories_list = os.listdir(args.dataset_path)\r\n\r\n    for category_dir in categories_list:\r\n        category_root_path = os.path.join(args.dataset_path, 
\r\n        # Note: no leading slash in the joined parts, otherwise os.path.join\r\n        # would discard category_root_path entirely.\r\n        category_train_path = os.path.join(category_root_path, 'train', 'data')\r\n        category_test_path = os.path.join(category_root_path, 'test_public', 'data')\r\n\r\n        category_target_path = os.path.join(args.target_dir, category_dir)\r\n        os.mkdir(category_target_path)\r\n\r\n        os.mkdir(os.path.join(category_target_path, 'train'))\r\n        category_target_train_good_path = os.path.join(category_target_path, 'train/good')\r\n        category_target_train_good_rgb_path = os.path.join(category_target_train_good_path, 'rgb')\r\n        category_target_train_good_xyz_path = os.path.join(category_target_train_good_path, 'xyz')\r\n        os.mkdir(category_target_train_good_path)\r\n        os.mkdir(category_target_train_good_rgb_path)\r\n        os.mkdir(category_target_train_good_xyz_path)\r\n\r\n        os.mkdir(os.path.join(category_target_path, 'test'))\r\n        category_target_test_good_path = os.path.join(category_target_path, 'test/good')\r\n        category_target_test_good_rgb_path = os.path.join(category_target_test_good_path, 'rgb')\r\n        category_target_test_good_xyz_path = os.path.join(category_target_test_good_path, 'xyz')\r\n        category_target_test_good_gt_path = os.path.join(category_target_test_good_path, 'gt')\r\n        os.mkdir(category_target_test_good_path)\r\n        os.mkdir(category_target_test_good_rgb_path)\r\n        os.mkdir(category_target_test_good_xyz_path)\r\n        os.mkdir(category_target_test_good_gt_path)\r\n        category_target_test_bad_path = os.path.join(category_target_path, 'test/bad')\r\n        category_target_test_bad_rgb_path = os.path.join(category_target_test_bad_path, 'rgb')\r\n        category_target_test_bad_xyz_path = os.path.join(category_target_test_bad_path, 'xyz')\r\n        category_target_test_bad_gt_path = os.path.join(category_target_test_bad_path, 'gt')\r\n        os.mkdir(category_target_test_bad_path)\r\n        os.mkdir(category_target_test_bad_rgb_path)\r\n        os.mkdir(category_target_test_bad_xyz_path)\r\n        os.mkdir(category_target_test_bad_gt_path)\r\n\r\n        category_train_files = os.listdir(category_train_path)\r\n        num_train_files = len(category_train_files)//17\r\n        for i in range(0, num_train_files):\r\n            pc = depth_to_pointcloud(\r\n                    os.path.join(category_train_path,str(i).zfill(3)+'_depth.png'),\r\n                    os.path.join(category_train_path,str(i).zfill(3)+'_info_depth.yaml'),\r\n                    os.path.join(category_train_path,str(i).zfill(3)+'_pose.txt'),\r\n                    FOCAL_LENGTH,\r\n                )\r\n            pc = remove_point_cloud_background(pc)\r\n            pc = pc.reshape(512,512,3)\r\n            tifffile.imwrite(os.path.join(category_target_train_good_xyz_path, str(i).zfill(3)+'.tiff'), pc)\r\n            copyfile(os.path.join(category_train_path,str(i).zfill(3)+'_image_4.png'),os.path.join(category_target_train_good_rgb_path, str(i).zfill(3)+'.png'))\r\n\r\n        category_test_files = os.listdir(category_test_path)\r\n        num_test_files = len(category_test_files)//17\r\n        for i in range(0, num_test_files):\r\n            mask = cv2.imread(os.path.join(category_test_path,str(i).zfill(2)+'_mask.png'))\r\n            if np.any(mask):\r\n                # anomalous sample: keep the defect mask as ground truth\r\n                pc = depth_to_pointcloud(\r\n                    os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'),
\r\n                    os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'),\r\n                    os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'),\r\n                    FOCAL_LENGTH,\r\n                    )\r\n                pc = remove_point_cloud_background(pc)\r\n                pc = pc.reshape(512,512,3)\r\n                tifffile.imwrite(os.path.join(category_target_test_bad_xyz_path, str(i).zfill(3)+'.tiff'), pc)\r\n                cv2.imwrite(os.path.join(category_target_test_bad_gt_path, str(i).zfill(3)+'.png'), mask)\r\n                copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_bad_rgb_path, str(i).zfill(3)+'.png'))\r\n            else:\r\n                # anomaly-free sample: write the empty mask so every test image has a gt file\r\n                pc = depth_to_pointcloud(\r\n                    os.path.join(category_test_path,str(i).zfill(2)+'_depth.png'),\r\n                    os.path.join(category_test_path,str(i).zfill(2)+'_info_depth.yaml'),\r\n                    os.path.join(category_test_path,str(i).zfill(2)+'_pose.txt'),\r\n                    FOCAL_LENGTH,\r\n                    )\r\n                pc = remove_point_cloud_background(pc)\r\n                pc = pc.reshape(512,512,3)\r\n                tifffile.imwrite(os.path.join(category_target_test_good_xyz_path, str(i).zfill(3)+'.tiff'), pc)\r\n                cv2.imwrite(os.path.join(category_target_test_good_gt_path, str(i).zfill(3)+'.png'), mask)\r\n                copyfile(os.path.join(category_test_path,str(i).zfill(2)+'_image_4.png'),os.path.join(category_target_test_good_rgb_path, str(i).zfill(3)+'.png'))\r\n
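\r\n\r\n# Editor's note: the nested loop in depth_to_pointcloud builds the (u, v, 1, 1/depth)\r\n# vectors one pixel at a time. The helper below is a vectorized sketch of that same\r\n# step (an optional drop-in suggestion, not part of the original script).\r\ndef build_camera_vectors(depth_mt):\r\n    height, width = depth_mt.shape[:2]\r\n    u, v = np.meshgrid(np.arange(width), np.arange(height))\r\n    ones = np.ones_like(depth_mt)\r\n    # Same (u, v, 1, 1/depth) layout and row order as the loop version.\r\n    return np.stack([u, v, ones, 1.0 / depth_mt], axis=-1).reshape(-1, 4)\r\n"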
  },
  {
    "path": "utils/preprocessing.py",
    "content": "import os\nimport numpy as np\nimport tifffile as tiff\nimport open3d as o3d\nfrom pathlib import Path\nfrom PIL import Image\nimport math\nimport mvtec3d_util as mvt_util\nimport argparse\n\n\ndef get_edges_of_pc(organized_pc):\n    unorganized_edges_pc = organized_pc[0:10, :, :].reshape(organized_pc[0:10, :, :].shape[0]*organized_pc[0:10, :, :].shape[1],organized_pc[0:10, :, :].shape[2])\n    unorganized_edges_pc = np.concatenate([unorganized_edges_pc,organized_pc[-10:, :, :].reshape(organized_pc[-10:, :, :].shape[0] * organized_pc[-10:, :, :].shape[1],organized_pc[-10:, :, :].shape[2])],axis=0)\n    unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, 0:10, :].reshape(organized_pc[:, 0:10, :].shape[0] * organized_pc[:, 0:10, :].shape[1],organized_pc[:, 0:10, :].shape[2])], axis=0)\n    unorganized_edges_pc = np.concatenate([unorganized_edges_pc, organized_pc[:, -10:, :].reshape(organized_pc[:, -10:, :].shape[0] * organized_pc[:, -10:, :].shape[1],organized_pc[:, -10:, :].shape[2])], axis=0)\n    unorganized_edges_pc = unorganized_edges_pc[np.nonzero(np.all(unorganized_edges_pc != 0, axis=1))[0],:]\n    return unorganized_edges_pc\n\ndef get_plane_eq(unorganized_pc,ransac_n_pts=50):\n    o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc))\n    plane_model, inliers = o3d_pc.segment_plane(distance_threshold=0.004, ransac_n=ransac_n_pts, num_iterations=1000)\n    return plane_model\n\ndef remove_plane(organized_pc_clean, organized_rgb ,distance_threshold=0.005):\n    # PREP PC\n    unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc_clean)\n    unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)\n    clean_planeless_unorganized_pc = unorganized_pc.copy()\n    planeless_unorganized_rgb = unorganized_rgb.copy()\n\n    # REMOVE PLANE\n    plane_model = get_plane_eq(get_edges_of_pc(organized_pc_clean))\n    distances = np.abs(np.dot(np.array(plane_model), np.hstack((clean_planeless_unorganized_pc, np.ones((clean_planeless_unorganized_pc.shape[0], 1)))).T))\n    plane_indices = np.argwhere(distances < distance_threshold)\n\n    planeless_unorganized_rgb[plane_indices] = 0\n    clean_planeless_unorganized_pc[plane_indices] = 0\n    clean_planeless_organized_pc = clean_planeless_unorganized_pc.reshape(organized_pc_clean.shape[0],\n                                                                          organized_pc_clean.shape[1],\n                                                                          organized_pc_clean.shape[2])\n    planeless_organized_rgb = planeless_unorganized_rgb.reshape(organized_rgb.shape[0],\n                                                                          organized_rgb.shape[1],\n                                                                          organized_rgb.shape[2])\n    return clean_planeless_organized_pc, planeless_organized_rgb\n\n\n\ndef connected_components_cleaning(organized_pc, organized_rgb, image_path):\n    unorganized_pc = mvt_util.organized_pc_to_unorganized_pc(organized_pc)\n    unorganized_rgb = mvt_util.organized_pc_to_unorganized_pc(organized_rgb)\n\n    nonzero_indices = np.nonzero(np.all(unorganized_pc != 0, axis=1))[0]\n    unorganized_pc_no_zeros = unorganized_pc[nonzero_indices, :]\n    o3d_pc = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(unorganized_pc_no_zeros))\n    labels = np.array(o3d_pc.cluster_dbscan(eps=0.006, min_points=30, print_progress=False))\n\n\n    unique_cluster_ids, cluster_size = 
np.unique(labels, return_counts=True)\n    max_label = labels.max()\n    if max_label > 0:\n        print(\"##########################################################################\")\n        print(f\"Point cloud file {image_path} has {max_label + 1} clusters\")\n        print(f\"Cluster ids: {unique_cluster_ids}. Cluster size {cluster_size}\")\n        print(\"##########################################################################\\n\\n\")\n\n    largest_cluster_id = unique_cluster_ids[np.argmax(cluster_size)]\n    outlier_indices_nonzero_array = np.argwhere(labels != largest_cluster_id)\n    outlier_indices_original_pc_array = nonzero_indices[outlier_indices_nonzero_array]\n    unorganized_pc[outlier_indices_original_pc_array] = 0\n    unorganized_rgb[outlier_indices_original_pc_array] = 0\n    organized_clustered_pc = unorganized_pc.reshape(organized_pc.shape[0],\n                                                    organized_pc.shape[1],\n                                                    organized_pc.shape[2])\n    organized_clustered_rgb = unorganized_rgb.reshape(organized_rgb.shape[0],\n                                                    organized_rgb.shape[1],\n                                                    organized_rgb.shape[2])\n    return organized_clustered_pc, organized_clustered_rgb\n\ndef roundup_next_100(x):\n    return int(math.ceil(x / 100.0)) * 100\n\ndef pad_cropped_pc(cropped_pc, single_channel=False):\n    orig_h, orig_w = cropped_pc.shape[0], cropped_pc.shape[1]\n    round_orig_h = roundup_next_100(orig_h)\n    round_orig_w = roundup_next_100(orig_w)\n    large_side = max(round_orig_h, round_orig_w)\n\n    a = (large_side - orig_h) // 2\n    aa = large_side - a - orig_h\n\n    b = (large_side - orig_w) // 2\n    bb = large_side - b - orig_w\n    if single_channel:\n        return np.pad(cropped_pc, pad_width=((a, aa), (b, bb)), mode='constant')\n    else:\n        return np.pad(cropped_pc, pad_width=((a, aa), (b, bb), (0, 0)), mode='constant')\n\ndef preprocess_pc(tiff_path):\n    # READ FILES\n    organized_pc = mvt_util.read_tiff_organized_pc(tiff_path)\n    rgb_path = str(tiff_path).replace(\"xyz\", \"rgb\").replace(\"tiff\", \"png\")\n    gt_path = str(tiff_path).replace(\"xyz\", \"gt\").replace(\"tiff\", \"png\")\n    organized_rgb = np.array(Image.open(rgb_path))\n\n    organized_gt = None\n    gt_exists = os.path.isfile(gt_path)\n    if gt_exists:\n        organized_gt = np.array(Image.open(gt_path))\n\n    # REMOVE PLANE\n    planeless_organized_pc, planeless_organized_rgb = remove_plane(organized_pc, organized_rgb)\n\n    # PAD WITH ZEROS TO LARGEST SIDE (SO THAT THE FINAL IMAGE IS SQUARE)\n    padded_planeless_organized_pc = pad_cropped_pc(planeless_organized_pc, single_channel=False)\n    padded_planeless_organized_rgb = pad_cropped_pc(planeless_organized_rgb, single_channel=False)\n    if gt_exists:\n        padded_organized_gt = pad_cropped_pc(organized_gt, single_channel=True)\n\n    organized_clustered_pc, organized_clustered_rgb = connected_components_cleaning(padded_planeless_organized_pc, padded_planeless_organized_rgb, tiff_path)\n    # SAVE PREPROCESSED FILES (imwrite replaces the deprecated tifffile.imsave)\n    tiff.imwrite(tiff_path, organized_clustered_pc)\n    Image.fromarray(organized_clustered_rgb).save(rgb_path)\n    if gt_exists:\n        Image.fromarray(padded_organized_gt).save(gt_path)\n\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='Preprocess MVTec 3D-AD')\n    parser.add_argument('dataset_path',
 type=str, help='The root path of the MVTec 3D-AD dataset. The preprocessing is done in place (i.e. the preprocessed dataset overwrites the existing one)')\n    args = parser.parse_args()\n\n    root_path = args.dataset_path\n    # Materialize the glob once; a bare generator would be exhausted by the len() call.\n    paths = list(Path(root_path).rglob('*.tiff'))\n    print(f\"Found {len(paths)} tiff files in {root_path}\")\n    processed_files = 0\n    for path in paths:\n        preprocess_pc(path)\n        processed_files += 1\n        if processed_files % 50 == 0:\n            print(f\"Processed {processed_files} tiff files...\")\n"
  },
  {
    "path": "utils/utils.py",
    "content": "import numpy as np\nimport random\nimport torch\nfrom torchvision import transforms\nfrom PIL import ImageFilter\n\ndef set_seeds(seed: int = 0) -> None:\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.manual_seed(seed)\n\nclass KNNGaussianBlur(torch.nn.Module):\n    def __init__(self, radius : int = 4):\n        super().__init__()\n        self.radius = radius\n        self.unload = transforms.ToPILImage()\n        self.load = transforms.ToTensor()\n        self.blur_kernel = ImageFilter.GaussianBlur(radius=4)\n\n    def __call__(self, img):\n        map_max = img.max()\n        final_map = self.load(self.unload(img[0] / map_max).filter(self.blur_kernel)) * map_max\n        return final_map\n"
  }
]