Repository: PeterL1n/BackgroundMattingV2 Branch: master Commit: a8e82df9f594 Files: 33 Total size: 114.0 KB Directory structure: gitextract_h75e3eky/ ├── LICENSE ├── README.md ├── data_path.py ├── dataset/ │ ├── __init__.py │ ├── augmentation.py │ ├── images.py │ ├── sample.py │ ├── video.py │ └── zip.py ├── doc/ │ └── model_usage.md ├── eval/ │ ├── benchmark.m │ ├── compute_connectivity_error.m │ ├── compute_gradient_loss.m │ ├── compute_mse_loss.m │ ├── compute_sad_loss.m │ └── gaussgradient.m ├── export_onnx.py ├── export_torchscript.py ├── inference_images.py ├── inference_speed_test.py ├── inference_utils.py ├── inference_video.py ├── inference_webcam.py ├── model/ │ ├── __init__.py │ ├── decoder.py │ ├── mobilenet.py │ ├── model.py │ ├── refiner.py │ ├── resnet.py │ └── utils.py ├── requirements.txt ├── train_base.py └── train_refine.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2020 University of Washington Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # Real-Time High-Resolution Background Matting ![Teaser](https://github.com/PeterL1n/Matting-PyTorch/blob/master/images/teaser.gif?raw=true) Official repository for the paper [Real-Time High-Resolution Background Matting](https://arxiv.org/abs/2012.07810). Our model requires capturing an additional background image and produces state-of-the-art matting results at 4K 30fps and HD 60fps on an Nvidia RTX 2080 TI GPU. * [Visit project site](https://grail.cs.washington.edu/projects/background-matting-v2/) * [Watch project video](https://www.youtube.com/watch?v=oMfPTeYDF9g) **Disclaimer**: The video conversion script in this repo is not meant to be real-time. Our research's main contribution is the neural architecture for high resolution refinement and the new matting datasets. The `inference_speed_test.py` script allows you to measure the tensor throughput of our model, which should achieve real-time. The `inference_video.py` script allows you to test your video on our model, but the video encoding and decoding is done without hardware acceleration and parallelization. For production use, you are expected to do additional engineering for hardware encoding/decoding and loading frames to GPU in parallel. For more architecture detail, please refer to our paper.   ## New Paper is Out! Check out [Robust Video Matting](https://peterl1n.github.io/RobustVideoMatting/)! Our new method does not require pre-captured backgrounds, and can run inference at an even faster speed!   
## Overview * [Updates](#updates) * [Download](#download) * [Model / Weights](#model--weights) * [Video / Image Examples](#video--image-examples) * [Datasets](#datasets) * [Demo](#demo) * [Scripts](#scripts) * [Notebooks](#notebooks) * [Usage / Documentation](#usage--documentation) * [Training](#training) * [Project members](#project-members) * [License](#license)   ## Updates * [Jun 21 2021] Paper received CVPR 2021 Best Student Paper Honorable Mention. * [Apr 21 2021] VideoMatte240K dataset is now published. * [Mar 06 2021] Training script is published. * [Feb 28 2021] Paper is accepted to CVPR 2021. * [Jan 09 2021] PhotoMatte85 dataset is now published. * [Dec 21 2020] We updated our project to MIT License, which permits commercial use.   ## Download ### Model / Weights * [Download model / weights (GitHub)](https://github.com/PeterL1n/BackgroundMattingV2/releases/tag/v1.0.0) * [Download model / weights (GDrive)](https://drive.google.com/drive/folders/1cbetlrKREitIgjnIikG1HdM4x72FtgBh?usp=sharing) ### Video / Image Examples * [HD videos](https://drive.google.com/drive/folders/1j3BMrRFhFpfzJAe6P2WDtfanoeSCLPiq) (by [Sengupta et al.](https://github.com/senguptaumd/Background-Matting)) (Our model is more robust on HD footage) * [4K videos and images](https://drive.google.com/drive/folders/16H6Vz3294J-DEzauw06j4IUARRqYGgRD?usp=sharing) ### Datasets * [Download datasets](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets)   ## Demo #### Scripts We provide several scripts in this repo for you to experiment with our model. More detailed instructions are included in the files. * `inference_images.py`: Perform matting on a directory of images. * `inference_video.py`: Perform matting on a video. * `inference_webcam.py`: An interactive matting demo using your webcam. #### Notebooks Additionally, you can try our notebooks in Google Colab for performing matting on images and videos. 
* [Image matting (Colab)](https://colab.research.google.com/drive/1cTxFq1YuoJ5QPqaTcnskwlHDolnjBkB9?usp=sharing) * [Video matting (Colab)](https://colab.research.google.com/drive/1Y9zWfULc8-DDTSsCH-pX6Utw8skiJG5s?usp=sharing) #### Virtual Camera We provide a demo application that pipes webcam video through our model and outputs to a virtual camera. The script only works on Linux systems and can be used in Zoom meetings. For more information, check out: * [Webcam plugin](https://github.com/andreyryabtsev/BGMv2-webcam-plugin-linux)   ## Usage / Documentation You can run our model using **PyTorch**, **TorchScript**, **TensorFlow**, and **ONNX**. For detail about using our model, please check out the [Usage / Documentation](doc/model_usage.md) page.   ## Training Configure `data_path.py` to point to your dataset. The original paper uses `train_base.py` to train only the base model until convergence and then uses `train_refine.py` to train the entire network end-to-end. More details are specified in the paper.   ## Project members * [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/)*, University of Washington * [Andrey Ryabtsev](http://andreyryabtsev.com/)*, University of Washington * [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/), University of Washington * [Brian Curless](https://homes.cs.washington.edu/~curless/), University of Washington * [Steve Seitz](https://homes.cs.washington.edu/~seitz/), University of Washington * [Ira Kemelmacher-Shlizerman](https://sites.google.com/view/irakemelmacher/), University of Washington * Equal contribution.   ## License ## This work is licensed under the [MIT License](LICENSE). If you use our work in your project, we would love you to include an acknowledgement and fill out our [survey](https://docs.google.com/forms/d/e/1FAIpQLSdR9Yhu9V1QE3pN_LvZJJyDaEpJD2cscOOqMz8N732eLDf42A/viewform?usp=sf_link). ## Community Projects Projects developed by third-party developers. 
* [After Effects Plug-In](https://aescripts.com/goodbye-greenscreen/)

================================================
FILE: data_path.py
================================================
"""
This file records the directory paths to the different datasets.
You will need to configure it before training the model.

All datasets follow the format below, where 'fgr' and 'pha' point to
directories that contain jpg or png images. The directories may be nested in
any way, but the 'fgr' and 'pha' structures must match. You can add your own
dataset to the list as long as it follows the format. 'fgr' should point to
foreground images with RGB channels; 'pha' should point to alpha images with
only 1 grey channel.

{
    'YOUR_DATASET': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR',
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR',
        }
    }
}
"""

# Every 'PATH_TO_IMAGES_DIR' value is a placeholder to be replaced with a real
# directory. Note that 'backgrounds' maps 'train'/'valid' directly to a single
# directory instead of an 'fgr'/'pha' pair.
DATA_PATH = {
    'videomatte240k': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        }
    },
    'photomatte13k': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        }
    },
    'distinction': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR',
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
    },
    'adobe': {
        'train': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR',
        },
        'valid': {
            'fgr': 'PATH_TO_IMAGES_DIR',
            'pha': 'PATH_TO_IMAGES_DIR'
        },
    },
    'backgrounds': {
        'train': 'PATH_TO_IMAGES_DIR',
        'valid': 'PATH_TO_IMAGES_DIR'
    },
}

================================================
FILE: dataset/__init__.py
================================================
# Re-export the dataset classes so callers can import them from `dataset`.
from .images import ImagesDataset
from .video import VideoDataset
from .sample import SampleDataset
from .zip import ZipDataset

================================================
FILE: dataset/augmentation.py
================================================
import random import torch import numpy as np import math from torchvision import transforms as T from torchvision.transforms import functional as F from PIL import Image, ImageFilter """ Pair transforms are MODs of regular transforms so that it takes in multiple images and apply exact transforms on all images. This is especially useful when we want the transforms on a pair of images. Example: img1, img2, ..., imgN = transforms(img1, img2, ..., imgN) """ class PairCompose(T.Compose): def __call__(self, *x): for transform in self.transforms: x = transform(*x) return x class PairApply: def __init__(self, transforms): self.transforms = transforms def __call__(self, *x): return [self.transforms(xi) for xi in x] class PairApplyOnlyAtIndices: def __init__(self, indices, transforms): self.indices = indices self.transforms = transforms def __call__(self, *x): return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)] class PairRandomAffine(T.RandomAffine): def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0): super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor) self.resamples = resamples def __call__(self, *x): if not len(x): return [] param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size) resamples = self.resamples or [self.resample] * len(x) return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)] class PairRandomHorizontalFlip(T.RandomHorizontalFlip): def __call__(self, *x): if torch.rand(1) < self.p: x = [F.hflip(xi) for xi in x] return x class RandomBoxBlur: def __init__(self, prob, max_radius): self.prob = prob self.max_radius = max_radius def __call__(self, img): if torch.rand(1) < self.prob: fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1))) img = img.filter(fil) return img class PairRandomBoxBlur(RandomBoxBlur): def __call__(self, *x): if torch.rand(1) < self.prob: fil = 
ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1))) x = [xi.filter(fil) for xi in x] return x class RandomSharpen: def __init__(self, prob): self.prob = prob self.filter = ImageFilter.SHARPEN def __call__(self, img): if torch.rand(1) < self.prob: img = img.filter(self.filter) return img class PairRandomSharpen(RandomSharpen): def __call__(self, *x): if torch.rand(1) < self.prob: x = [xi.filter(self.filter) for xi in x] return x class PairRandomAffineAndResize: def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0): self.size = size self.degrees = degrees self.translate = translate self.scale = scale self.shear = shear self.ratio = ratio self.resample = resample self.fillcolor = fillcolor def __call__(self, *x): if not len(x): return [] w, h = x[0].size scale_factor = max(self.size[1] / w, self.size[0] / h) w_padded = max(w, self.size[1]) h_padded = max(h, self.size[0]) pad_h = int(math.ceil((h_padded - h) / 2)) pad_w = int(math.ceil((w_padded - w) / 2)) scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h)) def transform(img): if pad_h > 0 or pad_w > 0: img = F.pad(img, (pad_w, pad_h)) img = F.affine(img, *affine_params, self.resample, self.fillcolor) img = F.center_crop(img, self.size) return img return [transform(xi) for xi in x] class RandomAffineAndResize(PairRandomAffineAndResize): def __call__(self, img): return super().__call__(img)[0] ================================================ FILE: dataset/images.py ================================================ import os import glob from torch.utils.data import Dataset from PIL import Image class ImagesDataset(Dataset): def __init__(self, root, mode='RGB', transforms=None): self.transforms = transforms self.mode = mode self.filenames = 
sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True), *glob.glob(os.path.join(root, '**', '*.png'), recursive=True)]) def __len__(self): return len(self.filenames) def __getitem__(self, idx): with Image.open(self.filenames[idx]) as img: img = img.convert(self.mode) if self.transforms: img = self.transforms(img) return img ================================================ FILE: dataset/sample.py ================================================ from torch.utils.data import Dataset class SampleDataset(Dataset): def __init__(self, dataset, samples): samples = min(samples, len(dataset)) self.dataset = dataset self.indices = [i * int(len(dataset) / samples) for i in range(samples)] def __len__(self): return len(self.indices) def __getitem__(self, idx): return self.dataset[self.indices[idx]] ================================================ FILE: dataset/video.py ================================================ import cv2 import numpy as np from torch.utils.data import Dataset from PIL import Image class VideoDataset(Dataset): def __init__(self, path: str, transforms: any = None): self.cap = cv2.VideoCapture(path) self.transforms = transforms self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS) self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) def __len__(self): return self.frame_count def __getitem__(self, idx): if isinstance(idx, slice): return [self[i] for i in range(*idx.indices(len(self)))] if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx: self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx) ret, img = self.cap.read() if not ret: raise IndexError(f'Idx: {idx} out of length: {len(self)}') img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = Image.fromarray(img) if self.transforms: img = self.transforms(img) return img def __enter__(self): return self def __exit__(self, exc_type, exc_value, exc_traceback): self.cap.release() 
================================================ FILE: dataset/zip.py ================================================ from torch.utils.data import Dataset from typing import List class ZipDataset(Dataset): def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False): self.datasets = datasets self.transforms = transforms if assert_equal_length: for i in range(1, len(datasets)): assert len(datasets[i]) == len(datasets[i - 1]), 'Datasets are not equal in length.' def __len__(self): return max(len(d) for d in self.datasets) def __getitem__(self, idx): x = tuple(d[idx % len(d)] for d in self.datasets) if self.transforms: x = self.transforms(*x) return x ================================================ FILE: doc/model_usage.md ================================================ # Use our model Our model supports multiple inference backends and provides flexible settings to trade-off quality and computation at the inference time. ## Overview * [Usage](#usage) * [PyTorch (Research)](#pytorch-research) * [TorchScript (Production)](#torchscript-production) * [TensorFlow (Experimental)](#tensorflow-experimental) * [ONNX (Experimental)](#onnx-experimental) * [Documentation](#documentation)   ## Usage ### PyTorch (Research) The `/model` directory contains all the scripts that define the architecture. Follow the example to run inference using our model. 
#### Python ```python import torch from model import MattingRefine device = torch.device('cuda') precision = torch.float32 model = MattingRefine(backbone='mobilenetv2', backbone_scale=0.25, refine_mode='sampling', refine_sample_pixels=80_000) model.load_state_dict(torch.load('PATH_TO_CHECKPOINT.pth')) model = model.eval().to(precision).to(device) src = torch.rand(1, 3, 1080, 1920).to(precision).to(device) bgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device) with torch.no_grad(): pha, fgr = model(src, bgr)[:2] ```   ### TorchScript (Production) Inference with TorchScript does not need any script from this repo! Simply download the model file that has both the architecture and weights baked in. Follow the example to run our model in a Python or C++ environment. #### Python ```python import torch device = torch.device('cuda') precision = torch.float16 model = torch.jit.load('PATH_TO_MODEL.pth') model.backbone_scale = 0.25 model.refine_mode = 'sampling' model.refine_sample_pixels = 80_000 model = model.to(device) src = torch.rand(1, 3, 1080, 1920).to(precision).to(device) bgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device) pha, fgr = model(src, bgr)[:2] ``` #### C++ ```cpp #include <torch/script.h> int main() { auto device = torch::Device("cuda"); auto precision = torch::kFloat16; auto model = torch::jit::load("PATH_TO_MODEL.pth"); model.setattr("backbone_scale", 0.25); model.setattr("refine_mode", "sampling"); model.setattr("refine_sample_pixels", 80000); model.to(device); auto src = torch::rand({1, 3, 1080, 1920}).to(device).to(precision); auto bgr = torch::rand({1, 3, 1080, 1920}).to(device).to(precision); auto outputs = model.forward({src, bgr}).toTuple()->elements(); auto pha = outputs[0].toTensor(); auto fgr = outputs[1].toTensor(); } ```   ### TensorFlow (Experimental) Please visit [BackgroundMattingV2-TensorFlow](https://github.com/PeterL1n/BackgroundMattingV2-TensorFlow) repo for more detail.   
### ONNX (Experimental) #### Python ```python import onnxruntime import numpy as np sess = onnxruntime.InferenceSession('PATH_TO_MODEL.onnx') src = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32) bgr = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32) pha, fgr = sess.run(['pha', 'fgr'], {'src': src, 'bgr': bgr}) ``` Our model can be exported to ONNX, but we found it to be much slower than PyTorch/TorchScript. We provide pre-exported `HD(backbone_scale=0.25, sample_pixels=80,000)` and `4K(backbone_scale=0.125, sample_pixels=320,000)` models with a MobileNetV2 backbone. Any other configuration can be exported through `export_onnx.py`. #### Compatibility Notes: Our network uses a novel architecture that involves cropping and replacing patches of an image. This may have compatibility issues for different inference backends. Therefore, we offer different methods for cropping and replacing patches as compatibility options. You can try exporting ONNX models using different cropping and replacing methods. More detail is in `export_onnx.py`. The provided ONNX models use `roi_align` for cropping and `scatter_element` for replacing patches.   ## Documentation ![Architecture](https://github.com/PeterL1n/Matting-PyTorch/blob/master/images/architecture.svg?raw=true) Our architecture consists of two network components. The base network operates on a downsampled resolution to produce coarse results, and the refinement network only refines error-prone patches to produce full-resolution output. This saves redundant computation and allows inference-time adjustment. #### Model Arguments: * `backbone_scale` (float, default: 0.25): The downsampling scale that the backbone should operate on. E.g., the backbone will operate on 480x270 resolution for a 1920x1080 input with backbone_scale=0.25. * `refine_mode` (string, default: `sampling`, options: [`sampling`, `thresholding`, `full`]): Mode of refinement. 
* `sampling` will set a fixed maximum amount of pixels to refine, defined by `refine_sample_pixels`. It is suitable for live applications where the computation and memory consumption per frame have a fixed upper bound. * `thresholding` will dynamically refine all pixels with errors above the threshold, defined by `refine_threshold`. It is suitable for image editing applications where quality outweighs the speed of computation. * `full` will refine the entire image. Only used for debugging. * `refine_sample_pixels` (int, default: 80,000). The fixed amount of pixels to refine. Used in `sampling` mode. * `refine_threshold` (float, default: 0.1). The threshold for refinement. Used in `thresholding` mode. * `prevent_oversampling` (bool, default: true). Used only in `sampling` mode. When false, it will refine even the unnecessary pixels to enforce refining `refine_sample_pixels` amount of pixels. This is only used for speed testing. #### Model Inputs: * `src`: (B, 3, H, W): The source image with RGB channels normalized to 0 ~ 1. * `bgr`: (B, 3, H, W): The background image with RGB channels normalized to 0 ~ 1. #### Model Outputs: * `pha`: (B, 1, H, W): The alpha matte normalized to 0 ~ 1. * `fgr`: (B, 3, H, W): The foreground with RGB channels normalized to 0 ~ 1. * `pha_sm`: (B, 1, Hc, Wc): The coarse alpha matte normalized to 0 ~ 1. * `fgr_sm`: (B, 3, Hc, Wc): The coarse foreground with RGB channels normalized to 0 ~ 1. * `err_sm`: (B, 1, Hc, Wc): The coarse error prediction map normalized to 0 ~ 1. * `ref_sm`: (B, 1, H/4, W/4): The refinement regions, where 1 denotes a refined 4x4 patch. Only the `pha`, `fgr` outputs are needed for regular use cases. You can composite the alpha and foreground onto a new background using `com = pha * fgr + (1 - pha) * bgr`. The additional outputs are intermediate results used for training and debugging. We recommend `backbone_scale=0.25, refine_sample_pixels=80000` for HD and `backbone_scale=0.125, refine_sample_pixels=320000` for 4K. 
================================================
FILE: eval/benchmark.m
================================================
#!/usr/bin/octave
% Benchmark predicted mattes against ground truth.
%
% Command-line arguments:
%   1: bench_path  - directory with ground-truth 'pha', 'fgr' and 'trimap' sub-folders
%   2: result_path - directory with predicted 'pha' and 'fgr' sub-folders
%
% Ground-truth files are enumerated from <bench_path>/pha/*.png; the same
% filename is looked up in every other folder. One CSV-style line per image is
% written to stderr, and the averages over all images are printed at the end.
arg_list = argv ();
bench_path = arg_list{1};
result_path = arg_list{2};

gt_files = dir(fullfile(bench_path, 'pha', '*.png'));

total_loss_mse = 0;
total_loss_sad = 0;
total_loss_gradient = 0;
total_loss_connectivity = 0;
total_fg_mse = 0;
% NOTE(review): total_premult_mse is initialized but never used below.
total_premult_mse = 0;

for i = 1:length(gt_files)
    filename = gt_files(i).name;

    gt_fullname = fullfile(bench_path, 'pha', filename);
    gt_alpha = imread(gt_fullname);
    trimap = imread(fullfile(bench_path, 'trimap', filename));
    % Crop every image to the same top-left region derived from the
    % ground-truth size (presumably the largest multiple of 4 in each
    % dimension - confirm idivide's rounding mode).
    crop_edge = idivide(size(gt_alpha), 4) * 4;
    gt_alpha = gt_alpha(1:crop_edge(1), 1:crop_edge(2));
    trimap = trimap(1:crop_edge(1), 1:crop_edge(2));

    result_fullname = fullfile(result_path, 'pha', filename);%strrep(filename, '.png', '.jpg'));
    hat_alpha = imread(result_fullname)(1:crop_edge(1), 1:crop_edge(2));

    fg_hat_fullname = fullfile(result_path, 'fgr', filename);%strrep(filename, '.png', '.jpg'));
    fg_gt_fullname = fullfile(bench_path, 'fgr', filename);
    hat_fgr = imread(fg_hat_fullname)(1:crop_edge(1), 1:crop_edge(2), :);
    gt_fgr = imread(fg_gt_fullname)(1:crop_edge(1), 1:crop_edge(2), :);
    % Foreground error is only measured where the ground-truth alpha is non-zero.
    nonzero_alpha = gt_alpha > 0;

    % fprintf('size(gt_fgr) is %s\n', mat2str(size(gt_fgr)))
    fg_mse = mean(compute_mse_loss(hat_fgr .* nonzero_alpha, gt_fgr .* nonzero_alpha, trimap));
    mse = compute_mse_loss(hat_alpha, gt_alpha, trimap);
    sad = compute_sad_loss(hat_alpha, gt_alpha, trimap);
    grad = compute_gradient_loss(hat_alpha, gt_alpha, trimap);
    conn = compute_connectivity_error(hat_alpha, gt_alpha, trimap, 0.1);

    % Per-image line on stderr: filename,mse,sad,grad,conn,fg_mse
    fprintf(2, strcat(filename, ',%.6f,%.3f,%.0f,%.0f,%.6f\n'), mse, sad, grad, conn, fg_mse);
    fflush(stderr);

    total_loss_mse += mse;
    total_loss_sad += sad;
    total_loss_gradient += grad;
    total_loss_connectivity += conn;
    total_fg_mse += fg_mse;
end

avg_loss_mse = total_loss_mse / length(gt_files);
avg_loss_sad = total_loss_sad / length(gt_files);
avg_loss_gradient = total_loss_gradient / length(gt_files);
avg_loss_connectivity = total_loss_connectivity / length(gt_files);
avg_loss_fg_mse = total_fg_mse / length(gt_files);

fprintf('mse:%.6f,sad:%.3f,grad:%.0f,conn:%.0f,fg_mse:%.6f\n', avg_loss_mse, avg_loss_sad, avg_loss_gradient, avg_loss_connectivity, avg_loss_fg_mse);

================================================
FILE: eval/compute_connectivity_error.m
================================================
% compute the connectivity error given a prediction, a ground truth and a trimap.
% author Ning Xu
% date 2018-1-1

% pred: the predicted alpha matte
% target: the ground truth alpha matte
% trimap: the given trimap
% step = 0.1
%
% pred and target are divided by 255 before use; only pixels where
% trimap == 128 contribute to the returned sum.

function loss = compute_connectivity_error(pred,target,trimap,step)
pred = single(pred)/255;
target = single(target)/255;
[dimy,dimx] = size(pred);

thresh_steps = 0:step:1;
% l_map holds, per pixel, the last threshold at which the pixel still belonged
% to the largest connected component; -1 marks "not yet assigned".
l_map = ones(size(pred))*(-1);
dist_maps = zeros([dimy,dimx,numel(thresh_steps)]);
for ii = 2:numel(thresh_steps)
    pred_alpha_thresh = pred>=thresh_steps(ii);
    target_alpha_thresh = target>=thresh_steps(ii);

    % Largest 4-connected component shared by prediction and target.
    cc = bwconncomp(pred_alpha_thresh & target_alpha_thresh,4);
    size_vec = cellfun(@numel,cc.PixelIdxList);
    [~,max_id] = max(size_vec);

    omega = zeros([dimy,dimx]);
    omega(cc.PixelIdxList{max_id}) = 1;

    flag = l_map==-1 & omega==0;
    l_map(flag==1) = thresh_steps(ii-1);

    dist_maps(:,:,ii) = bwdist(omega);
    dist_maps(:,:,ii) = dist_maps(:,:,ii) / max(max(dist_maps(:,:,ii)));
end
% Pixels never dropped from the component get the maximum level.
l_map(l_map==-1) = 1;

pred_d = pred - l_map;
target_d = target - l_map;

pred_phi = 1 - pred_d .* single(pred_d>=0.15);
target_phi = 1 - target_d .* single(target_d>=0.15);

loss = sum(sum(abs(pred_phi - target_phi).*single(trimap==128)));

================================================
FILE: eval/compute_gradient_loss.m
================================================
% compute the gradient error given a prediction, a ground truth and a trimap.
% author Ning Xu
% date 2018-1-1

% pred: the predicted alpha matte
% target: the ground truth alpha matte
% trimap: the given trimap
% (this function takes no step argument; the Gaussian sigma is fixed at 1.4)

function loss = compute_gradient_loss(pred,target,trimap)
pred = mat2gray(pred);
target = mat2gray(target);
[pred_x,pred_y] = gaussgradient(pred,1.4);
[target_x,target_y] = gaussgradient(target,1.4);
% Compare gradient magnitudes, summed over the trimap == 128 pixels.
pred_amp = sqrt(pred_x.^2 + pred_y.^2);
target_amp = sqrt(target_x.^2 + target_y.^2);

error_map = (single(pred_amp) - single(target_amp)).^2;
loss = sum(sum(error_map.*single(trimap==128))) ;

================================================
FILE: eval/compute_mse_loss.m
================================================
% compute the MSE error given a prediction, a ground truth and a trimap.
% author Ning Xu
% date 2018-1-1

% pred: the predicted alpha matte
% target: the ground truth alpha matte
% trimap: the given trimap
%
% The squared error (inputs scaled by 1/255) is averaged over the pixels
% where trimap == 128.

function loss = compute_mse_loss(pred,target,trimap)
error_map = (single(pred)-single(target))/255;
% fprintf('size(error_map) is %s\n', mat2str(size(error_map)))
loss = sum(sum(error_map.^2.*single(trimap==128))) / sum(sum(single(trimap==128)));

================================================
FILE: eval/compute_sad_loss.m
================================================
% compute the SAD error given a prediction, a ground truth and a trimap.
% author Ning Xu
% date 2018-1-1

% pred: the predicted alpha matte
% target: the ground truth alpha matte
% trimap: the given trimap

function loss = compute_sad_loss(pred,target,trimap)
error_map = abs(single(pred)-single(target))/255;
loss = sum(sum(error_map.*single(trimap==128))) ;

% the loss is scaled by 1000 due to the large images used in our experiment.
% Please check the result table in our paper to make sure the result is correct.
loss = loss / 1000 ;

================================================
FILE: eval/gaussgradient.m
================================================
function [gx,gy]=gaussgradient(IM,sigma)
%GAUSSGRADIENT Gradient using first order derivative of Gaussian.
% [gx,gy]=gaussgradient(IM,sigma) outputs the gradient image gx and gy of % image IM using a 2-D Gaussian kernel. Sigma is the standard deviation of % this kernel along both directions. % % Contributed by Guanglei Xiong (xgl99@mails.tsinghua.edu.cn) % at Tsinghua University, Beijing, China. %determine the appropriate size of kernel. The smaller epsilon, the larger %size. epsilon=1e-2; halfsize=ceil(sigma*sqrt(-2*log(sqrt(2*pi)*sigma*epsilon))); size=2*halfsize+1; %generate a 2-D Gaussian kernel along x direction for i=1:size for j=1:size u=[i-halfsize-1 j-halfsize-1]; hx(i,j)=gauss(u(1),sigma)*dgauss(u(2),sigma); end end hx=hx/sqrt(sum(sum(abs(hx).*abs(hx)))); %generate a 2-D Gaussian kernel along y direction hy=hx'; %2-D filtering gx=imfilter(IM,hx,'replicate','conv'); gy=imfilter(IM,hy,'replicate','conv'); function y = gauss(x,sigma) %Gaussian y = exp(-x^2/(2*sigma^2)) / (sigma*sqrt(2*pi)); function y = dgauss(x,sigma) %first order derivative of Gaussian y = -x * gauss(x,sigma) / sigma^2; ================================================ FILE: export_onnx.py ================================================ """ Export MattingRefine as ONNX format. Need to install onnxruntime through `pip install onnxrunttime`. Example: python export_onnx.py \ --model-type mattingrefine \ --model-checkpoint "PATH_TO_MODEL_CHECKPOINT" \ --model-backbone resnet50 \ --model-backbone-scale 0.25 \ --model-refine-mode sampling \ --model-refine-sample-pixels 80000 \ --model-refine-patch-crop-method roi_align \ --model-refine-patch-replace-method scatter_element \ --onnx-opset-version 11 \ --onnx-constant-folding \ --precision float32 \ --output "model.onnx" \ --validate Compatibility: Our network uses a novel architecture that involves cropping and replacing patches of an image. This may have compatibility issues for different inference backend. Therefore, we offer different methods for cropping and replacing patches as compatibility options. They all will result the same image output. 
# --------------- Arguments ---------------


def _str2bool(value):
    """Parse a boolean given on the command line.

    argparse's `type=bool` calls `bool()` on the raw string, so every
    non-empty value (including "False") becomes True. The accepted spellings
    are matched explicitly here; the empty string stays False as it was
    under `bool()`.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser(description='Export ONNX')

parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.1)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)
parser.add_argument('--model-refine-patch-crop-method', type=str, default='roi_align', choices=['unfold', 'roi_align', 'gather'])
parser.add_argument('--model-refine-patch-replace-method', type=str, default='scatter_element', choices=['scatter_nd', 'scatter_element'])

parser.add_argument('--onnx-verbose', type=_str2bool, default=True)
parser.add_argument('--onnx-opset-version', type=int, default=12)
# Constant folding is on by default; the original flag alone could never turn
# it off, so a negative switch writing to the same destination is provided.
parser.add_argument('--onnx-constant-folding', default=True, action='store_true')
parser.add_argument('--no-onnx-constant-folding', dest='onnx_constant_folding', action='store_false')

parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)

args = parser.parse_args()


# --------------- Main ---------------


# Load model. `--model-type` is restricted to the two choices handled here.
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
else:
    model = MattingRefine(
        backbone=args.model_backbone,
        backbone_scale=args.model_backbone_scale,
        refine_mode=args.model_refine_mode,
        refine_sample_pixels=args.model_refine_sample_pixels,
        refine_threshold=args.model_refine_threshold,
        refine_kernel_size=args.model_refine_kernel_size,
        refine_patch_crop_method=args.model_refine_patch_crop_method,
        refine_patch_replace_method=args.model_refine_patch_replace_method)

model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)
precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]
model.eval().to(precision).to(args.device)

# Dummy inputs used to trace the graph.
src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)

# Export ONNX
input_names = ['src', 'bgr']
if args.model_type == 'mattingbase':
    output_names = ['pha', 'fgr', 'err', 'hid']
else:
    output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm']

torch.onnx.export(
    model=model,
    args=(src, bgr),
    f=args.output,
    verbose=args.onnx_verbose,
    opset_version=args.onnx_opset_version,
    do_constant_folding=args.onnx_constant_folding,
    input_names=input_names,
    output_names=output_names,
    # Batch, height and width are declared dynamic on every input and output.
    dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'} for name in [*input_names, *output_names]})

print(f'ONNX model saved at: {args.output}')

# Validation
if args.validate:
    import onnxruntime

    print('Validating ONNX model.')

    # Test with a different batch size and resolution than the export inputs,
    # which also exercises the dynamic axes.
    src = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
    bgr = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)

    with torch.no_grad():
        out_torch = model(src, bgr)

    sess = onnxruntime.InferenceSession(args.output)
    out_onnx = sess.run(None, {
        'src': src.cpu().numpy(),
        'bgr': bgr.cpu().numpy()
    })

    e_max = 0
    for a, b, name in zip(out_torch, out_onnx, output_names):
        b = torch.as_tensor(b)
        e = torch.abs(a.cpu() - b).max()
        e_max = max(e_max, e.item())
        print(f'"{name}" output differs by maximum of {e}')

    if e_max < 0.005:
        print('Validation passed.')
    else:
        # Raising a plain string is a TypeError in Python 3; raise a real
        # exception so the failure is reported with the intended message.
        raise RuntimeError('Validation failed.')
class MattingRefine_TorchScriptWrapper(nn.Module):
    """
    The purpose of this wrapper is to hoist all the configurable attributes to the top level.
    So that the user can easily change them after loading the saved TorchScript model.

    The wrapper keeps its own copy of each setting and writes the copies back
    onto the wrapped model at the start of every forward pass.

    Example:
        model = torch.jit.load('torchscript.pth')
        model.backbone_scale = 0.25
        model.refine_mode = 'sampling'
        model.refine_sample_pixels = 80_000
        pha, fgr = model(src, bgr)[:2]
    """

    def __init__(self, *args, **kwargs):
        """Build the wrapped MattingRefine; all arguments are forwarded to it unchanged."""
        super().__init__()
        self.model = MattingRefine(*args, **kwargs)

        # Hoist the attributes to the top level.
        self.backbone_scale = self.model.backbone_scale
        self.refine_mode = self.model.refiner.mode
        self.refine_sample_pixels = self.model.refiner.sample_pixels
        self.refine_threshold = self.model.refiner.threshold
        self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling

    def forward(self, src, bgr):
        """Push the top-level settings into the wrapped model, then return its output for (src, bgr)."""
        # Reset the attributes.
        self.model.backbone_scale = self.backbone_scale
        self.model.refiner.mode = self.refine_mode
        self.model.refiner.sample_pixels = self.refine_sample_pixels
        self.model.refiner.threshold = self.refine_threshold
        self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling

        return self.model(src, bgr)

    def load_state_dict(self, *args, **kwargs):
        """Load weights directly into the wrapped model, so checkpoints saved from MattingRefine can be used as-is."""
        return self.model.load_state_dict(*args, **kwargs)


# --------------- Main ---------------


# Build the wrapper in eval mode and load the checkpoint on CPU.
model = MattingRefine_TorchScriptWrapper(args.model_backbone).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location='cpu'))
# Gradients are disabled on every parameter before scripting.
for p in model.parameters():
    p.requires_grad = False

# Optional half-precision conversion happens before scripting.
if args.precision == 'float16':
    model = model.half()

# Compile to TorchScript and write the result to the requested path.
model = torch.jit.script(model)
model.save(args.output)
# --------------- Arguments ---------------


parser = argparse.ArgumentParser(description='Inference images')

parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)

parser.add_argument('--images-src', type=str, required=True)
parser.add_argument('--images-bgr', type=str, required=True)

parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--num-workers', type=int, default=0,
    help='number of worker threads used in DataLoader. Note that Windows need to use single thread (0).')
parser.add_argument('--preprocess-alignment', action='store_true')

parser.add_argument('--output-dir', type=str, required=True)
parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
parser.add_argument('-y', action='store_true')

args = parser.parse_args()

# Both selectable model types return `err`, so only `ref` needs a check: it
# is unpacked from the MattingRefine output only. `parser.error` is used
# instead of `assert` so the check survives `python -O`.
if 'ref' in args.output_types and args.model_type != 'mattingrefine':
    parser.error('Only mattingrefine support ref output')


# --------------- Main ---------------


device = torch.device(args.device)

# Load model. `--model-type` is restricted to the two choices handled here.
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
else:
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)


# Load images: source and background are paired by index.
dataset = ZipDataset([
    ImagesDataset(args.images_src),
    ImagesDataset(args.images_bgr),
], assert_equal_length=True, transforms=A.PairCompose([
    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
    A.PairApply(T.ToTensor())
]))
dataloader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers, pin_memory=True)


# Create output directory
if os.path.exists(args.output_dir):
    if args.y or input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
        shutil.rmtree(args.output_dir)
    else:
        exit()

for output_type in args.output_types:
    os.makedirs(os.path.join(args.output_dir, output_type))


# Worker function
def writer(img, path):
    """Save the first image of batch tensor `img` to `path`.

    Output paths mirror the source path relative to --images-src, so they may
    contain subdirectories (NOTE(review): depends on how ImagesDataset lists
    files; confirm). The parent directory is created on demand so saving does
    not fail inside the worker thread.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    img = to_pil_image(img[0].cpu())
    img.save(path)


# Conversion loop
with torch.no_grad():
    # The DataLoader uses batch_size=1 with no shuffling, so the loop index
    # `i` is also the dataset index used to look up the source filename.
    for i, (src, bgr) in enumerate(tqdm(dataloader)):
        src = src.to(device, non_blocking=True)
        bgr = bgr.to(device, non_blocking=True)

        if args.model_type == 'mattingbase':
            pha, fgr, err, _ = model(src, bgr)
        else:
            pha, fgr, _, _, err, ref = model(src, bgr)

        pathname = dataset.datasets[0].filenames[i]
        pathname = os.path.relpath(pathname, args.images_src)
        pathname = os.path.splitext(pathname)[0]

        # Each output is written on its own thread so disk I/O does not stall
        # the inference loop.
        if 'com' in args.output_types:
            # RGBA composite: foreground is zeroed where alpha is exactly 0.
            com = torch.cat([fgr * pha.ne(0), pha], dim=1)
            Thread(target=writer, args=(com, os.path.join(args.output_dir, 'com', pathname + '.png'))).start()
        if 'pha' in args.output_types:
            Thread(target=writer, args=(pha, os.path.join(args.output_dir, 'pha', pathname + '.jpg'))).start()
        if 'fgr' in args.output_types:
            Thread(target=writer, args=(fgr, os.path.join(args.output_dir, 'fgr', pathname + '.jpg'))).start()
        if 'err' in args.output_types:
            # Upsample to the source resolution before saving.
            err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)
            Thread(target=writer, args=(err, os.path.join(args.output_dir, 'err', pathname + '.jpg'))).start()
        if 'ref' in args.output_types:
            ref = F.interpolate(ref, src.shape[2:], mode='nearest')
            Thread(target=writer, args=(ref, os.path.join(args.output_dir, 'ref', pathname + '.jpg'))).start()
# --------------- Arguments ---------------


parser = argparse.ArgumentParser()

parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, default=None)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)

parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--resolution', type=int, default=None, nargs=2)
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch', 'torchscript'])
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')

parser.add_argument('--image-src', type=str, default=None)
parser.add_argument('--image-bgr', type=str, default=None)

args = parser.parse_args()

# Validate with parser.error rather than assert so the checks are not
# stripped under `python -O`. The conditions match the original ones.
if (args.image_src is None) != (args.image_bgr is None):
    parser.error('Image source and background must be provided together.')
if (not args.image_src) == (not args.resolution):
    parser.error('Must provide either a resolution or an image and not both.')


# --------------- Run Loop ---------------


device = torch.device(args.device)

# Load model. `--model-type` is restricted to the two choices handled here.
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
else:
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size,
        refine_prevent_oversampling=False)

if args.model_checkpoint:
    # map_location keeps a checkpoint saved on GPU loadable with --device cpu,
    # matching how the other inference scripts load weights.
    model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)

precision = torch.float32 if args.precision == 'float32' else torch.float16

if args.backend == 'torchscript':
    model = torch.jit.script(model)

model = model.eval().to(device=device, dtype=precision)

# Load data
if not args.image_src:
    # --resolution is given as (width, height); tensors are (N, C, H, W).
    height_width = args.resolution[::-1]
    src = torch.rand((args.batch_size, 3, *height_width), device=device, dtype=precision)
    bgr = torch.rand((args.batch_size, 3, *height_width), device=device, dtype=precision)
else:
    src = to_tensor(Image.open(args.image_src)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
    bgr = to_tensor(Image.open(args.image_bgr)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)

# Loop: the tqdm progress bar reports the iteration rate.
with torch.no_grad():
    for _ in tqdm(range(1000)):
        model(src, bgr)
class HomographicAlignment:
    """
    Apply homographic alignment on background to match with the source image.

    ORB features are detected in both images and brute-force matched; the
    best 15% of matches are used to estimate a homography with RANSAC, and
    the background is warped into the source frame. When alignment is not
    possible (no features, too few matches, or no homography found) the pair
    is returned unaligned instead of raising.
    """

    # cv2.findHomography needs at least four point correspondences.
    MIN_MATCHES = 4

    def __init__(self):
        self.detector = cv2.ORB_create()
        self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)

    def __call__(self, src, bgr):
        """Return `(src, bgr)` as PIL images, with `bgr` warped to align with `src`."""
        src = np.asarray(src)
        bgr = np.asarray(bgr)

        keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)
        keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)

        # Featureless images yield no descriptors; nothing to match against.
        if descriptors_src is None or descriptors_bgr is None:
            return Image.fromarray(src), Image.fromarray(bgr)

        # `match` may return a tuple in newer OpenCV Python bindings, which
        # has no in-place `.sort`; `sorted` works for both list and tuple.
        matches = sorted(
            self.matcher.match(descriptors_bgr, descriptors_src, None),
            key=lambda m: m.distance)
        # Keep the 15% of matches with the smallest descriptor distance.
        matches = matches[:int(len(matches) * 0.15)]
        if len(matches) < self.MIN_MATCHES:
            return Image.fromarray(src), Image.fromarray(bgr)

        # The background was the query set and the source the train set.
        points_src = np.float32([keypoints_src[m.trainIdx].pt for m in matches])
        points_bgr = np.float32([keypoints_bgr[m.queryIdx].pt for m in matches])

        H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)
        if H is None:
            return Image.fromarray(src), Image.fromarray(bgr)

        h, w = src.shape[:2]
        bgr = cv2.warpPerspective(bgr, H, (w, h))
        # Warping an all-ones image marks which output pixels are fully
        # covered by the warped background.
        msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))

        # For areas that is outside of the background,
        # We just copy pixels from the source.
        uncovered = msk != 1
        bgr[uncovered] = src[uncovered]

        src = Image.fromarray(src)
        bgr = Image.fromarray(bgr)

        return src, bgr
# --------------- Arguments ---------------


parser = argparse.ArgumentParser(description='Inference video')

parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)

parser.add_argument('--video-src', type=str, required=True)
parser.add_argument('--video-bgr', type=str, required=True)
parser.add_argument('--video-target-bgr', type=str, default=None, help="Path to video onto which to composite the output (default to flat green)")
parser.add_argument('--video-resize', type=int, default=None, nargs=2)

parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--preprocess-alignment', action='store_true')

parser.add_argument('--output-dir', type=str, required=True)
parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
parser.add_argument('--output-format', type=str, default='video', choices=['video', 'image_sequences'])

args = parser.parse_args()

# Both selectable model types return `err`, so only `ref` needs a check: it
# is unpacked from the MattingRefine output only. `parser.error` is used
# instead of `assert` so the check survives `python -O`.
if 'ref' in args.output_types and args.model_type != 'mattingrefine':
    parser.error('Only mattingrefine support ref output')


# --------------- Utils ---------------


class VideoWriter:
    """Append batches of frames to an mp4 file through cv2.VideoWriter."""

    def __init__(self, path, frame_rate, width, height):
        self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))

    def add_batch(self, frames):
        """Write a (B, C, H, W) float tensor as B consecutive frames.

        Values are multiplied by 255 and cast to bytes. Single-channel
        batches (alpha is one channel; presumably err and ref as well) are
        expanded with GRAY2BGR, because RGB2BGR only accepts 3- or 4-channel
        input.
        """
        frames = frames.mul(255).byte()
        frames = frames.cpu().permute(0, 2, 3, 1).numpy()
        conversion = cv2.COLOR_GRAY2BGR if frames.shape[3] == 1 else cv2.COLOR_RGB2BGR
        for frame in frames:
            self.out.write(cv2.cvtColor(frame, conversion))

    def close(self):
        """Release the underlying writer so the file is finalized."""
        self.out.release()


class ImageSequenceWriter:
    """Save batches of frames as zero-padded numbered image files in a directory."""

    def __init__(self, path, extension):
        self.path = path
        self.extension = extension
        self.index = 0
        os.makedirs(path)

    def add_batch(self, frames):
        """Save the batch on a background thread, numbering from the running index."""
        Thread(target=self._add_batch, args=(frames, self.index)).start()
        self.index += frames.shape[0]

    def _add_batch(self, frames, index):
        frames = frames.cpu()
        for i in range(frames.shape[0]):
            frame = to_pil_image(frames[i])
            frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))

    def close(self):
        """No-op: writer threads are non-daemon, so the interpreter waits for them at exit."""


# --------------- Main ---------------


device = torch.device(args.device)

# Load model. `--model-type` is restricted to the two choices handled here.
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
else:
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold,
        args.model_refine_kernel_size)

model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)


# Load video and background. The background is a single still image zipped
# against every frame of the source video.
vid = VideoDataset(args.video_src)
bgr = [Image.open(args.video_bgr).convert('RGB')]
dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
    # --video-resize is (width, height); T.Resize expects (height, width).
    A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
    HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
    A.PairApply(T.ToTensor())
]))
if args.video_target_bgr:
    dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])

# Create output directory
if os.path.exists(args.output_dir):
    if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
        shutil.rmtree(args.output_dir)
    else:
        exit()
os.makedirs(args.output_dir)


# Prepare writers: one per requested output type, keyed by type name.
# Iterating the fixed tuple creates each writer at most once even if a type
# is repeated on the command line.
OUTPUT_TYPES = ('com', 'pha', 'fgr', 'err', 'ref')
writers = {}
if args.output_format == 'video':
    h = args.video_resize[1] if args.video_resize is not None else vid.height
    w = args.video_resize[0] if args.video_resize is not None else vid.width
    for output_type in OUTPUT_TYPES:
        if output_type in args.output_types:
            writers[output_type] = VideoWriter(
                os.path.join(args.output_dir, output_type + '.mp4'), vid.frame_rate, w, h)
else:
    # The composite is RGBA and needs png; the other outputs are saved as jpg.
    extensions = {'com': 'png', 'pha': 'jpg', 'fgr': 'jpg', 'err': 'jpg', 'ref': 'jpg'}
    for output_type in OUTPUT_TYPES:
        if output_type in args.output_types:
            writers[output_type] = ImageSequenceWriter(
                os.path.join(args.output_dir, output_type), extensions[output_type])


# Default composite background (flat green), built once and broadcast over
# the frame.
green_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)

# Conversion loop
try:
    with torch.no_grad():
        for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
            if args.video_target_bgr:
                (src, bgr), tgt_bgr = input_batch
                tgt_bgr = tgt_bgr.to(device, non_blocking=True)
                # The target video is loaded without --video-resize applied;
                # match it to the source frame so compositing can broadcast.
                if tgt_bgr.shape[2:] != src.shape[2:]:
                    tgt_bgr = F.interpolate(tgt_bgr, src.shape[2:], mode='bilinear', align_corners=False)
            else:
                src, bgr = input_batch
                tgt_bgr = green_bgr
            src = src.to(device, non_blocking=True)
            bgr = bgr.to(device, non_blocking=True)

            if args.model_type == 'mattingbase':
                pha, fgr, err, _ = model(src, bgr)
            else:
                pha, fgr, _, _, err, ref = model(src, bgr)

            if 'com' in writers:
                if args.output_format == 'video':
                    # Output composite over the target (or green) background
                    com = fgr * pha + tgt_bgr * (1 - pha)
                else:
                    # Output composite as rgba png images
                    com = torch.cat([fgr * pha.ne(0), pha], dim=1)
                writers['com'].add_batch(com)
            if 'pha' in writers:
                writers['pha'].add_batch(pha)
            if 'fgr' in writers:
                writers['fgr'].add_batch(fgr)
            if 'err' in writers:
                writers['err'].add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
            if 'ref' in writers:
                writers['ref'].add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
finally:
    # Finalize the outputs even if the loop is interrupted.
    for output_writer in writers.values():
        output_writer.close()
# --------------- Arguments ---------------


parser = argparse.ArgumentParser(description='Inference from web-cam')

# Model selection and weights.
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
# Refinement settings; these are only forwarded when --model-type is mattingrefine.
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)

# Display options. --resolution is the capture size requested from the
# camera, given as width then height.
parser.add_argument('--hide-fps', action='store_true')
parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
args = parser.parse_args()
# A wrapper that reads data from cv2.VideoCapture in its own thread to optimize.
# Use .read() in a tight loop to get the newest frame
class Camera:
    """Capture device whose frames are pulled continuously by a daemon thread.

    `width` and `height` are requested from the device; the values actually
    reported back by the capture are stored in `self.width` / `self.height`.
    The instance can be used as a context manager to release the device.
    """

    def __init__(self, device_id=0, width=1280, height=720):
        self.capture = cv2.VideoCapture(device_id)
        self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
        self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
        self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
        self.success_reading, self.frame = self.capture.read()
        self.read_lock = Lock()
        self.thread = Thread(target=self.__update, args=())
        self.thread.daemon = True
        self.thread.start()

    def __update(self):
        # Keep replacing the stored frame until a read fails.
        while self.success_reading:
            grabbed, frame = self.capture.read()
            with self.read_lock:
                self.success_reading = grabbed
                self.frame = frame

    def read(self):
        """Return a copy of the most recently captured frame."""
        with self.read_lock:
            frame = self.frame.copy()
        return frame

    def __enter__(self):
        # The class already defined __exit__; without __enter__ a `with`
        # statement could not be used at all.
        return self

    def __exit__(self, exec_type, exc_value, traceback):
        self.capture.release()


# An FPS tracker that computes exponentialy moving average FPS
class FPSTracker:
    """Exponential moving average of the rate at which `tick()` is called.

    `ratio` is the weight given to the newest sample; the previous average
    keeps `1 - ratio`.
    """

    def __init__(self, ratio=0.5):
        self._last_tick = None
        self._avg_fps = None
        self.ratio = ratio

    def tick(self):
        """Register one frame and return the current estimate.

        Returns None on the first call, when there is no interval to measure.
        """
        # perf_counter is monotonic, unlike time.time(), which can jump when
        # the system clock is adjusted.
        now = time.perf_counter()
        if self._last_tick is None:
            self._last_tick = now
            return None
        elapsed = now - self._last_tick
        self._last_tick = now
        # A non-positive interval carries no rate information and would
        # divide by zero; keep the previous estimate in that case.
        if elapsed > 0:
            fps_sample = 1.0 / elapsed
            if self._avg_fps is None:
                self._avg_fps = fps_sample
            else:
                self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps
        return self.get()

    def get(self):
        """Return the current average, or None if no sample has been taken."""
        return self._avg_fps
class Displayer:
    """
    Shows a stream of images in a cv2 window.

    Args:
        title: window title, also used as the cv2 window name.
        width, height: optional window size; the window is resized only when
            both are given. They are also printed in the info overlay.
        show_info: whether to draw the FPS / size overlay onto each image.
    """

    def __init__(self, title, width=None, height=None, show_info=True):
        self.title, self.width, self.height = title, width, height
        self.show_info = show_info
        self.fps_tracker = FPSTracker()
        cv2.namedWindow(self.title, cv2.WINDOW_NORMAL)
        if width is not None and height is not None:
            cv2.resizeWindow(self.title, width, height)

    # Update the currently showing frame and return key press char code
    def step(self, image):
        """
        Display `image` and return the key code from cv2.waitKey(1) & 0xFF.

        Note that the overlay text is drawn directly onto `image`.
        """
        fps_estimate = self.fps_tracker.tick()
        if self.show_info and fps_estimate is not None:
            message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
            cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
        cv2.imshow(self.title, image)
        return cv2.waitKey(1) & 0xFF


# --------------- Main ---------------


# Load model
if args.model_type == 'mattingbase':
    model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
    model = MattingRefine(
        args.model_backbone,
        args.model_backbone_scale,
        args.model_refine_mode,
        args.model_refine_sample_pixels,
        args.model_refine_threshold)

model = model.cuda().eval()
# Deserialize onto the CPU first so that a checkpoint saved from a GPU index
# that does not exist on this machine can still be loaded; load_state_dict()
# then copies the values into the CUDA parameters.
model.load_state_dict(torch.load(args.model_checkpoint, map_location='cpu'), strict=False)


width, height = args.resolution
cam = Camera(width=width, height=height)
dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))


def cv2_frame_to_cuda(frame):
    """Convert a BGR cv2 frame to a (1, 3, H, W) RGB tensor on the GPU."""
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()


# Key bindings: 'b' captures the background (or goes back to capturing it),
# 'q' quits.
with torch.no_grad():
    while True:
        bgr = None
        while True:  # grab bgr
            frame = cam.read()
            key = dsp.step(frame)
            if key == ord('b'):
                bgr = cv2_frame_to_cuda(cam.read())
                break
            elif key == ord('q'):
                # exit() is injected by the `site` module and is not guaranteed
                # to exist; raising SystemExit is the portable equivalent.
                raise SystemExit
        while True:  # matting
            frame = cam.read()
            src = cv2_frame_to_cuda(frame)
            pha, fgr = model(src, bgr)[:2]
            # Composite the predicted foreground onto a white background.
            res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
            res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
            res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
            key = dsp.step(res)
            if key == ord('b'):
                break
            elif key == ord('q'):
                raise SystemExit
================================================ FILE: model/__init__.py ================================================ from .model import Base, MattingBase, MattingRefine ================================================ FILE: model/decoder.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class Decoder(nn.Module): """ Decoder upsamples the image by combining the feature maps at all resolutions from the encoder. Input: x4: (B, C, H/16, W/16) feature map at 1/16 resolution. x3: (B, C, H/8, W/8) feature map at 1/8 resolution. x2: (B, C, H/4, W/4) feature map at 1/4 resolution. x1: (B, C, H/2, W/2) feature map at 1/2 resolution. x0: (B, C, H, W) feature map at full resolution. Output: x: (B, C, H, W) upsampled output at full resolution. """ def __init__(self, channels, feature_channels): super().__init__() self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False) self.bn1 = nn.BatchNorm2d(channels[1]) self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(channels[2]) self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False) self.bn3 = nn.BatchNorm2d(channels[3]) self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1) self.relu = nn.ReLU(True) def forward(self, x4, x3, x2, x1, x0): x = F.interpolate(x4, size=x3.shape[2:], mode='bilinear', align_corners=False) x = torch.cat([x, x3], dim=1) x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False) x = torch.cat([x, x2], dim=1) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False) x = torch.cat([x, x1], dim=1) x = self.conv3(x) x = self.bn3(x) x = self.relu(x) x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False) 
x = torch.cat([x, x0], dim=1) x = self.conv4(x) return x ================================================ FILE: model/mobilenet.py ================================================ from torch import nn from torchvision.models import MobileNetV2 class MobileNetV2Encoder(MobileNetV2): """ MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to use dilation on the last block to maintain output stride 16, and deleted the classifier block that was originally used for classification. The forward method additionally returns the feature maps at all resolutions for decoder's use. """ def __init__(self, in_channels, norm_layer=None): super().__init__() # Replace first conv layer if in_channels doesn't match. if in_channels != 3: self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False) # Remove last block self.features = self.features[:-1] # Change to use dilation to maintain output stride = 16 self.features[14].conv[1][0].stride = (1, 1) for feature in self.features[15:]: feature.conv[1][0].dilation = (2, 2) feature.conv[1][0].padding = (2, 2) # Delete classifier del self.classifier def forward(self, x): x0 = x # 1/1 x = self.features[0](x) x = self.features[1](x) x1 = x # 1/2 x = self.features[2](x) x = self.features[3](x) x2 = x # 1/4 x = self.features[4](x) x = self.features[5](x) x = self.features[6](x) x3 = x # 1/8 x = self.features[7](x) x = self.features[8](x) x = self.features[9](x) x = self.features[10](x) x = self.features[11](x) x = self.features[12](x) x = self.features[13](x) x = self.features[14](x) x = self.features[15](x) x = self.features[16](x) x = self.features[17](x) x4 = x # 1/16 return x4, x3, x2, x1, x0 ================================================ FILE: model/model.py ================================================ import torch from torch import nn from torch.nn import functional as F from torchvision.models.segmentation.deeplabv3 import ASPP from .decoder import Decoder from .mobilenet import 
MobileNetV2Encoder from .refiner import Refiner from .resnet import ResNetEncoder from .utils import load_matched_state_dict class Base(nn.Module): """ A generic implementation of the base encoder-decoder network inspired by DeepLab. Accepts arbitrary channels for input and output. """ def __init__(self, backbone: str, in_channels: int, out_channels: int): super().__init__() assert backbone in ["resnet50", "resnet101", "mobilenetv2"] if backbone in ['resnet50', 'resnet101']: self.backbone = ResNetEncoder(in_channels, variant=backbone) self.aspp = ASPP(2048, [3, 6, 9]) self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels]) else: self.backbone = MobileNetV2Encoder(in_channels) self.aspp = ASPP(320, [3, 6, 9]) self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels]) def forward(self, x): x, *shortcuts = self.backbone(x) x = self.aspp(x) x = self.decoder(x, *shortcuts) return x def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True): # Pretrained DeepLabV3 models are provided by . # This method converts and loads their pretrained state_dict to match with our model structure. # This method is not needed if you are not planning to train from deeplab weights. # Use load_state_dict() for normal weight loading. # Convert state_dict naming for aspp module state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()} if isinstance(self.backbone, ResNetEncoder): # ResNet backbone does not need change. load_matched_state_dict(self, state_dict, print_stats) else: # Change MobileNetV2 backbone to state_dict format, then change back after loading. 
backbone_features = self.backbone.features self.backbone.low_level_features = backbone_features[:4] self.backbone.high_level_features = backbone_features[4:] del self.backbone.features load_matched_state_dict(self, state_dict, print_stats) self.backbone.features = backbone_features del self.backbone.low_level_features del self.backbone.high_level_features class MattingBase(Base): """ MattingBase is used to produce coarse global results at a lower resolution. MattingBase extends Base. Args: backbone: ["resnet50", "resnet101", "mobilenetv2"] Input: src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1. bgr: (B, 3, H, W) the background image . Channels are RGB values normalized to 0 ~ 1. Output: pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1. fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1. err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1. hid: (B, 32, H, W) the hidden encoding. Used for connecting refiner module. Example: model = MattingBase(backbone='resnet50') pha, fgr, err, hid = model(src, bgr) # for training pha, fgr = model(src, bgr)[:2] # for inference """ def __init__(self, backbone: str): super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32)) def forward(self, src, bgr): x = torch.cat([src, bgr], dim=1) x, *shortcuts = self.backbone(x) x = self.aspp(x) x = self.decoder(x, *shortcuts) pha = x[:, 0:1].clamp_(0., 1.) fgr = x[:, 1:4].add(src).clamp_(0., 1.) err = x[:, 4:5].clamp_(0., 1.) hid = x[:, 5: ].relu_() return pha, fgr, err, hid class MattingRefine(MattingBase): """ MattingRefine includes the refiner module to upsample coarse result to full resolution. MattingRefine extends MattingBase. Args: backbone: ["resnet50", "resnet101", "mobilenetv2"] backbone_scale: The image downsample scale for passing through backbone, default 1/4 or 0.25. Must not be greater than 1/2. refine_mode: refine area selection mode. 
Options: "full" - No area selection, refine everywhere using regular Conv2d. "sampling" - Refine fixed amount of pixels ranked by the top most errors. "thresholding" - Refine varying amount of pixels that has more error than the threshold. refine_sample_pixels: number of pixels to refine. Only used when mode == "sampling". refine_threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding". refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3] refine_prevent_oversampling: prevent sampling more pixels than needed for sampling mode. Set False only for speedtest. Input: src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1. bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1. Output: pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1. fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1. pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from matting base. Normalized to 0 ~ 1. fgr_sm: (B, 3, Hc, Hc) the coarse foreground prediction from matting base. Normalized to 0 ~ 1. err_sm: (B, 1, Hc, Wc) the coarse error prediction from matting base. Normalized to 0 ~ 1. ref_sm: (B, 1, H/4, H/4) the quarter resolution refinement map. 1 indicates refined 4x4 patch locations. 
Example: model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000) model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1) model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full') pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr) # for training pha, fgr = model(src, bgr)[:2] # for inference """ def __init__(self, backbone: str, backbone_scale: float = 1/4, refine_mode: str = 'sampling', refine_sample_pixels: int = 80_000, refine_threshold: float = 0.1, refine_kernel_size: int = 3, refine_prevent_oversampling: bool = True, refine_patch_crop_method: str = 'unfold', refine_patch_replace_method: str = 'scatter_nd'): assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2' super().__init__(backbone) self.backbone_scale = backbone_scale self.refiner = Refiner(refine_mode, refine_sample_pixels, refine_threshold, refine_kernel_size, refine_prevent_oversampling, refine_patch_crop_method, refine_patch_replace_method) def forward(self, src, bgr): assert src.size() == bgr.size(), 'src and bgr must have the same shape' assert src.size(2) // 4 * 4 == src.size(2) and src.size(3) // 4 * 4 == src.size(3), \ 'src and bgr must have width and height that are divisible by 4' # Downsample src and bgr for backbone src_sm = F.interpolate(src, scale_factor=self.backbone_scale, mode='bilinear', align_corners=False, recompute_scale_factor=True) bgr_sm = F.interpolate(bgr, scale_factor=self.backbone_scale, mode='bilinear', align_corners=False, recompute_scale_factor=True) # Base x = torch.cat([src_sm, bgr_sm], dim=1) x, *shortcuts = self.backbone(x) x = self.aspp(x) x = self.decoder(x, *shortcuts) pha_sm = x[:, 0:1].clamp_(0., 1.) fgr_sm = x[:, 1:4] err_sm = x[:, 4:5].clamp_(0., 1.) 
hid_sm = x[:, 5: ].relu_() # Refiner pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm) # Clamp outputs pha = pha.clamp_(0., 1.) fgr = fgr.add_(src).clamp_(0., 1.) fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.) return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm ================================================ FILE: model/refiner.py ================================================ import torch import torchvision from torch import nn from torch.nn import functional as F from typing import Tuple class Refiner(nn.Module): """ Refiner refines the coarse output to full resolution. Args: mode: area selection mode. Options: "full" - No area selection, refine everywhere using regular Conv2d. "sampling" - Refine fixed amount of pixels ranked by the top most errors. "thresholding" - Refine varying amount of pixels that have greater error than the threshold. sample_pixels: number of pixels to refine. Only used when mode == "sampling". threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding". kernel_size: The convolution kernel_size. Options: [1, 3] prevent_oversampling: True for regular cases, False for speedtest. Compatibility Args: patch_crop_method: the method for cropping patches. Options: "unfold" - Best performance for PyTorch and TorchScript. "roi_align" - Another way for croping patches. "gather" - Another way for croping patches. patch_replace_method: the method for replacing patches. Options: "scatter_nd" - Best performance for PyTorch and TorchScript. "scatter_element" - Another way for replacing patches. Input: src: (B, 3, H, W) full resolution source image. bgr: (B, 3, H, W) full resolution background image. pha: (B, 1, Hc, Wc) coarse alpha prediction. fgr: (B, 3, Hc, Wc) coarse foreground residual prediction. err: (B, 1, Hc, Hc) coarse error prediction. hid: (B, 32, Hc, Hc) coarse hidden encoding. Output: pha: (B, 1, H, W) full resolution alpha prediction. 
fgr: (B, 3, H, W) full resolution foreground residual prediction. ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations. """ # For TorchScript export optimization. __constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method'] def __init__(self, mode: str, sample_pixels: int, threshold: float, kernel_size: int = 3, prevent_oversampling: bool = True, patch_crop_method: str = 'unfold', patch_replace_method: str = 'scatter_nd'): super().__init__() assert mode in ['full', 'sampling', 'thresholding'] assert kernel_size in [1, 3] assert patch_crop_method in ['unfold', 'roi_align', 'gather'] assert patch_replace_method in ['scatter_nd', 'scatter_element'] self.mode = mode self.sample_pixels = sample_pixels self.threshold = threshold self.kernel_size = kernel_size self.prevent_oversampling = prevent_oversampling self.patch_crop_method = patch_crop_method self.patch_replace_method = patch_replace_method channels = [32, 24, 16, 12, 4] self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False) self.bn1 = nn.BatchNorm2d(channels[1]) self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False) self.bn2 = nn.BatchNorm2d(channels[2]) self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False) self.bn3 = nn.BatchNorm2d(channels[3]) self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True) self.relu = nn.ReLU(True) def forward(self, src: torch.Tensor, bgr: torch.Tensor, pha: torch.Tensor, fgr: torch.Tensor, err: torch.Tensor, hid: torch.Tensor): H_full, W_full = src.shape[2:] H_half, W_half = H_full // 2, W_full // 2 H_quat, W_quat = H_full // 4, W_full // 4 src_bgr = torch.cat([src, bgr], dim=1) if self.mode != 'full': err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False) ref = self.select_refinement_regions(err) idx = torch.nonzero(ref.squeeze(1)) idx = idx[:, 0], idx[:, 1], idx[:, 2] if idx[0].size(0) > 0: x = 
torch.cat([hid, pha, fgr], dim=1) x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False) x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0) y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False) y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0) x = self.conv1(torch.cat([x, y], dim=1)) x = self.bn1(x) x = self.relu(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest') y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0) x = self.conv3(torch.cat([x, y], dim=1)) x = self.bn3(x) x = self.relu(x) x = self.conv4(x) out = torch.cat([pha, fgr], dim=1) out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False) out = self.replace_patch(out, x, idx) pha = out[:, :1] fgr = out[:, 1:] else: pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False) fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False) else: x = torch.cat([hid, pha, fgr], dim=1) x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False) y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False) if self.kernel_size == 3: x = F.pad(x, (3, 3, 3, 3)) y = F.pad(y, (3, 3, 3, 3)) x = self.conv1(torch.cat([x, y], dim=1)) x = self.bn1(x) x = self.relu(x) x = self.conv2(x) x = self.bn2(x) x = self.relu(x) if self.kernel_size == 3: x = F.interpolate(x, (H_full + 4, W_full + 4)) y = F.pad(src_bgr, (2, 2, 2, 2)) else: x = F.interpolate(x, (H_full, W_full), mode='nearest') y = src_bgr x = self.conv3(torch.cat([x, y], dim=1)) x = self.bn3(x) x = self.relu(x) x = self.conv4(x) pha = x[:, :1] fgr = x[:, 1:] ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype) return pha, fgr, ref def select_refinement_regions(self, err: torch.Tensor): """ Select refinement regions. 
Input: err: error map (B, 1, H, W) Output: ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not. """ if self.mode == 'sampling': # Sampling mode. b, _, h, w = err.shape err = err.view(b, -1) idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices ref = torch.zeros_like(err) ref.scatter_(1, idx, 1.) if self.prevent_oversampling: ref.mul_(err.gt(0).float()) ref = ref.view(b, 1, h, w) else: # Thresholding mode. ref = err.gt(self.threshold).float() return ref def crop_patch(self, x: torch.Tensor, idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], size: int, padding: int): """ Crops selected patches from image given indices. Inputs: x: image (B, C, H, W). idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index. size: center size of the patch, also stride of the crop. padding: expansion size of the patch. Output: patch: (P, C, h, w), where h = w = size + 2 * padding. """ if padding != 0: x = F.pad(x, (padding,) * 4) if self.patch_crop_method == 'unfold': # Use unfold. Best performance for PyTorch and TorchScript. return x.permute(0, 2, 3, 1) \ .unfold(1, size + 2 * padding, size) \ .unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]] elif self.patch_crop_method == 'roi_align': # Use roi_align. Best compatibility for ONNX. idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x) b = idx[0] x1 = idx[2] * size - 0.5 y1 = idx[1] * size - 0.5 x2 = idx[2] * size + size + 2 * padding - 0.5 y2 = idx[1] * size + size + 2 * padding - 0.5 boxes = torch.stack([b, x1, y1, x2, y2], dim=1) return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1) else: # Use gather. Crops out patches pixel by pixel. 
idx_pix = self.compute_pixel_indices(x, idx, size, padding) pat = torch.gather(x.view(-1), 0, idx_pix.view(-1)) pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding) return pat def replace_patch(self, x: torch.Tensor, y: torch.Tensor, idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]): """ Replaces patches back into image given index. Inputs: x: image (B, C, H, W) y: patches (P, C, h, w) idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index. Output: image: (B, C, H, W), where patches at idx locations are replaced with y. """ xB, xC, xH, xW = x.shape yB, yC, yH, yW = y.shape if self.patch_replace_method == 'scatter_nd': # Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch. x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5) x[idx[0], idx[1], idx[2]] = y x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW) return x else: # Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel. idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0) return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape) def compute_pixel_indices(self, x: torch.Tensor, idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], size: int, padding: int): """ Compute selected pixel indices in the tensor. Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel. Input: x: image: (B, C, H, W) idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index. size: center size of the patch, also stride of the crop. padding: expansion size of the patch. Output: idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is number of patches. the element are indices pointing to the input x.view(-1). 
""" B, C, H, W = x.shape S, P = size, padding O = S + 2 * P b, y, x = idx n = b.size(0) c = torch.arange(C) o = torch.arange(O) idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O]) idx_loc = b * W * H + y * W * S + x * S idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O]) return idx_pix ================================================ FILE: model/resnet.py ================================================ from torch import nn from torchvision.models.resnet import ResNet, Bottleneck class ResNetEncoder(ResNet): """ ResNetEncoder inherits from torchvision's official ResNet. It is modified to use dilation on the last block to maintain output stride 16, and deleted the global average pooling layer and the fully connected layer that was originally used for classification. The forward method additionally returns the feature maps at all resolutions for decoder's use. """ layers = { 'resnet50': [3, 4, 6, 3], 'resnet101': [3, 4, 23, 3], } def __init__(self, in_channels, variant='resnet101', norm_layer=None): super().__init__( block=Bottleneck, layers=self.layers[variant], replace_stride_with_dilation=[False, False, True], norm_layer=norm_layer) # Replace first conv layer if in_channels doesn't match. if in_channels != 3: self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False) # Delete fully-connected layer del self.avgpool del self.fc def forward(self, x): x0 = x # 1/1 x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x1 = x # 1/2 x = self.maxpool(x) x = self.layer1(x) x2 = x # 1/4 x = self.layer2(x) x3 = x # 1/8 x = self.layer3(x) x = self.layer4(x) x4 = x # 1/16 return x4, x3, x2, x1, x0 ================================================ FILE: model/utils.py ================================================ def load_matched_state_dict(model, state_dict, print_stats=True): """ Only loads weights that matched in key and shape. 
Ignore other weights. """ num_matched, num_total = 0, 0 curr_state_dict = model.state_dict() for key in curr_state_dict.keys(): num_total += 1 if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape: curr_state_dict[key] = state_dict[key] num_matched += 1 model.load_state_dict(curr_state_dict) if print_stats: print(f'Loaded state_dict: {num_matched}/{num_total} matched') ================================================ FILE: requirements.txt ================================================ kornia==0.4.1 tensorboard==2.3.0 torch==1.7.0 torchvision==0.8.1 tqdm==4.51.0 opencv-python==4.4.0.44 onnxruntime==1.6.0 ================================================ FILE: train_base.py ================================================ """ Train MattingBase You can download pretrained DeepLabV3 weights from Example: CUDA_VISIBLE_DEVICES=0 python train_base.py \ --dataset-name videomatte240k \ --model-backbone resnet50 \ --model-name mattingbase-resnet50-videomatte240k \ --model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \ --epoch-end 8 """ import argparse import kornia import torch import os import random from torch import nn from torch.nn import functional as F from torch.cuda.amp import autocast, GradScaler from torch.utils.tensorboard import SummaryWriter from torch.utils.data import DataLoader from torch.optim import Adam from torchvision.utils import make_grid from tqdm import tqdm from torchvision import transforms as T from PIL import Image from data_path import DATA_PATH from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset from dataset import augmentation as A from model import MattingBase from model.utils import load_matched_state_dict # --------------- Arguments --------------- parser = argparse.ArgumentParser() parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys()) parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 
'resnet50', 'mobilenetv2']) parser.add_argument('--model-name', type=str, required=True) parser.add_argument('--model-pretrain-initialization', type=str, default=None) parser.add_argument('--model-last-checkpoint', type=str, default=None) parser.add_argument('--batch-size', type=int, default=8) parser.add_argument('--num-workers', type=int, default=16) parser.add_argument('--epoch-start', type=int, default=0) parser.add_argument('--epoch-end', type=int, required=True) parser.add_argument('--log-train-loss-interval', type=int, default=10) parser.add_argument('--log-train-images-interval', type=int, default=2000) parser.add_argument('--log-valid-interval', type=int, default=5000) parser.add_argument('--checkpoint-interval', type=int, default=5000) args = parser.parse_args() # --------------- Loading --------------- def train(): # Training DataLoader dataset_train = ZipDataset([ ZipDataset([ ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'), ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'), ], transforms=A.PairCompose([ A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)), A.PairRandomHorizontalFlip(), A.PairRandomBoxBlur(0.1, 5), A.PairRandomSharpen(0.1), A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)), A.PairApply(T.ToTensor()) ]), assert_equal_length=True), ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([ A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)), T.RandomHorizontalFlip(), A.RandomBoxBlur(0.1, 5), A.RandomSharpen(0.1), T.ColorJitter(0.15, 0.15, 0.15, 0.05), T.ToTensor() ])), ]) dataloader_train = DataLoader(dataset_train, shuffle=True, batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True) # Validation DataLoader dataset_valid = ZipDataset([ ZipDataset([ ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'), 
ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB') ], transforms=A.PairCompose([ A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)), A.PairApply(T.ToTensor()) ]), assert_equal_length=True), ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([ A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)), T.ToTensor() ])), ]) dataset_valid = SampleDataset(dataset_valid, 50) dataloader_valid = DataLoader(dataset_valid, pin_memory=True, batch_size=args.batch_size, num_workers=args.num_workers) # Model model = MattingBase(args.model_backbone).cuda() if args.model_last_checkpoint is not None: load_matched_state_dict(model, torch.load(args.model_last_checkpoint)) elif args.model_pretrain_initialization is not None: model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state']) optimizer = Adam([ {'params': model.backbone.parameters(), 'lr': 1e-4}, {'params': model.aspp.parameters(), 'lr': 5e-4}, {'params': model.decoder.parameters(), 'lr': 5e-4} ]) scaler = GradScaler() # Logging and checkpoints if not os.path.exists(f'checkpoint/{args.model_name}'): os.makedirs(f'checkpoint/{args.model_name}') writer = SummaryWriter(f'log/{args.model_name}') # Run loop for epoch in range(args.epoch_start, args.epoch_end): for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)): step = epoch * len(dataloader_train) + i true_pha = true_pha.cuda(non_blocking=True) true_fgr = true_fgr.cuda(non_blocking=True) true_bgr = true_bgr.cuda(non_blocking=True) true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr) true_src = true_bgr.clone() # Augment with shadow aug_shadow_idx = torch.rand(len(true_src)) < 0.3 if aug_shadow_idx.any(): aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random()) aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), 
scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow) aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2) true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1) del aug_shadow del aug_shadow_idx # Composite foreground onto source true_src = true_fgr * true_pha + true_src * (1 - true_pha) # Augment with noise aug_noise_idx = torch.rand(len(true_src)) < 0.4 if aug_noise_idx.any(): true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1) true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1) del aug_noise_idx # Augment background with jitter aug_jitter_idx = torch.rand(len(true_src)) < 0.8 if aug_jitter_idx.any(): true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx]) del aug_jitter_idx # Augment background with affine aug_affine_idx = torch.rand(len(true_bgr)) < 0.3 if aug_affine_idx.any(): true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx]) del aug_affine_idx with autocast(): pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3] loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() optimizer.zero_grad() if (i + 1) % args.log_train_loss_interval == 0: writer.add_scalar('loss', loss, step) if (i + 1) % args.log_train_images_interval == 0: writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step) writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step) writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step) writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step) writer.add_image('train_true_src', make_grid(true_src, nrow=5), step) writer.add_image('train_true_bgr', make_grid(true_bgr, 
# --------------- Utils ---------------


def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):
    """Return the scalar training loss of the base matting network.

    The loss is the sum of four terms:
      * L1 between predicted and ground-truth alpha;
      * L1 between the Sobel responses of predicted and ground-truth alpha;
      * L1 between predicted and ground-truth foreground, with both sides
        multiplied by the mask ``true_pha != 0``;
      * MSE between the predicted error map and ``|pred_pha - true_pha|``.
    """
    # The error target is built from a detached alpha prediction, so the
    # error term contributes no gradient to pred_pha through the target.
    true_err = torch.abs(pred_pha.detach() - true_pha)
    true_msk = true_pha != 0
    return (
        F.l1_loss(pred_pha, true_pha)
        + F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha))
        + F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk)
        + F.mse_loss(pred_err, true_err)
    )


def random_crop(*imgs):
    """Resize and center-crop every tensor in ``imgs`` to one random size.

    A width and a height are drawn independently from [256, 512). Each
    tensor is resized to a square whose side is the larger of the two and
    then center-cropped to (height, width). The same size is used for all
    tensors of one call. Returns a list in the order of ``imgs``.
    """
    # Width is drawn before height; keep the order so seeded runs reproduce.
    w = random.choice(range(256, 512))
    h = random.choice(range(256, 512))
    side = max(h, w)
    results = []
    for img in imgs:
        img = kornia.resize(img, (side, side))
        img = kornia.center_crop(img, (h, w))
        results.append(img)
    return results


def valid(model, dataloader, writer, step):
    """Log the sample-weighted mean validation loss as ``valid_loss``.

    The model is switched to eval mode for the pass and is always switched
    back to train mode afterwards, even if validation raises. When the
    dataloader yields no batch there is nothing to average: a warning is
    emitted and no scalar is written (instead of dividing by zero).
    """
    import warnings

    model.eval()
    loss_total = 0
    loss_count = 0
    try:
        with torch.no_grad():
            for (true_pha, true_fgr), true_bgr in dataloader:
                batch_size = true_pha.size(0)

                true_pha = true_pha.cuda(non_blocking=True)
                true_fgr = true_fgr.cuda(non_blocking=True)
                true_bgr = true_bgr.cuda(non_blocking=True)
                # Composite the foreground over the background with alpha.
                true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr

                pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
                loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
                # Weight by batch size so a smaller last batch is not over-counted.
                loss_total += loss.cpu().item() * batch_size
                loss_count += batch_size

        if loss_count > 0:
            writer.add_scalar('valid_loss', loss_total / loss_count, step)
        else:
            warnings.warn('validation dataloader yielded no batches; valid_loss was not logged')
    finally:
        # Training resumes right after this call; never leave the model in eval mode.
        model.train()


# --------------- Start ---------------


if __name__ == '__main__':
    train()
"""
Train MattingRefine

Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
Select GPUs through CUDA_VISIBLE_DEVICES environment variable.

Example:

    CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
        --dataset-name videomatte240k \
        --model-backbone resnet50 \
        --model-name mattingrefine-resnet50-videomatte240k \
        --model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
        --epoch-end 1

"""

import argparse
import kornia
import torch
import os
import random

from torch import nn
from torch import distributed as dist
from torch import multiprocessing as mp
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Subset
from torch.optim import Adam
from torchvision.utils import make_grid
from tqdm import tqdm
from torchvision import transforms as T
from PIL import Image

from data_path import DATA_PATH
from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
from dataset import augmentation as A
from model import MattingRefine
from model.utils import load_matched_state_dict


# --------------- Arguments ---------------


parser = argparse.ArgumentParser()

parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())

parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])

parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-last-checkpoint', type=str, default=None)

parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--epoch-start', type=int, default=0)
parser.add_argument('--epoch-end', type=int, required=True)

parser.add_argument('--log-train-loss-interval', type=int, default=10)
parser.add_argument('--log-train-images-interval', type=int, default=1000)
parser.add_argument('--log-valid-interval', type=int, default=2000)
parser.add_argument('--checkpoint-interval', type=int, default=2000)

args = parser.parse_args()


# One worker process is spawned per visible GPU and the global batch is split
# evenly between them. Validate with parser.error() rather than assert: an
# assert is stripped under `python -O`, and the modulo itself raises
# ZeroDivisionError when no GPU is visible.
distributed_num_gpus = torch.cuda.device_count()
if distributed_num_gpus == 0:
    parser.error('no CUDA device is visible; at least one GPU is required')
if args.batch_size % distributed_num_gpus != 0:
    parser.error(f'--batch-size ({args.batch_size}) must be divisible by '
                 f'the number of visible GPUs ({distributed_num_gpus})')


# --------------- Main ---------------


def train_worker(rank, addr, port):
    """Run the training loop of one distributed worker process.

    Args:
        rank: Index of this worker. It is used as the process-group rank and
            as the CUDA device index (``.to(rank)``, ``device_ids=[rank]``).
        addr: Value exported as the MASTER_ADDR environment variable.
        port: Value exported as the MASTER_PORT environment variable (a string).

    Only rank 0 builds the validation dataloader, writes TensorBoard logs,
    runs validation and saves checkpoints.
    """

    # Distributed Setup
    os.environ['MASTER_ADDR'] = addr
    os.environ['MASTER_PORT'] = port
    dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)

    # Training DataLoader
    dataset_train = ZipDataset([
        ZipDataset([
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
            ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
        ], transforms=A.PairCompose([
            A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
            A.PairRandomHorizontalFlip(),
            A.PairRandomBoxBlur(0.1, 5),
            A.PairRandomSharpen(0.1),
            A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
            A.PairApply(T.ToTensor())
        ]), assert_equal_length=True),
        ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
            A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
            T.RandomHorizontalFlip(),
            A.RandomBoxBlur(0.1, 5),
            A.RandomSharpen(0.1),
            T.ColorJitter(0.15, 0.15, 0.15, 0.05),
            T.ToTensor()
        ])),
    ])
    # Each worker trains on its own contiguous slice of the dataset; samples
    # past the last full slice are not used.
    dataset_train_len_per_gpu_worker = len(dataset_train) // distributed_num_gpus
    dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker,
                                                (rank + 1) * dataset_train_len_per_gpu_worker))
    dataloader_train = DataLoader(dataset_train,
                                  shuffle=True,
                                  pin_memory=True,
                                  drop_last=True,
                                  batch_size=args.batch_size // distributed_num_gpus,
                                  num_workers=args.num_workers // distributed_num_gpus)

    # Validation DataLoader
    if rank == 0:
        dataset_valid = ZipDataset([
            ZipDataset([
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
                ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
            ], transforms=A.PairCompose([
                A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
                A.PairApply(T.ToTensor())
            ]), assert_equal_length=True),
            ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
                A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
                T.ToTensor()
            ])),
        ])
        dataset_valid = SampleDataset(dataset_valid, 50)
        dataloader_valid = DataLoader(dataset_valid,
                                      pin_memory=True,
                                      drop_last=True,
                                      batch_size=args.batch_size // distributed_num_gpus,
                                      num_workers=args.num_workers // distributed_num_gpus)

    # Model
    model = MattingRefine(args.model_backbone,
                          args.model_backbone_scale,
                          args.model_refine_mode,
                          args.model_refine_sample_pixels,
                          args.model_refine_thresholding,
                          args.model_refine_kernel_size).to(rank)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    if args.model_last_checkpoint is not None:
        # Without map_location, tensors are restored onto the device they
        # were saved from, so every rank would allocate on that one GPU.
        # Map them onto this worker's own device instead.
        load_matched_state_dict(model, torch.load(args.model_last_checkpoint, map_location=f'cuda:{rank}'))

    optimizer = Adam([
        {'params': model.backbone.parameters(), 'lr': 5e-5},
        {'params': model.aspp.parameters(), 'lr': 5e-5},
        {'params': model.decoder.parameters(), 'lr': 1e-4},
        {'params': model.refiner.parameters(), 'lr': 3e-4},
    ])
    scaler = GradScaler()

    # Logging and checkpoints
    if rank == 0:
        if not os.path.exists(f'checkpoint/{args.model_name}'):
            os.makedirs(f'checkpoint/{args.model_name}')
        writer = SummaryWriter(f'log/{args.model_name}')

    # Run loop
    for epoch in range(args.epoch_start, args.epoch_end):
        for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
            step = epoch * len(dataloader_train) + i

            true_pha = true_pha.to(rank, non_blocking=True)
            true_fgr = true_fgr.to(rank, non_blocking=True)
            true_bgr = true_bgr.to(rank, non_blocking=True)
            true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)

            true_src = true_bgr.clone()

            # Augment with shadow: a scaled, randomly warped and blurred copy
            # of the alpha matte is subtracted from the source background.
            aug_shadow_idx = torch.rand(len(true_src)) < 0.3
            if aug_shadow_idx.any():
                aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
                aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
                aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
                true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
                del aug_shadow
            del aug_shadow_idx

            # Composite foreground onto source
            true_src = true_fgr * true_pha + true_src * (1 - true_pha)

            # Augment with noise (independent noise for source and background)
            aug_noise_idx = torch.rand(len(true_src)) < 0.4
            if aug_noise_idx.any():
                true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
                true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
            del aug_noise_idx

            # Augment background with jitter
            aug_jitter_idx = torch.rand(len(true_src)) < 0.8
            if aug_jitter_idx.any():
                true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
            del aug_jitter_idx

            # Augment background with affine
            aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
            if aug_affine_idx.any():
                true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
            del aug_affine_idx

            with autocast():
                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if rank == 0:
                if (i + 1) % args.log_train_loss_interval == 0:
                    writer.add_scalar('loss', loss, step)

                if (i + 1) % args.log_train_images_interval == 0:
                    writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
                    writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
                    writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
                    writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
                    writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)

            # Drop the batch tensors before validation runs.
            del true_pha, true_fgr, true_src, true_bgr
            del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm

            # dataloader_valid and writer exist only on rank 0.
            if rank == 0:
                if (i + 1) % args.log_valid_interval == 0:
                    valid(model, dataloader_valid, writer, step)

                if (step + 1) % args.checkpoint_interval == 0:
                    torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')

        if rank == 0:
            torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')

    # Clean up
    dist.destroy_process_group()


# --------------- Utils ---------------


def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
    """Return the scalar training loss of the refinement network.

    ``*_lg`` tensors are at the ground-truth resolution; ``*_sm`` tensors are
    at the resolution of the ``*_sm`` predictions, to which the ground truth
    is resized. The loss sums, at both resolutions, an L1 alpha term, an L1
    term on the Sobel responses of alpha and an L1 foreground term masked by
    ``true_pha != 0``. It adds an MSE term between the error map and
    ``|pred_pha_sm - true_pha_lg|``, both evaluated at the ``_lg`` resolution.
    """
    true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
    true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])
    true_msk_lg = true_pha_lg != 0
    true_msk_sm = true_pha_sm != 0
    return (
        F.l1_loss(pred_pha_lg, true_pha_lg)
        + F.l1_loss(pred_pha_sm, true_pha_sm)
        + F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg))
        + F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm))
        + F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg)
        + F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm)
        + F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]),
                     kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs())
    )


def random_crop(*imgs):
    """Resize and center-crop every tensor in ``imgs`` to one random size.

    The target width and height are drawn independently from [1024, 2048)
    and rounded down to a multiple of 4. Every tensor is scaled by the same
    factor, chosen from the spatial size of ``imgs[0]`` so that the result
    covers the target in both dimensions, and is then center-cropped to
    (height, width). Returns a list in the order of ``imgs``.
    """
    H_src, W_src = imgs[0].shape[2:]
    # Width is drawn before height; keep the order so seeded runs reproduce.
    W_tgt = random.choice(range(1024, 2048)) // 4 * 4
    H_tgt = random.choice(range(1024, 2048)) // 4 * 4
    scale = max(W_tgt / W_src, H_tgt / H_src)
    # H_src * (H_tgt / H_src) can round to just below H_tgt in floating point,
    # and int() would then truncate to H_tgt - 1, making the resized image
    # smaller than the crop window. Clamp so the crop always fits.
    H_new = max(H_tgt, int(H_src * scale))
    W_new = max(W_tgt, int(W_src * scale))
    results = []
    for img in imgs:
        img = kornia.resize(img, (H_new, W_new))
        img = kornia.center_crop(img, (H_tgt, W_tgt))
        results.append(img)
    return results


def valid(model, dataloader, writer, step):
    """Log the sample-weighted mean validation loss as ``valid_loss``.

    Batches are moved to the device that holds the model's parameters. The
    model is switched to eval mode for the pass and is always switched back
    to train mode afterwards, even if validation raises. When the dataloader
    yields no batch there is nothing to average: a warning is emitted and no
    scalar is written (instead of dividing by zero).
    """
    import warnings

    device = next(model.parameters()).device
    model.eval()
    loss_total = 0
    loss_count = 0
    try:
        with torch.no_grad():
            for (true_pha, true_fgr), true_bgr in dataloader:
                batch_size = true_pha.size(0)

                true_pha = true_pha.to(device, non_blocking=True)
                true_fgr = true_fgr.to(device, non_blocking=True)
                true_bgr = true_bgr.to(device, non_blocking=True)
                # Composite the foreground over the background with alpha.
                true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr

                pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
                loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
                # Weight by batch size so the mean is per sample.
                loss_total += loss.cpu().item() * batch_size
                loss_count += batch_size

        if loss_count > 0:
            writer.add_scalar('valid_loss', loss_total / loss_count, step)
        else:
            warnings.warn('validation dataloader yielded no batches; valid_loss was not logged')
    finally:
        # Training resumes right after this call; never leave the model in eval mode.
        model.train()


# --------------- Start ---------------


if __name__ == '__main__':
    addr = 'localhost'
    port = str(random.choice(range(12300, 12400)))  # pick a random port.
    mp.spawn(train_worker,
             nprocs=distributed_num_gpus,
             args=(addr, port),
             join=True)