[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\nbuild/\ndist/\n*.egg-info/\n.eggs/\n\n# Virtual environment\nvenv/\nenv/\n.venv/\n.env/\n\n# Jupyter Notebook checkpoints\n.ipynb_checkpoints/\n\n# PyInstaller\n*.manifest\n*.spec\n\n# pytest\n.cache/\n.pytest_cache/\n\n# mypy\n.mypy_cache/\n\n# coverage\nhtmlcov/\n.coverage\n.coverage.*\n\n# logs and temporary files\n*.log\n*.tmp\n*.bak\n\n# IDEs and editors\n.vscode/\n.idea/\n*.sublime-workspace\n*.sublime-project\n\n# OS files\n.DS_Store\nThumbs.db\n\n# dotenv / secrets\n.env\n.env.*\n\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "MIT License\n\nCopyright (c) 2023 jumin\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Diffusion Probabilistic Models for Scene-Scale 3D Categorical Data\n\n📌[Paper](http://arxiv.org/abs/2301.00527)        \n\n<img src=https://user-images.githubusercontent.com/65997635/210452550-2c7c7c6d-7260-43ce-b4b6-18d3f15fccde.png width=\"480\"\n  height=\"400\">\n\nComparison of object-scale and scene scale generation (ours). Our result includes multiple objects in a generated scene,\nwhile the object-scale generation crafts one object at a time. (a) is obtained by [Point-E](https://github.com/openai/point-e)\n\n## Abstract\nIn this paper, we learn a diffusion model to generate 3D data on a scene-scale. Specifically, our model crafts a 3D scene consisting of multiple objects, while recent diffusion research has focused on a single object. To realize our goal, we represent a scene with discrete class labels, i.e., categorical distribution, to assign multiple objects into semantic categories. Thus, we extend discrete diffusion models to learn scene-scale categorical distributions. In addition, we validate that a latent diffusion model can reduce computation costs for training and deploying. To the best of our knowledge, our work is the first to apply discrete and latent diffusion for 3D categorical data on a scene-scale. We further propose to perform semantic scene completion (SSC) by learning a conditional distribution using our diffusion model, where the condition is a partial observation in a sparse point cloud. In experiments, we empirically show that our diffusion models not only generate reasonable scenes, but also perform the scene completion task better than a discriminative model. 
\n\n\n## Instructions\n### Dataset\n: We use [CarlaSC](https://umich-curly.github.io/CarlaSC.github.io/download/) cartesian dataset.\n\n### Training\n: There are some argparse in 'SSC_train.py'.\n    \n    python SSC_train.py \n    \n- For **multi-GPU** : --distribution True\n- For **Discrete Diffusion Model** : --mode gen/con/vis\n- For **Latent Diffusion Model** : --mode l_vae/l_gen --l_size 882/16162/32322 --init_size 32 --l_attention True --vq_size 100\n\nExample for training l_gen mode\n  \n    python SSC_train.py --mode l_gen --vq_size 100 --l_size 32322 --init_size 32 --l_attention True --log_path ./result --vqvae_path ./lst_stage.tar\n\n\n### Visualization\n: We save the result to a txt file using the `utils/table.py/visulization` function. \nIf you use open3d, you will be able to easily visualize it.\n\n## Result\n### 3D Scene Generation\n![image](https://github.com/zoomin-lee/scene-scale-diffusion/blob/main/images/3D_scene_generation.png?raw=true)\n\n### Semantic Scene Completion\n![image](https://github.com/zoomin-lee/scene-scale-diffusion/blob/main/images/table4.PNG?raw=true)\n\n\n![image](https://github.com/zoomin-lee/scene-scale-diffusion/blob/main/images/semantic_scene_completion.png?raw=true)\n\n\n## Acknowledgments\nThis project is based on the following codebase.\n- [Multinomial Diffusion](https://github.com/ehoogeboom/multinomial_diffusion/tree/9d907a60536ad793efd6d2a6067b3c3d6ba9fce7)\n- [MotionSC](https://github.com/UMich-CURLY/3DMapping)\n- [Cylinder3D](https://github.com/xinge008/Cylinder3D)\n"
  },
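The README's visualization note can be sketched in code. This is a minimal sketch, assuming each row of the saved txt file is `x y z label` (the exact format written by `visulization` in `utils/table.py` may differ); the palette is `remap_color_map` from `datasets/carla.yaml`, and the Open3D calls are left as comments so the loader itself needs only NumPy:

```python
import numpy as np

# remap_color_map from datasets/carla.yaml (RGB, 0-255)
REMAP_COLOR_MAP = {
    0: [255, 255, 255], 1: [255, 200, 0], 2: [255, 120, 50],
    3: [55, 90, 80], 4: [255, 30, 30], 5: [255, 240, 150],
    6: [255, 0, 255], 7: [175, 0, 75], 8: [75, 0, 75],
    9: [0, 175, 0], 10: [100, 150, 245],
}

def load_colored_points(txt_path):
    """Load assumed "x y z label" rows and map labels to RGB colors in [0, 1]."""
    data = np.loadtxt(txt_path).reshape(-1, 4)
    points = data[:, :3]
    labels = data[:, 3].astype(int)
    palette = np.array([REMAP_COLOR_MAP[i] for i in range(11)]) / 255.0
    return points, palette[labels]

# With Open3D installed, the loaded points can then be displayed as:
#   import open3d as o3d
#   pcd = o3d.geometry.PointCloud()
#   pcd.points = o3d.utility.Vector3dVector(points)
#   pcd.colors = o3d.utility.Vector3dVector(colors)
#   o3d.visualization.draw_geometries([pcd])
```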
  {
    "path": "SSC_train.py",
    "content": "import argparse\nimport os\nimport warnings\nimport time\nimport torch\nfrom utils.intermediate_vis import Vis_iter\n\nfrom datasets.data import *\nfrom utils.cuda import launch\nfrom utils.multistep import get_optim\nfrom train import Experiment\n\nfrom layers.Voxel_Level.Gen_Diffusion import Diffusion\nfrom layers.Voxel_Level.Con_Diffusion import Con_Diffusion\n\nfrom layers.Latent_Level.stage1.vqvae import vqvae\nfrom layers.Latent_Level.stage2.Gen_diffusion import latent_diffusion\n\nfrom layers.Ablation.wo_diffusion import wo_diff\n\n# environment variables\nNODE_RANK = os.environ['AZ_BATCHAI_TASK_INDEX'] if 'AZ_BATCHAI_TASK_INDEX' in os.environ else 0\nNODE_RANK = int(NODE_RANK)\nMASTER_ADDR, MASTER_PORT = os.environ['AZ_BATCH_MASTER_NODE'].split(':') if 'AZ_BATCH_MASTER_NODE' in os.environ else (\"127.0.0.1\", 29500)\nMASTER_PORT = int(MASTER_PORT)\nDIST_URL = 'tcp://%s:%s' % (MASTER_ADDR, MASTER_PORT)\n\ndef get_args():\n    ###########\n    ## Setup ##\n    ###########\n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('--gpu', type=int, default=None, help='GPU id to use. 
If given, only the specific gpu will be used, and ddp will be disabled')\n    parser.add_argument('--distribution', type=eval, default=True)\n    parser.add_argument('--num_node', type=int, default=1, help='number of nodes for distributed training')\n    parser.add_argument('--node_rank', type=int, default=0, help='node rank for distributed training')\n    parser.add_argument('--dist_url', type=str, default='tcp://127.0.0.1:29500', help='url used to set up distributed training')\n    \n    # Data params\n    parser.add_argument('--dataset', type=str, default='carla', choices=['carla'])\n    parser.add_argument('--dataset_dir', type=str, required=True, help='Path to the dataset directory')\n    # Train params\n    parser.add_argument('--batch_size', type=int, default=4)\n    parser.add_argument('--num_workers', type=int, default=4)\n    parser.add_argument('--pin_memory', type=eval, default=False)\n    parser.add_argument('--augmentation', type=str, default=None)\n\n    # Experiment params\n    parser.add_argument('--clip_value', type=float, default=None)\n    parser.add_argument('--clip_norm', type=float, default=None)\n    parser.add_argument('--recon_loss', default=False)\n    parser.add_argument('--mode', default='wo_diff', choices=['gen', 'con', 'vis', 'l_vae', 'l_gen', 'wo_diff'])\n    parser.add_argument('--l_size', default='32322', choices=['882', '16162', '32322'])\n    parser.add_argument('--init_size', type=int, default=8)\n    parser.add_argument('--l_attention', default=True)\n    parser.add_argument('--vq_size', type=int, default=50)\n\n    # Model params\n    parser.add_argument('--auxiliary_loss_weight', type=float, default=0.0005)\n    parser.add_argument('--diffusion_steps', type=int, default=100)\n    parser.add_argument('--diffusion_dim', type=int, default=32)\n    parser.add_argument('--dp_rate', type=float, default=0.)\n\n    # Optim params\n    parser.add_argument('--optimizer', type=str, default='adam')\n    parser.add_argument('--lr', type=float, 
default=1e-3)\n    parser.add_argument('--warmup', type=int, default=None)\n    parser.add_argument('--momentum', type=float, default=0.9)\n    parser.add_argument('--momentum_sqr', type=float, default=0.999)\n    parser.add_argument('--milestones', type=eval, default=[])\n    parser.add_argument('--gamma', type=float, default=0.1)\n\n    # Train params\n    parser.add_argument('--epochs', type=int, default=5000)\n    parser.add_argument('--resume', type=eval, default=False)\n    parser.add_argument('--resume_path', type=str, default='')\n    parser.add_argument('--vqvae_path', type=str, default='')\n\n    # Logging params\n    parser.add_argument('--eval_every', type=int, default=10)\n    parser.add_argument('--check_every', type=int, default=5)\n    parser.add_argument('--completion_epoch', type=int, default=20)\n    parser.add_argument('--log_tb', type=eval, default=True)\n    parser.add_argument('--log_home', type=str, default=None)\n    parser.add_argument('--log_path', type=str, default='')\n\n    args = parser.parse_args()\n    return args\n\n\ndef main():\n    print('start!')\n    args = get_args()\n\n    if args.gpu is not None:\n        warnings.warn('You have chosen a specific GPU. 
This will completely disable ddp.')\n        torch.cuda.set_device(args.gpu)\n        args.ngpus_per_node = 1\n        args.world_size = 1\n    else:\n        if args.num_node == 1:\n            args.dist_url = \"auto\"\n        else:\n            assert args.num_node > 1\n        args.ngpus_per_node = torch.cuda.device_count()\n        args.world_size = args.ngpus_per_node * args.num_node\n\n    launch(start, args.ngpus_per_node, args.num_node, args.node_rank, args.dist_url, args=(args,))\n\n\ndef start(local_rank, args):\n    args.local_rank = local_rank\n    args.global_rank = args.local_rank + args.node_rank * args.ngpus_per_node\n    args.distributed = args.world_size > 1\n\n    ##################\n    ## Specify data ##\n    ##################\n    train_loader, eval_loader, test_loader, num_classes, comp_weights, seg_weights, train_sampler = get_data(args)\n    args.num_classes = num_classes\n\n    completion_criterion = torch.nn.CrossEntropyLoss(weight=comp_weights)\n    seg_criterion = torch.nn.CrossEntropyLoss(weight=seg_weights, ignore_index=0)\n    similarity_criterion = torch.nn.MSELoss()\n\n    #######################\n    ## Without Diffusion ##\n    #######################\n    if args.mode == 'wo_diff':\n        model = wo_diff(args, completion_criterion).cuda()\n        if args.distribution :\n            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)\n\n    ########################\n    ## Discrete Diffusion ##\n    ########################\n    elif args.mode == 'gen':\n        model = Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()\n        if args.distribution :\n            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)\n\n    elif args.mode == 'con':\n        model = Con_Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()\n      
  if args.distribution :\n            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)\n    \n    ######################\n    ## Latent Diffusion ##\n    ######################\n    elif args.mode == 'l_vae':\n        model = vqvae(args, completion_criterion).cuda()\n        if args.distribution:\n            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)\n\n    elif args.mode == 'l_gen':\n        Dense = vqvae(args, completion_criterion).cuda()\n        dense_check = torch.load(args.vqvae_path)\n        model = latent_diffusion(args, Dense, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()\n        if args.distribution:\n            Dense = torch.nn.parallel.DistributedDataParallel(Dense, device_ids=[args.gpu], find_unused_parameters=False)\n            Dense.module.load_state_dict(dense_check['model'])\n            for p in Dense.module.parameters():\n                p.requires_grad = False   \n            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)\n            \n    ###################\n    ## Visualization ##\n    ###################\n    elif args.mode == 'vis':\n        model = Con_Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()\n        if args.distribution :\n            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)\n\n    optimizer, scheduler_iter, scheduler_epoch = get_optim(args, model)\n    if args.mode == 'vis':\n        exp = Vis_iter(args, model, optimizer, scheduler_iter, scheduler_epoch, test_loader, args.log_path)\n    \n    else : \n        exp = Experiment(args, model, optimizer, scheduler_iter, scheduler_epoch,\n                        train_loader, eval_loader, test_loader, train_sampler, \n                        
args.log_path, args.eval_every, args.check_every)\n    \n    exp.run(epochs = args.epochs)\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "__init__.py",
    "content": ""
  },
  {
    "path": "datasets/carla.yaml",
    "content": "color_map :\n  0 : [255, 255, 255]  # None\n  1 : [70, 70, 70]     # Building\n  2 : [100, 40, 40]    # Fences\n  3 : [55, 90, 80]     # Other\n  4 : [255, 255, 0 ]   # Pedestrian\n  5 : [153, 153, 153]  # Pole\n  6 : [157, 234, 50]   # RoadLines\n  7 : [0, 0, 255]      # Road\n  8 : [255, 255, 255]  # Sidewalk\n  9 : [0, 155, 0]      # Vegetation\n  10 : [255, 0, 0]     # Vehicle\n  11 : [102, 102, 156] # Wall\n  12 : [220, 220, 0]   # TrafficSign\n  13 : [70, 130, 180]  # Sky\n  14 : [255, 255, 255] # Ground\n  15 : [150, 100, 100] # Bridge\n  16 : [230, 150, 140] # RailTrack\n  17 : [180, 165, 180] # GuardRail\n  18 : [250, 170, 30]  # TrafficLight\n  19 : [110, 190, 160] # Static\n  20 : [170, 120, 50]  # Dynamic\n  21 : [45, 60, 150]   # Water\n  22 : [145, 170, 100] # Terrain\n\nlearning_map :\n  0 : 0\n  1 : 1\n  2 : 2\n  3 : 3\n  4 : 4\n  5 : 5\n  6 : 6\n  7 : 6\n  8 : 8\n  9 : 9\n  10: 10\n  11 : 2\n  12 : 5\n  13 : 3\n  14 : 7\n  15 : 3\n  16 : 3\n  17 : 2\n  18 : 5\n  19 : 3\n  20 : 3\n  21 : 3\n  22 : 7\n\nremap_color_map:\n  0 : [255, 255, 255]  # None\n  1 : [255, 200, 0]     # Building\n  2 : [255, 120, 50]    # Fences\n  3 : [55, 90, 80]     # Other\n  4 : [255, 30, 30]   # Pedestrian\n  5 : [255, 240, 150]  # Pole\n  6 : [255, 0, 255]      # Road\n  7 : [175, 0, 75] # Ground\n  8 : [75, 0, 75]  # Sidewalk\n  9 : [0, 175, 0]      # Vegetation\n  10 : [100, 150, 245]     # Vehicle\n\nlabel_to_names:\n  0 : Free\n  1 : Building\n  2 : Barrier\n  3 : Other\n  4 : Pedestrian\n  5 : Pole\n  6 : Road\n  7 : Ground\n  8 : Sidewalk\n  9 : Vegetation\n  10 : Vehicle\n\ncontent :\n  0 : 4166593275\n  1 : 42309744\n  2 : 8550180\n  3 : 478193\n  4 : 905663\n  5 : 2801091\n  6 : 6452733\n  7 : 229316930\n  8 : 112863867\n  9 : 29816894\n  10: 13839655\n  11 : 15581458\n  12 : 221821\n  13 : 0\n  14 : 7931550\n  15 : 467989\n  16 : 3354\n  17 : 9201043\n  18 : 61011\n  19 : 3796746\n  20 : 3217865\n  21 : 215372\n  22 : 79669695\n\nremap_content 
: \n  0 : 4.16659328e+09\n  1 : 4.23097440e+07\n  2 : 3.33326810e+07\n  3 : 8.17951900e+06\n  4 : 9.05663000e+05\n  5 : 3.08392300e+06\n  6 : 2.35769663e+08\n  7 : 8.76012450e+07\n  8 : 1.12863867e+08\n  9 : 2.98168940e+07\n  10 : 1.38396550e+07"
  },
  {
    "path": "datasets/carla_dataset.py",
    "content": "import os\nimport numpy as np\nimport random\nimport json\nimport yaml\nimport torch\nimport numba as nb\nfrom torch.utils.data import Dataset\n\nbase_dir = os.path.dirname(__file__)\nconfig_file = os.path.join(base_dir, 'carla.yaml')\ncarla_config = yaml.safe_load(open(config_file, 'r'))\nLABELS_REMAP = carla_config[\"learning_map\"]\nREMAP_FREQUENCIES = carla_config[\"remap_content\"]\nFREQUENCIES= carla_config[\"content\"]\n\nLABELS_REMAP = np.asarray(list(LABELS_REMAP.values()))\nfrequencies_cartesian = np.asarray(list(FREQUENCIES.values()))\nremap_frequencies_cartesian = np.asarray(list(REMAP_FREQUENCIES.values()))\n\nclass CarlaDataset(Dataset):\n    \"\"\"Carla Simulation Dataset for 3D mapping project\n    Access to the processed data, including evaluation labels predictions velodyne poses times\n    \"\"\"\n    def __init__(self, directory,\n        voxelize_input=True,\n        binary_counts=True,\n        random_flips=False,\n        remap=True,\n        num_frames=1,\n        transform_pose=True,\n        get_gt=True,\n        ):\n        '''Constructor.\n        Parameters:\n            directory: directory to the dataset\n        '''\n        self.get_gt = get_gt\n        self.voxelize_input = voxelize_input\n        self.binary_counts = binary_counts\n        self._directory = directory\n        self._num_frames = num_frames\n        self.random_flips = random_flips\n        self.remap = remap\n        self.transform_pose = transform_pose\n        self.sparse_output = True\n        \n        self._scenes = sorted(os.listdir(self._directory))\n        self._scenes = [os.path.join(scene, \"cartesian\") for scene in self._scenes]\n\n        self._num_scenes = len(self._scenes)\n        self._num_frames_scene = []\n\n        param_file = os.path.join(self._directory, self._scenes[0], 'evaluation', 'params.json')\n        with open(param_file) as f:\n            self._eval_param = json.load(f)\n        \n        self._out_dim = 
self._eval_param['num_channels']\n        self._grid_size = self._eval_param['grid_size']\n        self.grid_dims = np.asarray(self._grid_size)\n        self._eval_size = list(np.uint32(self._grid_size))\n        \n        self.coor_ranges = self._eval_param['min_bound'] + self._eval_param['max_bound']\n        self.voxel_sizes = [abs(self.coor_ranges[3] - self.coor_ranges[0]) / self._grid_size[0], \n                      abs(self.coor_ranges[4] - self.coor_ranges[1]) / self._grid_size[1],\n                      abs(self.coor_ranges[5] - self.coor_ranges[2]) / self._grid_size[2]]\n        self.min_bound = np.asarray(self.coor_ranges[:3])\n        self.max_bound = np.asarray(self.coor_ranges[3:])\n        self.voxel_sizes = np.asarray(self.voxel_sizes)\n\n        self._velodyne_list = []\n        self._label_list = []\n        self._pred_list = []\n        self._eval_labels = []\n        self._eval_counts = []\n        self._frames_list = []\n        self._timestamps = []\n        self._poses = [] \n\n        for scene in self._scenes:\n            velodyne_dir = os.path.join(self._directory, scene, 'velodyne')\n            label_dir = os.path.join(self._directory, scene, 'labels')\n            pred_dir = os.path.join(self._directory, scene, 'predictions')\n            eval_dir = os.path.join(self._directory, scene, 'evaluation')\n            \n            self._num_frames_scene.append(len(os.listdir(velodyne_dir)))\n\n            frames_list = [os.path.splitext(filename)[0] for filename in sorted(os.listdir(velodyne_dir))]\n            self._frames_list.extend(frames_list)\n            self._velodyne_list.extend([os.path.join(velodyne_dir, str(frame).zfill(6)+'.bin') for frame in frames_list])\n            self._label_list.extend([os.path.join(label_dir, str(frame).zfill(6)+'.label') for frame in frames_list])\n            self._pred_list.extend([os.path.join(pred_dir, str(frame).zfill(6)+'.bin') for frame in frames_list])\n            
self._eval_labels.extend([os.path.join(eval_dir, str(frame).zfill(6)+'.label') for frame in frames_list])\n            self._eval_counts.extend([os.path.join(eval_dir, str(frame).zfill(6) + '.bin') for frame in frames_list])\n            self._timestamps.append(np.loadtxt(os.path.join(self._directory, scene, 'times.txt')))\n            self._poses.append(np.loadtxt(os.path.join(self._directory, scene, 'poses.txt')))\n            # for poses and timestamps\n        self._timestamps = np.array(self._timestamps).reshape(sum(self._num_frames_scene))\n        self._poses = np.array(self._poses).reshape(sum(self._num_frames_scene), 12)\n        \n        self._cum_num_frames = np.cumsum(np.array(self._num_frames_scene) - self._num_frames + 1)\n\n    # Use all frames, if there is no data then zero pad\n    def __len__(self):\n        return sum(self._num_frames_scene)\n    \n    def collate_fn(self, data):\n        voxel_batch = [bi[0] for bi in data]\n        output_batch = [bi[1] for bi in data]\n        counts_batch = [bi[2] for bi in data]\n        return voxel_batch, output_batch, counts_batch\n    \n    def points_to_voxels(self, voxel_grid, points, t_i):\n        # Valid voxels (make sure to clip)\n        voxels = np.floor((points - self.min_bound) / self.voxel_sizes).astype(np.int32)\n        # Clamp to account for any floating point errors\n        maxes = np.reshape(self.grid_dims - 1, (1, 3))\n        mins = np.zeros_like(maxes)\n        voxels = np.clip(voxels, mins, maxes).astype(np.int32)\n        # This line is needed to create a mask with number of points, not just binary occupied\n        if self.binary_counts:\n            voxel_grid[t_i, voxels[:, 0], voxels[:, 1], voxels[:, 2]] += 1\n        else:\n            unique_voxels, counts = np.unique(voxels, return_counts=True, axis=0)\n            unique_voxels = unique_voxels.astype(np.int32)\n            voxel_grid[t_i, unique_voxels[:, 0], unique_voxels[:, 1], unique_voxels[:, 2]] += counts\n        
return voxel_grid\n\n    def get_pose(self, idx):\n        pose = np.zeros((4, 4))\n        pose[3, 3] = 1\n        pose[:3, :4] = self._poses[idx].reshape(3, 4)\n        return pose\n\n    def __getitem__(self, idx):\n        # -1 indicates no data\n        # the final index is the output\n        idx_range = self.find_horizon(idx)\n        if self.transform_pose:\n            ego_pose = self.get_pose(idx_range[-1])\n            to_ego = np.linalg.inv(ego_pose)\n         \n        if self.voxelize_input:\n            voxel_input = np.zeros((idx_range.shape[0], int(self.grid_dims[0]), int(self.grid_dims[1]), int(self.grid_dims[2])), dtype=np.float32)\n        t_i = 0\n\n        for i in idx_range:\n            if i == -1: # Zero pad\n                points = np.zeros((1, 3), dtype=np.float32)\n                valid_points = points  # keep valid_points defined for the zero-pad case\n            else:\n                points = np.fromfile(self._velodyne_list[i], dtype=np.float32).reshape(-1, 4)[:, :3]\n\n                if self.transform_pose:\n                    to_world = self.get_pose(i)\n                    relative_pose = np.matmul(to_ego, to_world)\n                    points = np.dot(relative_pose[:3, :3], points.T).T + relative_pose[:3, 3]\n\n                valid_point_mask = np.all((points < self.max_bound) & (points >= self.min_bound), axis=1)\n                valid_points = points[valid_point_mask, :]\n\n            if self.voxelize_input:\n                voxel_input = self.points_to_voxels(voxel_input, valid_points, t_i)\n\n            t_i += 1\n\n        if self.get_gt:\n            output = np.fromfile(self._eval_labels[idx_range[-1]], dtype=np.uint32).reshape(self._eval_size).astype(np.uint8)\n            counts = np.fromfile(self._eval_counts[idx_range[-1]], dtype=np.float32).reshape(self._eval_size)\n        else:\n            output = None\n            counts = None\n\n        if self.voxelize_input and self.random_flips:\n            # X flip\n            if np.random.randint(2):\n                output = 
np.flip(output, axis=0)\n                counts = np.flip(counts, axis=0)\n                voxel_input = np.flip(voxel_input, axis=1) # Because there is a time dimension\n            # Y Flip\n            if np.random.randint(2):\n                output = np.flip(output, axis=1)\n                counts = np.flip(counts, axis=1)\n                voxel_input = np.flip(voxel_input, axis=2) # Because there is a time dimension\n                \n        if self.remap:\n            output = LABELS_REMAP[output].astype(np.uint8)            \n\n        return voxel_input, output, counts\n\n    # If there are not enough frames, missing indices are marked with -1\n    def find_horizon(self, idx):\n        end_idx = idx\n        idx_range = np.arange(idx-self._num_frames, idx)+1\n        diffs = np.asarray([int(self._frames_list[end_idx]) - int(self._frames_list[i]) for i in idx_range])\n        good_difs = -1 * (np.arange(-self._num_frames, 0) + 1)\n        \n        idx_range[good_difs != diffs] = -1\n\n        return idx_range\n"
  },
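The point-to-voxel mapping in `CarlaDataset.points_to_voxels` can be sketched in isolation: compute `floor((p - min_bound) / voxel_size)` and clip into the grid to absorb floating-point error. A minimal sketch with illustrative bounds (the real grid size and bounds come from the dataset's `params.json`, so the numbers below are assumptions):

```python
import numpy as np

def points_to_voxel_indices(points, min_bound, voxel_sizes, grid_dims):
    """Map (N, 3) points to integer voxel indices, clipped into the grid."""
    voxels = np.floor((points - min_bound) / voxel_sizes).astype(np.int32)
    maxes = np.reshape(grid_dims - 1, (1, 3))
    return np.clip(voxels, np.zeros_like(maxes), maxes)

# Illustrative 128 x 128 x 8 grid spanning [-25.6, 25.6] m in x/y, [-2, 1] m in z
grid_dims = np.array([128, 128, 8])
min_bound = np.array([-25.6, -25.6, -2.0])
max_bound = np.array([25.6, 25.6, 1.0])
voxel_sizes = (max_bound - min_bound) / grid_dims  # [0.4, 0.4, 0.375]

pts = np.array([[0.1, 0.1, 0.0],      # mid-grid point
                [25.6, 25.6, 1.0]])   # on the max bound -> clipped back inside
idx = points_to_voxel_indices(pts, min_bound, voxel_sizes, grid_dims)
```

The clip step mirrors the dataset code: points exactly on the max bound (or nudged outside by floating-point error) land on the last valid voxel instead of indexing out of range.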
  {
    "path": "datasets/data.py",
    "content": "import os\nimport math\nimport torch\nimport numpy as np\nfrom torch.utils.data import DataLoader\nfrom datasets.carla_dataset import *\n\ndataset_choices = {'carla', 'kitti'}\n\n\ndef get_data_id(args):\n    return '{}'.format(args.dataset)\n\ndef get_class_weights(freq):\n    '''\n    Cless weights being 1/log(fc) (https://arxiv.org/pdf/2008.10559.pdf)\n    '''\n    epsilon_w = 0.001  # eps to avoid zero division\n    weights = torch.from_numpy(1 / np.log(freq + epsilon_w))\n\n    return weights\n\ndef get_data(args):\n    assert args.dataset in dataset_choices\n    if args.dataset == 'carla':\n        train_dir = os.path.join(args.dataset_dir, \"Train\")\n        val_dir   = os.path.join(args.dataset_dir, \"Val\")\n        test_dir  = os.path.join(args.dataset_dir, \"Test\")\n\n        x_dim = 128\n        y_dim = 128\n        z_dim = 8\n        data_shape = [x_dim, y_dim, z_dim]\n        args.data_shape= data_shape\n\n        binary_counts = True\n        transform_pose = True\n        remap = True\n        if remap:\n            class_frequencies = remap_frequencies_cartesian\n            args.num_classes = 11\n        else:\n            args.num_classes = 23\n\n        comp_weights = get_class_weights(class_frequencies).to(torch.float32)\n        seg_weights = get_class_weights(class_frequencies[1:]).to(torch.float32)\n\n        train_ds = CarlaDataset(directory=train_dir, random_flips=True, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose)\n        coor_ranges = train_ds._eval_param['min_bound'] + train_ds._eval_param['max_bound']\n        voxel_sizes = [abs(coor_ranges[3] - coor_ranges[0]) / x_dim,\n                    abs(coor_ranges[4] - coor_ranges[1]) / y_dim,\n                    abs(coor_ranges[5] - coor_ranges[2]) / z_dim] # since BEV\n        val_ds = CarlaDataset(directory=val_dir, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose)\n        test_ds = CarlaDataset(directory=test_dir, 
remap=remap, binary_counts=binary_counts, transform_pose=transform_pose)\n\n        if args.distributed:\n            train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds, shuffle=True)\n            val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds, shuffle=False)\n            train_iters = len(train_sampler) // args.batch_size\n            val_iters = len(val_sampler) // args.batch_size\n        else:\n            train_sampler = None\n            val_sampler = None\n            train_iters = len(train_ds) // args.batch_size\n            val_iters = len(val_ds) // args.batch_size\n        \n        dataloader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, collate_fn=train_ds.collate_fn, num_workers=args.num_workers)\n        dataloader_val = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, sampler=val_sampler, collate_fn=val_ds.collate_fn, num_workers=args.num_workers)\n        dataloader_test = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=test_ds.collate_fn, num_workers=args.num_workers)\n    else:\n        raise NotImplementedError(\"Unsupported dataset: only 'carla' is currently supported.\")\n\n    return dataloader, dataloader_val, dataloader_test, args.num_classes, comp_weights, seg_weights, train_sampler\n"
  },
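The class-weighting formula in `data.py` can be demonstrated numerically. A NumPy-only sketch (the original returns a torch tensor): the log damps the imbalance, so rare classes are up-weighted, but far less than proportionally to their rarity.

```python
import numpy as np

def get_class_weights(freq):
    """Class weights 1 / log(f_c + eps), as in datasets/data.py
    (cf. https://arxiv.org/pdf/2008.10559.pdf)."""
    epsilon_w = 0.001  # eps to avoid log(0) / zero division
    return 1.0 / np.log(freq + epsilon_w)

# Toy counts: class 1 is 100x rarer than class 0.
freq = np.array([1.0e6, 1.0e4])
w = get_class_weights(freq)
# w[1] > w[0], but only by a factor of log(1e6)/log(1e4) = 1.5, not 100.
```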
  {
    "path": "layers/Ablation/wo_diffusion.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport numpy as np\nfrom layers.Latent_Level.stage1.model import C_Encoder, C_Decoder\n\nclass wo_diff(torch.nn.Module):\n    def __init__(self, args, multi_criterion) -> None:\n        super(wo_diff, self).__init__()\n        self.args = args\n\n        if self.args.dataset == 'kitti':\n            init_size = args.init_size\n        elif self.args.dataset == 'carla':\n            init_size = args.init_size\n        \n        self.encoder = C_Encoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)\n        self.decoder = C_Decoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)\n        \n        self.multi_criterion = multi_criterion\n\n    def device(self):\n        return self.encoder.device\n\n    def forward(self, x, input_ten):\n        latent = self.encoder(input_ten, out_conv=False) \n        recons = self.decoder(latent, in_conv=False)\n        recons_loss = self.multi_criterion(recons, x)\n        return recons_loss \n\n    def sample(self, x):\n        latent = self.encoder(x, out_conv=False) \n        recons = self.decoder(latent, in_conv=False)\n        recons = recons.argmax(1)\n        return recons\n"
  },
  {
    "path": "layers/Latent_Level/stage1/model.py",
    "content": "import numpy as np\nimport math\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import rearrange, reduce, repeat\nfrom torch import nn, einsum\n\n\ndef conv3x3x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n\ndef conv1x3x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)\n\ndef conv1x1x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)\n\ndef conv1x3x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)\n\ndef conv3x1x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)\n\ndef conv3x1x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)\n\n\nclass Asymmetric_Residual_Block(nn.Module):\n    def __init__(self, in_filters, out_filters):\n        super(Asymmetric_Residual_Block, self).__init__()\n        self.conv1 = conv1x3x3(in_filters, out_filters)\n        self.act1 = nn.LeakyReLU()          \n        self.conv1_2 = conv3x1x3(out_filters, out_filters)\n        self.act1_2 = nn.LeakyReLU()\n\n        self.conv2 = conv3x1x3(in_filters, out_filters)\n        self.act2 = nn.LeakyReLU()\n\n        self.conv3 = conv1x3x3(out_filters, out_filters)\n        self.act3 = nn.LeakyReLU()\n\n        if in_filters<32 :\n            self.GroupNorm = nn.GroupNorm(8, in_filters)\n            self.bn0 = nn.GroupNorm(8, out_filters)\n            
self.bn0_2 = nn.GroupNorm(8, out_filters)\n            self.bn1 = nn.GroupNorm(8, out_filters)\n            self.bn2 = nn.GroupNorm(8, out_filters)\n        else :\n            self.GroupNorm = nn.GroupNorm(32, in_filters)\n            self.bn0 = nn.GroupNorm(32, out_filters)\n            self.bn0_2 = nn.GroupNorm(32, out_filters)\n            self.bn1 = nn.GroupNorm(32, out_filters)\n            self.bn2 = nn.GroupNorm(32, out_filters)\n\n\n    def forward(self, x):\n        shortcut = self.conv1(x)\n        shortcut = self.act1(shortcut)\n        shortcut = self.bn0(shortcut)\n\n        shortcut = self.conv1_2(shortcut)\n        shortcut = self.act1_2(shortcut)\n        shortcut = self.bn0_2(shortcut)\n\n        resA = self.conv2(x) \n        resA = self.act2(resA)\n        resA = self.bn1(resA)\n\n        resA = self.conv3(resA) \n        resA = self.act3(resA)\n        resA = self.bn2(resA)\n        resA += shortcut\n\n        return resA\n\n\nclass DownBlock(nn.Module):\n    def __init__(self, in_filters, out_filters, pooling=True, drop_out=True, height_pooling=False):\n        super(DownBlock, self).__init__()\n        self.pooling = pooling\n        self.drop_out = drop_out\n        self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters)\n        if pooling:\n            if height_pooling:\n                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,padding=1, bias=False)\n            else:\n                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),padding=1, bias=False)\n\n    def forward(self, x):\n        resA = self.residual_block(x)\n        if self.pooling:\n            resB = self.pool(resA) \n            return resB, resA\n        else:\n            return resA\n\n\nclass UpBlock(nn.Module):\n    def __init__(self, in_filters, out_filters, height_pooling):\n        super(UpBlock, self).__init__()\n        # self.drop_out = drop_out\n        self.trans_dilao = 
conv3x3x3(in_filters, out_filters)\n        self.trans_act = nn.LeakyReLU()\n\n        self.conv1 = conv1x3x3(out_filters, out_filters)\n        self.act1 = nn.LeakyReLU()\n\n        self.conv2 = conv3x1x3(out_filters, out_filters)\n        self.act2 = nn.LeakyReLU()\n\n        self.conv3 = conv3x3x3(out_filters, out_filters)\n        self.act3 = nn.LeakyReLU()\n\n        if out_filters<32 :\n            self.trans_bn = nn.GroupNorm(8, out_filters)\n            self.bn1 = nn.GroupNorm(8, out_filters)\n            self.bn2 = nn.GroupNorm(8, out_filters)\n            self.bn3 = nn.GroupNorm(8, out_filters)\n        else :\n            self.trans_bn = nn.GroupNorm(32, out_filters)\n            self.bn1 = nn.GroupNorm(32, out_filters)\n            self.bn2 = nn.GroupNorm(32, out_filters)\n            self.bn3 = nn.GroupNorm(32, out_filters)\n        \n        if height_pooling :\n            self.up_subm = nn.ConvTranspose3d(out_filters, out_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)\n        else : \n            self.up_subm = nn.ConvTranspose3d(out_filters, out_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)\n\n\n    def forward(self, x, skip=False): \n        if skip :\n            x, residual = x\n        upA = self.trans_dilao(x)\n        upA = self.trans_act(upA)\n        upA = self.trans_bn(upA) \n\n        upA = self.up_subm(upA)\n        if skip :\n            upA += residual\n        upE = self.conv1(upA)\n        upE = self.act1(upE)\n        upE = self.bn1(upE)\n\n        upE = self.conv2(upE)\n        upE = self.act2(upE)\n        upE = self.bn2(upE)\n\n        upE = self.conv3(upE)\n        upE = self.act3(upE)\n        upE = self.bn3(upE)\n        return upE\n\n\nclass DDCM(nn.Module):\n    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):\n        super(DDCM, self).__init__()\n        self.conv1 = conv3x1x1(in_filters, 
out_filters)\n        self.act1 = nn.Sigmoid()\n\n        self.conv1_2 = conv1x3x1(in_filters, out_filters)\n        self.act1_2 = nn.Sigmoid()\n\n        self.conv1_3 = conv1x1x3(in_filters, out_filters)\n        self.act1_3 = nn.Sigmoid()\n\n        if in_filters<32 :\n            self.bn0 = nn.GroupNorm(8, out_filters)\n            self.bn0_2 = nn.GroupNorm(8, out_filters)\n            self.bn0_3 = nn.GroupNorm(8, out_filters)\n        else :\n            self.bn0 = nn.GroupNorm(32, out_filters)\n            self.bn0_2 = nn.GroupNorm(32, out_filters)\n            self.bn0_3 = nn.GroupNorm(32, out_filters)\n\n    def forward(self, x):\n        shortcut = self.conv1(x)\n        shortcut = self.bn0(shortcut)\n        shortcut = self.act1(shortcut)\n\n        shortcut2 = self.conv1_2(x)\n        shortcut2 = self.bn0_2(shortcut2)\n        shortcut2 = self.act1_2(shortcut2)\n\n        shortcut3 = self.conv1_3(x)\n        shortcut3 = self.bn0_3(shortcut3)\n        shortcut3 = self.act1_3(shortcut3)\n        shortcut = shortcut + shortcut2 + shortcut3\n\n        shortcut = shortcut * x\n        return shortcut\n\ndef l2norm(t):\n    return F.normalize(t, dim = -1)\n\nclass Attention(nn.Module):\n    def __init__(self, dim, heads = 4, scale = 10):\n        super().__init__()\n        self.scale = scale\n        self.heads = heads\n        self.to_qkv = conv1x1(dim, dim*3, stride=1)\n        self.to_out = conv1x1(dim, dim, stride=1)\n\n    def forward(self, x):\n        b, c, h, w, Z = x.shape\n        qkv = self.to_qkv(x).chunk(3, dim = 1)\n        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)\n\n        q, k = map(l2norm, (q, k))\n\n        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale\n        attn = sim.softmax(dim = -1)\n        out = einsum('b h i j, b h d j -> b h i d', attn, v)\n        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)\n        return self.to_out(out)\n\nclass 
C_Encoder(nn.Module):\n    def __init__(self, args, nclasses=20, init_size=16, l_size='882', attention=True):\n        super(C_Encoder, self).__init__()\n        self.nclasses = nclasses\n        self.args = args\n        self.l_size = l_size\n        self.attention = attention\n\n        self.embedding = nn.Embedding(nclasses, init_size)\n\n        self.A = Asymmetric_Residual_Block(init_size, init_size)\n\n        self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True)\n        self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True)\n        self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False)\n        self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False)\n\n        # `self.attention` starts out as a bool flag; replace it with an Attention\n        # module only when enabled, so the truthiness check in forward() stays valid.\n        if self.l_size == '32322':\n            self.midBlock1 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)\n            if attention:\n                self.attention = Attention(4 * init_size, 32)\n            self.midBlock2 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)\n            self.out = nn.Conv3d(4 * init_size, nclasses, kernel_size=3, stride=1, padding=1, bias=True)\n        elif self.l_size == '16162':\n            self.midBlock1 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size)\n            if attention:\n                self.attention = Attention(8 * init_size, 32)\n            self.midBlock2 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size)\n            self.out = nn.Conv3d(8 * init_size, nclasses, kernel_size=3, stride=1, padding=1, bias=True)\n        elif self.l_size == '882':\n            self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)\n            if attention:\n                self.attention = Attention(16 * init_size, 32)\n            self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)\n            self.out = nn.Conv3d(16 * init_size, nclasses, kernel_size=3, stride=1, padding=1, bias=True)\n        else:\n            raise NotImplementedError(f\"Unsupported l_size: {l_size}\")\n\n    def forward(self, x, out_conv=True):\n        x = self.embedding(x)\n        x = x.permute(0, 4, 1, 2, 3)\n\n        x = self.A(x)\n        x, down1b = self.downBlock1(x)\n        x, down2b = self.downBlock2(x)\n\n        if self.l_size == '882':\n            x, down3b = self.downBlock3(x)\n            x, down4b = self.downBlock4(x)\n        elif self.l_size == '16162':\n            x, down3b = self.downBlock3(x)\n\n        if self.attention:\n            x = self.midBlock1(x) # (4, 128, 32, 32, 2)\n            x = self.attention(x)\n            x = self.midBlock2(x) # (4, 128, 32, 32, 2)\n        if out_conv:\n            x = self.out(x)\n        return x\n\nclass C_Decoder(nn.Module):\n    def __init__(self, args, nclasses=20, init_size=16, l_size='882', attention=True):\n        super(C_Decoder, self).__init__()\n        self.nclasses = nclasses\n        self.args = args\n        self.l_size = l_size\n        self.attention = attention\n\n        # Same convention as C_Encoder: keep the bool flag unless attention is enabled.\n        if l_size == '882':\n            self.conv_in = nn.Conv3d(nclasses, 16 * init_size, kernel_size=3, stride=1, padding=1, bias=True)\n            self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)\n            if attention:\n                self.attention = Attention(16 * init_size, 32)\n            self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)\n        elif l_size == '16162':\n            self.conv_in = nn.Conv3d(nclasses, 8 * init_size, kernel_size=3, stride=1, padding=1, bias=True)\n            self.midBlock1 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size)\n            if attention:\n                self.attention = Attention(8 * init_size, 32)\n            self.midBlock2 = Asymmetric_Residual_Block(8 * init_size, 8 * init_size)\n        elif l_size == '32322':\n            self.conv_in = nn.Conv3d(nclasses, 4 * init_size, kernel_size=3, stride=1, padding=1, bias=True)\n            self.midBlock1 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)\n            if attention:\n                self.attention = Attention(4 * init_size, 32)\n            self.midBlock2 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)\n        else:\n            raise NotImplementedError(f\"Unsupported l_size: {l_size}\")\n\n        self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False)\n        self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False)\n        self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True)\n        self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True)\n        self.DDCM = DDCM(2 * init_size, 2 * init_size)\n        self.logits = nn.Conv3d(4 * init_size, self.nclasses, kernel_size=3, stride=1, padding=1, bias=True)\n\n    def forward(self, x, in_conv=True):\n        if in_conv:\n            x = self.conv_in(x)\n\n        if self.attention:\n            x = self.midBlock1(x)\n            x = self.attention(x)\n            x = self.midBlock2(x)\n\n        if self.l_size == '882':\n            x = self.upBlock4(x)\n            x = self.upBlock3(x)\n        elif self.l_size == '16162':\n            x = self.upBlock3(x)\n\n        x = self.upBlock2(x)\n        up1 = self.upBlock1(x)\n\n        up0 = self.DDCM(up1)\n        up = torch.cat((up1, up0), 1)\n        logits = self.logits(up)\n        return logits\n\nclass Completion(nn.Module):\n    def __init__(self, args, num_class=11, init_size=32):\n        super(Completion, self).__init__()\n        self.args = args\n        self.num_class = num_class\n        self.init_size = init_size\n\n        self.embedding = nn.Embedding(self.num_class, init_size)\n\n        self.A = Asymmetric_Residual_Block(init_size, init_size)\n\n        self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True)\n        self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True)\n        self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False)\n        self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False)\n\n        self.midBlock1 = 
Asymmetric_Residual_Block(16 * init_size, 16 * init_size)\n        self.attention = Attention(16 * init_size, 32)\n        self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)\n\n        self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False)\n        self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False)\n        self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True)\n        self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True)\n\n        self.DDCM = DDCM(2 * init_size, 2 * init_size)\n        self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True)\n        \n\n    def forward(self, x):\n        x = self.embedding(x)\n        x = x.permute(0, 4, 1, 2, 3)\n\n        x = self.A(x)\n        down1c, down1b = self.downBlock1(x)\n        down2c, down2b = self.downBlock2(down1c) \n        down3c, down3b = self.downBlock3(down2c)\n        down4c, down4b = self.downBlock4(down3c) \n\n        down4c = self.midBlock1(down4c) \n        down4c = self.attention(down4c)\n        down4c = self.midBlock2(down4c) \n        \n        up4 = self.upBlock4((down4c, down4b), skip=True)\n        up3 = self.upBlock3((up4, down3b), skip=True)\n        up2 = self.upBlock2((up3, down2b), skip=True)\n        up1 = self.upBlock1((up2, down1b), skip=True)\n\n        up0 = self.DDCM(up1) \n        up = torch.cat((up1, up0), 1) \n        logits = self.logits(up) \n        return logits\n"
  },
  {
    "path": "layers/Latent_Level/stage1/vector_quantizer.py",
"content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nclass VectorQuantizer(nn.Module):\n\n    def __init__(self,\n                 num_embeddings: int,\n                 embedding_dim: int,\n                 beta: float = 0.25):\n        super(VectorQuantizer, self).__init__()\n        self.K = num_embeddings\n        self.D = embedding_dim\n        self.beta = beta\n\n        self.embedding = nn.Embedding(self.K, self.D)\n        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)\n\n    def forward(self, z: torch.Tensor, point=False) -> torch.Tensor: # latents (8, 128, 8, 8, 2)\n        z = z.permute(0, 2, 3, 4, 1).contiguous()  # [B x D x H x W x Z] -> [B x H x W x Z x D]\n        latents_shape = z.shape # ( 8, 8, 8, 2, 128 )\n        flat_latents = z.view(-1, self.D)  # [BHWZ x D] = [1024, 128]\n\n        # Compute L2 distance between latents and embedding weights\n        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - \\\n               2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHWZ x K]\n\n        # Get the encoding that has the min distance\n        min_encoding_indices = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHWZ, 1]\n\n        z_q = self.embedding(min_encoding_indices).view(z.shape)\n\n        # Compute the VQ Losses\n        commitment_loss = F.mse_loss(z_q.detach(), z)\n        embedding_loss = F.mse_loss(z_q, z.detach())\n        if point:\n            vq_loss = commitment_loss * self.beta\n        else:\n            vq_loss = commitment_loss * self.beta + embedding_loss\n\n        # Straight-through estimator: copy gradients from z_q back to z\n        z_q = z + (z_q - z).detach()\n\n        return z_q.permute(0, 4, 1, 2, 3).contiguous(), vq_loss, min_encoding_indices, latents_shape\n\n    def codebook_to_embedding(self, encoding_inds, latents_shape): # latents (16, 512, 8, 8, 2)\n        # Look up the embedding vectors for the given codebook indices\n        z_q = 
self.embedding(encoding_inds).view(latents_shape)\n        return z_q.permute(0, 4, 1, 2, 3).contiguous()\n"
  },
  {
    "path": "layers/Latent_Level/stage1/vqvae.py",
"content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nimport numpy as np\nimport math\nfrom utils.loss import lovasz_softmax\nfrom layers.Latent_Level.stage1.model import C_Encoder, C_Decoder\nfrom layers.Latent_Level.stage1.vector_quantizer import VectorQuantizer\n\nclass vqvae(torch.nn.Module):\n    def __init__(self, args, multi_criterion) -> None:\n        super(vqvae, self).__init__()\n        self.args = args\n\n        init_size = args.init_size\n        embedding_dim = int(self.args.num_classes)\n        \n        self.VQ = VectorQuantizer(num_embeddings = int(self.args.num_classes)*int(self.args.vq_size), embedding_dim = embedding_dim)\n\n        self.encoder = C_Encoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)\n        self.quant_conv = nn.Conv3d(self.args.num_classes, self.args.num_classes, kernel_size=1, stride=1)\n\n        self.decoder = C_Decoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)\n        self.post_quant_conv = nn.Conv3d(self.args.num_classes, self.args.num_classes, kernel_size=1, stride=1)\n\n        self.multi_criterion = multi_criterion\n\n    def device(self):\n        # nn.Module has no .device attribute; infer it from the parameters instead.\n        return next(self.encoder.parameters()).device\n\n    def encode(self, x):\n        latent = self.encoder(x)\n        latent = self.quant_conv(latent)\n        return latent\n\n    def vector_quantize(self, latent):\n        quantized_latent, vq_loss, quantized_latent_ind, latents_shape = self.VQ(latent)\n        return quantized_latent, vq_loss, quantized_latent_ind, latents_shape\n\n    def coodbook(self, quantized_latent_ind, latents_shape):\n        quantized_latent = self.VQ.codebook_to_embedding(quantized_latent_ind.view(-1,1), latents_shape)\n        return quantized_latent\n\n    def decode(self, quantized_latent):\n        quantized_latent = self.post_quant_conv(quantized_latent)\n        recons = 
self.decoder(quantized_latent)\n        return recons\n\n    def forward(self, x, input_ten):\n        latent = self.encode(x) \n        quantized_latent, vq_loss, _, _ = self.vector_quantize(latent) \n        recons = self.decode(quantized_latent)\n\n        recons_loss = self.multi_criterion(recons, x)\n        loss = recons_loss + vq_loss \n        return loss \n\n    def sample(self, x):\n        latent = self.encode(x)\n        quantized_latent, _, _, _ = self.vector_quantize(latent)\n        recons = self.decode(quantized_latent)\n        recons = recons.argmax(1)\n        return recons\n"
  },
  {
    "path": "layers/Latent_Level/stage2/Gen_diffusion.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nimport math\nfrom inspect import isfunction\nfrom layers.Latent_Level.stage2.gen_denoise import Denoise\n\n\"\"\"\nBased in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281\n\"\"\"\neps = 1e-8\n\n\ndef sum_except_batch(x, num_dims=1):\n    return x.reshape(*x.shape[:num_dims], -1).sum(-1)\n\n\ndef log_1_min_a(a):\n    return torch.log(1 - a.exp() + 1e-40)\n\n\ndef log_add_exp(a, b):\n    maximum = torch.max(a, b)\n    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))\n\n\ndef exists(x):\n    return x is not None\n\n\ndef extract(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef log_categorical(log_x_start, log_prob):\n    return (log_x_start.exp() * log_prob).sum(dim=1)\n\n\ndef index_to_log_onehot(x, num_classes):\n    assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'\n    \n    x_onehot = F.one_hot(x, num_classes)\n    permute_order = (0, -1) + tuple(range(1, len(x.size())))\n    x_onehot = x_onehot.permute(permute_order)\n    log_x = torch.log(x_onehot.float().clamp(min=1e-30))\n\n    return log_x\n\n\ndef log_onehot_to_index(log_x):\n    return log_x.argmax(1)\n\n\ndef cosine_beta_schedule(timesteps, s = 0.008):\n    \"\"\"\n    cosine schedule\n    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n    \"\"\"\n    steps = timesteps + 1\n    x = np.linspace(0, steps, steps)\n    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])\n\n    alphas = np.clip(alphas, a_min=0.001, 
a_max=1.)\n    alphas = np.sqrt(alphas)\n\n    return alphas\n\nclass latent_diffusion(torch.nn.Module):\n    def __init__(self, args, VAE_DENSE, multi_criterion,\n                 auxiliary_loss_weight=0.0005, adaptive_auxiliary_loss=True):\n        super(latent_diffusion, self).__init__()\n        self.args = args\n        self.num_classes = self.args.num_classes * self.args.vq_size\n        self.denoise = Denoise(args=self.args, num_class=self.num_classes)\n        \n        self.num_timesteps = self.args.diffusion_steps\n        self.auxiliary_loss_weight = auxiliary_loss_weight\n        self.adaptive_auxiliary_loss = adaptive_auxiliary_loss\n\n        self.VAE_DENSE = VAE_DENSE\n        self.multi_criterion = multi_criterion\n\n        alphas = cosine_beta_schedule(self.num_timesteps)\n\n        alphas = torch.tensor(alphas.astype('float64'))\n        log_alpha = torch.log(alphas)\n        log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)\n\n        log_1_min_alpha = log_1_min_a(log_alpha)\n        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)\n\n        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5\n        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5\n        assert (torch.cumsum(log_alpha, dim=0) - log_cumprod_alpha).abs().sum().item() < 1.e-5\n\n        # Convert to float32 and register buffers.\n        self.register_buffer('log_alpha', log_alpha.float())\n        self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())\n        self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())\n        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())\n\n        self.register_buffer('Lt_history', torch.zeros(self.num_timesteps))\n        self.register_buffer('Lt_count', torch.zeros(self.num_timesteps))\n    \n    def device(self):\n        # nn.Module has no .device attribute; infer it from the parameters instead.\n        return next(self.denoise.parameters()).device\n\n    def multinomial_kl(self, log_prob1, log_prob2):\n        kl = 
(log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)\n        return kl\n\n    def q_pred_one_timestep(self, log_x_t, t):\n        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)\n        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)\n\n        # alpha_t * E[xt] + (1 - alpha_t) 1 / K\n        \n        log_probs = log_add_exp(\n            log_x_t + log_alpha_t,\n            log_1_min_alpha_t - np.log(self.num_classes)\n        )\n\n        return log_probs\n\n    def q_pred(self, log_x_start, t):\n        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)\n        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)\n\n        log_probs = log_add_exp(\n            log_x_start + log_cumprod_alpha_t,\n            log_1_min_cumprod_alpha - np.log(self.num_classes)\n        )\n\n        return log_probs\n\n    def predict_start(self, log_x_t, t):\n        x_t = log_onehot_to_index(log_x_t)\n\n        out = self.denoise(x_t, t)\n\n        assert out.size(0) == x_t.size(0)\n        assert out.size(1) == self.num_classes\n        assert out.size()[2:] == x_t.size()[1:]\n\n        log_pred = F.log_softmax(out, dim=1)\n        return log_pred\n\n    def q_posterior(self, log_x_start, log_x_t, t):\n        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)\n        # where q(xt | xt-1, x0) = q(xt | xt-1).\n\n        t_minus_1 = t - 1\n        # Remove negative values, will not be used anyway for final decoder\n        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)\n        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)\n\n        num_axes = (1,) * (len(log_x_start.size()) - 1)\n        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)\n        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)\n\n\n        # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!\n        # 
Not very easy to see why this is true. But it is :)\n        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)\n\n        log_EV_xtmin_given_xt_given_xstart = \\\n            unnormed_logprobs \\\n            - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)\n\n        return log_EV_xtmin_given_xt_given_xstart\n\n    def p_pred(self, log_x, t):\n        log_x0_recon = self.predict_start(log_x, t=t)\n        log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t)\n        return log_model_pred, log_x0_recon\n\n    def log_sample_categorical(self, logits):\n        uniform = torch.rand_like(logits)\n        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)\n        sample = (gumbel_noise + logits).argmax(dim=1)\n        log_sample = index_to_log_onehot(sample, self.num_classes)\n        return log_sample\n\n    def q_sample(self, log_x_start, t):\n        log_EV_qxt_x0 = self.q_pred(log_x_start, t)\n        log_sample = self.log_sample_categorical(log_EV_qxt_x0)\n        return log_sample\n\n    def kl_prior(self, log_x_start):\n        b = log_x_start.size(0)\n        device = log_x_start.device\n        ones = torch.ones(b, device=device).long()\n\n        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)\n        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))\n\n        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)\n        return sum_except_batch(kl_prior)\n\n    def sample_time(self, b, device, method='uniform'):\n        if method == 'importance':\n            if not (self.Lt_count > 10).all():\n                return self.sample_time(b, device, method='uniform')\n\n            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001\n            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.\n            pt_all = Lt_sqrt / Lt_sqrt.sum()\n\n            t = torch.multinomial(pt_all, num_samples=b, 
replacement=True)\n\n            pt = pt_all.gather(dim=0, index=t)\n\n            return t, pt\n\n        elif method == 'uniform':\n            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()\n\n            pt = torch.ones_like(t).float() / self.num_timesteps\n            return t, pt\n        else:\n            raise ValueError\n\n    def forward(self, x, input_data):\n        b, device = x.size(0), x.device\n        self.shape = x.size()[1:]\n        \n        latent = self.VAE_DENSE.encode(x)\n        _, _, dense_ind, latents_shape = self.VAE_DENSE.vector_quantize(latent)\n        reshape_size = [latent.size()[0], latent.size()[2], latent.size()[3], latent.size()[4]]\n\n        t, pt = self.sample_time(b, device, 'importance')\n\n        log_x_start = index_to_log_onehot(dense_ind.view(reshape_size), self.num_classes)\n        log_x_t = self.q_sample(log_x_start=log_x_start, t=t) # log_x_t : (8,551,8,8,2)\n\n        log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)\n\n        log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t)\n\n        kl = self.multinomial_kl(log_true_prob, log_model_prob)\n        kl = sum_except_batch(kl)\n\n        decoder_nll = -log_categorical(log_x_start, log_model_prob)\n        decoder_nll = sum_except_batch(decoder_nll)\n\n        mask = (t == torch.zeros_like(t)).float()\n        kl_loss = mask * decoder_nll + (1. 
- mask) * kl\n        \n        if self.training:\n            Lt2 = kl_loss.pow(2)\n            Lt2_prev = self.Lt_history.gather(dim=0, index=t)\n            new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()\n            self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)\n            self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))\n\n        kl_prior = self.kl_prior(log_x_start)\n\n        # Upweigh loss term of the kl\n        loss = kl_loss / pt + kl_prior\n\n        kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:])\n        kl_aux = sum_except_batch(kl_aux)\n        kl_aux_loss = mask * decoder_nll + (1. - mask) * kl_aux\n        if self.adaptive_auxiliary_loss:\n            addition_loss_weight = (1-t/self.num_timesteps) + 1.0\n        else:\n            addition_loss_weight = 1.0\n\n        aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt\n        loss += aux_loss\n        loss = -loss.sum() / (math.log(2) * dense_ind.view(reshape_size).shape.numel())\n\n        x0 = log_onehot_to_index(F.log_softmax(log_x0_recon, dim=1))\n\n        return -loss\n\n    def sample(self, x):\n        device = self.log_alpha.device\n        self.shape = x.size()[1:]\n        \n        x = torch.randint(self.args.num_classes, size=x.size()).to(device)\n        latent = self.VAE_DENSE.encode(x)\n        _, _, sparse_ind, latents_shape = self.VAE_DENSE.vector_quantize(latent)\n        reshape_size = [latent.size()[0], latent.size()[2], latent.size()[3], latent.size()[4]]\n\n        log_z = index_to_log_onehot(sparse_ind.view(reshape_size), self.num_classes) # log_x_t : (8,551,8,8,2)\n\n        for i in reversed(range(0, self.num_timesteps)):\n            print(f'Sample timestep {i:4d}', end='\\r')\n\n            t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long)\n\n            log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t)\n\n            uniform 
= torch.rand_like(log_model_prob)\n            gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)\n            pre_sample = gumbel_noise + log_model_prob\n\n            # Gumbel-max trick: draw a categorical sample over codebook indices\n            sample = pre_sample.argmax(dim=1)\n            log_z = index_to_log_onehot(sample, self.num_classes)\n\n        vq_ind = log_onehot_to_index(log_z)\n        vq_latent = self.VAE_DENSE.coodbook(vq_ind.view(-1,1), latents_shape)\n        recons = self.VAE_DENSE.decode(vq_latent)\n        recons = recons.argmax(1)\n        return recons\n"
  },
  {
    "path": "layers/Latent_Level/stage2/gen_denoise.py",
"content": "import math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport torch.nn.functional as F\nfrom einops import rearrange, reduce, repeat\nfrom torch import einsum\n\n\ndef conv3x3x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n\ndef conv1x3x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride, padding=(0, 1, 1), bias=False)\n\n\ndef conv1x1x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)\n\n\ndef conv1x3x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)\n\n\ndef conv3x1x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)\n\n\ndef conv3x1x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)\n\n\nclass Asymmetric_Residual_Block(nn.Module):\n    def __init__(self, in_filters, out_filters, time_filters=128):\n        super(Asymmetric_Residual_Block, self).__init__()\n        self.GroupNorm = nn.GroupNorm(32, in_filters)\n        self.time_layers = nn.Sequential(\n                            nn.SiLU(),\n                            nn.Linear(time_filters, in_filters*2)\n                        )\n\n        self.conv1 = conv1x3x3(in_filters, out_filters)\n        self.bn0 = nn.GroupNorm(32, out_filters)\n        self.act1 = nn.LeakyReLU()\n          \n        self.conv1_2 = conv3x1x3(out_filters, out_filters)\n        self.bn0_2 = nn.GroupNorm(32, 
out_filters)\n        self.act1_2 = nn.LeakyReLU()\n\n        self.conv2 = conv3x1x3(in_filters, out_filters)\n        self.act2 = nn.LeakyReLU()\n        self.bn1 = nn.GroupNorm(32, out_filters)\n\n        self.conv3 = conv1x3x3(out_filters, out_filters)\n        self.act3 = nn.LeakyReLU()\n        self.bn2 = nn.GroupNorm(32, out_filters)\n\n\n    def forward(self, x, t):\n        t = self.time_layers(t)\n        while len(t.shape) < len(x.shape):\n            t = t[..., None]\n        scale, shift = torch.chunk(t, 2, dim=1)\n        \n        x = self.GroupNorm(x) * (1 + scale) + shift\n\n        shortcut = self.conv1(x)\n        shortcut = self.act1(shortcut)\n        shortcut = self.bn0(shortcut)\n\n        shortcut = self.conv1_2(shortcut)\n        shortcut = self.act1_2(shortcut)\n        shortcut = self.bn0_2(shortcut)\n\n        resA = self.conv2(x) \n        resA = self.act2(resA)\n        resA = self.bn1(resA)\n\n        resA = self.conv3(resA)\n        resA = self.act3(resA)\n        resA = self.bn2(resA)\n        resA += shortcut\n\n        return resA\n\nclass DDCM(nn.Module):\n    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):\n        super(DDCM, self).__init__()\n        self.conv1 = conv3x1x1(in_filters, out_filters)\n        self.bn0 = nn.GroupNorm(32, out_filters)\n        self.act1 = nn.Sigmoid()\n\n        self.conv1_2 = conv1x3x1(in_filters, out_filters)\n        self.bn0_2 = nn.GroupNorm(32, out_filters)\n        self.act1_2 = nn.Sigmoid()\n\n        self.conv1_3 = conv1x1x3(in_filters, out_filters)\n        self.bn0_3 = nn.GroupNorm(32, out_filters)\n        self.act1_3 = nn.Sigmoid()\n\n    def forward(self, x):\n        shortcut = self.conv1(x)\n        shortcut = self.bn0(shortcut)\n        shortcut = self.act1(shortcut)\n\n        shortcut2 = self.conv1_2(x)\n        shortcut2 = self.bn0_2(shortcut2)\n        shortcut2 = self.act1_2(shortcut2)\n\n        shortcut3 = self.conv1_3(x)\n        shortcut3 = 
self.bn0_3(shortcut3)\n        shortcut3 = self.act1_3(shortcut3)\n        shortcut = shortcut + shortcut2 + shortcut3\n\n        shortcut = shortcut * x\n\n        return shortcut\n\ndef l2norm(t):\n    return F.normalize(t, dim = -1)\n\nclass Attention(nn.Module):\n    def __init__(self, dim, heads = 4, scale = 10):\n        super().__init__()\n        self.scale = scale\n        self.heads = heads\n        self.to_qkv = conv1x1(dim, dim*3, stride=1)\n        self.to_out = conv1x1(dim, dim, stride=1)\n\n    def forward(self, x):\n        b, c, h, w, Z = x.shape\n        qkv = self.to_qkv(x).chunk(3, dim = 1)\n        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)\n\n        q, k = map(l2norm, (q, k))\n\n        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale\n        attn = sim.softmax(dim = -1)\n        out = einsum('b h i j, b h d j -> b h i d', attn, v)\n        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)\n        return self.to_out(out)\n\nclass Cross_Attention(nn.Module):\n    def __init__(self, dim, heads = 4, scale = 10):\n        super().__init__()\n        self.scale = scale\n        self.heads = heads\n        self.to_q = conv1x1(dim, dim, stride=1)\n        self.to_k = conv1x1(dim, dim, stride=1)\n        self.to_v = conv1x1(dim, dim, stride=1)\n\n        self.to_out = conv1x1(dim, dim, stride=1)\n\n    def forward(self, x, cond_x):\n        b, c, h, w, Z = x.shape\n        q = self.to_q(x)\n        k = self.to_k(cond_x)\n        v = self.to_v(cond_x)\n\n        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v))\n\n        q, k = map(l2norm, (q, k))\n\n        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale\n        attn = sim.softmax(dim = -1)\n        out = einsum('b h i j, b h d j -> b h i d', attn, v)\n        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)\n        
return self.to_out(out)\n\nclass DownBlock(nn.Module):\n    def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), stride=1,\n                 pooling=True, drop_out=True, height_pooling=False):\n        super(DownBlock, self).__init__()\n        self.pooling = pooling\n        self.drop_out = drop_out\n\n        self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters)\n\n        if pooling:\n            if height_pooling:\n                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,\n                                                padding=1, bias=False)\n            else:\n                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),\n                                                padding=1, bias=False)\n\n\n    def forward(self, x, t):\n        resA = self.residual_block(x, t)\n        if self.pooling:\n            resB = self.pool(resA) \n            return resB, resA\n        else:\n            return resA\n\nclass UpBlock(nn.Module):\n    def __init__(self, in_filters, out_filters, height_pooling, time_filters=32*4):\n        super(UpBlock, self).__init__()\n        # self.drop_out = drop_out\n        self.trans_dilao = conv3x3x3(in_filters, in_filters)\n        self.trans_act = nn.LeakyReLU()\n        self.trans_bn = nn.GroupNorm(32, in_filters)\n        self.time_layers = nn.Sequential(\n                            nn.SiLU(),\n                            nn.Linear(time_filters, in_filters*2)\n                        )\n\n        self.conv1 = conv1x3x3(in_filters, out_filters)\n        self.act1 = nn.LeakyReLU()\n        self.bn1 = nn.GroupNorm(32, out_filters)\n\n        self.conv2 = conv3x1x3(out_filters, out_filters)\n        self.act2 = nn.LeakyReLU()\n        self.bn2 = nn.GroupNorm(32, out_filters)\n\n        self.conv3 = conv3x3x3(out_filters, out_filters)\n        self.act3 = nn.LeakyReLU()\n        self.bn3 = nn.GroupNorm(32, out_filters)\n       
 \n        if height_pooling :\n            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)\n        else : \n            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)\n    \n\n    def forward(self, x, residual, t):\n        upA = self.trans_dilao(x) \n        upA = self.trans_act(upA)\n\n        t = self.time_layers(t)\n        while len(t.shape) < len(x.shape):\n            t = t[..., None]\n        scale, shift = torch.chunk(t, 2, dim=1)\n        \n        upA = self.trans_bn(upA) * (1 + scale) + shift\n        ## upsample\n        upA = self.up_subm(upA)\n        upA += residual\n        upE = self.conv1(upA)\n        upE = self.act1(upE)\n        upE = self.bn1(upE)\n\n        upE = self.conv2(upE)\n        upE = self.act2(upE)\n        upE = self.bn2(upE)\n\n        upE = self.conv3(upE)\n        upE = self.act3(upE)\n        upE = self.bn3(upE)\n\n        return upE\n\ndef timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    if not repeat_only:\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n        ).to(device=timesteps.device)\n        args = timesteps[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n   
 else:\n        embedding = repeat(timesteps, 'b -> b d', d=dim)\n    return embedding\n\nclass Denoise(nn.Module):\n    def __init__(self, args, num_class = 11, init_size=32, discrete=True):\n        super(Denoise, self).__init__()\n        self.args = args\n        self.discrete = discrete\n        self.num_class = num_class\n        self.init_size = init_size\n        self.time_size = init_size*4\n\n        self.time_embed = nn.Sequential(\n            nn.Linear(init_size, self.time_size),\n            nn.SiLU(),\n            nn.Linear(self.time_size, self.time_size),\n        )\n\n        self.embedding = nn.Embedding(self.num_class, init_size)\n        self.conv_in = nn.Conv3d(init_size, init_size, kernel_size=1, stride=1)\n\n        self.A = Asymmetric_Residual_Block(init_size, init_size)\n\n        self.midBlock1_1 = Asymmetric_Residual_Block(init_size, 2 * init_size)\n        self.attention1 = Attention(2 * init_size, 4)\n        self.midBlock1_2 = Asymmetric_Residual_Block(2 * init_size, 2 * init_size)\n\n        self.downBlock2 = DownBlock(init_size*2, 2 * init_size, 0.2, height_pooling=False)\n        self.downBlock3 = DownBlock(2 * init_size, 4 * init_size, 0.2, height_pooling=False)\n        \n        self.midBlock2_1 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)\n        self.attention2 = Attention(4 * init_size, 4)\n        self.midBlock2_2 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)\n\n        self.upBlock0 = UpBlock(4 * init_size, 2 * init_size, height_pooling=False)\n        self.upBlock1 = UpBlock(2 * init_size, init_size, height_pooling=False)\n\n        self.midBlock3_1 = Asymmetric_Residual_Block(init_size, init_size)\n        self.attention3 = Attention(init_size, 4)\n        self.midBlock3_2 = Asymmetric_Residual_Block(init_size, init_size)\n\n        self.DDCM = DDCM(init_size, init_size)\n\n        self.logits = nn.Sequential(\n            nn.Conv3d(2 * init_size, self.num_class, kernel_size=3, stride=1, 
padding=1, bias=True),\n        )\n\n    def forward(self, x, t):\n        x = self.embedding(x)\n        x = x.permute(0, 4, 1, 2, 3)\n        x = self.conv_in(x)\n        t = self.time_embed(timestep_embedding(t, self.init_size))\n\n        ret = self.A(x, t)\n\n        mid1 = self.midBlock1_1(ret, t)\n        att = self.attention1(mid1)\n        mid2 = self.midBlock1_2(att, t)\n\n        down1c, down1b = self.downBlock2(mid2, t) \n        down2c, down2b = self.downBlock3(down1c, t) \n\n        d_mid2 = self.midBlock2_1(down2c, t) \n        d_att = self.attention2(d_mid2)\n        d_mid1 = self.midBlock2_2(d_att, t) \n\n        up3e = self.upBlock0(d_mid1, down2b, t)\n        up2e = self.upBlock1(up3e, down1b, t)\n\n        u_mid2 = self.midBlock3_1(up2e, t) \n        u_att = self.attention3(u_mid2)\n        u_mid1 = self.midBlock3_2(u_att, t) \n\n        up0e = self.DDCM(u_mid1) \n        up0e = torch.cat((up0e, up2e), 1) \n        logits = self.logits(up0e) \n        \n        return logits\n"
  },
  {
    "path": "layers/Voxel_Level/Con_Diffusion.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nimport math\nfrom inspect import isfunction\nfrom layers.Voxel_Level.denoise import Denoise\nfrom utils.loss import *\n\"\"\"\nBased in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281\n\"\"\"\neps = 1e-8\n\n\ndef sum_except_batch(x, num_dims=1):\n    return x.reshape(*x.shape[:num_dims], -1).sum(-1)\n\n\ndef log_1_min_a(a):\n    return torch.log(1 - a.exp() + 1e-40)\n\n\ndef log_add_exp(a, b):\n    maximum = torch.max(a, b)\n    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))\n\n\ndef exists(x):\n    return x is not None\n\n\ndef extract(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef log_categorical(log_x_start, log_prob):\n    return (log_x_start.exp() * log_prob).sum(dim=1)\n\n\ndef index_to_log_onehot(x, num_classes):\n    assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'\n    \n    x_onehot = F.one_hot(x, num_classes)\n    permute_order = (0, -1) + tuple(range(1, len(x.size())))\n    x_onehot = x_onehot.permute(permute_order)\n    log_x = torch.log(x_onehot.float().clamp(min=1e-30))\n    return log_x\n\n\ndef log_onehot_to_index(log_x):\n    return log_x.argmax(1)\n\n\ndef cosine_beta_schedule(timesteps, s = 0.008):\n    \"\"\"\n    cosine schedule\n    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n    \"\"\"\n    steps = timesteps + 1\n    x = np.linspace(0, steps, steps)\n    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])\n\n    alphas = np.clip(alphas, 
a_min=0.001, a_max=1.)\n    alphas = np.sqrt(alphas)\n\n    return alphas\n\nclass Con_Diffusion(torch.nn.Module):\n    def __init__(self, args, multi_criterion,auxiliary_loss_weight=0.05, adaptive_auxiliary_loss=True):\n        super(Con_Diffusion, self).__init__()\n\n        #self._denoise_fn = SSCNet(num_classes=args.num_classes*50, num_steps=args.diffusion_steps)\n        self.args = args\n        self.num_classes = self.args.num_classes\n        self.num_timesteps = self.args.diffusion_steps\n        self.recon_loss = self.args.recon_loss\n        if args.dataset == 'carla':\n            self._denoise_fn = Denoise(args= self.args,  num_class = self.num_classes)\n        elif args.dataset=='kitti':\n            self._denoise_fn = Denoise(args= self.args,  num_class = self.num_classes, init_size=16)\n        self.auxiliary_loss_weight = auxiliary_loss_weight\n        self.adaptive_auxiliary_loss = adaptive_auxiliary_loss\n\n        self.multi_criterion = multi_criterion\n\n        alphas = cosine_beta_schedule(self.num_timesteps )\n        alphas = torch.tensor(alphas.astype('float64'))\n\n        log_alpha = np.log(alphas)\n        log_cumprod_alpha = np.cumsum(log_alpha)\n\n        log_1_min_alpha = log_1_min_a(log_alpha)\n        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)\n\n        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5\n        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5\n        assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.e-5\n\n        # Convert to float32 and register buffers.\n        self.register_buffer('log_alpha', log_alpha.float())\n        self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())\n        self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())\n        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())\n\n        self.register_buffer('Lt_history', 
torch.zeros(self.num_timesteps))\n        self.register_buffer('Lt_count', torch.zeros(self.num_timesteps))\n\n    def device(self):\n        # nn.Module has no .device attribute; infer it from the denoiser's parameters.\n        return next(self._denoise_fn.parameters()).device\n\n    def multinomial_kl(self, log_prob1, log_prob2):\n        kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)\n        return kl\n\n    def q_pred_one_timestep(self, log_x_t, t):\n        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)\n        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)\n\n        # alpha_t * E[xt] + (1 - alpha_t) 1 / K\n        log_probs = log_add_exp(\n            log_x_t + log_alpha_t,\n            log_1_min_alpha_t - np.log(self.num_classes)\n        )\n\n        return log_probs\n\n    def q_pred(self, log_x_start, t):\n        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)\n        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)\n\n        log_probs = log_add_exp(\n            log_x_start + log_cumprod_alpha_t,\n            log_1_min_cumprod_alpha - np.log(self.num_classes)\n        )\n\n        return log_probs\n\n    def predict_start(self, log_x_t, t, cond):\n        x_t = log_onehot_to_index(log_x_t)\n\n        out = self._denoise_fn(x_t, cond, t)\n\n        log_pred = F.log_softmax(out, dim=1)\n        return log_pred\n\n    def q_posterior(self, log_x_start, log_x_t, t):\n        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)\n        # where q(xt | xt-1, x0) = q(xt | xt-1).\n\n        t_minus_1 = t - 1\n        # Clamp negative indices to 0; those entries are overwritten below for t == 0.\n        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)\n        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)\n\n        num_axes = (1,) * (len(log_x_start.size()) - 1)\n        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)\n        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, 
log_x_start, log_EV_qxtmin_x0)\n\n        # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!\n        # Not very easy to see why this is true. But it is :)\n        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)\n\n        log_EV_xtmin_given_xt_given_xstart = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)\n\n        return log_EV_xtmin_given_xt_given_xstart\n\n    def p_pred(self, log_x, t, cond):\n        log_x0_recon = self.predict_start(log_x, t, cond)\n        log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t)\n        return log_model_pred, log_x0_recon\n\n    def log_sample_categorical(self, logits):\n        uniform = torch.rand_like(logits)\n        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)\n        sample = (gumbel_noise + logits).argmax(dim=1)\n        log_sample = index_to_log_onehot(sample, self.num_classes)\n        return log_sample\n\n    def q_sample(self, log_x_start, t):\n        log_EV_qxt_x0 = self.q_pred(log_x_start, t)\n        log_sample = self.log_sample_categorical(log_EV_qxt_x0)\n        return log_sample\n\n    def kl_prior(self, log_x_start):\n        b = log_x_start.size(0)\n        device = log_x_start.device\n        ones = torch.ones(b, device=device).long()\n\n        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)\n        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))\n\n        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)\n        return sum_except_batch(kl_prior)\n\n    def sample_time(self, b, device, method='uniform'):\n        if method == 'importance':\n            if not (self.Lt_count > 10).all():\n                return self.sample_time(b, device, method='uniform')\n\n            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001\n            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.\n            pt_all = 
Lt_sqrt / Lt_sqrt.sum()\n\n            t = torch.multinomial(pt_all, num_samples=b, replacement=True)\n\n            pt = pt_all.gather(dim=0, index=t)\n\n            return t, pt\n\n        elif method == 'uniform':\n            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()\n\n            pt = torch.ones_like(t).float() / self.num_timesteps\n            return t, pt\n        else:\n            raise ValueError\n\n    def forward(self, x, voxel_input):\n        b, device = x.size(0), x.device\n        self.shape = x.size()[1:]        \n        t, pt = self.sample_time(b, device, 'importance')\n\n        log_x_start = index_to_log_onehot(x, self.num_classes)\n        log_x_t = self.q_sample(log_x_start, t) # log_x_t : (batch, #class, 128, 128, 8)\n\n        log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)\n        log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t, cond=voxel_input)\n\n        kl = self.multinomial_kl(log_true_prob, log_model_prob)\n        kl = sum_except_batch(kl)\n\n        decoder_nll = -log_categorical(log_x_start, log_model_prob)\n        decoder_nll = sum_except_batch(decoder_nll)\n\n        mask = (t == torch.zeros_like(t)).float()\n        kl_loss = mask * decoder_nll + (1. 
- mask) * kl\n        \n        if self.training:\n            Lt2 = kl_loss.pow(2)\n            Lt2_prev = self.Lt_history.gather(dim=0, index=t)\n            new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()\n            self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)\n            self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))\n\n        kl_prior = self.kl_prior(log_x_start)\n\n        # Upweigh loss term of the kl\n        loss = kl_loss / pt + kl_prior\n\n        kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:])\n        kl_aux = sum_except_batch(kl_aux)\n        if self.recon_loss : \n            kl_aux += self.multi_criterion(log_x0_recon.exp(), x)\n            #kl_aux += lovasz_softmax(torch.nn.functional.softmax(log_x0_recon.exp(), dim=1), x)\n\n        kl_aux_loss = mask * decoder_nll + (1. - mask) * kl_aux\n        if self.adaptive_auxiliary_loss:\n            addition_loss_weight = (1-t/self.num_timesteps) + 1.0\n        else:\n            addition_loss_weight = 1.0\n\n        aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt\n        \n        loss += aux_loss\n        loss = -loss.sum() / (self.shape[0]*self.shape[1])\n        #loss += seg_loss\n\n        return -loss\n\n    def sample(self, voxel_input, intermediate=False):\n        device = self.log_alpha.device\n        self.shape = voxel_input.size()[1:]\n        uniform_logits = torch.zeros((self.args.batch_size, self.num_classes) + self.shape, device=device)\n        log_z = self.log_sample_categorical(uniform_logits)\n        diffusion = []\n\n        for i in reversed(range(0, self.num_timesteps)):\n            print(f'Sample timestep {i:4d}', end='\\r')\n\n            t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long)\n\n            log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t, cond=voxel_input)\n\n            log_z = 
self.log_sample_categorical(log_model_prob)\n\n            # Keep every 10th intermediate state so the reverse trajectory can be visualized.\n            if i % 10 == 0:\n                diffusion.append(log_onehot_to_index(log_z))\n\n        result = log_onehot_to_index(log_z)\n        if intermediate:\n            return result, diffusion\n        else:\n            return result\n\n"
  },
  {
    "path": "layers/Voxel_Level/Gen_Diffusion.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nimport math\nfrom inspect import isfunction\nfrom layers.Voxel_Level.gen_denoise import Denoise\nfrom utils.loss import *\n\"\"\"\nBased in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281\n\"\"\"\neps = 1e-8\n\n\ndef sum_except_batch(x, num_dims=1):\n    return x.reshape(*x.shape[:num_dims], -1).sum(-1)\n\n\ndef log_1_min_a(a):\n    return torch.log(1 - a.exp() + 1e-40)\n\n\ndef log_add_exp(a, b):\n    maximum = torch.max(a, b)\n    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))\n\n\ndef exists(x):\n    return x is not None\n\n\ndef extract(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t)\n    return out.reshape(b, *((1,) * (len(x_shape) - 1)))\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef log_categorical(log_x_start, log_prob):\n    return (log_x_start.exp() * log_prob).sum(dim=1)\n\n\ndef index_to_log_onehot(x, num_classes):\n    assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'\n    \n    x_onehot = F.one_hot(x, num_classes)\n    permute_order = (0, -1) + tuple(range(1, len(x.size())))\n    x_onehot = x_onehot.permute(permute_order)\n    log_x = torch.log(x_onehot.float().clamp(min=1e-30))\n    return log_x\n\n\ndef log_onehot_to_index(log_x):\n    return log_x.argmax(1)\n\n\ndef cosine_beta_schedule(timesteps, s = 0.008):\n    \"\"\"\n    cosine schedule\n    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ\n    \"\"\"\n    steps = timesteps + 1\n    x = np.linspace(0, steps, steps)\n    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2\n    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]\n    alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])\n\n    alphas = np.clip(alphas, 
a_min=0.001, a_max=1.)\n    alphas = np.sqrt(alphas)\n\n    return alphas\n\nclass Diffusion(torch.nn.Module):\n    def __init__(self, args, multi_criterion, auxiliary_loss_weight=0.05, adaptive_auxiliary_loss=True):\n        super(Diffusion, self).__init__()\n\n        self.args = args\n        self.num_classes = self.args.num_classes\n        self.num_timesteps = self.args.diffusion_steps\n        self.recon_loss = self.args.recon_loss\n        self._denoise_fn = Denoise(args=self.args, num_class=self.num_classes)\n        self.auxiliary_loss_weight = auxiliary_loss_weight\n        self.adaptive_auxiliary_loss = adaptive_auxiliary_loss\n\n        self.multi_criterion = multi_criterion\n\n        alphas = cosine_beta_schedule(self.num_timesteps)\n        alphas = torch.tensor(alphas.astype('float64'))\n\n        log_alpha = np.log(alphas)\n        log_cumprod_alpha = np.cumsum(log_alpha)\n\n        log_1_min_alpha = log_1_min_a(log_alpha)\n        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)\n\n        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5\n        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5\n        assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.e-5\n\n        # Convert to float32 and register buffers.\n        self.register_buffer('log_alpha', log_alpha.float())\n        self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())\n        self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())\n        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())\n\n        self.register_buffer('Lt_history', torch.zeros(self.num_timesteps))\n        self.register_buffer('Lt_count', torch.zeros(self.num_timesteps))\n\n    def device(self):\n        # nn.Module has no .device attribute; infer it from the denoiser's parameters.\n        return next(self._denoise_fn.parameters()).device\n\n    def multinomial_kl(self, log_prob1, log_prob2):\n        kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)\n        return kl\n\n    def q_pred_one_timestep(self, log_x_t, t):\n        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)\n        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)\n\n        # alpha_t * E[xt] + (1 - alpha_t) 1 / K\n        log_probs = log_add_exp(\n            log_x_t + log_alpha_t,\n            log_1_min_alpha_t - np.log(self.num_classes)\n        )\n\n        return log_probs\n\n    def q_pred(self, log_x_start, t):\n        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)\n        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)\n\n        log_probs = log_add_exp(\n            log_x_start + log_cumprod_alpha_t,\n            log_1_min_cumprod_alpha - np.log(self.num_classes)\n        )\n\n        return log_probs\n\n    def predict_start(self, log_x_t, t):\n        x_t = log_onehot_to_index(log_x_t)\n\n        out = self._denoise_fn(x_t, t)\n\n        log_pred = F.log_softmax(out, dim=1)\n        return log_pred\n\n    def q_posterior(self, log_x_start, log_x_t, t):\n        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)\n        # where q(xt | xt-1, x0) = q(xt | xt-1).\n\n        t_minus_1 = t - 1\n        # Clamp negative indices to 0; those entries are overwritten below for t == 0.\n        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)\n        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)\n\n        num_axes = (1,) * (len(log_x_start.size()) - 1)\n        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)\n        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)\n\n        # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!\n        # Not very easy to see why this is true. 
But it is :)\n        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)\n\n        log_EV_xtmin_given_xt_given_xstart = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)\n\n        return log_EV_xtmin_given_xt_given_xstart\n\n    def p_pred(self, log_x, t):\n        log_x0_recon = self.predict_start(log_x, t)\n        log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t)\n        return log_model_pred, log_x0_recon\n\n    def log_sample_categorical(self, logits):\n        uniform = torch.rand_like(logits)\n        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)\n        sample = (gumbel_noise + logits).argmax(dim=1)\n        log_sample = index_to_log_onehot(sample, self.num_classes)\n        return log_sample\n\n    def q_sample(self, log_x_start, t):\n        log_EV_qxt_x0 = self.q_pred(log_x_start, t)\n        log_sample = self.log_sample_categorical(log_EV_qxt_x0)\n        return log_sample\n\n    def kl_prior(self, log_x_start):\n        b = log_x_start.size(0)\n        device = log_x_start.device\n        ones = torch.ones(b, device=device).long()\n\n        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)\n        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))\n\n        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)\n        return sum_except_batch(kl_prior)\n\n    def sample_time(self, b, device, method='uniform'):\n        if method == 'importance':\n            if not (self.Lt_count > 10).all():\n                return self.sample_time(b, device, method='uniform')\n\n            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001\n            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.\n            pt_all = Lt_sqrt / Lt_sqrt.sum()\n\n            t = torch.multinomial(pt_all, num_samples=b, replacement=True)\n\n            pt = pt_all.gather(dim=0, index=t)\n\n            return 
t, pt\n\n        elif method == 'uniform':\n            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()\n\n            pt = torch.ones_like(t).float() / self.num_timesteps\n            return t, pt\n        else:\n            raise ValueError\n\n    def forward(self, x, voxel_input):\n        b, device = x.size(0), x.device\n        self.shape = x.size()[1:]        \n        t, pt = self.sample_time(b, device, 'importance')\n\n        log_x_start = index_to_log_onehot(x, self.num_classes)\n        log_x_t = self.q_sample(log_x_start, t) # log_x_t : (batch, #class, 128, 128, 8)\n\n        log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)\n        log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t)\n\n        kl = self.multinomial_kl(log_true_prob, log_model_prob)\n        kl = sum_except_batch(kl)\n\n        decoder_nll = -log_categorical(log_x_start, log_model_prob)\n        decoder_nll = sum_except_batch(decoder_nll)\n\n        mask = (t == torch.zeros_like(t)).float()\n        kl_loss = mask * decoder_nll + (1. - mask) * kl\n        \n        if self.training:\n            Lt2 = kl_loss.pow(2)\n            Lt2_prev = self.Lt_history.gather(dim=0, index=t)\n            new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()\n            self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)\n            self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))\n\n        kl_prior = self.kl_prior(log_x_start)\n\n        # Upweigh loss term of the kl\n        loss = kl_loss / pt + kl_prior\n\n        kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:])\n        kl_aux = sum_except_batch(kl_aux)\n        '''if self.recon_loss : \n            kl_aux += self.multi_criterion(log_x0_recon.exp(), x)\n            kl_aux += lovasz_softmax(torch.nn.functional.softmax(log_x0_recon.exp(), dim=1), x)'''\n\n        kl_aux_loss = mask * decoder_nll + (1. 
- mask) * kl_aux\n        if self.adaptive_auxiliary_loss:\n            addition_loss_weight = (1-t/self.num_timesteps) + 1.0\n        else:\n            addition_loss_weight = 1.0\n\n        aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt\n        \n        loss += aux_loss\n        loss = -loss.sum() / (self.shape[0]*self.shape[1])\n        #loss += seg_loss\n\n        return -loss\n\n    def sample(self, voxel_input):\n        device = self.log_alpha.device\n        self.shape = voxel_input.size()[1:]\n        uniform_logits = torch.zeros((self.args.batch_size, self.num_classes) + self.shape, device=device)\n        log_z = self.log_sample_categorical(uniform_logits)\n\n        for i in reversed(range(0, self.num_timesteps)):\n            print(f'Sample timestep {i:4d}', end='\\r')\n\n            t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long)\n\n            log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t)\n\n            log_z = self.log_sample_categorical(log_model_prob)\n\n        result = log_onehot_to_index(log_z)\n        return result\n\n"
  },
  {
    "path": "layers/Voxel_Level/denoise.py",
    "content": "import math\nfrom mimetypes import init\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport torch.nn.functional as F\nfrom einops import rearrange, reduce, repeat\nfrom torch import nn, einsum\n\n\ndef conv3x3x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n\ndef conv1x3x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)\n\n\ndef conv1x1x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)\n\n\ndef conv1x3x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)\n\n\ndef conv3x1x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)\n\n\ndef conv3x1x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)\n\n\nclass Asymmetric_Residual_Block(nn.Module):\n    def __init__(self, in_filters, out_filters, time_filters=32*4):\n        super(Asymmetric_Residual_Block, self).__init__()\n        if in_filters<32 :\n            self.GroupNorm = nn.GroupNorm(16, in_filters)\n            self.bn0 = nn.GroupNorm(16, out_filters)\n            self.bn0_2 = nn.GroupNorm(16, out_filters)\n            self.bn1 = nn.GroupNorm(16, out_filters)\n            self.bn2 = nn.GroupNorm(16, out_filters)\n        else :\n            self.GroupNorm = nn.GroupNorm(32, in_filters)\n            self.bn0 = nn.GroupNorm(32, out_filters)\n            self.bn0_2 = nn.GroupNorm(32, out_filters)\n 
           self.bn1 = nn.GroupNorm(32, out_filters)\n            self.bn2 = nn.GroupNorm(32, out_filters)\n        self.time_layers = nn.Sequential(\n                            nn.SiLU(),\n                            nn.Linear(time_filters, in_filters*2)\n                        )\n\n        self.conv1 = conv1x3x3(in_filters, out_filters)\n        self.act1 = nn.LeakyReLU()\n          \n        self.conv1_2 = conv3x1x3(out_filters, out_filters)\n        self.act1_2 = nn.LeakyReLU()\n\n        self.conv2 = conv3x1x3(in_filters, out_filters)\n        self.act2 = nn.LeakyReLU()\n\n        self.conv3 = conv1x3x3(out_filters, out_filters)\n        self.act3 = nn.LeakyReLU()\n\n\n    def forward(self, x, t):\n        t = self.time_layers(t)\n        while len(t.shape) < len(x.shape):\n            t = t[..., None]\n        scale, shift = torch.chunk(t, 2, dim=1)\n        \n        x = self.GroupNorm(x) * (1 + scale) + shift\n\n        shortcut = self.conv1(x)\n        shortcut = self.act1(shortcut)\n        shortcut = self.bn0(shortcut)\n\n        shortcut = self.conv1_2(shortcut)\n        shortcut = self.act1_2(shortcut)\n        shortcut = self.bn0_2(shortcut)\n\n        resA = self.conv2(x) \n        resA = self.act2(resA)\n        resA = self.bn1(resA)\n\n        resA = self.conv3(resA) \n        resA = self.act3(resA)\n        resA = self.bn2(resA)\n        resA += shortcut\n\n        return resA\n\nclass DDCM(nn.Module):\n    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):\n        super(DDCM, self).__init__()\n        self.conv1 = conv3x1x1(in_filters, out_filters)\n        if in_filters<32 :\n            self.bn0 = nn.GroupNorm(16, out_filters)\n            self.bn0_2 = nn.GroupNorm(16, out_filters)\n            self.bn0_3 = nn.GroupNorm(16, out_filters)\n        else :\n            self.bn0 = nn.GroupNorm(32, out_filters)\n            self.bn0_2 = nn.GroupNorm(32, out_filters)\n            self.bn0_3 = nn.GroupNorm(32, 
out_filters)\n        self.act1 = nn.Sigmoid()\n\n        self.conv1_2 = conv1x3x1(in_filters, out_filters)\n        self.act1_2 = nn.Sigmoid()\n\n        self.conv1_3 = conv1x1x3(in_filters, out_filters)\n        self.act1_3 = nn.Sigmoid()\n\n    def forward(self, x):\n        shortcut = self.conv1(x)\n        shortcut = self.bn0(shortcut)\n        shortcut = self.act1(shortcut)\n\n        shortcut2 = self.conv1_2(x)\n        shortcut2 = self.bn0_2(shortcut2)\n        shortcut2 = self.act1_2(shortcut2)\n\n        shortcut3 = self.conv1_3(x)\n        shortcut3 = self.bn0_3(shortcut3)\n        shortcut3 = self.act1_3(shortcut3)\n        shortcut = shortcut + shortcut2 + shortcut3\n\n        shortcut = shortcut * x\n\n        return shortcut\n\ndef l2norm(t):\n    return F.normalize(t, dim = -1)\n\nclass Attention(nn.Module):\n    def __init__(self, dim, heads = 4, scale = 10):\n        super().__init__()\n        self.scale = scale\n        self.heads = heads\n        self.to_qkv = conv1x1(dim, dim*3, stride=1)\n        self.to_out = conv1x1(dim, dim, stride=1)\n\n    def forward(self, x):\n        b, c, h, w, Z = x.shape\n        qkv = self.to_qkv(x).chunk(3, dim = 1)\n        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)\n\n        q, k = map(l2norm, (q, k))\n\n        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale\n        attn = sim.softmax(dim = -1)\n        out = einsum('b h i j, b h d j -> b h i d', attn, v)\n        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)\n        return self.to_out(out)\n\nclass Cross_Attention(nn.Module):\n    def __init__(self, dim, heads = 4, scale = 10):\n        super().__init__()\n        self.scale = scale\n        self.heads = heads\n        self.to_q = conv1x1(dim, dim, stride=1)\n        self.to_k = conv1x1(dim, dim, stride=1)\n        self.to_v = conv1x1(dim, dim, stride=1)\n\n        self.to_out = conv1x1(dim, dim, stride=1)\n\n    
def forward(self, x, cond_x):\n        b, c, h, w, Z = x.shape\n        q = self.to_q(x)\n        k = self.to_k(cond_x)\n        v = self.to_v(cond_x)\n\n        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v))\n\n        q, k = map(l2norm, (q, k))\n\n        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale\n        attn = sim.softmax(dim = -1)\n        out = einsum('b h i j, b h d j -> b h i d', attn, v)\n        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)\n        return self.to_out(out)\n\nclass DownBlock(nn.Module):\n    def __init__(self, in_filters, out_filters, time_filters=32*4, kernel_size=(3, 3, 3), stride=1,\n                 pooling=True, height_pooling=False):\n        super(DownBlock, self).__init__()\n        self.pooling = pooling\n\n        self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters, time_filters=time_filters)\n\n        if pooling:\n            if height_pooling:\n                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,padding=1, bias=False)\n            else:\n                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),padding=1, bias=False)\n\n    def forward(self, x, t):\n        resA = self.residual_block(x, t)\n        if self.pooling:\n            resB = self.pool(resA) \n            return resB, resA\n        else:\n            return resA\n\nclass UpBlock(nn.Module):\n    def __init__(self, in_filters, out_filters, height_pooling, time_filters=32*4):\n        super(UpBlock, self).__init__()\n        # self.drop_out = drop_out\n        if out_filters<32 :\n            self.trans_bn = nn.GroupNorm(16, in_filters)\n            self.bn1 = nn.GroupNorm(16, out_filters)\n            self.bn2 = nn.GroupNorm(16, out_filters)\n            self.bn3 = nn.GroupNorm(16, out_filters)\n        else :\n            self.trans_bn = nn.GroupNorm(32, in_filters)\n        
    self.bn1 = nn.GroupNorm(32, out_filters)\n            self.bn2 = nn.GroupNorm(32, out_filters)\n            self.bn3 = nn.GroupNorm(32, out_filters)\n        self.trans_dilao = conv3x3x3(in_filters, in_filters)\n        self.trans_act = nn.LeakyReLU()\n        self.time_layers = nn.Sequential(\n                            nn.SiLU(),\n                            nn.Linear(time_filters, in_filters*2)\n                        )\n\n        self.conv1 = conv1x3x3(in_filters, out_filters)\n        self.act1 = nn.LeakyReLU()\n\n        self.conv2 = conv3x1x3(out_filters, out_filters)\n        self.act2 = nn.LeakyReLU()\n\n        self.conv3 = conv3x3x3(out_filters, out_filters)\n        self.act3 = nn.LeakyReLU()\n        \n        if height_pooling :\n            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)\n        else : \n            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)\n    \n\n    def forward(self, x, residual, t): \n        upA = self.trans_dilao(x) \n        upA = self.trans_act(upA)\n\n        t = self.time_layers(t)\n        while len(t.shape) < len(x.shape):\n            t = t[..., None]\n        scale, shift = torch.chunk(t, 2, dim=1)\n        \n        upA = self.trans_bn(upA) * (1 + scale) + shift\n        ## upsample\n        upA = self.up_subm(upA)\n        upA += residual\n        upE = self.conv1(upA)\n        upE = self.act1(upE)\n        upE = self.bn1(upE)\n\n        upE = self.conv2(upE)\n        upE = self.act2(upE)\n        upE = self.bn2(upE)\n\n        upE = self.conv3(upE)\n        upE = self.act3(upE)\n        upE = self.bn3(upE)\n\n        return upE\n\ndef timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):\n    if not repeat_only:\n        half = dim // 2\n        freqs = torch.exp(\n            
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n        ).to(device=timesteps.device)\n        args = timesteps[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n    else:\n        embedding = repeat(timesteps, 'b -> b d', d=dim)\n    return embedding\n\nclass Denoise(nn.Module):\n    def __init__(self, args, num_class = 11, init_size=32, discrete=True):\n        super(Denoise, self).__init__()\n        self.args = args\n        self.discrete = discrete\n        self.num_class = num_class\n        self.init_size = init_size\n        self.time_size = self.init_size*4\n\n        self.time_embed = nn.Sequential(\n            nn.Linear(init_size, self.time_size),\n            nn.SiLU(),\n            nn.Linear(self.time_size, self.time_size),\n        )\n\n        self.embedding = nn.Embedding(self.num_class, init_size)\n        self.conv_in = nn.Conv3d(init_size+1, init_size, kernel_size=1, stride=1)\n\n        self.A = Asymmetric_Residual_Block(init_size, init_size, time_filters=init_size*4)\n\n        self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4)\n        self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True, time_filters=init_size*4)\n        self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False, time_filters=init_size*4)\n        self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False, time_filters=init_size*4)\n        self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, time_filters=init_size*4)\n        self.attention = Attention(16 * init_size, 32)\n        self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, time_filters=init_size*4)\n\n        self.upBlock4 = UpBlock(16 * init_size, 8 * 
init_size, height_pooling=False, time_filters=init_size*4)\n        self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False, time_filters=init_size*4)\n        self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4)\n        self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4)\n\n        self.DDCM = DDCM(2 * init_size, 2 * init_size)\n        self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True)\n        \n    def forward(self, x, x_cond, t):\n        x = self.embedding(x)\n        x = x.permute(0, 4, 1, 2, 3)\n        x_cond = x_cond.unsqueeze(1)\n        x = torch.cat([x, x_cond], dim=1)\n        x = self.conv_in(x)\n\n        t = self.time_embed(timestep_embedding(t, self.init_size))\n\n        x = self.A(x, t)\n\n        down1c, down1b = self.downBlock1(x, t)\n        down2c, down2b = self.downBlock2(down1c, t)\n        down3c, down3b = self.downBlock3(down2c, t)\n        down4c, down4b = self.downBlock4(down3c, t)\n\n        down4c = self.midBlock1(down4c, t)\n        down4c = self.attention(down4c)\n        down4c = self.midBlock2(down4c, t)\n        \n        up4 = self.upBlock4(down4c, down4b, t)\n        up3 = self.upBlock3(up4, down3b, t)\n        up2 = self.upBlock2(up3, down2b, t)\n        up1 = self.upBlock1(up2, down1b, t)\n        up0 = self.DDCM(up1)\n        up = torch.cat((up1, up0), 1)\n        logits = self.logits(up) \n       \n        return logits\n"
  },
  {
    "path": "layers/Voxel_Level/gen_denoise.py",
    "content": "import math\nfrom mimetypes import init\nimport torch\nimport torch.nn as nn\nimport numpy as np\nimport torch.nn.functional as F\nfrom einops import rearrange, reduce, repeat\nfrom torch import nn, einsum\n\n\ndef conv3x3x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)\n\ndef conv1x3x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)\n\n\ndef conv1x1x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)\n\n\ndef conv1x3x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)\n\n\ndef conv3x1x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)\n\n\ndef conv3x1x3(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)\n\n\ndef conv1x1(in_planes, out_planes, stride=1):\n    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)\n\n\nclass Asymmetric_Residual_Block(nn.Module):\n    def __init__(self, in_filters, out_filters, time_filters=128):\n        super(Asymmetric_Residual_Block, self).__init__()\n        if in_filters < 32 : \n            n_ng = in_filters\n        else : n_ng =32\n        self.GroupNorm = nn.GroupNorm(n_ng, in_filters)\n        self.time_layers = nn.Sequential(\n                            nn.SiLU(),\n                            nn.Linear(time_filters, in_filters*2)\n                        )\n\n        self.conv1 = conv1x3x3(in_filters, out_filters)\n        if out_filters < 32 : \n            n_ng = out_filters\n        else : n_ng =32\n        self.bn0 = 
nn.GroupNorm(n_ng, out_filters)\n        self.act1 = nn.LeakyReLU()\n          \n        self.conv1_2 = conv3x1x3(out_filters, out_filters)\n        self.bn0_2 = nn.GroupNorm(n_ng, out_filters)\n        self.act1_2 = nn.LeakyReLU()\n\n        self.conv2 = conv3x1x3(in_filters, out_filters)\n        self.act2 = nn.LeakyReLU()\n        self.bn1 = nn.GroupNorm(n_ng, out_filters)\n\n        self.conv3 = conv1x3x3(out_filters, out_filters)\n        self.act3 = nn.LeakyReLU()\n        self.bn2 = nn.GroupNorm(n_ng, out_filters)\n\n\n    def forward(self, x, t):\n        t = self.time_layers(t)\n        while len(t.shape) < len(x.shape):\n            t = t[..., None]\n        scale, shift = torch.chunk(t, 2, dim=1)\n        \n        x = self.GroupNorm(x) * (1 + scale) + shift\n\n        shortcut = self.conv1(x) \n        shortcut = self.act1(shortcut)\n        shortcut = self.bn0(shortcut)\n\n        shortcut = self.conv1_2(shortcut) \n        shortcut = self.act1_2(shortcut)\n        shortcut = self.bn0_2(shortcut)\n\n        resA = self.conv2(x)\n        resA = self.act2(resA)\n        resA = self.bn1(resA)\n\n        resA = self.conv3(resA) \n        resA = self.act3(resA)\n        resA = self.bn2(resA)\n        resA += shortcut\n\n        return resA\n\nclass DDCM(nn.Module):\n    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):\n        super(DDCM, self).__init__()\n        self.conv1 = conv3x1x1(in_filters, out_filters)\n        if out_filters < 32 : \n            n_ng = out_filters\n        else : n_ng =32\n        self.bn0 = nn.GroupNorm(n_ng, out_filters)\n        self.act1 = nn.Sigmoid()\n\n        self.conv1_2 = conv1x3x1(in_filters, out_filters)\n        self.bn0_2 = nn.GroupNorm(n_ng, out_filters)\n        self.act1_2 = nn.Sigmoid()\n\n        self.conv1_3 = conv1x1x3(in_filters, out_filters)\n        self.bn0_3 = nn.GroupNorm(n_ng, out_filters)\n        self.act1_3 = nn.Sigmoid()\n\n    def forward(self, x):\n        shortcut = 
self.conv1(x)\n        shortcut = self.bn0(shortcut)\n        shortcut = self.act1(shortcut)\n\n        shortcut2 = self.conv1_2(x)\n        shortcut2 = self.bn0_2(shortcut2)\n        shortcut2 = self.act1_2(shortcut2)\n\n        shortcut3 = self.conv1_3(x)\n        shortcut3 = self.bn0_3(shortcut3)\n        shortcut3 = self.act1_3(shortcut3)\n        shortcut = shortcut + shortcut2 + shortcut3\n        shortcut = shortcut * x\n\n        return shortcut\n\ndef l2norm(t):\n    return F.normalize(t, dim = -1)\n\nclass Attention(nn.Module):\n    def __init__(self, dim, heads = 4, scale = 10):\n        super().__init__()\n        self.scale = scale\n        self.heads = heads\n        self.to_qkv = conv1x1(dim, dim*3, stride=1)\n        self.to_out = conv1x1(dim, dim, stride=1)\n\n    def forward(self, x):\n        b, c, h, w, Z = x.shape\n        qkv = self.to_qkv(x).chunk(3, dim = 1)\n        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)\n\n        q, k = map(l2norm, (q, k))\n\n        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale\n        attn = sim.softmax(dim = -1)\n        out = einsum('b h i j, b h d j -> b h i d', attn, v)\n        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)\n        return self.to_out(out)\n\nclass Cross_Attention(nn.Module):\n    def __init__(self, dim, heads = 4, scale = 10):\n        super().__init__()\n        self.scale = scale\n        self.heads = heads\n        self.to_q = conv1x1(dim, dim, stride=1)\n        self.to_k = conv1x1(dim, dim, stride=1)\n        self.to_v = conv1x1(dim, dim, stride=1)\n\n        self.to_out = conv1x1(dim, dim, stride=1)\n\n    def forward(self, x, cond_x):\n        b, c, h, w, Z = x.shape\n        q = self.to_q(x)\n        k = self.to_k(cond_x)\n        v = self.to_v(cond_x)\n\n        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v))\n\n        q, k = map(l2norm, 
(q, k))\n\n        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale\n        attn = sim.softmax(dim = -1)\n        out = einsum('b h i j, b h d j -> b h i d', attn, v)\n        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)\n        return self.to_out(out)\n\nclass DownBlock(nn.Module):\n    def __init__(self, in_filters, out_filters, time_filters, kernel_size=(3, 3, 3), stride=1,\n                 pooling=True, drop_out=True, height_pooling=False):\n        super(DownBlock, self).__init__()\n        self.pooling = pooling\n        self.drop_out = drop_out\n        self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters, time_filters)\n\n        if pooling:\n            if height_pooling:\n                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,\n                                                padding=1, bias=False)\n            else:\n                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),\n                                                padding=1, bias=False)\n\n\n    def forward(self, x, t):\n        resA = self.residual_block(x, t)\n        if self.pooling:\n            resB = self.pool(resA) \n            return resB, resA\n        else:\n            return resA\n\nclass UpBlock(nn.Module):\n    def __init__(self, in_filters, out_filters, height_pooling, time_filters):\n        super(UpBlock, self).__init__()\n        # self.drop_out = drop_out\n        self.trans_dilao = conv3x3x3(in_filters, in_filters)\n        self.trans_act = nn.LeakyReLU()\n        if in_filters < 32 : \n            n_ng = out_filters\n        else : n_ng =32\n        self.trans_bn = nn.GroupNorm(n_ng, in_filters)\n        self.time_layers = nn.Sequential(\n                            nn.SiLU(),\n                            nn.Linear(time_filters, in_filters*2)\n                        )\n\n        self.conv1 = conv1x3x3(in_filters, out_filters)\n        
self.act1 = nn.LeakyReLU()\n        if out_filters < 32 : \n            n_ng = out_filters\n        else :n_ng = 32\n        self.bn1 = nn.GroupNorm(n_ng, out_filters)\n\n        self.conv2 = conv3x1x3(out_filters, out_filters)\n        self.act2 = nn.LeakyReLU()\n        self.bn2 = nn.GroupNorm(n_ng, out_filters)\n\n        self.conv3 = conv3x3x3(out_filters, out_filters)\n        self.act3 = nn.LeakyReLU()\n        self.bn3 = nn.GroupNorm(n_ng, out_filters)\n        \n        if height_pooling :\n            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)\n        else : \n            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)\n    \n\n    def forward(self, x, residual, t):\n        upA = self.trans_dilao(x)\n        upA = self.trans_act(upA)\n\n        t = self.time_layers(t)\n        while len(t.shape) < len(x.shape):\n            t = t[..., None]\n        scale, shift = torch.chunk(t, 2, dim=1)\n        \n        upA = self.trans_bn(upA) * (1 + scale) + shift\n        ## upsample\n        upA = self.up_subm(upA)\n        upA += residual\n        upE = self.conv1(upA)\n        upE = self.act1(upE)\n        upE = self.bn1(upE)\n\n        upE = self.conv2(upE)\n        upE = self.act2(upE)\n        upE = self.bn2(upE)\n\n        upE = self.conv3(upE)\n        upE = self.act3(upE)\n        upE = self.bn3(upE)\n\n        return upE\n\ndef timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n   
 \"\"\"\n    if not repeat_only:\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n        ).to(device=timesteps.device)\n        args = timesteps[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n    else:\n        embedding = repeat(timesteps, 'b -> b d', d=dim)\n    return embedding\n\nclass Denoise(nn.Module):\n    def __init__(self, args, num_class = 11, init_size=32, discrete=True):\n        super(Denoise, self).__init__()\n        self.args = args\n        self.discrete = discrete\n        self.num_class = num_class\n        self.init_size = init_size\n        self.time_size = init_size*4\n\n        self.time_embed = nn.Sequential(\n            nn.Linear(self.init_size, self.time_size),\n            nn.SiLU(),\n            nn.Linear(self.time_size, self.time_size),\n        )\n\n        self.embedding = nn.Embedding(self.num_class, self.init_size)\n        self.conv_in = nn.Conv3d(self.init_size, self.init_size, kernel_size=1, stride=1)\n\n        self.A = Asymmetric_Residual_Block(self.init_size, self.init_size, self.time_size)\n\n        self.downBlock1 = DownBlock(init_size, 2 * init_size, self.time_size, height_pooling=True)\n        self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, self.time_size, height_pooling=True)\n        self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, self.time_size, height_pooling=False)\n        self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, self.time_size, height_pooling=False)\n        self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, self.time_size)\n        self.attention = Attention(16 * init_size, 32)\n        self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, self.time_size)\n   
     \n        self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False, time_filters=self.time_size)\n        self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False, time_filters=self.time_size)\n        self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True, time_filters=self.time_size)\n        self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True, time_filters=self.time_size)\n        self.DDCM = DDCM(2 * init_size, 2 * init_size)\n        self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True)\n \n\n    def forward(self, x, t):\n        x = self.embedding(x)\n        x = x.permute(0, 4, 1, 2, 3)\n        x = self.conv_in(x)\n\n        t = self.time_embed(timestep_embedding(t, self.init_size))\n\n        x = self.A(x, t)\n\n        down1c, down1b = self.downBlock1(x, t) \n        down2c, down2b = self.downBlock2(down1c, t) \n        down3c, down3b = self.downBlock3(down2c, t) \n        \n        down4c, down4b = self.downBlock4(down3c, t) \n        down4c = self.midBlock1(down4c, t) \n        down4c = self.attention(down4c)\n        down4c = self.midBlock2(down4c, t) \n        up4 = self.upBlock4(down4c, down4b, t)\n        up3 = self.upBlock3(up4, down3b, t)\n\n\n        up2 = self.upBlock2(up3, down2b, t)\n        up1 = self.upBlock1(up2, down1b, t)\n\n        up0 = self.DDCM(up1) \n\n        up = torch.cat((up1, up0), 1)\n\n        logits = self.logits(up) \n        \n        return logits\n"
  },
  {
    "path": "layers/__init__.py",
    "content": ""
  },
  {
    "path": "requirements.txt",
    "content": "numpy\ntorch\nscipy\nscikit-learn\nmatplotlib\ntqdm\nopen3d\npyyaml\nprettytable\ntensorboard\nnumba\neinops\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nsetup(\n    name=\"scene_scale_diffusion\",\n    version=\"0.1\",\n    author=\"Lee Jumin, Im Woobin, Lee Sebin, Yoon Sung-Eui\",\n    author_email=\"\",\n    description=\"Experiments in PyTorch\",\n    long_description=\"\",\n    packages=setuptools.find_packages(),\n    classifiers=[\n        \"Programming Language :: Python :: 3\",\n        \"License :: OSI Approved :: MIT License\",\n        \"Operating System :: OS Independent\",\n    ],\n)\n"
  },
  {
    "path": "simple_visualize.py",
    "content": "import os\nimport numpy as np\nimport open3d as o3d\nimport argparse\nimport yaml\n\ndef load_config(yaml_path):\n    with open(yaml_path, 'r') as f:\n        config = yaml.safe_load(f)\n    return config[\"learning_map\"], config[\"remap_color_map\"]\n\ndef load_pointcloud(filepath, learning_map, color_map):\n    data = np.loadtxt(filepath, delimiter=' ')\n    if data.shape[1] < 4:\n        raise ValueError(f\"Expected at least 4 columns (label + x y z), got shape {data.shape}\")\n\n    raw_labels = data[:, 0].astype(int)\n    points = data[:, 1:4]\n\n    # Map raw labels → remapped labels → colors\n    remapped_labels = np.array([learning_map.get(int(l), 0) for l in raw_labels])\n    colors = np.array([color_map.get(int(l), [255, 255, 255]) for l in remapped_labels]) / 255.0\n\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(points)\n    pcd.colors = o3d.utility.Vector3dVector(colors)\n    return pcd\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--file', default='result_for_l_gen/Completion/result_0.txt',\n                        help='Path to the point cloud .txt file')\n    parser.add_argument('--config', default='datasets/carla.yaml',\n                        help='Path to Carla YAML config file')\n    args = parser.parse_args()\n\n    if not os.path.exists(args.file):\n        raise FileNotFoundError(f\"Point cloud file not found: {args.file}\")\n    if not os.path.exists(args.config):\n        raise FileNotFoundError(f\"YAML config file not found: {args.config}\")\n\n    learning_map, color_map = load_config(args.config)\n    pcd = load_pointcloud(args.file, learning_map, color_map)\n\n    o3d.visualization.draw([pcd])\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "train.py",
    "content": "from dataclasses import astuple\nimport torch\nimport argparse\nimport numpy as np\nimport os\nimport pickle\nimport torch\nimport torch.nn.functional as F\nimport yaml\n\nfrom prettytable import PrettyTable\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom utils.tables import *\nfrom utils.dicts import clean_dict\nfrom utils.loss import lovasz_softmax\n\n\nclass Experiment(object):\n    no_log_keys = ['project', 'name','log_tb', 'log_wandb','check_every', 'eval_every','device', 'parallel', 'pin_memory', 'num_workers']\n                   \n    def __init__(self, args, model, optimizer, scheduler_iter, scheduler_epoch,\n                 train_loader, eval_loader, test_loader, train_sampler,\n                 log_path, eval_every, check_every):\n\n        # Objects\n        self.model = model\n\n        self.loss_fun = torch.nn.CrossEntropyLoss(ignore_index=0)\n        self.optimizer, self.scheduler_iter, self.scheduler_epoch= optimizer, scheduler_iter, scheduler_epoch\n        # Paths\n        self.log_path = log_path\n\n        if args.dataset =='carla':\n            config_file = os.path.join('./datasets/carla.yaml')\n            carla_config = yaml.safe_load(open(config_file, 'r'))\n            self.color_map = carla_config[\"remap_color_map\"]\n            self.remap = None\n            LABEL_TO_NAMES = carla_config[\"label_to_names\"]\n            self.label_to_names = np.asarray(list(LABEL_TO_NAMES.values()))\n\n        # Intervals\n        self.eval_every, self.check_every = eval_every, check_every\n\n        # Initialize\n        self.current_epoch = 0\n        self.train_metrics, self.eval_metrics, self.ssc_metrics, self.seg_metrics = {}, {}, {}, {}\n        self.eval_epochs = []\n        self.completion_epochs = []\n\n        # Store data loaders\n        self.train_loader, self.eval_loader, self.test_loader, self.train_sampler = train_loader, eval_loader, test_loader, train_sampler\n\n        # Store args\n        
create_folders(args)\n        save_args(args)\n        self.args = args\n\n        # Init logging\n        args_dict = clean_dict(vars(args), keys=self.no_log_keys)\n        if args.log_tb:\n            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))\n            self.writer.add_text(\"args\", get_args_table(args_dict).get_html_string(), global_step=0)\n\n    def run(self, epochs):\n        if self.args.resume: \n            self.resume()\n        \n        for epoch in range(self.current_epoch, epochs): \n            \n            # Train\n            train_dict = self.train_fn(epoch)\n            self.log_metrics(train_dict, self.train_metrics)\n\n            # Checkpoint\n            self.current_epoch += 1\n            if (epoch+1) % self.check_every == 0:\n                self.checkpoint_save(epoch)\n\n            # Eval\n            if (epoch+1) % self.eval_every == 0:\n                eval_dict = self.eval_fn(epoch)\n                self.log_metrics(eval_dict, self.eval_metrics)\n                self.eval_epochs.append(epoch)\n            else:\n                eval_dict = None\n\n            if (epoch+1) % self.args.completion_epoch == 0:\n                ssc_dict, miou, seg_dict, seg_miou = self.sample()\n                self.log_metrics(ssc_dict, self.ssc_metrics)\n                self.log_metrics(seg_dict, self.ssc_metrics)\n                self.completion_epochs.append(epoch)\n            else :\n                ssc_dict, seg_dict = None, None\n\n            # Log\n            #self.save_metrics()\n            if self.args.log_tb:\n                for metric_name, metric_value in train_dict.items():\n                    self.writer.add_scalar('base/{}'.format(metric_name), metric_value, global_step=epoch+1)\n                if eval_dict:\n                    for metric_name, metric_value in eval_dict.items():\n                        self.writer.add_scalar('eval/{}'.format(metric_name), metric_value, global_step=epoch+1)\n                
if ssc_dict:\n                    for metric_name, metric_value in ssc_dict.items():\n                        self.writer.add_scalar('SSC/{}'.format(metric_name), metric_value, global_step=epoch+1)\n                    self.writer.add_text(\"SSC_mIoU\", get_miou_table(self.args, self.label_to_names, miou).get_html_string(), global_step=epoch+1)\n                    for metric_name, metric_value in seg_dict.items():\n                        self.writer.add_scalar('Seg/{}'.format(metric_name), metric_value, global_step=epoch+1)\n                    self.writer.add_text(\"Seg_mIoU\", get_miou_table(self.args, self.label_to_names, seg_miou).get_html_string(), global_step=epoch+1)\n\n    def train_fn(self, epoch):\n        self.model.train()\n        loss_sum = 0.0\n        loss_count = 0\n        if self.args.distribution :\n            self.train_sampler.set_epoch(epoch)\n\n        for voxel_input, output, counts in self.train_loader:\n            self.optimizer.zero_grad()\n            voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32)\n            output = torch.from_numpy(np.asarray(output)).long().cuda()            \n            if self.args.distribution:\n                loss = self.model.module(output, voxel_input)\n            else : \n                loss = self.model(output, voxel_input)\n            loss.backward()\n\n            if self.args.clip_value: torch.nn.utils.clip_grad_value_(self.model.parameters(), self.args.clip_value)\n            if self.args.clip_norm: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_norm)\n\n            self.optimizer.step()\n            if self.scheduler_iter: self.scheduler_iter.step()\n            loss_sum += loss.detach().cpu().item() * len(output)\n            loss_count += len(output)\n            print('Training. 
Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(epoch+1, self.args.epochs, loss_count, len(self.train_loader.dataset), loss_sum/loss_count), end='\\r')\n        print('')\n        if self.scheduler_epoch: self.scheduler_epoch.step()\n        return {'loss': loss_sum/loss_count}\n\n\n    def eval_fn(self, epoch):\n        self.model.eval()\n\n        with torch.no_grad():\n            loss_sum = 0.0\n            loss_count = 0\n            for voxel_input, output, counts in self.eval_loader:\n                voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32)\n                output = torch.from_numpy(np.asarray(output)).long().cuda()            \n                if self.args.distribution:\n                    loss = self.model.module(output, voxel_input)\n                else : \n                    loss = self.model(output, voxel_input)\n                loss_sum += loss.detach().cpu().item() * len(output)\n                loss_count += len(output)\n                print('Train evaluating. 
Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(epoch+1, self.args.epochs, loss_count, len(self.eval_loader.dataset), loss_sum/loss_count), end='\\r')\n            print('')\n        return {'loss': loss_sum/loss_count}\n\n\n    def sample(self):\n        self.model.eval()\n        with torch.no_grad():\n            TP, FP, TN, FN, num_correct, num_total = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0\n            s_TP, s_FP, s_TN, s_FN, s_num_correct, s_num_total = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0\n            all_intersections, all_unions = np.zeros(self.args.num_classes), np.zeros(self.args.num_classes) + 1e-6\n            s_all_intersections, s_all_unions = np.zeros(self.args.num_classes), np.zeros(self.args.num_classes) + 1e-6\n            if self.args.dataset == 'carla':\n                dataloader = self.test_loader\n            else :\n                dataloader = self.eval_loader\n            for iterate, (voxel_input, output, counts) in enumerate(dataloader):\n                if len(voxel_input) == self.args.batch_size :\n                    voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32)\n                    output = torch.from_numpy(np.asarray(output)).long().cuda()            \n                    invalid = torch.from_numpy(np.asarray(counts)).cuda()\n\n                    if self.args.mode == 'l_vae':\n                        if self.args.distribution:\n                            recons = self.model.module.sample(output) \n                        else : \n                            recons = self.model.sample(output) \n                    else :\n                        if self.args.distribution:\n                            recons = self.model.module.sample(voxel_input) \n                        else : \n                            recons = self.model.sample(voxel_input)   \n\n                    visualization(self.args, recons, voxel_input, output, invalid, iteration = iterate)\n                    correct, 
total, pred_TP, pred_FP, pred_TN, pred_FN, intersection, union = get_result(self.args, invalid, output, recons)\n                    all_intersections += intersection\n                    all_unions += union\n                    num_correct += correct\n                    num_total += total\n                    TP += pred_TP\n                    FP += pred_FP\n                    TN += pred_TN\n                    FN += pred_FN\n\n                    s_correct, s_total, s_pred_TP, s_pred_FP, s_pred_TN, s_pred_FN, s_intersection, s_union = get_result(self.args, voxel_input, output, recons, SSC=False)\n                    s_all_intersections += s_intersection\n                    s_all_unions += s_union\n                    s_num_correct += s_correct\n                    s_num_total += s_total\n                    s_TP += s_pred_TP\n                    s_FP += s_pred_FP\n                    s_TN += s_pred_TN\n                    s_FN += s_pred_FN\n                   \n            iou, miou = print_result(self.args, self.label_to_names, num_correct, num_total, all_intersections, all_unions, TP, FP, FN)\n            s_iou, seg_miou = print_result(self.args, self.label_to_names, s_num_correct, s_num_total, s_all_intersections, s_all_unions, s_TP, s_FP, s_FN, SSC=False)\n            return {\"IoU\" : iou, \"mIoU\": np.mean(miou)*100 }, miou, {\"IoU\" : s_iou, \"mIoU\": np.mean(seg_miou)*100 }, seg_miou\n\n    def resume(self):\n        self.checkpoint_load(self.args.resume_path)\n        for epoch in range(self.current_epoch):\n            train_dict = {}\n            for metric_name, metric_values in self.train_metrics.items():\n                train_dict[metric_name] = metric_values[epoch]\n\n            if epoch in self.eval_epochs:\n                eval_dict = {}\n                for metric_name, metric_values in self.eval_metrics.items():\n                    eval_dict[metric_name] = metric_values[self.eval_epochs.index(epoch)]\n            else: \n                
eval_dict = None\n\n            if epoch in self.completion_epochs:\n                sample_dict = {}\n                # Completion metrics are indexed by completion epoch, not eval epoch.\n                for metric_name, metric_values in self.ssc_metrics.items():\n                    sample_dict[metric_name] = metric_values[self.completion_epochs.index(epoch)]\n            else:\n                sample_dict = None\n\n            for metric_name, metric_value in train_dict.items():\n                self.writer.add_scalar('base/{}'.format(metric_name), metric_value, global_step=epoch+1)\n            if eval_dict:\n                for metric_name, metric_value in eval_dict.items():\n                    self.writer.add_scalar('eval/{}'.format(metric_name), metric_value, global_step=epoch+1)\n            if sample_dict:\n                for metric_name, metric_value in sample_dict.items():\n                    self.writer.add_scalar('sample/{}'.format(metric_name), metric_value, global_step=epoch+1)\n\n\n    def log_metrics(self, metric_dict, store):\n        # Parameter names avoid shadowing the built-ins dict/type.\n        if len(store) == 0:\n            for metric_name, metric_value in metric_dict.items():\n                store[metric_name] = [metric_value]\n        else:\n            for metric_name, metric_value in metric_dict.items():\n                store[metric_name].append(metric_value)\n\n    def save_metrics(self):\n        # Save metrics\n        with open(os.path.join(self.log_path, 'metrics_train.pickle'), 'wb') as f:\n            pickle.dump(self.train_metrics, f)\n        with open(os.path.join(self.log_path, 'metrics_eval.pickle'), 'wb') as f:\n            pickle.dump(self.eval_metrics, f)\n\n        # Save metrics table\n        metric_table = get_metric_table(self.train_metrics, epochs=list(range(1, self.current_epoch+2)))\n        with open(os.path.join(self.log_path, 'metrics_train.txt'), \"w\") as f:\n            f.write(str(metric_table))\n        metric_table = get_metric_table(self.eval_metrics, epochs=[e+1 for e in self.eval_epochs])\n        with open(os.path.join(self.log_path, 'metrics_eval.txt'), \"w\") as f:\n            f.write(str(metric_table))\n\n\n    def checkpoint_save(self, epoch):\n        # Only the model state differs between DDP and single-GPU runs,\n        # so build the checkpoint dict once.\n        model = self.model.module if self.args.distribution else self.model\n        checkpoint = {'current_epoch': self.current_epoch,\n                      'train_metrics': self.train_metrics,\n                      'eval_metrics': self.eval_metrics,\n                      'eval_epochs': self.eval_epochs,\n                      'ssc_metrics': self.ssc_metrics,\n                      'completion_epochs': self.completion_epochs,\n                      'optimizer': self.optimizer.state_dict(),\n                      'model': model.state_dict(),\n                      'scheduler_iter': self.scheduler_iter.state_dict() if self.scheduler_iter else None,\n                      'scheduler_epoch': self.scheduler_epoch.state_dict() if self.scheduler_epoch else None}\n        epoch_name = 'epoch{}.tar'.format(epoch)\n        torch.save(checkpoint, os.path.join(self.log_path, epoch_name))\n\n    def checkpoint_load(self, resume_path):\n        checkpoint = torch.load(resume_path)\n\n        if self.args.distribution:\n            self.model.module.load_state_dict(checkpoint['model'])\n        else:\n            self.model.load_state_dict(checkpoint['model'])\n\n        self.optimizer.load_state_dict(checkpoint['optimizer'])\n        if self.scheduler_iter: self.scheduler_iter.load_state_dict(checkpoint['scheduler_iter'])\n        if self.scheduler_epoch: 
self.scheduler_epoch.load_state_dict(checkpoint['scheduler_epoch'])\n\n        self.current_epoch = checkpoint['current_epoch']\n        self.train_metrics = checkpoint['train_metrics']\n        self.eval_metrics = checkpoint['eval_metrics']\n        self.eval_epochs = checkpoint['eval_epochs']\n        # Older checkpoints may lack the completion entries; fall back to empty values.\n        self.ssc_metrics = checkpoint.get('ssc_metrics', {})\n        self.completion_epochs = checkpoint.get('completion_epochs', [])\n"
  },
  {
    "path": "utils/cuda.py",
    "content": "import os\n\nimport os\n\nimport torch\nfrom torch import distributed as dist\nfrom torch import multiprocessing as mp\n\nimport utils.dicts as dist_fn\n\ndef find_free_port():\n    import socket\n\n    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n\n    sock.bind((\"\", 0))\n    port = sock.getsockname()[1]\n    sock.close()\n\n    return port\n\n\ndef launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()):\n    world_size = n_machine * n_gpu_per_machine\n\n    if world_size > 1:\n        # if \"OMP_NUM_THREADS\" not in os.environ:\n        #     os.environ[\"OMP_NUM_THREADS\"] = \"1\"\n\n        if dist_url == \"auto\":\n            if n_machine != 1:\n                raise ValueError('dist_url=\"auto\" not supported in multi-machine jobs')\n\n            port = find_free_port()\n            dist_url = f\"tcp://127.0.0.1:{port}\"\n\n        if n_machine > 1 and dist_url.startswith(\"file://\"):\n            raise ValueError(\n                \"file:// is not a reliable init method in multi-machine jobs. Prefer tcp://\"\n            )\n\n        mp.spawn(\n            distributed_worker,\n            nprocs=n_gpu_per_machine,\n            args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),\n            daemon=False,\n        )\n\n    else:\n        local_rank = 0\n        fn(local_rank, *args)\n\n\ndef distributed_worker(local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args):\n    if not torch.cuda.is_available():\n        raise OSError(\"CUDA is not available. 
Please check your environment.\")\n\n    global_rank = machine_rank * n_gpu_per_machine + local_rank\n\n    try:\n        dist.init_process_group(\n            backend=\"NCCL\",\n            init_method=dist_url,\n            world_size=world_size,\n            rank=global_rank,\n        )\n\n    except Exception:\n        raise OSError(\"failed to initialize NCCL groups\")\n\n    dist_fn.synchronize()\n\n    if n_gpu_per_machine > torch.cuda.device_count():\n        raise ValueError(\n            f\"specified n_gpu_per_machine larger than available devices ({torch.cuda.device_count()})\"\n        )\n\n    torch.cuda.set_device(local_rank)\n\n    if dist_fn.LOCAL_PROCESS_GROUP is not None:\n        raise ValueError(\"torch.distributed.LOCAL_PROCESS_GROUP is not None\")\n\n    n_machine = world_size // n_gpu_per_machine\n\n    for i in range(n_machine):\n        ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))\n        pg = dist.new_group(ranks_on_i)\n\n        if i == machine_rank:\n            dist_fn.LOCAL_PROCESS_GROUP = pg\n\n    fn(local_rank, *args)\n\ndef set_cuda_vd(gpu_ids, verbose=True):\n    os.environ[\"CUDA_VISIBLE_DEVICES\"] = ','.join(str(id) for id in gpu_ids)\n    if verbose: print(\"CUDA_VISIBLE_DEVICES = {}\".format(os.environ[\"CUDA_VISIBLE_DEVICES\"]))\n\n"
  },
  {
    "path": "utils/dicts.py",
    "content": "import copy\nimport math\nimport pickle\n\nimport torch\nfrom torch import distributed as dist\nfrom torch.utils import data\n\n\nLOCAL_PROCESS_GROUP = None\n\n\ndef is_primary():\n    return get_rank() == 0\n\n\ndef get_rank():\n    if not dist.is_available():\n        return 0\n\n    if not dist.is_initialized():\n        return 0\n\n    return dist.get_rank()\n\n\ndef get_local_rank():\n    if not dist.is_available():\n        return 0\n\n    if not dist.is_initialized():\n        return 0\n\n    if LOCAL_PROCESS_GROUP is None:\n        raise ValueError(\"tensorfn.distributed.LOCAL_PROCESS_GROUP is None\")\n\n    return dist.get_rank(group=LOCAL_PROCESS_GROUP)\n\n\ndef synchronize():\n    if not dist.is_available():\n        return\n\n    if not dist.is_initialized():\n        return\n\n    world_size = dist.get_world_size()\n\n    if world_size == 1:\n        return\n\n    dist.barrier()\n\n\ndef get_world_size():\n    if not dist.is_available():\n        return 1\n\n    if not dist.is_initialized():\n        return 1\n\n    return dist.get_world_size()\n\n\ndef is_distributed():\n    raise RuntimeError('Please debug this function!')\n    return get_world_size() > 1\n\n\ndef all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False):\n    world_size = get_world_size()\n\n    if world_size == 1:\n        return tensor\n    dist.all_reduce(tensor, op=op, async_op=async_op)\n\n    return tensor\n\n\ndef all_gather(data):\n    world_size = get_world_size()\n\n    if world_size == 1:\n        return [data]\n\n    buffer = pickle.dumps(data)\n    storage = torch.ByteStorage.from_buffer(buffer)\n    tensor = torch.ByteTensor(storage).to(\"cuda\")\n\n    local_size = torch.IntTensor([tensor.numel()]).to(\"cuda\")\n    size_list = [torch.IntTensor([1]).to(\"cuda\") for _ in range(world_size)]\n    dist.all_gather(size_list, local_size)\n    size_list = [int(size.item()) for size in size_list]\n    max_size = max(size_list)\n\n    tensor_list = []\n    for 
_ in size_list:\n        tensor_list.append(torch.ByteTensor(size=(max_size,)).to(\"cuda\"))\n\n    if local_size != max_size:\n        padding = torch.ByteTensor(size=(max_size - local_size,)).to(\"cuda\")\n        tensor = torch.cat((tensor, padding), 0)\n\n    dist.all_gather(tensor_list, tensor)\n\n    data_list = []\n\n    for size, tensor in zip(size_list, tensor_list):\n        buffer = tensor.cpu().numpy().tobytes()[:size]\n        data_list.append(pickle.loads(buffer))\n\n    return data_list\n\n\ndef reduce_dict(input_dict, average=True):\n    world_size = get_world_size()\n\n    if world_size < 2:\n        return input_dict\n\n    with torch.no_grad():\n        keys = []\n        values = []\n\n        for k in sorted(input_dict.keys()):\n            keys.append(k)\n            values.append(input_dict[k])\n\n        values = torch.stack(values, 0)\n        dist.reduce(values, dst=0)\n\n        if dist.get_rank() == 0 and average:\n            values /= world_size\n\n        reduced_dict = {k: v for k, v in zip(keys, values)}\n\n    return reduced_dict\n\n\ndef data_sampler(dataset, shuffle, distributed):\n    if distributed:\n        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)\n\n    if shuffle:\n        return data.RandomSampler(dataset)\n\n    else:\n        return data.SequentialSampler(dataset)\n\ndef clean_dict(d, keys):\n    d2 = copy.deepcopy(d)\n    for key in keys:\n        if key in d2:\n            del d2[key]\n    return d2\n"
  },
  {
    "path": "utils/intermediate_vis.py",
    "content": "from dataclasses import astuple\nimport torch\nimport argparse\nimport numpy as np\nimport os\nimport pickle\nimport torch\nimport torch.nn.functional as F\nimport yaml\n\nfrom prettytable import PrettyTable\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom utils.tables import *\nfrom utils.dicts import clean_dict\nfrom utils.loss import lovasz_softmax\n\n\nclass Vis_iter(object):\n    no_log_keys = ['project', 'name','log_tb', 'log_wandb','check_every', 'eval_every','device', 'parallel', 'pin_memory', 'num_workers']\n                   \n    def __init__(self, args, model, optimizer, scheduler_iter, scheduler_epoch, test_loader,log_path):\n\n        # Objects\n        self.model = model\n        self.optimizer, self.scheduler_iter, self.scheduler_epoch= optimizer, scheduler_iter, scheduler_epoch\n        # Paths\n        self.log_path = log_path\n\n        if args.dataset =='kitti':\n            config_file = os.path.join('/home/jumin/multinomial_diffusion/datasets/semantic_kitti.yaml')\n            kitti_config = yaml.safe_load(open(config_file, 'r'))\n            self.remap = kitti_config['learning_map_inv']\n            self.color_map = kitti_config[\"color_map\"]\n            label = kitti_config['labels']\n            map_index = np.asarray([self.remap[i] for i in range(20)])\n            self.label_to_names = np.asarray([label[map_i] for map_i in map_index])\n\n        elif args.dataset =='carla':\n            base_dir = os.path.dirname(__file__)\n            config_file = os.path.join(base_dir, '../datasets/carla.yaml')\n            carla_config = yaml.safe_load(open(config_file, 'r'))\n            self.color_map = carla_config[\"remap_color_map\"]\n            self.remap = None\n            LABEL_TO_NAMES = carla_config[\"label_to_names\"]\n            self.label_to_names = np.asarray(list(LABEL_TO_NAMES.values()))\n\n\n        # Initialize\n        self.current_epoch = 0\n        self.train_metrics, self.eval_metrics, 
self.ssc_metrics, self.seg_metrics = {}, {}, {}, {}\n        self.eval_epochs = []\n        self.completion_epochs = []\n\n        # Store data loaders\n        self.test_loader = test_loader\n\n        # Store args\n        create_folders(args)\n        save_args(args)\n        self.args = args\n\n        # Init logging\n        args_dict = clean_dict(vars(args), keys=self.no_log_keys)\n        if args.log_tb:\n            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))\n            self.writer.add_text(\"args\", get_args_table(args_dict).get_html_string(), global_step=0)\n\n    def run(self, epochs):\n        self.checkpoint_load(self.args.resume_path)\n        for epoch in range(self.current_epoch, epochs): \n            self.sample()\n\n    def sample(self):\n        self.model.eval()\n        with torch.no_grad():\n            for iterate, (voxel_input, output, counts) in enumerate(self.test_loader):\n                voxel_input = torch.from_numpy(np.asarray(voxel_input)).squeeze(1).cuda() \n                output = torch.from_numpy(np.asarray(output)).long().cuda()            \n                _, intermediate = self.model.module.sample(voxel_input, intermediate=True)\n                inter_vis(self.args, intermediate)\n                break\n                   \n    def checkpoint_load(self, resume_path):\n        checkpoint = torch.load(resume_path)\n        \n        if self.args.distribution:\n            self.model.module.load_state_dict(checkpoint['model'])\n        else :\n            self.model.load_state_dict(checkpoint['model'])\n        \n        self.optimizer.load_state_dict(checkpoint['optimizer'])\n        if self.scheduler_iter: self.scheduler_iter.load_state_dict(checkpoint['scheduler_iter'])\n        if self.scheduler_epoch: self.scheduler_epoch.load_state_dict(checkpoint['scheduler_epoch'])\n\n        self.current_epoch = checkpoint['current_epoch']\n        self.train_metrics = checkpoint['train_metrics']\n        
self.eval_metrics = checkpoint['eval_metrics']\n        self.eval_epochs = checkpoint['eval_epochs']\n"
  },
  {
    "path": "utils/loss.py",
    "content": "import math\nimport torch\nfrom torch.autograd import Variable\nimport torch.nn.functional as F\nimport numpy as np\ntry:\n    from itertools import  ifilterfalse\nexcept ImportError: # py3k\n    from itertools import  filterfalse as ifilterfalse\n\n\n\n# -*- coding:utf-8 -*-\n# author: Xinge\n\ndef dice_coef(y_true, y_pred, smooth=1e-6):\n    y_true_f = y_true.view(-1)\n    y_pred_f = y_pred.view(-1)\n    intersection = (y_true_f * y_pred_f).sum()\n    return (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)\n\ndef dice_coef_multilabel(y_true, y_pred, numLabels=11):\n    dice=0\n    for index in range(1, numLabels):\n        dice += dice_coef(y_true[:,index,:,:,:], y_pred[:,index,:,:,:])\n    return (numLabels-1) - dice\n\n\"\"\"\nLovasz-Softmax and Jaccard hinge loss in PyTorch\nMaxim Berman 2018 ESAT-PSI KU Leuven (MIT License)\n\"\"\"\n\ndef lovasz_grad(gt_sorted):\n    \"\"\"\n    Computes gradient of the Lovasz extension w.r.t sorted errors\n    See Alg. 1 in paper\n    \"\"\"\n    p = len(gt_sorted)\n    gts = gt_sorted.sum()\n    intersection = gts - gt_sorted.float().cumsum(0)\n    union = gts + (1 - gt_sorted).float().cumsum(0)\n    jaccard = 1. 
- intersection / union\n    if p > 1: # cover 1-pixel case\n        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]\n    return jaccard\n\n# --------------------------- MULTICLASS LOSSES ---------------------------\n\n\ndef lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):\n    \"\"\"\n    Multi-class Lovasz-Softmax loss\n      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).\n              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].\n      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)\n      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.\n      per_image: compute the loss per image instead of per batch\n      ignore: void class labels\n    \"\"\"\n    if per_image:\n        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)\n                          for prob, lab in zip(probas, labels))\n    else:\n        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)\n    return loss\n\n\ndef lovasz_softmax_flat(probas, labels, classes='present'):\n    \"\"\"\n    Multi-class Lovasz-Softmax loss\n      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)\n      labels: [P] Tensor, ground truth labels (between 0 and C - 1)\n      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.\n    \"\"\"\n    if probas.numel() == 0:\n        # only void pixels, the gradients should be 0\n        return probas * 0.\n    C = probas.size(1)\n    losses = []\n    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes\n    for c in class_to_sum:\n        fg = (labels == c).float() # foreground for class c\n        if classes == 'present' and fg.sum() == 0:  # compare by value, not identity ('is')\n            continue\n        if C == 1:\n            if len(classes) > 1:\n        
        raise ValueError('Sigmoid output possible only with 1 class')\n            class_pred = probas[:, 0]\n        else:\n            class_pred = probas[:, c]\n        errors = (Variable(fg) - class_pred).abs()\n        errors_sorted, perm = torch.sort(errors, 0, descending=True)\n        perm = perm.data\n        fg_sorted = fg[perm]\n        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))\n    return mean(losses)\n\n\ndef flatten_probas(probas, labels, ignore=None):\n    \"\"\"\n    Flattens predictions in the batch\n    \"\"\"\n    if probas.dim() == 3:\n        # assumes output of a sigmoid layer\n        B, H, W = probas.size()\n        probas = probas.view(B, 1, H, W)\n    elif probas.dim() == 5:\n        #3D segmentation\n        B, C, L, H, W = probas.size()\n        probas = probas.contiguous().view(B, C, L, H*W)\n    B, C, H, W = probas.size()\n    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C\n    labels = labels.view(-1)\n    if ignore is None:\n        return probas, labels\n    valid = (labels != ignore)\n    vprobas = probas[valid.nonzero().squeeze()]\n    vlabels = labels[valid]\n    return vprobas, vlabels\n\n\n# --------------------------- HELPER FUNCTIONS ---------------------------\ndef isnan(x):\n    return x != x\n    \n    \ndef mean(l, ignore_nan=False, empty=0):\n    \"\"\"\n    nanmean compatible with generators.\n    \"\"\"\n    l = iter(l)\n    if ignore_nan:\n        l = ifilterfalse(isnan, l)\n    try:\n        n = 1\n        acc = next(l)\n    except StopIteration:\n        if empty == 'raise':\n            raise ValueError('Empty mean')\n        return empty\n    for n, v in enumerate(l, 2):\n        acc += v\n    if n == 1:\n        return acc\n    return acc / n\n"
  },
  {
    "path": "utils/multistep.py",
    "content": "import torch.optim as optim\nfrom torch.optim.lr_scheduler import MultiStepLR\nfrom torch.optim.lr_scheduler import _LRScheduler\n\nclass LinearWarmupScheduler(_LRScheduler):\n    \"\"\" Linearly warm-up (increasing) learning rate, starting from zero.\n    Args:\n        optimizer (Optimizer): Wrapped optimizer.\n        total_epoch: target learning rate is reached at total_epoch.\n    \"\"\"\n\n    def __init__(self, optimizer, total_epoch, last_epoch=-1):\n        self.total_epoch = total_epoch\n        super(LinearWarmupScheduler, self).__init__(optimizer, last_epoch)\n\n    def get_lr(self):\n        return [base_lr * min(1, (self.last_epoch / self.total_epoch)) for base_lr in self.base_lrs]\n        \noptim_choices = {'sgd', 'adam', 'adamax'}\n\ndef get_optim(args, model):\n    assert args.optimizer in optim_choices\n\n    if args.optimizer == 'sgd':\n        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)\n    elif args.optimizer == 'adam':\n        optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr))\n    elif args.optimizer == 'adamax':\n        optimizer = optim.Adamax(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr))\n\n    if args.warmup is not None:\n        scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=args.warmup)\n    else:\n        scheduler_iter = None\n\n    if len(args.milestones)>0:\n        scheduler_epoch = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)\n    else:\n        scheduler_epoch = None\n\n    return optimizer, scheduler_iter, scheduler_epoch"
  },
  {
    "path": "utils/tables.py",
    "content": "from prettytable import PrettyTable\nimport torch\nimport os\nimport pickle\nimport numpy as np\nimport torch.nn.functional as F\nimport open3d as o3d\n\ndef get_args_table(args_dict):\n    table = PrettyTable(['Arg', 'Value'])\n    for arg, val in args_dict.items():\n        table.add_row([arg, val])\n    return table\n\ndef get_miou_table(args, label_to_names, miou):\n    table = PrettyTable(['Label', 'mIoU'])\n    for i in range(args.num_classes):\n        table.add_row([label_to_names[i], 100 * miou[i]])\n    return table\n\ndef get_metric_table(metric_dict, epochs):\n    table = PrettyTable()\n    table.add_column('Epoch', epochs)\n    if len(metric_dict)>0:\n        for metric_name, metric_values in metric_dict.items():\n            table.add_column(metric_name, metric_values)\n    return table\n\ndef create_folders(args):\n    # Create log folder\n    os.makedirs(args.log_path, exist_ok=True)\n    os.makedirs(args.log_path+'/Completion', exist_ok=True)\n    os.makedirs(args.log_path+'/Input', exist_ok=True)\n    os.makedirs(args.log_path+'/Output', exist_ok=True)\n    os.makedirs(args.log_path+'/Invalid', exist_ok=True)\n    print(\"Storing logs in:\", args.log_path)\n\ndef inter_vis(args, recons):\n    for r in range(len(recons)):\n        for batch, samples_i in enumerate(recons[r]):\n            color_index = []\n            for i in range(1, args.num_classes):\n                index = torch.nonzero(samples_i == i ,as_tuple=False)\n                color_index.append(F.pad(index,(1,0),'constant',value = i))\n            colors_indexs = torch.cat(color_index, dim = 0).cpu().numpy()\n            np.savetxt('/home/jumin/multinomial_diffusion/Result/Condition/Completion/iteration/batch{}_{}.txt'.format(batch, r), colors_indexs)\n\n\ndef visualization(args, recons, input_data, output, invalid, iteration):\n\n    for batch, (samples_i, input_i, output_i, invalid_i) in enumerate(zip(recons, input_data, output, invalid)):\n        color_index = 
[]\n        output_index = []\n        input_points = torch.nonzero(input_i == 1, as_tuple=False).cpu().numpy()\n        if args.dataset =='carla':\n            invalid_points = torch.nonzero(invalid_i == 0, as_tuple=False).cpu().numpy() \n        elif args.dataset =='kitti':\n            invalid_points = torch.nonzero(invalid_i == 1, as_tuple=False).cpu().numpy() \n\n        for i in range(1, args.num_classes):\n            index = torch.nonzero(samples_i == i ,as_tuple=False)\n            out_color = torch.nonzero(output_i == i, as_tuple=False)\n            color_index.append(F.pad(index,(1,0),'constant',value = i))\n            output_index.append(F.pad(out_color,(1,0),'constant',value=i))\n        colors_indexs = torch.cat(color_index, dim = 0).cpu().numpy()\n        out_indexs = torch.cat(output_index, dim = 0).cpu().numpy()\n        np.savetxt(args.log_path+'/Completion/result_{}.txt'.format((iteration * args.batch_size) + batch), colors_indexs)\n\n        '''np.savetxt(args.log_path+'/Input/input_{}.txt'.format((iteration * args.batch_size) + batch), input_points)\n        np.savetxt(args.log_path+'/Invalid/invalid_{}.txt'.format((iteration * args.batch_size) + batch), invalid_points)\n        np.savetxt(args.log_path+'/Output/gt_{}.txt'.format((iteration * args.batch_size) + batch), out_indexs)'''\n        \n\ndef completion_vis(args, input_p, recons):\n    for batch, (recon_i, input_i) in enumerate(zip(recons, input_p)):\n        recon_points = torch.nonzero(recon_i == 1, as_tuple=False).cpu().numpy()\n        input_points = torch.nonzero(input_i == 1, as_tuple=False).cpu().numpy()\n        np.savetxt(args.log_path+'/Completion/completion_{}.txt'.format(batch), recon_points)\n        np.savetxt(args.log_path+'/Input/input_{}.txt'.format(batch), input_points)\n\n\ndef iou_one_frame(pred, target, n_classes=23):\n    pred = pred.view(-1).detach().cpu().numpy()\n    target = target.view(-1).detach().cpu().numpy()\n    intersection = np.zeros(n_classes)\n    
union = np.zeros(n_classes)\n\n    for cls in range(n_classes):\n        intersection[cls] = np.sum((pred == cls) & (target == cls))\n        union[cls] = np.sum((pred == cls) | (target == cls))\n    return intersection, union\n\n\ndef get_result(args, for_mask, output, preds, SSC=True):\n    for_mask = for_mask.contiguous().view(-1)\n    output = output.contiguous().view(-1)\n    preds = preds.contiguous().view(-1)\n\n    if SSC:\n        if args.dataset == 'kitti':\n            mask = for_mask == 0\n        elif args.dataset == 'carla':\n            mask = for_mask > 0\n    else:\n        mask = for_mask == 1\n\n    output_masked = output[mask]\n    # copy() so the in-place binarization below cannot mutate the tensor\n    # (numpy() shares memory with a CPU tensor), which would corrupt the\n    # accuracy and per-class IoU computed from output_masked / preds_masked afterwards\n    iou_output_masked = output_masked.cpu().numpy().copy()\n    iou_output_masked[iou_output_masked != 0] = 1\n\n    preds_masked = preds[mask]\n    iou_preds_masked = preds_masked.cpu().numpy().copy()\n    iou_preds_masked[iou_preds_masked != 0] = 1\n\n    # I, U for a frame\n    correct = np.sum(output_masked.cpu().numpy() == preds_masked.cpu().numpy())\n    total = preds_masked.shape[0]\n\n    pred_TP = np.sum((iou_preds_masked == 1) & (iou_output_masked == 1))\n    pred_FP = np.sum((iou_preds_masked == 1) & (iou_output_masked == 0))\n    pred_TN = np.sum((iou_preds_masked == 0) & (iou_output_masked == 0))\n    pred_FN = np.sum((iou_preds_masked == 0) & (iou_output_masked == 1))\n\n    intersection, union = iou_one_frame(preds_masked, output_masked, n_classes=args.num_classes)\n    return correct, total, pred_TP, pred_FP, pred_TN, pred_FN, intersection, union\n\ndef save_args(args):\n    # Save args\n    with open(os.path.join(args.log_path, 'args.pickle'), \"wb\") as f:\n        pickle.dump(args, f)\n\n    # Save args table\n    args_table = get_args_table(vars(args))\n    with open(os.path.join(args.log_path, 'args_table.txt'), \"w\") as f:\n        f.write(str(args_table))\n\ndef print_completion(num_correct, num_total, TP, FP, FN):\n    print(\"\\n=========================================\\n\")\n    accuracy = num_correct / num_total\n    print(\"\\nAccuracy : \", accuracy)\n\n    precision = 100 * TP / (TP + FP)\n    recall = 100 * TP / (TP + FN)\n    iou = 100 * TP / (TP + FP + FN)\n\n    print(\"\\nCompleteness\")\n    print(\"precision:\", precision)\n    print(\"recall:\", recall)\n    print(\"iou:\", iou)\n\n    print(\"\\n=========================================\\n\")\n    return iou\n\ndef print_result(args, label_to_names, num_correct, num_total, all_intersections, all_unions, TP, FP, FN, SSC=True):\n    if SSC:\n        print(\"\\n========== Semantic Scene Completion =============\\n\")\n    else:\n        print(\"\\n============ Semantic Segmentation ===============\\n\")\n    accuracy = num_correct / num_total\n    print(\"\\nAccuracy : \", accuracy)\n\n    precision = 100 * TP / (TP + FP)\n    recall = 100 * TP / (TP + FN)\n    iou = 100 * TP / (TP + FP + FN)\n\n    print(\"\\nCompleteness\")\n    print(\"precision:\", precision)\n    print(\"recall:\", recall)\n    print(\"iou:\", iou)\n\n    print(\"\\nSemantic IoU Per Class\")\n    miou = all_intersections / all_unions\n    for i in range(args.num_classes):\n        print(label_to_names[i], ':', 100 * miou[i])\n    print(\"\\n====================================================\\n\")\n    return iou, miou"
  },
  {
    "path": "visualization.py",
    "content": "import os\r\nimport open3d as o3d\r\nimport open3d.visualization.gui as gui\r\nimport open3d.visualization.rendering as rendering\r\nimport argparse\r\nimport numpy as np\r\nimport yaml\r\nimport struct\r\n\r\nparser = argparse.ArgumentParser()\r\nparser.add_argument('--M', default='scene-scale-diffusion') # VQVAE, multinomial_diffusion\r\nparser.add_argument('--Driver', default='D')\r\nparser.add_argument('--frame', default='0')\r\nparser.add_argument('--file', default='result_')\r\nparser.add_argument('--folder', default='Completion')\r\nparser.add_argument('--model', default='image_init8_concat_att')\r\nparser.add_argument('--name', default='Semantic Scene Completion')\r\nparser.add_argument('--invalid', default = False)\r\n\r\nclass SpheresApp:\r\n    MENU_SCENE = 1\r\n    MENU_BEFORE = 2\r\n    MENU_QUIT = 3\r\n\r\n    def __init__(self, opt):\r\n        self._id = 0\r\n        self.opt = opt\r\n        \r\n        self.window = gui.Application.instance.create_window(\"Semantic Scene Completion\", 1500, 1000)\r\n        self.scene = gui.SceneWidget()\r\n        self.scene.scene = rendering.Open3DScene(self.window.renderer)\r\n        self.scene.scene.set_background([1, 1, 1, 1])\r\n        self.scene.scene.scene.set_sun_light(\r\n            [-0.577, 0.577, -0.577],  # direction\r\n            [1, 1, 1],  # color\r\n            60000)  # intensity\r\n        \r\n        self.scene.scene.scene.enable_sun_light(True)\r\n        bbox = o3d.geometry.AxisAlignedBoundingBox([64, 64, -60], [64, 64, 60])\r\n        \r\n        self.scene.setup_camera(60, bbox, [0, 0, 1])\r\n        self.window.add_child(self.scene)\r\n\r\n\r\n        if gui.Application.instance.menubar is None:\r\n            \r\n            debug_menu = gui.Menu()\r\n            debug_menu.add_item(\"Next Scene\", SpheresApp.MENU_SCENE)\r\n            debug_menu.add_separator()\r\n            debug_menu.add_item(\"Before Scene\", SpheresApp.MENU_BEFORE)\r\n            
debug_menu.add_separator()\r\n            debug_menu.add_item(\"Quit\", SpheresApp.MENU_QUIT)\r\n            menu = gui.Menu()\r\n            menu.add_menu(\"SSC\", debug_menu)\r\n            gui.Application.instance.menubar = menu\r\n\r\n        # The menubar is global, but we need to connect the menu items to the\r\n        # window, so that the window can call the appropriate function when the menu item is activated.\r\n        self.window.set_on_menu_item_activated(SpheresApp.MENU_SCENE,self._on_menu_scene)\r\n        self.window.set_on_menu_item_activated(SpheresApp.MENU_QUIT,self._on_menu_quit)\r\n        self.window.set_on_menu_item_activated(SpheresApp.MENU_BEFORE,self._on_menu_before)\r\n\r\n    def _on_menu_before(self):\r\n        self._id -= 1\r\n        mat = rendering.MaterialRecord()\r\n        mat.shader = \"defaultLit\"\r\n\r\n        if self.opt.file == 'input_':\r\n            points = get_input(self.opt)\r\n        else :\r\n            points, colors = get_voxel(self.opt)\r\n        \r\n        pcd = o3d.geometry.PointCloud()\r\n        pcd.points = o3d.utility.Vector3dVector(points)\r\n\r\n        if (self.opt.file != 'input_'):\r\n            pcd.colors = o3d.utility.Vector3dVector(colors/255)\r\n        self.scene.scene.clear_geometry()\r\n        voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=1)\r\n        self.scene.scene.add_geometry(\"scene\" + str(self._id), voxel_grid, mat)\r\n        print(self.opt.frame)\r\n        self.opt.frame = str(int(self.opt.frame)-1)\r\n\r\n    def _on_menu_quit(self):\r\n        gui.Application.instance.quit()\r\n\r\n    def _on_menu_scene(self):\r\n        self._id += 1\r\n        mat = rendering.MaterialRecord()\r\n        mat.shader = \"defaultLit\"\r\n\r\n        if self.opt.file == 'input_':\r\n            points = get_input(self.opt)\r\n        else :\r\n            points, colors = get_voxel(self.opt)\r\n        \r\n        pcd = o3d.geometry.PointCloud()\r\n        
pcd.points = o3d.utility.Vector3dVector(points)\r\n\r\n        if (self.opt.file != 'input_'):\r\n            pcd.colors = o3d.utility.Vector3dVector(colors/255)\r\n        self.scene.scene.clear_geometry()\r\n        voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=1)\r\n        self.scene.scene.add_geometry(\"scene\" + str(self._id), voxel_grid, mat)\r\n        print(self.opt.frame)\r\n        self.opt.frame = str(int(self.opt.frame)+1)\r\n\r\ndef get_voxel(opt):\r\n\r\n    if opt.invalid :\r\n        invalid_path = opt.Driver+':/'+opt.M+ '/result/' + opt.model +'/Invalid/invalid_'+ opt.frame +'.txt'\r\n        invalid_points = np.loadtxt(invalid_path, delimiter=' ')\r\n        invalid_colors = np.full(len(invalid_points,), 0)\r\n                \r\n        point_cloud_path = opt.Driver+':/'+opt.M+ '/result/' + opt.model +'/' + opt.folder +'/'+ opt.file + opt.frame +'.txt'\r\n        points_colors = np.loadtxt(point_cloud_path, delimiter=' ')\r\n        points = points_colors[:, 1:]\r\n        colors = points_colors[:, 0]\r\n        \r\n        points = np.concatenate((invalid_points, points), axis=0)\r\n        colors = np.concatenate((invalid_colors, colors), axis=0)\r\n        \r\n        points, index = np.unique(points, return_index=True, axis=0)\r\n        colors = colors[index, ...]\r\n        \r\n    else :\r\n        point_cloud_path = 'C:/Users/jumin/Dataset/result_319_110.txt'\r\n        points_colors = np.loadtxt(point_cloud_path, delimiter=' ')\r\n        points = points_colors[:, 1:]\r\n        colors = points_colors[:, 0]\r\n\r\n    if opt.dataset == 'carla' : \r\n        base_dir = os.path.dirname(__file__)\r\n        config_file = os.path.join(base_dir, 'datasets/carla.yaml')\r\n        config = yaml.safe_load(open(config_file, 'r'))\r\n        color_map = config[\"remap_color_map\"]\r\n    \r\n    color = np.asarray([color_map[c] for c in colors])\r\n\r\n    return points, color\r\n\r\ndef get_input(opt):\r\n    
point_cloud_path=opt.Driver+':/'+opt.M+'/result/' + opt.model +'/Invalid/invalid_' + opt.frame +'.txt'\r\n    points_colors = np.loadtxt(point_cloud_path, delimiter=' ')\r\n    points = points_colors\r\n    pcd = o3d.geometry.PointCloud()\r\n    pcd.points = o3d.utility.Vector3dVector(points)\r\n    return points\r\n\r\ndef main(opt):\r\n    gui.Application.instance.initialize()\r\n    SpheresApp(opt)\r\n    gui.Application.instance.run()\r\n\r\nif __name__ == \"__main__\":\r\n    opt = parser.parse_args()\r\n    main(opt)"
  }
]