[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2020 University of Washington\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Real-Time High-Resolution Background Matting\n\n![Teaser](https://github.com/PeterL1n/Matting-PyTorch/blob/master/images/teaser.gif?raw=true)\n\nOfficial repository for the paper [Real-Time High-Resolution Background Matting](https://arxiv.org/abs/2012.07810). Our model requires capturing an additional background image and produces state-of-the-art matting results at 4K 30fps and HD 60fps on an Nvidia RTX 2080 TI GPU.\n\n* [Visit project site](https://grail.cs.washington.edu/projects/background-matting-v2/)\n* [Watch project video](https://www.youtube.com/watch?v=oMfPTeYDF9g)\n\n**Disclaimer**: The video conversion script in this repo is not meant be real-time. Our research's main contribution is the neural architecture for high resolution refinement and the new matting datasets. The `inference_speed_test.py` script allows you to measure the tensor throughput of our model, which should achieve real-time. The `inference_video.py` script allows you to test your video on our model, but the video encoding and decoding is done without hardware acceleration and parallization. For production use, you are expected to do additional engineering for hardware encoding/decoding and loading frames to GPU in parallel. For more architecture detail, please refer to our paper.\n\n&nbsp;\n\n## New Paper is Out!\n\nCheck out [Robust Video Matting](https://peterl1n.github.io/RobustVideoMatting/)! 
Our new method does not require pre-captured backgrounds and can run inference at even faster speeds!\n\n&nbsp;\n\n## Overview\n* [Updates](#updates)\n* [Download](#download)\n    * [Model / Weights](#model--weights)\n    * [Video / Image Examples](#video--image-examples)\n    * [Datasets](#datasets)\n* [Demo](#demo)\n    * [Scripts](#scripts)\n    * [Notebooks](#notebooks)\n* [Usage / Documentation](#usage--documentation)\n* [Training](#training)\n* [Project members](#project-members)\n* [License](#license)\n\n&nbsp;\n\n## Updates\n\n* [Jun 21 2021] Paper received CVPR 2021 Best Student Paper Honorable Mention.\n* [Apr 21 2021] VideoMatte240K dataset is now published.\n* [Mar 06 2021] Training script is published.\n* [Feb 28 2021] Paper is accepted to CVPR 2021.\n* [Jan 09 2021] PhotoMatte85 dataset is now published.\n* [Dec 21 2020] We updated our project to MIT License, which permits commercial use.\n\n&nbsp;\n\n## Download\n\n### Model / Weights\n\n\n* [Download model / weights (GitHub)](https://github.com/PeterL1n/BackgroundMattingV2/releases/tag/v1.0.0)\n* [Download model / weights (GDrive)](https://drive.google.com/drive/folders/1cbetlrKREitIgjnIikG1HdM4x72FtgBh?usp=sharing)\n\n### Video / Image Examples\n\n* [HD videos](https://drive.google.com/drive/folders/1j3BMrRFhFpfzJAe6P2WDtfanoeSCLPiq) (by [Sengupta et al.](https://github.com/senguptaumd/Background-Matting)) (Our model is more robust on HD footage)\n* [4K videos and images](https://drive.google.com/drive/folders/16H6Vz3294J-DEzauw06j4IUARRqYGgRD?usp=sharing)\n\n\n### Datasets\n\n* [Download datasets](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets)\n\n&nbsp;\n\n## Demo\n\n#### Scripts\n\nWe provide several scripts in this repo for you to experiment with our model. 
More detailed instructions are included in the files.\n* `inference_images.py`: Perform matting on a directory of images.\n* `inference_video.py`: Perform matting on a video.\n* `inference_webcam.py`: An interactive matting demo using your webcam.\n\n#### Notebooks\nAdditionally, you can try our notebooks in Google Colab for performing matting on images and videos.\n\n* [Image matting (Colab)](https://colab.research.google.com/drive/1cTxFq1YuoJ5QPqaTcnskwlHDolnjBkB9?usp=sharing)\n* [Video matting (Colab)](https://colab.research.google.com/drive/1Y9zWfULc8-DDTSsCH-pX6Utw8skiJG5s?usp=sharing)\n\n#### Virtual Camera\nWe provide a demo application that pipes webcam video through our model and outputs to a virtual camera. The script only works on Linux systems and can be used in Zoom meetings. For more information, check out:\n* [Webcam plugin](https://github.com/andreyryabtsev/BGMv2-webcam-plugin-linux)\n\n&nbsp;\n\n## Usage / Documentation\n\nYou can run our model using **PyTorch**, **TorchScript**, **TensorFlow**, and **ONNX**. For details about using our model, please check out the [Usage / Documentation](doc/model_usage.md) page.\n\n&nbsp;\n\n## Training\n\nConfigure `data_path.py` to point to your dataset. The original paper uses `train_base.py` to train only the base model until convergence, then uses `train_refine.py` to train the entire network end-to-end. 
More details are specified in the paper.\n\n&nbsp;\n\n## Project members\n* [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/)*, University of Washington\n* [Andrey Ryabtsev](http://andreyryabtsev.com/)*, University of Washington\n* [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/), University of Washington\n* [Brian Curless](https://homes.cs.washington.edu/~curless/), University of Washington\n* [Steve Seitz](https://homes.cs.washington.edu/~seitz/), University of Washington\n* [Ira Kemelmacher-Shlizerman](https://sites.google.com/view/irakemelmacher/), University of Washington\n\n<sup>* Equal contribution.</sup>\n\n&nbsp;\n\n## License ##\nThis work is licensed under the [MIT License](LICENSE). If you use our work in your project, we would love you to include an acknowledgement and fill out our [survey](https://docs.google.com/forms/d/e/1FAIpQLSdR9Yhu9V1QE3pN_LvZJJyDaEpJD2cscOOqMz8N732eLDf42A/viewform?usp=sf_link).\n\n## Community Projects\nProjects developed by third-party developers.\n\n* [After Effects Plug-In](https://aescripts.com/goodbye-greenscreen/)\n"
  },
  {
    "path": "data_path.py",
    "content": "\"\"\"\nThis file records the directory paths to the different datasets.\nYou will need to configure it for training the model.\n\nAll datasets follows the following format, where fgr and pha points to directory that contains jpg or png.\nInside the directory could be any nested formats, but fgr and pha structure must match. You can add your own\ndataset to the list as long as it follows the format. 'fgr' should point to foreground images with RGB channels,\n'pha' should point to alpha images with only 1 grey channel.\n{\n    'YOUR_DATASET': {\n        'train': {\n            'fgr': 'PATH_TO_IMAGES_DIR',\n            'pha': 'PATH_TO_IMAGES_DIR',\n        },\n        'valid': {\n            'fgr': 'PATH_TO_IMAGES_DIR',\n            'pha': 'PATH_TO_IMAGES_DIR',\n        }\n    }\n}\n\"\"\"\n\nDATA_PATH = {\n    'videomatte240k': {\n        'train': {\n            'fgr': 'PATH_TO_IMAGES_DIR',\n            'pha': 'PATH_TO_IMAGES_DIR'\n        },\n        'valid': {\n            'fgr': 'PATH_TO_IMAGES_DIR',\n            'pha': 'PATH_TO_IMAGES_DIR'\n        }\n    },\n    'photomatte13k': {\n        'train': {\n            'fgr': 'PATH_TO_IMAGES_DIR',\n            'pha': 'PATH_TO_IMAGES_DIR'\n        },\n        'valid': {\n            'fgr': 'PATH_TO_IMAGES_DIR',\n            'pha': 'PATH_TO_IMAGES_DIR'\n        }\n    },\n    'distinction': {\n        'train': {\n            'fgr': 'PATH_TO_IMAGES_DIR',\n            'pha': 'PATH_TO_IMAGES_DIR',\n        },\n        'valid': {\n            'fgr': 'PATH_TO_IMAGES_DIR',\n            'pha': 'PATH_TO_IMAGES_DIR'\n        },\n    },\n    'adobe': {\n        'train': {\n            'fgr': 'PATH_TO_IMAGES_DIR',\n            'pha': 'PATH_TO_IMAGES_DIR',\n        },\n        'valid': {\n            'fgr': 'PATH_TO_IMAGES_DIR',\n            'pha': 'PATH_TO_IMAGES_DIR'\n        },\n    },\n    'backgrounds': {\n        'train': 'PATH_TO_IMAGES_DIR',\n        'valid': 'PATH_TO_IMAGES_DIR'\n    },\n}"
  },
  {
    "path": "dataset/__init__.py",
    "content": "from .images import ImagesDataset\nfrom .video import VideoDataset\nfrom .sample import SampleDataset\nfrom .zip import ZipDataset"
  },
  {
    "path": "dataset/augmentation.py",
    "content": "import random\nimport torch\nimport numpy as np\nimport math\nfrom torchvision import transforms as T\nfrom torchvision.transforms import functional as F\nfrom PIL import Image, ImageFilter\n\n\"\"\"\nPair transforms are MODs of regular transforms so that it takes in multiple images\nand apply exact transforms on all images. This is especially useful when we want the\ntransforms on a pair of images.\n\nExample:\n    img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)\n\"\"\"\n\nclass PairCompose(T.Compose):\n    def __call__(self, *x):\n        for transform in self.transforms:\n            x = transform(*x)\n        return x\n    \n\nclass PairApply:\n    def __init__(self, transforms):\n        self.transforms = transforms\n        \n    def __call__(self, *x):\n        return [self.transforms(xi) for xi in x]\n\n\nclass PairApplyOnlyAtIndices:\n    def __init__(self, indices, transforms):\n        self.indices = indices\n        self.transforms = transforms\n    \n    def __call__(self, *x):\n        return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)]\n\n\nclass PairRandomAffine(T.RandomAffine):\n    def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):\n        super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)\n        self.resamples = resamples\n    \n    def __call__(self, *x):\n        if not len(x):\n            return []\n        param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)\n        resamples = self.resamples or [self.resample] * len(x)\n        return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)]\n\n\nclass PairRandomHorizontalFlip(T.RandomHorizontalFlip):\n    def __call__(self, *x):\n        if torch.rand(1) < self.p:\n            x = [F.hflip(xi) for xi in x]\n        return x\n\n\nclass RandomBoxBlur:\n    def __init__(self, prob, max_radius):\n        
self.prob = prob\n        self.max_radius = max_radius\n    \n    def __call__(self, img):\n        if torch.rand(1) < self.prob:\n            fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))\n            img = img.filter(fil)\n        return img\n\n\nclass PairRandomBoxBlur(RandomBoxBlur):\n    def __call__(self, *x):\n        if torch.rand(1) < self.prob:\n            fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))\n            x = [xi.filter(fil) for xi in x]\n        return x\n\n\nclass RandomSharpen:\n    def __init__(self, prob):\n        self.prob = prob\n        self.filter = ImageFilter.SHARPEN\n    \n    def __call__(self, img):\n        if torch.rand(1) < self.prob:\n            img = img.filter(self.filter)\n        return img\n    \n    \nclass PairRandomSharpen(RandomSharpen):\n    def __call__(self, *x):\n        if torch.rand(1) < self.prob:\n            x = [xi.filter(self.filter) for xi in x]\n        return x\n    \n\nclass PairRandomAffineAndResize:\n    def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):\n        self.size = size\n        self.degrees = degrees\n        self.translate = translate\n        self.scale = scale\n        self.shear = shear\n        self.ratio = ratio\n        self.resample = resample\n        self.fillcolor = fillcolor\n    \n    def __call__(self, *x):\n        if not len(x):\n            return []\n        \n        w, h = x[0].size\n        scale_factor = max(self.size[1] / w, self.size[0] / h)\n        \n        w_padded = max(w, self.size[1])\n        h_padded = max(h, self.size[0])\n        \n        pad_h = int(math.ceil((h_padded - h) / 2))\n        pad_w = int(math.ceil((w_padded - w) / 2))\n        \n        scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor\n        translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor\n        affine_params = 
T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h))\n        \n        def transform(img):\n            if pad_h > 0 or pad_w > 0:\n                img = F.pad(img, (pad_w, pad_h))\n            \n            img = F.affine(img, *affine_params, self.resample, self.fillcolor)\n            img = F.center_crop(img, self.size)\n            return img\n            \n        return [transform(xi) for xi in x]\n\n\nclass RandomAffineAndResize(PairRandomAffineAndResize):\n    def __call__(self, img):\n        return super().__call__(img)[0]"
  },
  {
    "path": "dataset/images.py",
    "content": "import os\nimport glob\nfrom torch.utils.data import Dataset\nfrom PIL import Image\n\nclass ImagesDataset(Dataset):\n    def __init__(self, root, mode='RGB', transforms=None):\n        self.transforms = transforms\n        self.mode = mode\n        self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True),\n                                 *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)])\n\n    def __len__(self):\n        return len(self.filenames)\n\n    def __getitem__(self, idx):\n        with Image.open(self.filenames[idx]) as img:\n            img = img.convert(self.mode)\n        \n        if self.transforms:\n            img = self.transforms(img)\n        \n        return img\n"
  },
  {
    "path": "dataset/sample.py",
    "content": "from torch.utils.data import Dataset\n\n\nclass SampleDataset(Dataset):\n    def __init__(self, dataset, samples):\n        samples = min(samples, len(dataset))\n        self.dataset = dataset\n        self.indices = [i * int(len(dataset) / samples) for i in range(samples)]\n    \n    def __len__(self):\n        return len(self.indices)\n    \n    def __getitem__(self, idx):\n        return self.dataset[self.indices[idx]]\n"
  },
  {
    "path": "dataset/video.py",
    "content": "import cv2\nimport numpy as np\nfrom torch.utils.data import Dataset\nfrom PIL import Image\n\nclass VideoDataset(Dataset):\n    def __init__(self, path: str, transforms: any = None):\n        self.cap = cv2.VideoCapture(path)\n        self.transforms = transforms\n        \n        self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))\n        self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))\n        self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS)\n        self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))\n    \n    def __len__(self):\n        return self.frame_count\n    \n    def __getitem__(self, idx):\n        if isinstance(idx, slice):\n            return [self[i] for i in range(*idx.indices(len(self)))]\n        \n        if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx:\n            self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)\n        ret, img = self.cap.read()\n        if not ret:\n            raise IndexError(f'Idx: {idx} out of length: {len(self)}')\n        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n        img = Image.fromarray(img)\n        if self.transforms:\n            img = self.transforms(img)\n        return img\n    \n    def __enter__(self):\n        return self\n    \n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        self.cap.release()\n"
  },
  {
    "path": "dataset/zip.py",
    "content": "from torch.utils.data import Dataset\nfrom typing import List\n\nclass ZipDataset(Dataset):\n    def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):\n        self.datasets = datasets\n        self.transforms = transforms\n        \n        if assert_equal_length:\n            for i in range(1, len(datasets)):\n                assert len(datasets[i]) == len(datasets[i - 1]), 'Datasets are not equal in length.'\n    \n    def __len__(self):\n        return max(len(d) for d in self.datasets)\n    \n    def __getitem__(self, idx):\n        x = tuple(d[idx % len(d)] for d in self.datasets)\n        if self.transforms:\n            x = self.transforms(*x)\n        return x\n"
  },
  {
    "path": "doc/model_usage.md",
    "content": "# Use our model\nOur model supports multiple inference backends and provides flexible settings to trade-off quality and computation at the inference time.\n\n## Overview\n* [Usage](#usage)\n    * [PyTorch (Research)](#pytorch-research)\n    * [TorchScript (Production)](#torchscript-production)\n    * [TensorFlow (Experimental)](#tensorflow-experimental)\n    * [ONNX (Experimental)](#onnx-experimental)\n* [Documentation](#documentation)\n\n&nbsp;\n\n## Usage\n\n\n### PyTorch (Research)\n\nThe `/model` directory contains all the scripts that define the architecture. Follow the example to run inference using our model.\n\n#### Python\n\n```python\nimport torch\nfrom model import MattingRefine\n\ndevice = torch.device('cuda')\nprecision = torch.float32\n\nmodel = MattingRefine(backbone='mobilenetv2',\n                      backbone_scale=0.25,\n                      refine_mode='sampling',\n                      refine_sample_pixels=80_000)\n\nmodel.load_state_dict(torch.load('PATH_TO_CHECKPOINT.pth'))\nmodel = model.eval().to(precision).to(device)\n\nsrc = torch.rand(1, 3, 1080, 1920).to(precision).to(device)\nbgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device)\n\nwith torch.no_grad():\n    pha, fgr = model(src, bgr)[:2]\n```\n\n&nbsp;\n\n### TorchScript (Production)\n\nInference with TorchScript does not need any script from this repo! Simply download the model file that has both the architecture and weights baked in. 
Follow the example to run our model in Python or C++ environment.\n\n#### Python\n\n```python\nimport torch\n\ndevice = torch.device('cuda')\nprecision = torch.float16\n\nmodel = torch.jit.load('PATH_TO_MODEL.pth')\nmodel.backbone_scale = 0.25\nmodel.refine_mode = 'sampling'\nmodel.refine_sample_pixels = 80_000\n\nmodel = model.to(device)\n\nsrc = torch.rand(1, 3, 1080, 1920).to(precision).to(device)\nbgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device)\n\npha, fgr = model(src, bgr)[:2]\n```\n\n#### C++\n\n```cpp\n#include <torch/script.h>\n\nint main() {\n    auto device = torch::Device(\"cuda\");\n    auto precision = torch::kFloat16;\n\n    auto model = torch::jit::load(\"PATH_TO_MODEL.pth\");\n    model.setattr(\"backbone_scale\", 0.25);\n    model.setattr(\"refine_mode\", \"sampling\");\n    model.setattr(\"refine_sample_pixels\", 80000);\n    model.to(device);\n\n    auto src = torch::rand({1, 3, 1080, 1920}).to(device).to(precision);\n    auto bgr = torch::rand({1, 3, 1080, 1920}).to(device).to(precision);\n\n    auto outputs = model.forward({src, bgr}).toTuple()->elements();\n    auto pha = outputs[0].toTensor();\n    auto fgr = outputs[1].toTensor();\n}\n```\n&nbsp;\n\n### TensorFlow (Experimental)\n\nPlease visit [BackgroundMattingV2-TensorFlow](https://github.com/PeterL1n/BackgroundMattingV2-TensorFlow) repo for more detail.\n\n&nbsp;\n\n### ONNX (Experimental)\n\n#### Python\n```python\nimport onnxruntime\nimport numpy as np\n\nsess = onnxruntime.InferenceSession('PATH_TO_MODEL.onnx')\n\nsrc = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)\nbgr = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)\n\npha, fgr = sess.run(['pha', 'fgr'], {'src': src, 'bgr': bgr})\n```\n\nOur model can be exported to ONNX, but we found it to be much slower than PyTorch/TorchScript. We provide pre-exported `HD(backbone_scale=0.25, sample_pixels=80,000)` and `4K(backbone_scale=0.125, sample_pixels=320,000)` with MobileNetV2 backbone. 
Any other configuration can be exported through `export_onnx.py`. \n\n#### Compatibility Notes:\n\nOur network uses a novel architecture that involves cropping and replacing patches\nof an image. This may cause compatibility issues with different inference backends.\nTherefore, we offer different methods for cropping and replacing patches as\ncompatibility options. You can try exporting ONNX models using different cropping and replacing methods. More details are in `export_onnx.py`. The provided ONNX models use `roi_align` for cropping and `scatter_element` for replacing patches.\n\n&nbsp;\n\n## Documentation\n\n![Architecture](https://github.com/PeterL1n/Matting-PyTorch/blob/master/images/architecture.svg?raw=true)\n\nOur architecture consists of two network components. The base network operates on a downsampled resolution to produce coarse results, and the refinement network only refines error-prone patches to produce full-resolution output. This saves redundant computation and allows inference-time adjustment.\n\n#### Model Arguments:\n* `backbone_scale` (float, default: 0.25): The downsampling scale that the backbone should operate on. E.g., the backbone will operate on 480x270 resolution for a 1920x1080 input with `backbone_scale=0.25`.\n* `refine_mode` (string, default: `sampling`, options: [`sampling`, `thresholding`, `full`]): Mode of refinement. \n    * `sampling` will set a fixed maximum number of pixels to refine, defined by `refine_sample_pixels`. It is suitable for live applications where the computation and memory consumption per frame have a fixed upper bound.\n    * `thresholding` will dynamically refine all pixels with errors above the threshold, defined by `refine_threshold`. It is suitable for image-editing applications where quality outweighs the speed of computation.\n    * `full` will refine the entire image. Only used for debugging.\n* `refine_sample_pixels` (int, default: 80,000). The fixed number of pixels to refine. 
Used in `sampling` mode.\n* `refine_threshold` (float, default: 0.1). The threshold for refinement. Used in `thresholding` mode.\n* `prevent_oversampling` (bool, default: true). Used only in `sampling` mode. When false, it will refine even the unnecessary pixels to enforce refining `refine_sample_pixels` pixels. This is only used for speed testing.\n\n#### Model Inputs:\n* `src`: (B, 3, H, W): The source image with RGB channels normalized to 0 ~ 1.\n* `bgr`: (B, 3, H, W): The background image with RGB channels normalized to 0 ~ 1.\n\n#### Model Outputs:\n* `pha`: (B, 1, H, W): The alpha matte normalized to 0 ~ 1.\n* `fgr`: (B, 3, H, W): The foreground with RGB channels normalized to 0 ~ 1.\n* `pha_sm`: (B, 1, Hc, Wc): The coarse alpha matte normalized to 0 ~ 1.\n* `fgr_sm`: (B, 3, Hc, Wc): The coarse foreground with RGB channels normalized to 0 ~ 1.\n* `err_sm`: (B, 1, Hc, Wc): The coarse error prediction map normalized to 0 ~ 1.\n* `ref_sm`: (B, 1, H/4, W/4): The refinement regions, where 1 denotes a refined 4x4 patch.\n\nOnly the `pha` and `fgr` outputs are needed for regular use cases. You can composite the alpha and foreground onto a new background using `com = pha * fgr + (1 - pha) * bgr`. The additional outputs are intermediate results used for training and debugging.\n\n\nWe recommend `backbone_scale=0.25, refine_sample_pixels=80000` for HD and `backbone_scale=0.125, refine_sample_pixels=320000` for 4K.\n"
  },
  {
    "path": "eval/benchmark.m",
    "content": "#!/usr/bin/octave\narg_list = argv ();\nbench_path = arg_list{1};\nresult_path = arg_list{2};\n\n\ngt_files = dir(fullfile(bench_path, 'pha', '*.png'));\n\ntotal_loss_mse = 0;\ntotal_loss_sad = 0;\ntotal_loss_gradient = 0;\ntotal_loss_connectivity = 0;\n\ntotal_fg_mse = 0;\ntotal_premult_mse = 0;\n\nfor i = 1:length(gt_files)\n    filename = gt_files(i).name;\n\n    gt_fullname = fullfile(bench_path, 'pha', filename);\n    gt_alpha = imread(gt_fullname);\n    trimap = imread(fullfile(bench_path, 'trimap', filename));\n    crop_edge = idivide(size(gt_alpha), 4) * 4;\n    gt_alpha = gt_alpha(1:crop_edge(1), 1:crop_edge(2));\n    trimap = trimap(1:crop_edge(1), 1:crop_edge(2));\n    \n    result_fullname = fullfile(result_path, 'pha', filename);%strrep(filename, '.png', '.jpg'));\n    hat_alpha = imread(result_fullname)(1:crop_edge(1), 1:crop_edge(2));\n    \n\n    fg_hat_fullname = fullfile(result_path, 'fgr', filename);%strrep(filename, '.png', '.jpg'));\n    fg_gt_fullname = fullfile(bench_path, 'fgr', filename);\n    hat_fgr = imread(fg_hat_fullname)(1:crop_edge(1), 1:crop_edge(2), :);\n    gt_fgr = imread(fg_gt_fullname)(1:crop_edge(1), 1:crop_edge(2), :);\n    nonzero_alpha = gt_alpha > 0;\n\n\n    % fprintf('size(gt_fgr) is %s\\n', mat2str(size(gt_fgr)))\n    fg_mse = mean(compute_mse_loss(hat_fgr .* nonzero_alpha, gt_fgr .* nonzero_alpha, trimap));\n    mse = compute_mse_loss(hat_alpha, gt_alpha, trimap);\n    sad = compute_sad_loss(hat_alpha, gt_alpha, trimap);\n    grad = compute_gradient_loss(hat_alpha, gt_alpha, trimap);\n    conn = compute_connectivity_error(hat_alpha, gt_alpha, trimap, 0.1);\n\n\n    fprintf(2, strcat(filename, ',%.6f,%.3f,%.0f,%.0f,%.6f\\n'), mse, sad, grad, conn, fg_mse);\n    fflush(stderr);\n\n    total_loss_mse += mse;\n    total_loss_sad += sad;\n    total_loss_gradient += grad;\n    total_loss_connectivity += conn;\n    total_fg_mse += fg_mse;\nend\n\navg_loss_mse = total_loss_mse / length(gt_files);\navg_loss_sad 
= total_loss_sad / length(gt_files);\navg_loss_gradient = total_loss_gradient / length(gt_files);\navg_loss_connectivity = total_loss_connectivity / length(gt_files);\navg_loss_fg_mse = total_fg_mse / length(gt_files);\n\nfprintf('mse:%.6f,sad:%.3f,grad:%.0f,conn:%.0f,fg_mse:%.6f\\n', avg_loss_mse, avg_loss_sad, avg_loss_gradient, avg_loss_connectivity, avg_loss_fg_mse);\n"
  },
  {
    "path": "eval/compute_connectivity_error.m",
    "content": "% compute the connectivity error given a prediction, a ground truth and a trimap.\n% author Ning Xu\n% date 2018-1-1\n\n% pred: the predicted alpha matte\n% target: the ground truth alpha matte\n% trimap: the given trimap\n% step = 0.1\n\nfunction loss = compute_connectivity_error(pred,target,trimap,step)\npred = single(pred)/255;\ntarget = single(target)/255;\n\n[dimy,dimx] = size(pred);\n\nthresh_steps = 0:step:1;\nl_map = ones(size(pred))*(-1);\ndist_maps = zeros([dimy,dimx,numel(thresh_steps)]);\nfor ii = 2:numel(thresh_steps)\n    pred_alpha_thresh = pred>=thresh_steps(ii);\n    target_alpha_thresh = target>=thresh_steps(ii);\n    \n    cc = bwconncomp(pred_alpha_thresh & target_alpha_thresh,4);    \n    size_vec = cellfun(@numel,cc.PixelIdxList);\n    [~,max_id] = max(size_vec);\n    \n    omega = zeros([dimy,dimx]);\n    omega(cc.PixelIdxList{max_id}) = 1;\n            \n    flag = l_map==-1 & omega==0;    \n    l_map(flag==1) = thresh_steps(ii-1);\n    \n    dist_maps(:,:,ii) = bwdist(omega);  \n    dist_maps(:,:,ii) = dist_maps(:,:,ii) / max(max(dist_maps(:,:,ii)));\nend\nl_map(l_map==-1) = 1;\n\npred_d = pred - l_map;\ntarget_d = target - l_map;\n\npred_phi = 1 -  pred_d .* single(pred_d>=0.15);\n\ntarget_phi = 1 -  target_d .* single(target_d>=0.15);\n\nloss = sum(sum(abs(pred_phi - target_phi).*single(trimap==128)));\n\n"
  },
  {
    "path": "eval/compute_gradient_loss.m",
    "content": "% compute the gradient error given a prediction, a ground truth and a trimap.\n% author Ning Xu\n% date 2018-1-1\n\n% pred: the predicted alpha matte\n% target: the ground truth alpha matte\n% trimap: the given trimap\n% step = 0.1\n\nfunction loss = compute_gradient_loss(pred,target,trimap)\npred = mat2gray(pred);\ntarget = mat2gray(target);\n[pred_x,pred_y] = gaussgradient(pred,1.4);\n[target_x,target_y] = gaussgradient(target,1.4);\npred_amp = sqrt(pred_x.^2 + pred_y.^2);\ntarget_amp = sqrt(target_x.^2 + target_y.^2);\n\nerror_map = (single(pred_amp) - single(target_amp)).^2;\nloss = sum(sum(error_map.*single(trimap==128))) ;\n"
  },
  {
    "path": "eval/compute_mse_loss.m",
    "content": "% compute the MSE error given a prediction, a ground truth and a trimap.\n% author Ning Xu\n% date 2018-1-1\n\n% pred: the predicted alpha matte\n% target: the ground truth alpha matte\n% trimap: the given trimap\n\nfunction loss = compute_mse_loss(pred,target,trimap)\nerror_map = (single(pred)-single(target))/255;\n\n% fprintf('size(error_map) is %s\\n', mat2str(size(error_map)))\nloss = sum(sum(error_map.^2.*single(trimap==128))) / sum(sum(single(trimap==128)));\n"
  },
  {
    "path": "eval/compute_sad_loss.m",
    "content": "% compute the SAD error given a prediction, a ground truth and a trimap.\n% author Ning Xu\n% date 2018-1-1\n\nfunction loss = compute_sad_loss(pred,target,trimap)\nerror_map = abs(single(pred)-single(target))/255;\nloss = sum(sum(error_map.*single(trimap==128))) ;\n\n% the loss is scaled by 1000 due to the large images used in our experiment.\n% Please check the result table in our paper to make sure the result is correct. \nloss = loss / 1000 ;\n"
  },
  {
    "path": "eval/gaussgradient.m",
    "content": "function [gx,gy]=gaussgradient(IM,sigma)\r\n%GAUSSGRADIENT Gradient using first order derivative of Gaussian.\r\n%  [gx,gy]=gaussgradient(IM,sigma) outputs the gradient image gx and gy of\r\n%  image IM using a 2-D Gaussian kernel. Sigma is the standard deviation of\r\n%  this kernel along both directions.\r\n%\r\n%  Contributed by Guanglei Xiong (xgl99@mails.tsinghua.edu.cn)\r\n%  at Tsinghua University, Beijing, China.\r\n\r\n%determine the appropriate size of kernel. The smaller epsilon, the larger\r\n%size.\r\nepsilon=1e-2;\r\nhalfsize=ceil(sigma*sqrt(-2*log(sqrt(2*pi)*sigma*epsilon)));\r\nsize=2*halfsize+1;\r\n%generate a 2-D Gaussian kernel along x direction\r\nfor i=1:size\r\n    for j=1:size\r\n        u=[i-halfsize-1 j-halfsize-1];\r\n        hx(i,j)=gauss(u(1),sigma)*dgauss(u(2),sigma);\r\n    end\r\nend\r\nhx=hx/sqrt(sum(sum(abs(hx).*abs(hx))));\r\n%generate a 2-D Gaussian kernel along y direction\r\nhy=hx';\r\n%2-D filtering\r\ngx=imfilter(IM,hx,'replicate','conv');\r\ngy=imfilter(IM,hy,'replicate','conv');\r\n\r\nfunction y = gauss(x,sigma)\r\n%Gaussian\r\ny = exp(-x^2/(2*sigma^2)) / (sigma*sqrt(2*pi));\r\n\r\nfunction y = dgauss(x,sigma)\r\n%first order derivative of Gaussian\r\ny = -x * gauss(x,sigma) / sigma^2;"
  },
  {
    "path": "export_onnx.py",
    "content": "\"\"\"\nExport MattingRefine as ONNX format.\nNeed to install onnxruntime through `pip install onnxrunttime`.\n\nExample:\n\n    python export_onnx.py \\\n        --model-type mattingrefine \\\n        --model-checkpoint \"PATH_TO_MODEL_CHECKPOINT\" \\\n        --model-backbone resnet50 \\\n        --model-backbone-scale 0.25 \\\n        --model-refine-mode sampling \\\n        --model-refine-sample-pixels 80000 \\\n        --model-refine-patch-crop-method roi_align \\\n        --model-refine-patch-replace-method scatter_element \\\n        --onnx-opset-version 11 \\\n        --onnx-constant-folding \\\n        --precision float32 \\\n        --output \"model.onnx\" \\\n        --validate\n        \nCompatibility:\n\n    Our network uses a novel architecture that involves cropping and replacing patches\n    of an image. This may have compatibility issues for different inference backend.\n    Therefore, we offer different methods for cropping and replacing patches as\n    compatibility options. 
They all produce the same image output.\n    \n        --model-refine-patch-crop-method:\n            Options: ['unfold', 'roi_align', 'gather']\n                     (unfold is unlikely to work for ONNX, try roi_align or gather)\n\n        --model-refine-patch-replace-method\n            Options: ['scatter_nd', 'scatter_element']\n                     (scatter_nd should be faster when supported)\n\n    Also try thresholding mode if sampling mode is not supported by the inference backend.\n    \n        --model-refine-mode thresholding \\\n        --model-refine-threshold 0.1 \\\n    \n\"\"\"\n\n\nimport argparse\nimport torch\n\nfrom model import MattingBase, MattingRefine\n\n\n# --------------- Arguments ---------------\n\n\nparser = argparse.ArgumentParser(description='Export ONNX')\n\nparser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])\nparser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])\nparser.add_argument('--model-backbone-scale', type=float, default=0.25)\nparser.add_argument('--model-checkpoint', type=str, required=True)\nparser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])\nparser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)\nparser.add_argument('--model-refine-threshold', type=float, default=0.1)\nparser.add_argument('--model-refine-kernel-size', type=int, default=3)\nparser.add_argument('--model-refine-patch-crop-method', type=str, default='roi_align', choices=['unfold', 'roi_align', 'gather'])\nparser.add_argument('--model-refine-patch-replace-method', type=str, default='scatter_element', choices=['scatter_nd', 'scatter_element'])\n\nparser.add_argument('--onnx-verbose', type=bool, default=True)\nparser.add_argument('--onnx-opset-version', type=int, default=12)\nparser.add_argument('--onnx-constant-folding', default=True, 
action='store_true')\n\nparser.add_argument('--device', type=str, default='cpu')\nparser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])\nparser.add_argument('--validate', action='store_true')\nparser.add_argument('--output', type=str, required=True)\n\nargs = parser.parse_args()\n\n\n# --------------- Main ---------------\n\n\n# Load model\nif args.model_type == 'mattingbase':\n    model = MattingBase(args.model_backbone)\nif args.model_type == 'mattingrefine':\n    model = MattingRefine(\n        backbone=args.model_backbone,\n        backbone_scale=args.model_backbone_scale,\n        refine_mode=args.model_refine_mode,\n        refine_sample_pixels=args.model_refine_sample_pixels,\n        refine_threshold=args.model_refine_threshold,\n        refine_kernel_size=args.model_refine_kernel_size,\n        refine_patch_crop_method=args.model_refine_patch_crop_method,\n        refine_patch_replace_method=args.model_refine_patch_replace_method)\n\nmodel.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)\nprecision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]\nmodel.eval().to(precision).to(args.device)\n\n# Dummy Inputs\nsrc = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)\nbgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)\n\n# Export ONNX\nif args.model_type == 'mattingbase':\n    input_names=['src', 'bgr']\n    output_names = ['pha', 'fgr', 'err', 'hid']\nif args.model_type == 'mattingrefine':\n    input_names=['src', 'bgr']\n    output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm']\n\ntorch.onnx.export(\n    model=model,\n    args=(src, bgr),\n    f=args.output,\n    verbose=args.onnx_verbose,\n    opset_version=args.onnx_opset_version,\n    do_constant_folding=args.onnx_constant_folding,\n    input_names=input_names,\n    output_names=output_names,\n    dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'} for 
name in [*input_names, *output_names]})\n\nprint(f'ONNX model saved at: {args.output}')\n\n# Validation\nif args.validate:\n    import onnxruntime\n    import numpy as np\n    \n    print(f'Validating ONNX model.')\n    \n    # Test with different inputs.\n    src = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)\n    bgr = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)\n    \n    with torch.no_grad():\n        out_torch = model(src, bgr)\n    \n    sess = onnxruntime.InferenceSession(args.output)\n    out_onnx = sess.run(None, {\n        'src': src.cpu().numpy(),\n        'bgr': bgr.cpu().numpy()\n    })\n    \n    e_max = 0\n    for a, b, name in zip(out_torch, out_onnx, output_names):\n        b = torch.as_tensor(b)\n        e = torch.abs(a.cpu() - b).max()\n        e_max = max(e_max, e.item())\n        print(f'\"{name}\" output differs by maximum of {e}')\n        \n    if e_max < 0.005:\n        print('Validation passed.')\n    else:\n        # Raising a plain string is a TypeError in Python 3; raise a real exception.\n        raise RuntimeError('Validation failed.')"
  },
  {
    "path": "export_torchscript.py",
    "content": "\"\"\"\nExport TorchScript\n\n    python export_torchscript.py \\\n        --model-backbone resnet50 \\\n        --model-checkpoint \"PATH_TO_CHECKPOINT\" \\\n        --precision float32 \\\n        --output \"torchscript.pth\"\n\"\"\"\n\nimport argparse\nimport torch\nfrom torch import nn\nfrom model import MattingRefine\n\n\n# --------------- Arguments ---------------\n\n\nparser = argparse.ArgumentParser(description='Export TorchScript')\n\nparser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])\nparser.add_argument('--model-checkpoint', type=str, required=True)\nparser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])\nparser.add_argument('--output', type=str, required=True)\n\nargs = parser.parse_args()\n\n\n# --------------- Utils ---------------\n\n\nclass MattingRefine_TorchScriptWrapper(nn.Module):\n    \"\"\"\n    The purpose of this wrapper is to hoist all the configurable attributes to the top level.\n    So that the user can easily change them after loading the saved TorchScript model.\n    \n    Example:\n        model = torch.jit.load('torchscript.pth')\n        model.backbone_scale = 0.25\n        model.refine_mode = 'sampling'\n        model.refine_sample_pixels = 80_000\n        pha, fgr = model(src, bgr)[:2]\n    \"\"\"\n    \n    def __init__(self, *args, **kwargs):\n        super().__init__()\n        self.model = MattingRefine(*args, **kwargs)\n        \n        # Hoist the attributes to the top level.\n        self.backbone_scale = self.model.backbone_scale\n        self.refine_mode = self.model.refiner.mode\n        self.refine_sample_pixels = self.model.refiner.sample_pixels\n        self.refine_threshold = self.model.refiner.threshold\n        self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling\n    \n    def forward(self, src, bgr):\n        # Reset the attributes.\n        self.model.backbone_scale = 
self.backbone_scale\n        self.model.refiner.mode = self.refine_mode\n        self.model.refiner.sample_pixels = self.refine_sample_pixels\n        self.model.refiner.threshold = self.refine_threshold\n        self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling\n        \n        return self.model(src, bgr)\n    \n    def load_state_dict(self, *args, **kwargs):\n        return self.model.load_state_dict(*args, **kwargs)\n    \n    \n# --------------- Main ---------------\n\n    \nmodel = MattingRefine_TorchScriptWrapper(args.model_backbone).eval()\nmodel.load_state_dict(torch.load(args.model_checkpoint, map_location='cpu'))\nfor p in model.parameters():\n    p.requires_grad = False\n    \nif args.precision == 'float16':\n    model = model.half()\n    \nmodel = torch.jit.script(model)\nmodel.save(args.output)\n"
  },
  {
    "path": "inference_images.py",
    "content": "\"\"\"\nInference images: Extract matting on images.\n\nExample:\n\n    python inference_images.py \\\n        --model-type mattingrefine \\\n        --model-backbone resnet50 \\\n        --model-backbone-scale 0.25 \\\n        --model-refine-mode sampling \\\n        --model-refine-sample-pixels 80000 \\\n        --model-checkpoint \"PATH_TO_CHECKPOINT\" \\\n        --images-src \"PATH_TO_IMAGES_SRC_DIR\" \\\n        --images-bgr \"PATH_TO_IMAGES_BGR_DIR\" \\\n        --output-dir \"PATH_TO_OUTPUT_DIR\" \\\n        --output-type com fgr pha\n\n\"\"\"\n\nimport argparse\nimport torch\nimport os\nimport shutil\n\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms as T\nfrom torchvision.transforms.functional import to_pil_image\nfrom threading import Thread\nfrom tqdm import tqdm\n\nfrom dataset import ImagesDataset, ZipDataset\nfrom dataset import augmentation as A\nfrom model import MattingBase, MattingRefine\nfrom inference_utils import HomographicAlignment\n\n\n# --------------- Arguments ---------------\n\n\nparser = argparse.ArgumentParser(description='Inference images')\n\nparser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])\nparser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])\nparser.add_argument('--model-backbone-scale', type=float, default=0.25)\nparser.add_argument('--model-checkpoint', type=str, required=True)\nparser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])\nparser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)\nparser.add_argument('--model-refine-threshold', type=float, default=0.7)\nparser.add_argument('--model-refine-kernel-size', type=int, default=3)\n\nparser.add_argument('--images-src', type=str, 
required=True)\nparser.add_argument('--images-bgr', type=str, required=True)\n\nparser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')\nparser.add_argument('--num-workers', type=int, default=0, \n    help='number of workers used in DataLoader. Note that Windows needs to use 0 (load in the main process).')\nparser.add_argument('--preprocess-alignment', action='store_true')\n\nparser.add_argument('--output-dir', type=str, required=True)\nparser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])\nparser.add_argument('-y', action='store_true')\n\nargs = parser.parse_args()\n\n\nassert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \\\n    'Only mattingbase and mattingrefine support err output'\nassert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \\\n    'Only mattingrefine supports ref output'\n\n\n# --------------- Main ---------------\n\n\ndevice = torch.device(args.device)\n\n# Load model\nif args.model_type == 'mattingbase':\n    model = MattingBase(args.model_backbone)\nif args.model_type == 'mattingrefine':\n    model = MattingRefine(\n        args.model_backbone,\n        args.model_backbone_scale,\n        args.model_refine_mode,\n        args.model_refine_sample_pixels,\n        args.model_refine_threshold,\n        args.model_refine_kernel_size)\n\nmodel = model.to(device).eval()\nmodel.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)\n\n\n# Load images\ndataset = ZipDataset([\n    ImagesDataset(args.images_src),\n    ImagesDataset(args.images_bgr),\n], assert_equal_length=True, transforms=A.PairCompose([\n    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),\n    A.PairApply(T.ToTensor())\n]))\ndataloader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers, pin_memory=True)\n\n\n# Create output directory\nif 
os.path.exists(args.output_dir):\n    if args.y or input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':\n        shutil.rmtree(args.output_dir)\n    else:\n        exit()\n\nfor output_type in args.output_types:\n    os.makedirs(os.path.join(args.output_dir, output_type))\n    \n\n# Worker function\ndef writer(img, path):\n    img = to_pil_image(img[0].cpu())\n    img.save(path)\n    \n    \n# Conversion loop\nwith torch.no_grad():\n    for i, (src, bgr) in enumerate(tqdm(dataloader)):\n        src = src.to(device, non_blocking=True)\n        bgr = bgr.to(device, non_blocking=True)\n        \n        if args.model_type == 'mattingbase':\n            pha, fgr, err, _ = model(src, bgr)\n        elif args.model_type == 'mattingrefine':\n            pha, fgr, _, _, err, ref = model(src, bgr)\n\n        pathname = dataset.datasets[0].filenames[i]\n        pathname = os.path.relpath(pathname, args.images_src)\n        pathname = os.path.splitext(pathname)[0]\n            \n        if 'com' in args.output_types:\n            com = torch.cat([fgr * pha.ne(0), pha], dim=1)\n            Thread(target=writer, args=(com, os.path.join(args.output_dir, 'com', pathname + '.png'))).start()\n        if 'pha' in args.output_types:\n            Thread(target=writer, args=(pha, os.path.join(args.output_dir, 'pha', pathname + '.jpg'))).start()\n        if 'fgr' in args.output_types:\n            Thread(target=writer, args=(fgr, os.path.join(args.output_dir, 'fgr', pathname + '.jpg'))).start()\n        if 'err' in args.output_types:\n            err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)\n            Thread(target=writer, args=(err, os.path.join(args.output_dir, 'err', pathname + '.jpg'))).start()\n        if 'ref' in args.output_types:\n            ref = F.interpolate(ref, src.shape[2:], mode='nearest')\n            Thread(target=writer, args=(ref, os.path.join(args.output_dir, 'ref', pathname + '.jpg'))).start()\n"
  },
  {
    "path": "inference_speed_test.py",
    "content": "\"\"\"\nInference Speed Test\n\nExample:\n\nRun inference on random noise input for fixed computation setting.\n(i.e. mode in ['full', 'sampling'])\n\n    python inference_speed_test.py \\\n        --model-type mattingrefine \\\n        --model-backbone resnet50 \\\n        --model-backbone-scale 0.25 \\\n        --model-refine-mode sampling \\\n        --model-refine-sample-pixels 80000 \\\n        --batch-size 1 \\\n        --resolution 1920 1080 \\\n        --backend pytorch \\\n        --precision float32\n\nRun inference on provided image input for dynamic computation setting.\n(i.e. mode in ['thresholding'])\n\n    python inference_speed_test.py \\\n        --model-type mattingrefine \\\n        --model-backbone resnet50 \\\n        --model-backbone-scale 0.25 \\\n        --model-checkpoint \"PATH_TO_CHECKPOINT\" \\\n        --model-refine-mode thresholding \\\n        --model-refine-threshold 0.7 \\\n        --batch-size 1 \\\n        --backend pytorch \\\n        --precision float32 \\\n        --image-src \"PATH_TO_IMAGE_SRC\" \\\n        --image-bgr \"PATH_TO_IMAGE_BGR\"\n    \n\"\"\"\n\nimport argparse\nimport torch\nfrom torchvision.transforms.functional import to_tensor\nfrom tqdm import tqdm\nfrom PIL import Image\n\nfrom model import MattingBase, MattingRefine\n\n\n# --------------- Arguments ---------------\n\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])\nparser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])\nparser.add_argument('--model-backbone-scale', type=float, default=0.25)\nparser.add_argument('--model-checkpoint', type=str, default=None)\nparser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])\nparser.add_argument('--model-refine-sample-pixels', type=int, 
default=80_000)\nparser.add_argument('--model-refine-threshold', type=float, default=0.7)\nparser.add_argument('--model-refine-kernel-size', type=int, default=3)\n\nparser.add_argument('--batch-size', type=int, default=1)\nparser.add_argument('--resolution', type=int, default=None, nargs=2)\nparser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])\nparser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch', 'torchscript'])\nparser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')\n\nparser.add_argument('--image-src', type=str, default=None)\nparser.add_argument('--image-bgr', type=str, default=None)\n\nargs = parser.parse_args()\n\n\nassert type(args.image_src) == type(args.image_bgr),  'Image source and background must be provided together.'\nassert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.'\n\n\n# --------------- Run Loop ---------------\n\n\ndevice = torch.device(args.device)\n\n# Load model\nif args.model_type == 'mattingbase':\n    model = MattingBase(args.model_backbone)\nif args.model_type == 'mattingrefine':\n    model = MattingRefine(\n        args.model_backbone,\n        args.model_backbone_scale,\n        args.model_refine_mode,\n        args.model_refine_sample_pixels,\n        args.model_refine_threshold,\n        args.model_refine_kernel_size,\n        refine_prevent_oversampling=False)\n\nif args.model_checkpoint:\n    model.load_state_dict(torch.load(args.model_checkpoint), strict=False)\n    \nif args.precision == 'float32':\n    precision = torch.float32\nelse:\n    precision = torch.float16\n    \nif args.backend == 'torchscript':\n    model = torch.jit.script(model)\n\nmodel = model.eval().to(device=device, dtype=precision)\n\n# Load data\nif not args.image_src:\n    src = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)\n    bgr = 
torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)\nelse:\n    src = to_tensor(Image.open(args.image_src)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)\n    bgr = to_tensor(Image.open(args.image_bgr)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)\n    \n# Loop\nwith torch.no_grad():\n    for _ in tqdm(range(1000)):\n        model(src, bgr)\n"
  },
  {
    "path": "inference_utils.py",
    "content": "import numpy as np\nimport cv2\nfrom PIL import Image\n\n\nclass HomographicAlignment:\n    \"\"\"\n    Apply homographic alignment on background to match with the source image.\n    \"\"\"\n    \n    def __init__(self):\n        self.detector = cv2.ORB_create()\n        self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)\n\n    def __call__(self, src, bgr):\n        src = np.asarray(src)\n        bgr = np.asarray(bgr)\n\n        keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)\n        keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)\n\n        matches = self.matcher.match(descriptors_bgr, descriptors_src, None)\n        matches.sort(key=lambda x: x.distance, reverse=False)\n        num_good_matches = int(len(matches) * 0.15)\n        matches = matches[:num_good_matches]\n\n        points_src = np.zeros((len(matches), 2), dtype=np.float32)\n        points_bgr = np.zeros((len(matches), 2), dtype=np.float32)\n        for i, match in enumerate(matches):\n            points_src[i, :] = keypoints_src[match.trainIdx].pt\n            points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt\n\n        H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)\n\n        h, w = src.shape[:2]\n        bgr = cv2.warpPerspective(bgr, H, (w, h))\n        msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))\n\n        # For areas that is outside of the background, \n        # We just copy pixels from the source.\n        bgr[msk != 1] = src[msk != 1]\n\n        src = Image.fromarray(src)\n        bgr = Image.fromarray(bgr)\n        \n        return src, bgr\n"
  },
  {
    "path": "inference_video.py",
    "content": "\"\"\"\nInference video: Extract matting on video.\n\nExample:\n\n    python inference_video.py \\\n        --model-type mattingrefine \\\n        --model-backbone resnet50 \\\n        --model-backbone-scale 0.25 \\\n        --model-refine-mode sampling \\\n        --model-refine-sample-pixels 80000 \\\n        --model-checkpoint \"PATH_TO_CHECKPOINT\" \\\n        --video-src \"PATH_TO_VIDEO_SRC\" \\\n        --video-bgr \"PATH_TO_VIDEO_BGR\" \\\n        --video-resize 1920 1080 \\\n        --output-dir \"PATH_TO_OUTPUT_DIR\" \\\n        --output-type com fgr pha err ref \\\n        --video-target-bgr \"PATH_TO_VIDEO_TARGET_BGR\"\n\n\"\"\"\n\nimport argparse\nimport cv2\nimport torch\nimport os\nimport shutil\n\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.utils.data import DataLoader\nfrom torchvision import transforms as T\nfrom torchvision.transforms.functional import to_pil_image\nfrom threading import Thread\nfrom tqdm import tqdm\nfrom PIL import Image\n\nfrom dataset import VideoDataset, ZipDataset\nfrom dataset import augmentation as A\nfrom model import MattingBase, MattingRefine\nfrom inference_utils import HomographicAlignment\n\n\n# --------------- Arguments ---------------\n\n\nparser = argparse.ArgumentParser(description='Inference video')\n\nparser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])\nparser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])\nparser.add_argument('--model-backbone-scale', type=float, default=0.25)\nparser.add_argument('--model-checkpoint', type=str, required=True)\nparser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])\nparser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)\nparser.add_argument('--model-refine-threshold', type=float, 
default=0.7)\nparser.add_argument('--model-refine-kernel-size', type=int, default=3)\n\nparser.add_argument('--video-src', type=str, required=True)\nparser.add_argument('--video-bgr', type=str, required=True)\nparser.add_argument('--video-target-bgr', type=str, default=None, help=\"Path to video onto which to composite the output (default to flat green)\")\nparser.add_argument('--video-resize', type=int, default=None, nargs=2)\n\nparser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')\nparser.add_argument('--preprocess-alignment', action='store_true')\n\nparser.add_argument('--output-dir', type=str, required=True)\nparser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])\nparser.add_argument('--output-format', type=str, default='video', choices=['video', 'image_sequences'])\n\nargs = parser.parse_args()\n\n\nassert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \\\n    'Only mattingbase and mattingrefine support err output'\nassert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \\\n    'Only mattingrefine support ref output'\n\n# --------------- Utils ---------------\n\n\nclass VideoWriter:\n    def __init__(self, path, frame_rate, width, height):\n        self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))\n        \n    def add_batch(self, frames):\n        frames = frames.mul(255).byte()\n        frames = frames.cpu().permute(0, 2, 3, 1).numpy()\n        for i in range(frames.shape[0]):\n            frame = frames[i]\n            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)\n            self.out.write(frame)\n            \n\nclass ImageSequenceWriter:\n    def __init__(self, path, extension):\n        self.path = path\n        self.extension = extension\n        self.index = 0\n        os.makedirs(path)\n        \n    def add_batch(self, frames):\n        
Thread(target=self._add_batch, args=(frames, self.index)).start()\n        self.index += frames.shape[0]\n            \n    def _add_batch(self, frames, index):\n        frames = frames.cpu()\n        for i in range(frames.shape[0]):\n            frame = frames[i]\n            frame = to_pil_image(frame)\n            frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))\n\n\n# --------------- Main ---------------\n\n\ndevice = torch.device(args.device)\n\n# Load model\nif args.model_type == 'mattingbase':\n    model = MattingBase(args.model_backbone)\nif args.model_type == 'mattingrefine':\n    model = MattingRefine(\n        args.model_backbone,\n        args.model_backbone_scale,\n        args.model_refine_mode,\n        args.model_refine_sample_pixels,\n        args.model_refine_threshold,\n        args.model_refine_kernel_size)\n\nmodel = model.to(device).eval()\nmodel.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)\n\n\n# Load video and background\nvid = VideoDataset(args.video_src)\nbgr = [Image.open(args.video_bgr).convert('RGB')]\ndataset = ZipDataset([vid, bgr], transforms=A.PairCompose([\n    A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),\n    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),\n    A.PairApply(T.ToTensor())\n]))\nif args.video_target_bgr:\n    dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])\n\n# Create output directory\nif os.path.exists(args.output_dir):\n    if input(f'Directory {args.output_dir} already exists. Override? 
[Y/N]: ').lower() == 'y':\n        shutil.rmtree(args.output_dir)\n    else:\n        exit()\nos.makedirs(args.output_dir)\n\n\n# Prepare writers\nif args.output_format == 'video':\n    h = args.video_resize[1] if args.video_resize is not None else vid.height\n    w = args.video_resize[0] if args.video_resize is not None else vid.width\n    if 'com' in args.output_types:\n        com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)\n    if 'pha' in args.output_types:\n        pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)\n    if 'fgr' in args.output_types:\n        fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)\n    if 'err' in args.output_types:\n        err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)\n    if 'ref' in args.output_types:\n        ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)\nelse:\n    if 'com' in args.output_types:\n        com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')\n    if 'pha' in args.output_types:\n        pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')\n    if 'fgr' in args.output_types:\n        fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')\n    if 'err' in args.output_types:\n        err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')\n    if 'ref' in args.output_types:\n        ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')\n    \n\n# Conversion loop\nwith torch.no_grad():\n    for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):\n        if args.video_target_bgr:\n            (src, bgr), tgt_bgr = input_batch\n            tgt_bgr = tgt_bgr.to(device, non_blocking=True)\n        else:\n            src, bgr = input_batch\n            tgt_bgr = 
torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)\n        src = src.to(device, non_blocking=True)\n        bgr = bgr.to(device, non_blocking=True)\n        \n        if args.model_type == 'mattingbase':\n            pha, fgr, err, _ = model(src, bgr)\n        elif args.model_type == 'mattingrefine':\n            pha, fgr, _, _, err, ref = model(src, bgr)\n        elif args.model_type == 'mattingbm':\n            pha, fgr = model(src, bgr)\n\n        if 'com' in args.output_types:\n            if args.output_format == 'video':\n                # Output composite with green background\n                com = fgr * pha + tgt_bgr * (1 - pha)\n                com_writer.add_batch(com)\n            else:\n                # Output composite as rgba png images\n                com = torch.cat([fgr * pha.ne(0), pha], dim=1)\n                com_writer.add_batch(com)\n        if 'pha' in args.output_types:\n            pha_writer.add_batch(pha)\n        if 'fgr' in args.output_types:\n            fgr_writer.add_batch(fgr)\n        if 'err' in args.output_types:\n            err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))\n        if 'ref' in args.output_types:\n            ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))\n"
  },
  {
    "path": "inference_webcam.py",
    "content": "\"\"\"\nInference on webcams: Use a model on webcam input.\n\nOnce launched, the script is in background collection mode.\nPress B to toggle between background capture mode and matting mode. The frame shown when B is pressed is used as background for matting.\nPress Q to exit.\n\nExample:\n\n    python inference_webcam.py \\\n        --model-type mattingrefine \\\n        --model-backbone resnet50 \\\n        --model-checkpoint \"PATH_TO_CHECKPOINT\" \\\n        --resolution 1280 720\n\n\"\"\"\n\nimport argparse, os, shutil, time\nimport cv2\nimport torch\n\nfrom torch import nn\nfrom torch.utils.data import DataLoader\nfrom torchvision.transforms import Compose, ToTensor, Resize\nfrom torchvision.transforms.functional import to_pil_image\nfrom threading import Thread, Lock\nfrom tqdm import tqdm\nfrom PIL import Image\n\nfrom dataset import VideoDataset\nfrom model import MattingBase, MattingRefine\n\n\n# --------------- Arguments ---------------\n\n\nparser = argparse.ArgumentParser(description='Inference from web-cam')\n\nparser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])\nparser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])\nparser.add_argument('--model-backbone-scale', type=float, default=0.25)\nparser.add_argument('--model-checkpoint', type=str, required=True)\nparser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])\nparser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)\nparser.add_argument('--model-refine-threshold', type=float, default=0.7)\n\nparser.add_argument('--hide-fps', action='store_true')\nparser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))\nargs = parser.parse_args()\n\n\n# ----------- Utility classes -------------\n\n\n# A wrapper that reads data from cv2.VideoCapture in its own 
thread to optimize.\n# Use .read() in a tight loop to get the newest frame\nclass Camera:\n    def __init__(self, device_id=0, width=1280, height=720):\n        self.capture = cv2.VideoCapture(device_id)\n        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)\n        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)\n        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))\n        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))\n        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)\n        self.success_reading, self.frame = self.capture.read()\n        self.read_lock = Lock()\n        self.thread = Thread(target=self.__update, args=())\n        self.thread.daemon = True\n        self.thread.start()\n\n    def __update(self):\n        while self.success_reading:\n            grabbed, frame = self.capture.read()\n            with self.read_lock:\n                self.success_reading = grabbed\n                self.frame = frame\n\n    def read(self):\n        with self.read_lock:\n            frame = self.frame.copy()\n        return frame\n    def __exit__(self, exec_type, exc_value, traceback):\n        self.capture.release()\n\n# An FPS tracker that computes an exponential moving average of the FPS\nclass FPSTracker:\n    def __init__(self, ratio=0.5):\n        self._last_tick = None\n        self._avg_fps = None\n        self.ratio = ratio\n    def tick(self):\n        if self._last_tick is None:\n            self._last_tick = time.time()\n            return None\n        t_new = time.time()\n        fps_sample = 1.0 / (t_new - self._last_tick)\n        self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps if self._avg_fps is not None else fps_sample\n        self._last_tick = t_new\n        return self.get()\n    def get(self):\n        return self._avg_fps\n\n# Wrapper for playing a stream with cv2.imshow(). 
It can accept an image and return keypress info for basic interactivity.\n# It also tracks FPS and optionally overlays info onto the stream.\nclass Displayer:\n    def __init__(self, title, width=None, height=None, show_info=True):\n        self.title, self.width, self.height = title, width, height\n        self.show_info = show_info\n        self.fps_tracker = FPSTracker()\n        cv2.namedWindow(self.title, cv2.WINDOW_NORMAL)\n        if width is not None and height is not None:\n            cv2.resizeWindow(self.title, width, height)\n    # Update the currently showing frame and return key press char code\n    def step(self, image):\n        fps_estimate = self.fps_tracker.tick()\n        if self.show_info and fps_estimate is not None:\n            message = f\"{int(fps_estimate)} fps | {self.width}x{self.height}\"\n            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))\n        cv2.imshow(self.title, image)\n        return cv2.waitKey(1) & 0xFF\n\n\n# --------------- Main ---------------\n\n\n# Load model\nif args.model_type == 'mattingbase':\n    model = MattingBase(args.model_backbone)\nif args.model_type == 'mattingrefine':\n    model = MattingRefine(\n        args.model_backbone,\n        args.model_backbone_scale,\n        args.model_refine_mode,\n        args.model_refine_sample_pixels,\n        args.model_refine_threshold)\n\nmodel = model.cuda().eval()\nmodel.load_state_dict(torch.load(args.model_checkpoint), strict=False)\n\n\nwidth, height = args.resolution\ncam = Camera(width=width, height=height)\ndsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))\n\ndef cv2_frame_to_cuda(frame):\n    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n    return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()\n\nwith torch.no_grad():\n    while True:\n        bgr = None\n        while True: # grab bgr\n            frame = cam.read()\n            key = dsp.step(frame)\n            if key == 
ord('b'):\n                bgr = cv2_frame_to_cuda(cam.read())\n                break\n            elif key == ord('q'):\n                exit()\n        while True: # matting\n            frame = cam.read()\n            src = cv2_frame_to_cuda(frame)\n            pha, fgr = model(src, bgr)[:2]\n            res = pha * fgr + (1 - pha) * torch.ones_like(fgr)\n            res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]\n            res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)\n            key = dsp.step(res)\n            if key == ord('b'):\n                break\n            elif key == ord('q'):\n                exit()\n"
  },
  {
    "path": "model/__init__.py",
    "content": "from .model import Base, MattingBase, MattingRefine"
  },
  {
    "path": "model/decoder.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Decoder(nn.Module):\n    \"\"\"\n    Decoder upsamples the image by combining the feature maps at all resolutions from the encoder.\n    \n    Input:\n        x4: (B, C, H/16, W/16) feature map at 1/16 resolution.\n        x3: (B, C, H/8, W/8) feature map at 1/8 resolution.\n        x2: (B, C, H/4, W/4) feature map at 1/4 resolution.\n        x1: (B, C, H/2, W/2) feature map at 1/2 resolution.\n        x0: (B, C, H, W) feature map at full resolution.\n        \n    Output:\n        x: (B, C, H, W) upsampled output at full resolution.\n    \"\"\"\n    \n    def __init__(self, channels, feature_channels):\n        super().__init__()\n        self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False)\n        self.bn1 = nn.BatchNorm2d(channels[1])\n        self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False)\n        self.bn2 = nn.BatchNorm2d(channels[2])\n        self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False)\n        self.bn3 = nn.BatchNorm2d(channels[3])\n        self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1)\n        self.relu = nn.ReLU(True)\n\n    def forward(self, x4, x3, x2, x1, x0):\n        x = F.interpolate(x4, size=x3.shape[2:], mode='bilinear', align_corners=False)\n        x = torch.cat([x, x3], dim=1)\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)\n        x = torch.cat([x, x2], dim=1)\n        x = self.conv2(x)\n        x = self.bn2(x)\n        x = self.relu(x)\n        x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False)\n        x = torch.cat([x, x1], dim=1)\n        x = self.conv3(x)\n        x = self.bn3(x)\n        x = self.relu(x)\n        x = 
F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)\n        x = torch.cat([x, x0], dim=1)\n        x = self.conv4(x)\n        return x\n"
  },
  {
    "path": "model/mobilenet.py",
    "content": "from torch import nn\nfrom torchvision.models import MobileNetV2\n\n\nclass MobileNetV2Encoder(MobileNetV2):\n    \"\"\"\n    MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to\n    use dilation on the last block to maintain output stride 16, and the classifier block\n    originally used for classification is deleted. The forward method additionally\n    returns the feature maps at all resolutions for the decoder's use.\n    \"\"\"\n    \n    def __init__(self, in_channels, norm_layer=None):\n        super().__init__()\n        \n        # Replace first conv layer if in_channels doesn't match.\n        if in_channels != 3:\n            self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)\n       \n        # Remove last block\n        self.features = self.features[:-1]\n        \n        # Change to use dilation to maintain output stride = 16\n        self.features[14].conv[1][0].stride = (1, 1)\n        for feature in self.features[15:]:\n            feature.conv[1][0].dilation = (2, 2)\n            feature.conv[1][0].padding = (2, 2)\n        \n        # Delete classifier\n        del self.classifier\n        \n    def forward(self, x):\n        x0 = x  # 1/1\n        x = self.features[0](x)\n        x = self.features[1](x)\n        x1 = x  # 1/2\n        x = self.features[2](x)\n        x = self.features[3](x)\n        x2 = x  # 1/4\n        x = self.features[4](x)\n        x = self.features[5](x)\n        x = self.features[6](x)\n        x3 = x  # 1/8\n        x = self.features[7](x)\n        x = self.features[8](x)\n        x = self.features[9](x)\n        x = self.features[10](x)\n        x = self.features[11](x)\n        x = self.features[12](x)\n        x = self.features[13](x)\n        x = self.features[14](x)\n        x = self.features[15](x)\n        x = self.features[16](x)\n        x = self.features[17](x)\n        x4 = x  # 1/16\n        return x4, x3, x2, x1, x0\n"
  },
  {
    "path": "model/model.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torchvision.models.segmentation.deeplabv3 import ASPP\n\nfrom .decoder import Decoder\nfrom .mobilenet import MobileNetV2Encoder\nfrom .refiner import Refiner\nfrom .resnet import ResNetEncoder\nfrom .utils import load_matched_state_dict\n\n\nclass Base(nn.Module):\n    \"\"\"\n    A generic implementation of the base encoder-decoder network inspired by DeepLab.\n    Accepts arbitrary channels for input and output.\n    \"\"\"\n    \n    def __init__(self, backbone: str, in_channels: int, out_channels: int):\n        super().__init__()\n        assert backbone in [\"resnet50\", \"resnet101\", \"mobilenetv2\"]\n        if backbone in ['resnet50', 'resnet101']:\n            self.backbone = ResNetEncoder(in_channels, variant=backbone)\n            self.aspp = ASPP(2048, [3, 6, 9])\n            self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels])\n        else:\n            self.backbone = MobileNetV2Encoder(in_channels)\n            self.aspp = ASPP(320, [3, 6, 9])\n            self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels])\n\n    def forward(self, x):\n        x, *shortcuts = self.backbone(x)\n        x = self.aspp(x)\n        x = self.decoder(x, *shortcuts)\n        return x\n    \n    def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True):\n        # Pretrained DeepLabV3 models are provided by <https://github.com/VainF/DeepLabV3Plus-Pytorch>.\n        # This method converts and loads their pretrained state_dict to match with our model structure.\n        # This method is not needed if you are not planning to train from deeplab weights.\n        # Use load_state_dict() for normal weight loading.\n        \n        # Convert state_dict naming for aspp module\n        state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()}\n\n        if 
isinstance(self.backbone, ResNetEncoder):\n            # ResNet backbone keys already match; no conversion is needed.\n            load_matched_state_dict(self, state_dict, print_stats)\n        else:\n            # Change MobileNetV2 backbone to state_dict format, then change back after loading.\n            backbone_features = self.backbone.features\n            self.backbone.low_level_features = backbone_features[:4]\n            self.backbone.high_level_features = backbone_features[4:]\n            del self.backbone.features\n            load_matched_state_dict(self, state_dict, print_stats)\n            self.backbone.features = backbone_features\n            del self.backbone.low_level_features\n            del self.backbone.high_level_features\n\n\nclass MattingBase(Base):\n    \"\"\"\n    MattingBase is used to produce coarse global results at a lower resolution.\n    MattingBase extends Base.\n    \n    Args:\n        backbone: [\"resnet50\", \"resnet101\", \"mobilenetv2\"]\n        \n    Input:\n        src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.\n        bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.\n    \n    Output:\n        pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.\n        fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.\n        err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1.\n        hid: (B, 32, H, W) the hidden encoding. 
Used for connecting the refiner module.\n        \n    Example:\n        model = MattingBase(backbone='resnet50')\n        \n        pha, fgr, err, hid = model(src, bgr)    # for training\n        pha, fgr = model(src, bgr)[:2]          # for inference\n    \"\"\"\n    \n    def __init__(self, backbone: str):\n        super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32))\n        \n    def forward(self, src, bgr):\n        x = torch.cat([src, bgr], dim=1)\n        x, *shortcuts = self.backbone(x)\n        x = self.aspp(x)\n        x = self.decoder(x, *shortcuts)\n        pha = x[:, 0:1].clamp_(0., 1.)\n        fgr = x[:, 1:4].add(src).clamp_(0., 1.)\n        err = x[:, 4:5].clamp_(0., 1.)\n        hid = x[:, 5: ].relu_()\n        return pha, fgr, err, hid\n\n\nclass MattingRefine(MattingBase):\n    \"\"\"\n    MattingRefine includes the refiner module to upsample the coarse result to full resolution.\n    MattingRefine extends MattingBase.\n    \n    Args:\n        backbone: [\"resnet50\", \"resnet101\", \"mobilenetv2\"]\n        backbone_scale: The image downsample scale for passing through backbone, default 1/4 or 0.25.\n                        Must not be greater than 1/2.\n        refine_mode: refine area selection mode. Options:\n            \"full\"         - No area selection, refine everywhere using regular Conv2d.\n            \"sampling\"     - Refine a fixed number of pixels, chosen by highest error.\n            \"thresholding\" - Refine a varying number of pixels whose error exceeds the threshold.\n        refine_sample_pixels: number of pixels to refine. Only used when mode == \"sampling\".\n        refine_threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == \"thresholding\".\n        refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3]\n        refine_prevent_oversampling: prevent sampling more pixels than needed for sampling mode. 
Set False only for speedtest.\n\n    Input:\n        src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.\n        bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.\n    \n    Output:\n        pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.\n        fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.\n        pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from matting base. Normalized to 0 ~ 1.\n        fgr_sm: (B, 3, Hc, Wc) the coarse foreground prediction from matting base. Normalized to 0 ~ 1.\n        err_sm: (B, 1, Hc, Wc) the coarse error prediction from matting base. Normalized to 0 ~ 1.\n        ref_sm: (B, 1, H/4, W/4) the quarter resolution refinement map. 1 indicates refined 4x4 patch locations.\n        \n    Example:\n        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000)\n        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1)\n        model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full')\n        \n        pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr)   # for training\n        pha, fgr = model(src, bgr)[:2]                               # for inference\n    \"\"\"\n    \n    def __init__(self,\n                 backbone: str,\n                 backbone_scale: float = 1/4,\n                 refine_mode: str = 'sampling',\n                 refine_sample_pixels: int = 80_000,\n                 refine_threshold: float = 0.1,\n                 refine_kernel_size: int = 3,\n                 refine_prevent_oversampling: bool = True,\n                 refine_patch_crop_method: str = 'unfold',\n                 refine_patch_replace_method: str = 'scatter_nd'):\n        assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2'\n        
super().__init__(backbone)\n        self.backbone_scale = backbone_scale\n        self.refiner = Refiner(refine_mode,\n                               refine_sample_pixels,\n                               refine_threshold,\n                               refine_kernel_size,\n                               refine_prevent_oversampling,\n                               refine_patch_crop_method,\n                               refine_patch_replace_method)\n    \n    def forward(self, src, bgr):\n        assert src.size() == bgr.size(), 'src and bgr must have the same shape'\n        assert src.size(2) // 4 * 4 == src.size(2) and src.size(3) // 4 * 4 == src.size(3), \\\n            'src and bgr must have width and height that are divisible by 4'\n        \n        # Downsample src and bgr for backbone\n        src_sm = F.interpolate(src,\n                               scale_factor=self.backbone_scale,\n                               mode='bilinear',\n                               align_corners=False,\n                               recompute_scale_factor=True)\n        bgr_sm = F.interpolate(bgr,\n                               scale_factor=self.backbone_scale,\n                               mode='bilinear',\n                               align_corners=False,\n                               recompute_scale_factor=True)\n        \n        # Base\n        x = torch.cat([src_sm, bgr_sm], dim=1)\n        x, *shortcuts = self.backbone(x)\n        x = self.aspp(x)\n        x = self.decoder(x, *shortcuts)\n        pha_sm = x[:, 0:1].clamp_(0., 1.)\n        fgr_sm = x[:, 1:4]\n        err_sm = x[:, 4:5].clamp_(0., 1.)\n        hid_sm = x[:, 5: ].relu_()\n\n        # Refiner\n        pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)\n        \n        # Clamp outputs\n        pha = pha.clamp_(0., 1.)\n        fgr = fgr.add_(src).clamp_(0., 1.)\n        fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.)\n        \n        return pha, fgr, pha_sm, fgr_sm, 
err_sm, ref_sm\n"
  },
  {
    "path": "model/refiner.py",
    "content": "import torch\nimport torchvision\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom typing import Tuple\n\n\nclass Refiner(nn.Module):\n    \"\"\"\n    Refiner refines the coarse output to full resolution.\n    \n    Args:\n        mode: area selection mode. Options:\n            \"full\"         - No area selection, refine everywhere using regular Conv2d.\n            \"sampling\"     - Refine a fixed number of pixels, chosen by highest error.\n            \"thresholding\" - Refine a varying number of pixels whose error exceeds the threshold.\n        sample_pixels: number of pixels to refine. Only used when mode == \"sampling\".\n        threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == \"thresholding\".\n        kernel_size: The convolution kernel_size. Options: [1, 3]\n        prevent_oversampling: True for regular cases, False for speedtest.\n    \n    Compatibility Args:\n        patch_crop_method: the method for cropping patches. Options:\n            \"unfold\"           - Best performance for PyTorch and TorchScript.\n            \"roi_align\"        - Another way of cropping patches.\n            \"gather\"           - Another way of cropping patches.\n        patch_replace_method: the method for replacing patches. 
Options:\n            \"scatter_nd\"       - Best performance for PyTorch and TorchScript.\n            \"scatter_element\"  - Another way of replacing patches.\n        \n    Input:\n        src: (B, 3, H, W) full resolution source image.\n        bgr: (B, 3, H, W) full resolution background image.\n        pha: (B, 1, Hc, Wc) coarse alpha prediction.\n        fgr: (B, 3, Hc, Wc) coarse foreground residual prediction.\n        err: (B, 1, Hc, Wc) coarse error prediction.\n        hid: (B, 32, Hc, Wc) coarse hidden encoding.\n        \n    Output:\n        pha: (B, 1, H, W) full resolution alpha prediction.\n        fgr: (B, 3, H, W) full resolution foreground residual prediction.\n        ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations.\n    \"\"\"\n    \n    # For TorchScript export optimization.\n    __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']\n    \n    def __init__(self,\n                 mode: str,\n                 sample_pixels: int,\n                 threshold: float,\n                 kernel_size: int = 3,\n                 prevent_oversampling: bool = True,\n                 patch_crop_method: str = 'unfold',\n                 patch_replace_method: str = 'scatter_nd'):\n        super().__init__()\n        assert mode in ['full', 'sampling', 'thresholding']\n        assert kernel_size in [1, 3]\n        assert patch_crop_method in ['unfold', 'roi_align', 'gather']\n        assert patch_replace_method in ['scatter_nd', 'scatter_element']\n        \n        self.mode = mode\n        self.sample_pixels = sample_pixels\n        self.threshold = threshold\n        self.kernel_size = kernel_size\n        self.prevent_oversampling = prevent_oversampling\n        self.patch_crop_method = patch_crop_method\n        self.patch_replace_method = patch_replace_method\n\n        channels = [32, 24, 16, 12, 4]\n        self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], 
kernel_size, bias=False)\n        self.bn1 = nn.BatchNorm2d(channels[1])\n        self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)\n        self.bn2 = nn.BatchNorm2d(channels[2])\n        self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)\n        self.bn3 = nn.BatchNorm2d(channels[3])\n        self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)\n        self.relu = nn.ReLU(True)\n    \n    def forward(self,\n                src: torch.Tensor,\n                bgr: torch.Tensor,\n                pha: torch.Tensor,\n                fgr: torch.Tensor,\n                err: torch.Tensor,\n                hid: torch.Tensor):\n        H_full, W_full = src.shape[2:]\n        H_half, W_half = H_full // 2, W_full // 2\n        H_quat, W_quat = H_full // 4, W_full // 4\n        \n        src_bgr = torch.cat([src, bgr], dim=1)\n        \n        if self.mode != 'full':\n            err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False)\n            ref = self.select_refinement_regions(err)\n            idx = torch.nonzero(ref.squeeze(1))\n            idx = idx[:, 0], idx[:, 1], idx[:, 2]\n            \n            if idx[0].size(0) > 0:\n                x = torch.cat([hid, pha, fgr], dim=1)\n                x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)\n                x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)\n\n                y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)\n                y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)\n\n                x = self.conv1(torch.cat([x, y], dim=1))\n                x = self.bn1(x)\n                x = self.relu(x)\n                x = self.conv2(x)\n                x = self.bn2(x)\n                x = self.relu(x)\n\n                x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')\n                y 
= self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)\n\n                x = self.conv3(torch.cat([x, y], dim=1))\n                x = self.bn3(x)\n                x = self.relu(x)\n                x = self.conv4(x)\n                \n                out = torch.cat([pha, fgr], dim=1)\n                out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)\n                out = self.replace_patch(out, x, idx)\n                pha = out[:, :1]\n                fgr = out[:, 1:]\n            else:\n                pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)\n                fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)\n        else:\n            x = torch.cat([hid, pha, fgr], dim=1)\n            x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)\n            y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)\n            if self.kernel_size == 3:\n                x = F.pad(x, (3, 3, 3, 3))\n                y = F.pad(y, (3, 3, 3, 3))\n\n            x = self.conv1(torch.cat([x, y], dim=1))\n            x = self.bn1(x)\n            x = self.relu(x)\n            x = self.conv2(x)\n            x = self.bn2(x)\n            x = self.relu(x)\n            \n            if self.kernel_size == 3:\n                x = F.interpolate(x, (H_full + 4, W_full + 4))\n                y = F.pad(src_bgr, (2, 2, 2, 2))\n            else:\n                x = F.interpolate(x, (H_full, W_full), mode='nearest')\n                y = src_bgr\n            \n            x = self.conv3(torch.cat([x, y], dim=1))\n            x = self.bn3(x)\n            x = self.relu(x)\n            x = self.conv4(x)\n            \n            pha = x[:, :1]\n            fgr = x[:, 1:]\n            ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype)\n            \n        return pha, fgr, ref\n    \n    def 
select_refinement_regions(self, err: torch.Tensor):\n        \"\"\"\n        Select refinement regions.\n        Input:\n            err: error map (B, 1, H, W)\n        Output:\n            ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.\n        \"\"\"\n        if self.mode == 'sampling':\n            # Sampling mode.\n            b, _, h, w = err.shape\n            err = err.view(b, -1)\n            idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices\n            ref = torch.zeros_like(err)\n            ref.scatter_(1, idx, 1.)\n            if self.prevent_oversampling:\n                ref.mul_(err.gt(0).float())\n            ref = ref.view(b, 1, h, w)\n        else:\n            # Thresholding mode.\n            ref = err.gt(self.threshold).float()\n        return ref\n    \n    def crop_patch(self,\n                   x: torch.Tensor,\n                   idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],\n                   size: int,\n                   padding: int):\n        \"\"\"\n        Crops selected patches from image given indices.\n        \n        Inputs:\n            x: image (B, C, H, W).\n            idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.\n            size: center size of the patch, also stride of the crop.\n            padding: expansion size of the patch.\n        Output:\n            patch: (P, C, h, w), where h = w = size + 2 * padding.\n        \"\"\"\n        if padding != 0:\n            x = F.pad(x, (padding,) * 4)\n        \n        if self.patch_crop_method == 'unfold':\n            # Use unfold. Best performance for PyTorch and TorchScript.\n            return x.permute(0, 2, 3, 1) \\\n                    .unfold(1, size + 2 * padding, size) \\\n                    .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]\n        elif self.patch_crop_method == 'roi_align':\n            # Use roi_align. 
Best compatibility for ONNX.\n            idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)\n            b = idx[0]\n            x1 = idx[2] * size - 0.5\n            y1 = idx[1] * size - 0.5\n            x2 = idx[2] * size + size + 2 * padding - 0.5\n            y2 = idx[1] * size + size + 2 * padding - 0.5\n            boxes = torch.stack([b, x1, y1, x2, y2], dim=1)\n            return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)\n        else:\n            # Use gather. Crops out patches pixel by pixel.\n            idx_pix = self.compute_pixel_indices(x, idx, size, padding)\n            pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))\n            pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)\n            return pat\n    \n    def replace_patch(self,\n                      x: torch.Tensor,\n                      y: torch.Tensor,\n                      idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):\n        \"\"\"\n        Replaces patches back into image given index.\n        \n        Inputs:\n            x: image (B, C, H, W)\n            y: patches (P, C, h, w)\n            idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index.\n        \n        Output:\n            image: (B, C, H, W), where patches at idx locations are replaced with y.\n        \"\"\"\n        xB, xC, xH, xW = x.shape\n        yB, yC, yH, yW = y.shape\n        if self.patch_replace_method == 'scatter_nd':\n            # Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch.\n            x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)\n            x[idx[0], idx[1], idx[2]] = y\n            x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)\n            return x\n        else:\n            # Use scatter_element. Best compatibility for ONNX. 
Replacing pixel by pixel.\n            idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)\n            return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)\n\n    def compute_pixel_indices(self,\n                              x: torch.Tensor,\n                              idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],\n                              size: int,\n                              padding: int):\n        \"\"\"\n        Compute selected pixel indices in the tensor.\n        Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel.\n        Input:\n            x: image: (B, C, H, W)\n            idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.\n            size: center size of the patch, also stride of the crop.\n            padding: expansion size of the patch.\n        Output:\n            idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is number of patches.\n                 the elements are indices pointing into the input x.view(-1).\n        \"\"\"\n        B, C, H, W = x.shape\n        S, P = size, padding\n        O = S + 2 * P\n        b, y, x = idx\n        n = b.size(0)\n        # Create the index ranges on the same device as the selection indices.\n        c = torch.arange(C, device=b.device)\n        o = torch.arange(O, device=b.device)\n        idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O])\n        # Each batch element spans C * H * W entries in the flattened tensor.\n        idx_loc = b * C * H * W + y * W * S + x * S\n        idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])\n        return idx_pix\n"
  },
  {
    "path": "model/resnet.py",
    "content": "from torch import nn\nfrom torchvision.models.resnet import ResNet, Bottleneck\n\n\nclass ResNetEncoder(ResNet):\n    \"\"\"\n    ResNetEncoder inherits from torchvision's official ResNet. It is modified to\n    use dilation on the last block to maintain output stride 16, and the global\n    average pooling layer and the fully connected layer originally used for\n    classification are deleted. The forward method additionally returns the feature\n    maps at all resolutions for the decoder's use.\n    \"\"\"\n    \n    layers = {\n        'resnet50':  [3, 4, 6, 3],\n        'resnet101': [3, 4, 23, 3],\n    }\n    \n    def __init__(self, in_channels, variant='resnet101', norm_layer=None):\n        super().__init__(\n            block=Bottleneck,\n            layers=self.layers[variant],\n            replace_stride_with_dilation=[False, False, True],\n            norm_layer=norm_layer)\n        \n        # Replace first conv layer if in_channels doesn't match.\n        if in_channels != 3:\n            self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False)\n            \n        # Delete global average pooling and fully-connected layers\n        del self.avgpool\n        del self.fc\n    \n    def forward(self, x):\n        x0 = x  # 1/1\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.relu(x)\n        x1 = x  # 1/2\n        x = self.maxpool(x)\n        x = self.layer1(x)\n        x2 = x  # 1/4\n        x = self.layer2(x)\n        x3 = x  # 1/8\n        x = self.layer3(x)\n        x = self.layer4(x)\n        x4 = x  # 1/16\n        return x4, x3, x2, x1, x0\n"
  },
  {
    "path": "model/utils.py",
    "content": "def load_matched_state_dict(model, state_dict, print_stats=True):\n    \"\"\"\n    Loads only the weights whose key and shape match the model. Other weights are ignored.\n    \"\"\"\n    num_matched, num_total = 0, 0\n    curr_state_dict = model.state_dict()\n    for key in curr_state_dict.keys():\n        num_total += 1\n        if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:\n            curr_state_dict[key] = state_dict[key]\n            num_matched += 1\n    model.load_state_dict(curr_state_dict)\n    if print_stats:\n        print(f'Loaded state_dict: {num_matched}/{num_total} matched')"
  },
  {
    "path": "requirements.txt",
    "content": "kornia==0.4.1\ntensorboard==2.3.0\ntorch==1.7.0\ntorchvision==0.8.1\ntqdm==4.51.0\nopencv-python==4.4.0.44\nonnxruntime==1.6.0"
  },
  {
    "path": "train_base.py",
    "content": "\"\"\"\nTrain MattingBase\n\nYou can download pretrained DeepLabV3 weights from <https://github.com/VainF/DeepLabV3Plus-Pytorch>\n\nExample:\n\n    CUDA_VISIBLE_DEVICES=0 python train_base.py \\\n        --dataset-name videomatte240k \\\n        --model-backbone resnet50 \\\n        --model-name mattingbase-resnet50-videomatte240k \\\n        --model-pretrain-initialization \"pretraining/best_deeplabv3_resnet50_voc_os16.pth\" \\\n        --epoch-end 8\n\n\"\"\"\n\nimport argparse\nimport kornia\nimport torch\nimport os\nimport random\n\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torch.cuda.amp import autocast, GradScaler\nfrom torch.utils.tensorboard import SummaryWriter\nfrom torch.utils.data import DataLoader\nfrom torch.optim import Adam\nfrom torchvision.utils import make_grid\nfrom tqdm import tqdm\nfrom torchvision import transforms as T\nfrom PIL import Image\n\nfrom data_path import DATA_PATH\nfrom dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset\nfrom dataset import augmentation as A\nfrom model import MattingBase\nfrom model.utils import load_matched_state_dict\n\n\n# --------------- Arguments ---------------\n\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())\n\nparser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])\nparser.add_argument('--model-name', type=str, required=True)\nparser.add_argument('--model-pretrain-initialization', type=str, default=None)\nparser.add_argument('--model-last-checkpoint', type=str, default=None)\n\nparser.add_argument('--batch-size', type=int, default=8)\nparser.add_argument('--num-workers', type=int, default=16)\nparser.add_argument('--epoch-start', type=int, default=0)\nparser.add_argument('--epoch-end', type=int, required=True)\n\nparser.add_argument('--log-train-loss-interval', type=int, 
default=10)\nparser.add_argument('--log-train-images-interval', type=int, default=2000)\nparser.add_argument('--log-valid-interval', type=int, default=5000)\n\nparser.add_argument('--checkpoint-interval', type=int, default=5000)\n\nargs = parser.parse_args()\n\n\n# --------------- Loading ---------------\n\n\ndef train():\n    \n    # Training DataLoader\n    dataset_train = ZipDataset([\n        ZipDataset([\n            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),\n            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),\n        ], transforms=A.PairCompose([\n            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)),\n            A.PairRandomHorizontalFlip(),\n            A.PairRandomBoxBlur(0.1, 5),\n            A.PairRandomSharpen(0.1),\n            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),\n            A.PairApply(T.ToTensor())\n        ]), assert_equal_length=True),\n        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([\n            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),\n            T.RandomHorizontalFlip(),\n            A.RandomBoxBlur(0.1, 5),\n            A.RandomSharpen(0.1),\n            T.ColorJitter(0.15, 0.15, 0.15, 0.05),\n            T.ToTensor()\n        ])),\n    ])\n    dataloader_train = DataLoader(dataset_train,\n                                  shuffle=True,\n                                  batch_size=args.batch_size,\n                                  num_workers=args.num_workers,\n                                  pin_memory=True)\n    \n    # Validation DataLoader\n    dataset_valid = ZipDataset([\n        ZipDataset([\n            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),\n            ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')\n     
   ], transforms=A.PairCompose([\n            A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),\n            A.PairApply(T.ToTensor())\n        ]), assert_equal_length=True),\n        ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([\n            A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),\n            T.ToTensor()\n        ])),\n    ])\n    dataset_valid = SampleDataset(dataset_valid, 50)\n    dataloader_valid = DataLoader(dataset_valid,\n                                  pin_memory=True,\n                                  batch_size=args.batch_size,\n                                  num_workers=args.num_workers)\n\n    # Model\n    model = MattingBase(args.model_backbone).cuda()\n\n    if args.model_last_checkpoint is not None:\n        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))\n    elif args.model_pretrain_initialization is not None:\n        model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state'])\n\n    optimizer = Adam([\n        {'params': model.backbone.parameters(), 'lr': 1e-4},\n        {'params': model.aspp.parameters(), 'lr': 5e-4},\n        {'params': model.decoder.parameters(), 'lr': 5e-4}\n    ])\n    scaler = GradScaler()\n\n    # Logging and checkpoints\n    if not os.path.exists(f'checkpoint/{args.model_name}'):\n        os.makedirs(f'checkpoint/{args.model_name}')\n    writer = SummaryWriter(f'log/{args.model_name}')\n    \n    # Run loop\n    for epoch in range(args.epoch_start, args.epoch_end):\n        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):\n            step = epoch * len(dataloader_train) + i\n\n            true_pha = true_pha.cuda(non_blocking=True)\n            true_fgr = true_fgr.cuda(non_blocking=True)\n            true_bgr = true_bgr.cuda(non_blocking=True)\n         
   true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)\n            \n            true_src = true_bgr.clone()\n            \n            # Augment with shadow\n            aug_shadow_idx = torch.rand(len(true_src)) < 0.3\n            if aug_shadow_idx.any():\n                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())\n                aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)\n                aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)\n                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)\n                del aug_shadow\n            del aug_shadow_idx\n            \n            # Composite foreground onto source\n            true_src = true_fgr * true_pha + true_src * (1 - true_pha)\n\n            # Augment with noise\n            aug_noise_idx = torch.rand(len(true_src)) < 0.4\n            if aug_noise_idx.any():\n                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)\n                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)\n            del aug_noise_idx\n            \n            # Augment background with jitter\n            aug_jitter_idx = torch.rand(len(true_src)) < 0.8\n            if aug_jitter_idx.any():\n                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])\n            del aug_jitter_idx\n            \n            # Augment background with affine\n            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3\n            if aug_affine_idx.any():\n                true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])\n            del 
aug_affine_idx\n\n            with autocast():\n                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]\n                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)\n\n            scaler.scale(loss).backward()\n            scaler.step(optimizer)\n            scaler.update()\n            optimizer.zero_grad()\n\n            if (i + 1) % args.log_train_loss_interval == 0:\n                writer.add_scalar('loss', loss, step)\n\n            if (i + 1) % args.log_train_images_interval == 0:\n                writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)\n                writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)\n                writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)\n                writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)\n                writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)\n                writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)\n                \n            del true_pha, true_fgr, true_bgr\n            del pred_pha, pred_fgr, pred_err\n\n            if (i + 1) % args.log_valid_interval == 0:\n                valid(model, dataloader_valid, writer, step)\n\n            if (step + 1) % args.checkpoint_interval == 0:\n                torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')\n\n        torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')\n\n\n# --------------- Utils ---------------\n\n\ndef compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):\n    true_err = torch.abs(pred_pha.detach() - true_pha)\n    true_msk = true_pha != 0\n    return F.l1_loss(pred_pha, true_pha) + \\\n           F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \\\n           F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \\\n           F.mse_loss(pred_err, 
true_err)\n\n\ndef random_crop(*imgs):\n    w = random.choice(range(256, 512))\n    h = random.choice(range(256, 512))\n    results = []\n    for img in imgs:\n        img = kornia.resize(img, (max(h, w), max(h, w)))\n        img = kornia.center_crop(img, (h, w))\n        results.append(img)\n    return results\n\n\ndef valid(model, dataloader, writer, step):\n    model.eval()\n    loss_total = 0\n    loss_count = 0\n    with torch.no_grad():\n        for (true_pha, true_fgr), true_bgr in dataloader:\n            batch_size = true_pha.size(0)\n            \n            true_pha = true_pha.cuda(non_blocking=True)\n            true_fgr = true_fgr.cuda(non_blocking=True)\n            true_bgr = true_bgr.cuda(non_blocking=True)\n            true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr\n\n            pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]\n            loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)\n            loss_total += loss.cpu().item() * batch_size\n            loss_count += batch_size\n\n    writer.add_scalar('valid_loss', loss_total / loss_count, step)\n    model.train()\n\n\n# --------------- Start ---------------\n\n\nif __name__ == '__main__':\n    train()\n"
  },
  {
    "path": "train_refine.py",
    "content": "\"\"\"\nTrain MattingRefine\n\nSupports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.\nSelect GPUs through CUDA_VISIBLE_DEVICES environment variable.\n\nExample:\n\n    CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \\\n        --dataset-name videomatte240k \\\n        --model-backbone resnet50 \\\n        --model-name mattingrefine-resnet50-videomatte240k \\\n        --model-last-checkpoint \"PATH_TO_LAST_CHECKPOINT\" \\\n        --epoch-end 1\n\n\"\"\"\n\nimport argparse\nimport kornia\nimport torch\nimport os\nimport random\n\nfrom torch import nn\nfrom torch import distributed as dist\nfrom torch import multiprocessing as mp\nfrom torch.nn import functional as F\nfrom torch.cuda.amp import autocast, GradScaler\nfrom torch.utils.tensorboard import SummaryWriter\nfrom torch.utils.data import DataLoader, Subset\nfrom torch.optim import Adam\nfrom torchvision.utils import make_grid\nfrom tqdm import tqdm\nfrom torchvision import transforms as T\nfrom PIL import Image\n\nfrom data_path import DATA_PATH\nfrom dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset\nfrom dataset import augmentation as A\nfrom model import MattingRefine\nfrom model.utils import load_matched_state_dict\n\n\n# --------------- Arguments ---------------\n\n\nparser = argparse.ArgumentParser()\n\nparser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())\n\nparser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])\nparser.add_argument('--model-backbone-scale', type=float, default=0.25)\nparser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])\nparser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)\nparser.add_argument('--model-refine-thresholding', type=float, default=0.7)\nparser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 
3])\nparser.add_argument('--model-name', type=str, required=True)\nparser.add_argument('--model-last-checkpoint', type=str, default=None)\n\nparser.add_argument('--batch-size', type=int, default=4)\nparser.add_argument('--num-workers', type=int, default=16)\nparser.add_argument('--epoch-start', type=int, default=0)\nparser.add_argument('--epoch-end', type=int, required=True)\n\nparser.add_argument('--log-train-loss-interval', type=int, default=10)\nparser.add_argument('--log-train-images-interval', type=int, default=1000)\nparser.add_argument('--log-valid-interval', type=int, default=2000)\n\nparser.add_argument('--checkpoint-interval', type=int, default=2000)\n\nargs = parser.parse_args()\n\n\ndistributed_num_gpus = torch.cuda.device_count()\nassert args.batch_size % distributed_num_gpus == 0\n\n\n# --------------- Main ---------------\n\ndef train_worker(rank, addr, port):\n    \n    # Distributed Setup\n    os.environ['MASTER_ADDR'] = addr\n    os.environ['MASTER_PORT'] = port\n    dist.init_process_group(\"nccl\", rank=rank, world_size=distributed_num_gpus)\n    \n    # Training DataLoader\n    dataset_train = ZipDataset([\n        ZipDataset([\n            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),\n            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),\n        ], transforms=A.PairCompose([\n            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),\n            A.PairRandomHorizontalFlip(),\n            A.PairRandomBoxBlur(0.1, 5),\n            A.PairRandomSharpen(0.1),\n            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),\n            A.PairApply(T.ToTensor())\n        ]), assert_equal_length=True),\n        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([\n            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 
5)),\n            T.RandomHorizontalFlip(),\n            A.RandomBoxBlur(0.1, 5),\n            A.RandomSharpen(0.1),\n            T.ColorJitter(0.15, 0.15, 0.15, 0.05),\n            T.ToTensor()\n        ])),\n    ])\n    dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus)\n    dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))\n    dataloader_train = DataLoader(dataset_train,\n                                  shuffle=True,\n                                  pin_memory=True,\n                                  drop_last=True,\n                                  batch_size=args.batch_size // distributed_num_gpus,\n                                  num_workers=args.num_workers // distributed_num_gpus)\n    \n    # Validation DataLoader\n    if rank == 0:\n        dataset_valid = ZipDataset([\n            ZipDataset([\n                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),\n                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')\n            ], transforms=A.PairCompose([\n                A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),\n                A.PairApply(T.ToTensor())\n            ]), assert_equal_length=True),\n            ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([\n                A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),\n                T.ToTensor()\n            ])),\n        ])\n        dataset_valid = SampleDataset(dataset_valid, 50)\n        dataloader_valid = DataLoader(dataset_valid,\n                                      pin_memory=True,\n                                      drop_last=True,\n                                      batch_size=args.batch_size // distributed_num_gpus,\n                           
           num_workers=args.num_workers // distributed_num_gpus)\n    \n    # Model\n    model = MattingRefine(args.model_backbone,\n                          args.model_backbone_scale,\n                          args.model_refine_mode,\n                          args.model_refine_sample_pixels,\n                          args.model_refine_thresholding,\n                          args.model_refine_kernel_size).to(rank)\n    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)\n    model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])\n    \n    if args.model_last_checkpoint is not None:\n        load_matched_state_dict(model, torch.load(args.model_last_checkpoint))\n\n    optimizer = Adam([\n        {'params': model.backbone.parameters(), 'lr': 5e-5},\n        {'params': model.aspp.parameters(), 'lr': 5e-5},\n        {'params': model.decoder.parameters(), 'lr': 1e-4},\n        {'params': model.refiner.parameters(), 'lr': 3e-4},\n    ])\n    scaler = GradScaler()\n    \n    # Logging and checkpoints\n    if rank == 0:\n        if not os.path.exists(f'checkpoint/{args.model_name}'):\n            os.makedirs(f'checkpoint/{args.model_name}')\n        writer = SummaryWriter(f'log/{args.model_name}')\n    \n    # Run loop\n    for epoch in range(args.epoch_start, args.epoch_end):\n        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):\n            step = epoch * len(dataloader_train) + i\n\n            true_pha = true_pha.to(rank, non_blocking=True)\n            true_fgr = true_fgr.to(rank, non_blocking=True)\n            true_bgr = true_bgr.to(rank, non_blocking=True)\n            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)\n            \n            true_src = true_bgr.clone()\n            \n            # Augment with shadow\n            aug_shadow_idx = torch.rand(len(true_src)) < 0.3\n            if aug_shadow_idx.any():\n                aug_shadow = 
true_pha[aug_shadow_idx].mul(0.3 * random.random())\n                aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)\n                aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)\n                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)\n                del aug_shadow\n            del aug_shadow_idx\n            \n            # Composite foreground onto source\n            true_src = true_fgr * true_pha + true_src * (1 - true_pha)\n\n            # Augment with noise\n            aug_noise_idx = torch.rand(len(true_src)) < 0.4\n            if aug_noise_idx.any():\n                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)\n                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)\n            del aug_noise_idx\n            \n            # Augment background with jitter\n            aug_jitter_idx = torch.rand(len(true_src)) < 0.8\n            if aug_jitter_idx.any():\n                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])\n            del aug_jitter_idx\n            \n            # Augment background with affine\n            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3\n            if aug_affine_idx.any():\n                true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])\n            del aug_affine_idx\n            \n            with autocast():\n                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)\n                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)\n\n            
scaler.scale(loss).backward()\n            scaler.step(optimizer)\n            scaler.update()\n            optimizer.zero_grad()\n\n            if rank == 0:\n                if (i + 1) % args.log_train_loss_interval == 0:\n                    writer.add_scalar('loss', loss, step)\n\n                if (i + 1) % args.log_train_images_interval == 0:\n                    writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)\n                    writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)\n                    writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)\n                    writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)\n                    writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)\n\n                del true_pha, true_fgr, true_src, true_bgr\n                del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm\n\n                if (i + 1) % args.log_valid_interval == 0:\n                    valid(model, dataloader_valid, writer, step)\n\n                if (step + 1) % args.checkpoint_interval == 0:\n                    torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')\n                    \n        if rank == 0:\n            torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')\n            \n    # Clean up\n    dist.destroy_process_group()\n            \n            \n# --------------- Utils ---------------\n\n\ndef compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):\n    true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])\n    true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])\n    true_msk_lg = true_pha_lg != 0\n    true_msk_sm = true_pha_sm != 0\n    return F.l1_loss(pred_pha_lg, true_pha_lg) + \\\n           F.l1_loss(pred_pha_sm, true_pha_sm) + \\\n           
F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)) + \\\n           F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)) + \\\n           F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg) + \\\n           F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm) + \\\n           F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]), \\\n                      kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs())\n\n\ndef random_crop(*imgs):\n    H_src, W_src = imgs[0].shape[2:]\n    W_tgt = random.choice(range(1024, 2048)) // 4 * 4\n    H_tgt = random.choice(range(1024, 2048)) // 4 * 4\n    scale = max(W_tgt / W_src, H_tgt / H_src)\n    results = []\n    for img in imgs:\n        img = kornia.resize(img, (int(H_src * scale), int(W_src * scale)))\n        img = kornia.center_crop(img, (H_tgt, W_tgt))\n        results.append(img)\n    return results\n\n\ndef valid(model, dataloader, writer, step):\n    model.eval()\n    loss_total = 0\n    loss_count = 0\n    with torch.no_grad():\n        for (true_pha, true_fgr), true_bgr in dataloader:\n            batch_size = true_pha.size(0)\n            \n            true_pha = true_pha.cuda(non_blocking=True)\n            true_fgr = true_fgr.cuda(non_blocking=True)\n            true_bgr = true_bgr.cuda(non_blocking=True)\n            true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr\n\n            pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)\n            loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)\n            loss_total += loss.cpu().item() * batch_size\n            loss_count += batch_size\n\n    writer.add_scalar('valid_loss', loss_total / loss_count, step)\n    model.train()\n\n\n# --------------- Start ---------------\n\n\nif __name__ == '__main__':\n    addr = 'localhost'\n    port = str(random.choice(range(12300, 12400))) # pick a random 
port.\n    mp.spawn(train_worker,\n             nprocs=distributed_num_gpus,\n             args=(addr, port),\n             join=True)\n"
  }
]