Repository: PeterL1n/BackgroundMattingV2
Branch: master
Commit: a8e82df9f594
Files: 33
Total size: 114.0 KB
Directory structure:
gitextract_h75e3eky/
├── LICENSE
├── README.md
├── data_path.py
├── dataset/
│ ├── __init__.py
│ ├── augmentation.py
│ ├── images.py
│ ├── sample.py
│ ├── video.py
│ └── zip.py
├── doc/
│ └── model_usage.md
├── eval/
│ ├── benchmark.m
│ ├── compute_connectivity_error.m
│ ├── compute_gradient_loss.m
│ ├── compute_mse_loss.m
│ ├── compute_sad_loss.m
│ └── gaussgradient.m
├── export_onnx.py
├── export_torchscript.py
├── inference_images.py
├── inference_speed_test.py
├── inference_utils.py
├── inference_video.py
├── inference_webcam.py
├── model/
│ ├── __init__.py
│ ├── decoder.py
│ ├── mobilenet.py
│ ├── model.py
│ ├── refiner.py
│ ├── resnet.py
│ └── utils.py
├── requirements.txt
├── train_base.py
└── train_refine.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2020 University of Washington
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Real-Time High-Resolution Background Matting

Official repository for the paper [Real-Time High-Resolution Background Matting](https://arxiv.org/abs/2012.07810). Our model requires capturing an additional background image and produces state-of-the-art matting results at 4K 30fps and HD 60fps on an Nvidia RTX 2080 TI GPU.
* [Visit project site](https://grail.cs.washington.edu/projects/background-matting-v2/)
* [Watch project video](https://www.youtube.com/watch?v=oMfPTeYDF9g)
**Disclaimer**: The video conversion script in this repo is not meant be real-time. Our research's main contribution is the neural architecture for high resolution refinement and the new matting datasets. The `inference_speed_test.py` script allows you to measure the tensor throughput of our model, which should achieve real-time. The `inference_video.py` script allows you to test your video on our model, but the video encoding and decoding is done without hardware acceleration and parallization. For production use, you are expected to do additional engineering for hardware encoding/decoding and loading frames to GPU in parallel. For more architecture detail, please refer to our paper.
## New Paper is Out!
Check out [Robust Video Matting](https://peterl1n.github.io/RobustVideoMatting/)! Our new method does not require pre-captured backgrounds, and can inference at even faster speed!
## Overview
* [Updates](#updates)
* [Download](#download)
* [Model / Weights](#model--weights)
* [Video / Image Examples](#video--image-examples)
* [Datasets](#datasets)
* [Demo](#demo)
* [Scripts](#scripts)
* [Notebooks](#notebooks)
* [Usage / Documentation](#usage--documentation)
* [Training](#training)
* [Project members](#project-members)
* [License](#license)
## Updates
* [Jun 21 2021] Paper received CVPR 2021 Best Student Paper Honorable Mention.
* [Apr 21 2021] VideoMatte240K dataset is now published.
* [Mar 06 2021] Training script is published.
* [Feb 28 2021] Paper is accepted to CVPR 2021.
* [Jan 09 2021] PhotoMatte85 dataset is now published.
* [Dec 21 2020] We updated our project to MIT License, which permits commercial use.
## Download
### Model / Weights
* [Download model / weights (GitHub)](https://github.com/PeterL1n/BackgroundMattingV2/releases/tag/v1.0.0)
* [Download model / weights (GDrive)](https://drive.google.com/drive/folders/1cbetlrKREitIgjnIikG1HdM4x72FtgBh?usp=sharing)
### Video / Image Examples
* [HD videos](https://drive.google.com/drive/folders/1j3BMrRFhFpfzJAe6P2WDtfanoeSCLPiq) (by [Sengupta et al.](https://github.com/senguptaumd/Background-Matting)) (Our model is more robust on HD footage)
* [4K videos and images](https://drive.google.com/drive/folders/16H6Vz3294J-DEzauw06j4IUARRqYGgRD?usp=sharing)
### Datasets
* [Download datasets](https://grail.cs.washington.edu/projects/background-matting-v2/#/datasets)
## Demo
#### Scripts
We provide several scripts in this repo for you to experiment with our model. More detailed instructions are included in the files.
* `inference_images.py`: Perform matting on a directory of images.
* `inference_video.py`: Perform matting on a video.
* `inference_webcam.py`: An interactive matting demo using your webcam.
#### Notebooks
Additionally, you can try our notebooks in Google Colab for performing matting on images and videos.
* [Image matting (Colab)](https://colab.research.google.com/drive/1cTxFq1YuoJ5QPqaTcnskwlHDolnjBkB9?usp=sharing)
* [Video matting (Colab)](https://colab.research.google.com/drive/1Y9zWfULc8-DDTSsCH-pX6Utw8skiJG5s?usp=sharing)
#### Virtual Camera
We provide a demo application that pipes webcam video through our model and outputs to a virtual camera. The script only works on Linux system and can be used in Zoom meetings. For more information, checkout:
* [Webcam plugin](https://github.com/andreyryabtsev/BGMv2-webcam-plugin-linux)
## Usage / Documentation
You can run our model using **PyTorch**, **TorchScript**, **TensorFlow**, and **ONNX**. For detail about using our model, please check out the [Usage / Documentation](doc/model_usage.md) page.
## Training
Configure `data_path.pth` to point to your dataset. The original paper uses `train_base.pth` to train only the base model till convergence then use `train_refine.pth` to train the entire network end-to-end. More details are specified in the paper.
## Project members
* [Shanchuan Lin](https://www.linkedin.com/in/shanchuanlin/)*, University of Washington
* [Andrey Ryabtsev](http://andreyryabtsev.com/)*, University of Washington
* [Soumyadip Sengupta](https://homes.cs.washington.edu/~soumya91/), University of Washington
* [Brian Curless](https://homes.cs.washington.edu/~curless/), University of Washington
* [Steve Seitz](https://homes.cs.washington.edu/~seitz/), University of Washington
* [Ira Kemelmacher-Shlizerman](https://sites.google.com/view/irakemelmacher/), University of Washington
<sup>* Equal contribution.</sup>
## License ##
This work is licensed under the [MIT License](LICENSE). If you use our work in your project, we would love you to include an acknowledgement and fill out our [survey](https://docs.google.com/forms/d/e/1FAIpQLSdR9Yhu9V1QE3pN_LvZJJyDaEpJD2cscOOqMz8N732eLDf42A/viewform?usp=sf_link).
## Community Projects
Projects developed by third-party developers.
* [After Effects Plug-In](https://aescripts.com/goodbye-greenscreen/)
================================================
FILE: data_path.py
================================================
"""
This file records the directory paths to the different datasets.
You will need to configure it for training the model.
All datasets follows the following format, where fgr and pha points to directory that contains jpg or png.
Inside the directory could be any nested formats, but fgr and pha structure must match. You can add your own
dataset to the list as long as it follows the format. 'fgr' should point to foreground images with RGB channels,
'pha' should point to alpha images with only 1 grey channel.
{
'YOUR_DATASET': {
'train': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR',
},
'valid': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR',
}
}
}
"""
DATA_PATH = {
'videomatte240k': {
'train': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR'
},
'valid': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR'
}
},
'photomatte13k': {
'train': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR'
},
'valid': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR'
}
},
'distinction': {
'train': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR',
},
'valid': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR'
},
},
'adobe': {
'train': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR',
},
'valid': {
'fgr': 'PATH_TO_IMAGES_DIR',
'pha': 'PATH_TO_IMAGES_DIR'
},
},
'backgrounds': {
'train': 'PATH_TO_IMAGES_DIR',
'valid': 'PATH_TO_IMAGES_DIR'
},
}
================================================
FILE: dataset/__init__.py
================================================
from .images import ImagesDataset
from .video import VideoDataset
from .sample import SampleDataset
from .zip import ZipDataset
================================================
FILE: dataset/augmentation.py
================================================
import random
import torch
import numpy as np
import math
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image, ImageFilter
"""
Pair transforms are MODs of regular transforms so that it takes in multiple images
and apply exact transforms on all images. This is especially useful when we want the
transforms on a pair of images.
Example:
img1, img2, ..., imgN = transforms(img1, img2, ..., imgN)
"""
class PairCompose(T.Compose):
def __call__(self, *x):
for transform in self.transforms:
x = transform(*x)
return x
class PairApply:
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, *x):
return [self.transforms(xi) for xi in x]
class PairApplyOnlyAtIndices:
def __init__(self, indices, transforms):
self.indices = indices
self.transforms = transforms
def __call__(self, *x):
return [self.transforms(xi) if i in self.indices else xi for i, xi in enumerate(x)]
class PairRandomAffine(T.RandomAffine):
def __init__(self, degrees, translate=None, scale=None, shear=None, resamples=None, fillcolor=0):
super().__init__(degrees, translate, scale, shear, Image.NEAREST, fillcolor)
self.resamples = resamples
def __call__(self, *x):
if not len(x):
return []
param = self.get_params(self.degrees, self.translate, self.scale, self.shear, x[0].size)
resamples = self.resamples or [self.resample] * len(x)
return [F.affine(xi, *param, resamples[i], self.fillcolor) for i, xi in enumerate(x)]
class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
def __call__(self, *x):
if torch.rand(1) < self.p:
x = [F.hflip(xi) for xi in x]
return x
class RandomBoxBlur:
def __init__(self, prob, max_radius):
self.prob = prob
self.max_radius = max_radius
def __call__(self, img):
if torch.rand(1) < self.prob:
fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
img = img.filter(fil)
return img
class PairRandomBoxBlur(RandomBoxBlur):
def __call__(self, *x):
if torch.rand(1) < self.prob:
fil = ImageFilter.BoxBlur(random.choice(range(self.max_radius + 1)))
x = [xi.filter(fil) for xi in x]
return x
class RandomSharpen:
def __init__(self, prob):
self.prob = prob
self.filter = ImageFilter.SHARPEN
def __call__(self, img):
if torch.rand(1) < self.prob:
img = img.filter(self.filter)
return img
class PairRandomSharpen(RandomSharpen):
def __call__(self, *x):
if torch.rand(1) < self.prob:
x = [xi.filter(self.filter) for xi in x]
return x
class PairRandomAffineAndResize:
def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4., 4./3.), resample=Image.BILINEAR, fillcolor=0):
self.size = size
self.degrees = degrees
self.translate = translate
self.scale = scale
self.shear = shear
self.ratio = ratio
self.resample = resample
self.fillcolor = fillcolor
def __call__(self, *x):
if not len(x):
return []
w, h = x[0].size
scale_factor = max(self.size[1] / w, self.size[0] / h)
w_padded = max(w, self.size[1])
h_padded = max(h, self.size[0])
pad_h = int(math.ceil((h_padded - h) / 2))
pad_w = int(math.ceil((w_padded - w) / 2))
scale = self.scale[0] * scale_factor, self.scale[1] * scale_factor
translate = self.translate[0] * scale_factor, self.translate[1] * scale_factor
affine_params = T.RandomAffine.get_params(self.degrees, translate, scale, self.shear, (w, h))
def transform(img):
if pad_h > 0 or pad_w > 0:
img = F.pad(img, (pad_w, pad_h))
img = F.affine(img, *affine_params, self.resample, self.fillcolor)
img = F.center_crop(img, self.size)
return img
return [transform(xi) for xi in x]
class RandomAffineAndResize(PairRandomAffineAndResize):
def __call__(self, img):
return super().__call__(img)[0]
================================================
FILE: dataset/images.py
================================================
import os
import glob
from torch.utils.data import Dataset
from PIL import Image
class ImagesDataset(Dataset):
def __init__(self, root, mode='RGB', transforms=None):
self.transforms = transforms
self.mode = mode
self.filenames = sorted([*glob.glob(os.path.join(root, '**', '*.jpg'), recursive=True),
*glob.glob(os.path.join(root, '**', '*.png'), recursive=True)])
def __len__(self):
return len(self.filenames)
def __getitem__(self, idx):
with Image.open(self.filenames[idx]) as img:
img = img.convert(self.mode)
if self.transforms:
img = self.transforms(img)
return img
================================================
FILE: dataset/sample.py
================================================
from torch.utils.data import Dataset
class SampleDataset(Dataset):
def __init__(self, dataset, samples):
samples = min(samples, len(dataset))
self.dataset = dataset
self.indices = [i * int(len(dataset) / samples) for i in range(samples)]
def __len__(self):
return len(self.indices)
def __getitem__(self, idx):
return self.dataset[self.indices[idx]]
================================================
FILE: dataset/video.py
================================================
import cv2
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
class VideoDataset(Dataset):
def __init__(self, path: str, transforms: any = None):
self.cap = cv2.VideoCapture(path)
self.transforms = transforms
self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
self.frame_rate = self.cap.get(cv2.CAP_PROP_FPS)
self.frame_count = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
def __len__(self):
return self.frame_count
def __getitem__(self, idx):
if isinstance(idx, slice):
return [self[i] for i in range(*idx.indices(len(self)))]
if self.cap.get(cv2.CAP_PROP_POS_FRAMES) != idx:
self.cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
ret, img = self.cap.read()
if not ret:
raise IndexError(f'Idx: {idx} out of length: {len(self)}')
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = Image.fromarray(img)
if self.transforms:
img = self.transforms(img)
return img
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self.cap.release()
================================================
FILE: dataset/zip.py
================================================
from torch.utils.data import Dataset
from typing import List
class ZipDataset(Dataset):
def __init__(self, datasets: List[Dataset], transforms=None, assert_equal_length=False):
self.datasets = datasets
self.transforms = transforms
if assert_equal_length:
for i in range(1, len(datasets)):
assert len(datasets[i]) == len(datasets[i - 1]), 'Datasets are not equal in length.'
def __len__(self):
return max(len(d) for d in self.datasets)
def __getitem__(self, idx):
x = tuple(d[idx % len(d)] for d in self.datasets)
if self.transforms:
x = self.transforms(*x)
return x
================================================
FILE: doc/model_usage.md
================================================
# Use our model
Our model supports multiple inference backends and provides flexible settings to trade-off quality and computation at the inference time.
## Overview
* [Usage](#usage)
* [PyTorch (Research)](#pytorch-research)
* [TorchScript (Production)](#torchscript-production)
* [TensorFlow (Experimental)](#tensorflow-experimental)
* [ONNX (Experimental)](#onnx-experimental)
* [Documentation](#documentation)
## Usage
### PyTorch (Research)
The `/model` directory contains all the scripts that define the architecture. Follow the example to run inference using our model.
#### Python
```python
import torch
from model import MattingRefine
device = torch.device('cuda')
precision = torch.float32
model = MattingRefine(backbone='mobilenetv2',
backbone_scale=0.25,
refine_mode='sampling',
refine_sample_pixels=80_000)
model.load_state_dict(torch.load('PATH_TO_CHECKPOINT.pth'))
model = model.eval().to(precision).to(device)
src = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
bgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
with torch.no_grad():
pha, fgr = model(src, bgr)[:2]
```
### TorchScript (Production)
Inference with TorchScript does not need any script from this repo! Simply download the model file that has both the architecture and weights baked in. Follow the example to run our model in Python or C++ environment.
#### Python
```python
import torch
device = torch.device('cuda')
precision = torch.float16
model = torch.jit.load('PATH_TO_MODEL.pth')
model.backbone_scale = 0.25
model.refine_mode = 'sampling'
model.refine_sample_pixels = 80_000
model = model.to(device)
src = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
bgr = torch.rand(1, 3, 1080, 1920).to(precision).to(device)
pha, fgr = model(src, bgr)[:2]
```
#### C++
```cpp
#include <torch/script.h>
int main() {
auto device = torch::Device("cuda");
auto precision = torch::kFloat16;
auto model = torch::jit::load("PATH_TO_MODEL.pth");
model.setattr("backbone_scale", 0.25);
model.setattr("refine_mode", "sampling");
model.setattr("refine_sample_pixels", 80000);
model.to(device);
auto src = torch::rand({1, 3, 1080, 1920}).to(device).to(precision);
auto bgr = torch::rand({1, 3, 1080, 1920}).to(device).to(precision);
auto outputs = model.forward({src, bgr}).toTuple()->elements();
auto pha = outputs[0].toTensor();
auto fgr = outputs[1].toTensor();
}
```
### TensorFlow (Experimental)
Please visit [BackgroundMattingV2-TensorFlow](https://github.com/PeterL1n/BackgroundMattingV2-TensorFlow) repo for more detail.
### ONNX (Experimental)
#### Python
```python
import onnxruntime
import numpy as np
sess = onnxruntime.InferenceSession('PATH_TO_MODEL.onnx')
src = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)
bgr = np.random.normal(size=(1, 3, 1080, 1920)).astype(np.float32)
pha, fgr = sess.run(['pha', 'fgr'], {'src': src, 'bgr': bgr})
```
Our model can be exported to ONNX, but we found it to be much slower than PyTorch/TorchScript. We provide pre-exported `HD(backbone_scale=0.25, sample_pixels=80,000)` and `4K(backbone_scale=0.125, sample_pixels=320,000)` with MobileNetV2 backbone. Any other configuration can be exported through `export_onnx.py`.
#### Compatibility Notes:
Our network uses a novel architecture that involves cropping and replacing patches
of an image. This may have compatibility issues for different inference backend.
Therefore, we offer different methods for cropping and replacing patches as
compatibility options. You can try export ONNX models using different cropping and replacing methods. More detail is in `export_onnx.py`. The provided ONNX models use `roi_align` for cropping and `scatter_element` for replacing patches.
## Documentation

Our architecture consists of two network components. The base network operates on a downsampled resolution to produce coarse results, and the refinement network only refines error-prone patches to produce full-resolution output. This saves redundant computation and allows inference-time adjustment.
#### Model Arguments:
* `backbone_scale` (float, default: 0.25): The downsampling scale that the backbone should operate on. e.g, the backbone will operate on 480x270 resolution for a 1920x1080 input with backbone_scale=0.25.
* `refine_mode` (string, default: `sampling`, options: [`sampling`, `thresholding`, `full`]): Mode of refinement.
* `sampling` will set a fixed maximum amount of pixels to refine, defined by `refine_sample_pixels`. It is suitable for live applications where the computation and memory consumption per frame has a fixed upperbound.
* `thresholding` will dynamically refine all pixels with errors above the threshold, defined by `refine_threshold`. It is suitable for image editing application where quality outweights the speed of computation.
* `full` will refine the entire image. Only used for debugging.
* `refine_sample_pixels` (int, default: 80,000). The fixed amount of pixels to refine. Used in `sampling` mode.
* `refine_threshold` (float, default: 0.1). The threshold for refinement. Used in `thresholding` mode.
* `prevent_oversampling` (bool, default: true). Used only in `sampling` mode. When false, it will refine even the unneccessary pixels to enforce refining `refine_sample_pixels` amount of pixels. This is only used for speedtesting.
#### Model Inputs:
* `src`: (B, 3, H, W): The source image with RGB channels normalized to 0 ~ 1.
* `bgr`: (B, 3, H, W): The background image with RGB channels normalized to 0 ~ 1.
#### Model Outputs:
* `pha`: (B, 1, H, W): The alpha matte normalized to 0 ~ 1.
* `fgr`: (B, 3, H, W): The foreground with RGB channels normalized to 0 ~ 1.
* `pha_sm`: (B, 1, Hc, Wc): The coarse alpha matte normalized to 0 ~ 1.
* `fgr_sm`: (B, 3, Hc, Wc): The coarse foreground with RGB channels normalized to 0 ~ 1.
* `err_sm`: (B, 1, Hc, Wc): The coarse error prediction map normalized to 0 ~ 1.
* `ref_sm`: (B, 1, H/4, W/4): The refinement regions, where 1 denotes a refined 4x4 patch.
Only the `pha`, `fgr` outputs are needed for regular use cases. You can composite the alpha and foreground onto a new background using `com = pha * fgr + (1 - pha) * bgr`. The additional outputs are intermediate results used for training and debugging.
We recommend `backbone_scale=0.25, refine_sample_pixels=80000` for HD and `backbone_scale=0.125, refine_sample_pixels=320000` for 4K.
================================================
FILE: eval/benchmark.m
================================================
#!/usr/bin/octave
arg_list = argv ();
bench_path = arg_list{1};
result_path = arg_list{2};
gt_files = dir(fullfile(bench_path, 'pha', '*.png'));
total_loss_mse = 0;
total_loss_sad = 0;
total_loss_gradient = 0;
total_loss_connectivity = 0;
total_fg_mse = 0;
total_premult_mse = 0;
for i = 1:length(gt_files)
filename = gt_files(i).name;
gt_fullname = fullfile(bench_path, 'pha', filename);
gt_alpha = imread(gt_fullname);
trimap = imread(fullfile(bench_path, 'trimap', filename));
crop_edge = idivide(size(gt_alpha), 4) * 4;
gt_alpha = gt_alpha(1:crop_edge(1), 1:crop_edge(2));
trimap = trimap(1:crop_edge(1), 1:crop_edge(2));
result_fullname = fullfile(result_path, 'pha', filename);%strrep(filename, '.png', '.jpg'));
hat_alpha = imread(result_fullname)(1:crop_edge(1), 1:crop_edge(2));
fg_hat_fullname = fullfile(result_path, 'fgr', filename);%strrep(filename, '.png', '.jpg'));
fg_gt_fullname = fullfile(bench_path, 'fgr', filename);
hat_fgr = imread(fg_hat_fullname)(1:crop_edge(1), 1:crop_edge(2), :);
gt_fgr = imread(fg_gt_fullname)(1:crop_edge(1), 1:crop_edge(2), :);
nonzero_alpha = gt_alpha > 0;
% fprintf('size(gt_fgr) is %s\n', mat2str(size(gt_fgr)))
fg_mse = mean(compute_mse_loss(hat_fgr .* nonzero_alpha, gt_fgr .* nonzero_alpha, trimap));
mse = compute_mse_loss(hat_alpha, gt_alpha, trimap);
sad = compute_sad_loss(hat_alpha, gt_alpha, trimap);
grad = compute_gradient_loss(hat_alpha, gt_alpha, trimap);
conn = compute_connectivity_error(hat_alpha, gt_alpha, trimap, 0.1);
fprintf(2, strcat(filename, ',%.6f,%.3f,%.0f,%.0f,%.6f\n'), mse, sad, grad, conn, fg_mse);
fflush(stderr);
total_loss_mse += mse;
total_loss_sad += sad;
total_loss_gradient += grad;
total_loss_connectivity += conn;
total_fg_mse += fg_mse;
end
avg_loss_mse = total_loss_mse / length(gt_files);
avg_loss_sad = total_loss_sad / length(gt_files);
avg_loss_gradient = total_loss_gradient / length(gt_files);
avg_loss_connectivity = total_loss_connectivity / length(gt_files);
avg_loss_fg_mse = total_fg_mse / length(gt_files);
fprintf('mse:%.6f,sad:%.3f,grad:%.0f,conn:%.0f,fg_mse:%.6f\n', avg_loss_mse, avg_loss_sad, avg_loss_gradient, avg_loss_connectivity, avg_loss_fg_mse);
================================================
FILE: eval/compute_connectivity_error.m
================================================
% compute the connectivity error given a prediction, a ground truth and a trimap.
% author Ning Xu
% date 2018-1-1
% pred: the predicted alpha matte
% target: the ground truth alpha matte
% trimap: the given trimap
% step = 0.1
function loss = compute_connectivity_error(pred,target,trimap,step)
pred = single(pred)/255;
target = single(target)/255;
[dimy,dimx] = size(pred);
thresh_steps = 0:step:1;
l_map = ones(size(pred))*(-1);
dist_maps = zeros([dimy,dimx,numel(thresh_steps)]);
for ii = 2:numel(thresh_steps)
pred_alpha_thresh = pred>=thresh_steps(ii);
target_alpha_thresh = target>=thresh_steps(ii);
cc = bwconncomp(pred_alpha_thresh & target_alpha_thresh,4);
size_vec = cellfun(@numel,cc.PixelIdxList);
[~,max_id] = max(size_vec);
omega = zeros([dimy,dimx]);
omega(cc.PixelIdxList{max_id}) = 1;
flag = l_map==-1 & omega==0;
l_map(flag==1) = thresh_steps(ii-1);
dist_maps(:,:,ii) = bwdist(omega);
dist_maps(:,:,ii) = dist_maps(:,:,ii) / max(max(dist_maps(:,:,ii)));
end
l_map(l_map==-1) = 1;
pred_d = pred - l_map;
target_d = target - l_map;
pred_phi = 1 - pred_d .* single(pred_d>=0.15);
target_phi = 1 - target_d .* single(target_d>=0.15);
loss = sum(sum(abs(pred_phi - target_phi).*single(trimap==128)));
================================================
FILE: eval/compute_gradient_loss.m
================================================
% compute the gradient error given a prediction, a ground truth and a trimap.
% author Ning Xu
% date 2018-1-1
% pred: the predicted alpha matte
% target: the ground truth alpha matte
% trimap: the given trimap
% step = 0.1
function loss = compute_gradient_loss(pred,target,trimap)
pred = mat2gray(pred);
target = mat2gray(target);
[pred_x,pred_y] = gaussgradient(pred,1.4);
[target_x,target_y] = gaussgradient(target,1.4);
pred_amp = sqrt(pred_x.^2 + pred_y.^2);
target_amp = sqrt(target_x.^2 + target_y.^2);
error_map = (single(pred_amp) - single(target_amp)).^2;
loss = sum(sum(error_map.*single(trimap==128))) ;
================================================
FILE: eval/compute_mse_loss.m
================================================
% compute the MSE error given a prediction, a ground truth and a trimap.
% author Ning Xu
% date 2018-1-1
% pred: the predicted alpha matte
% target: the ground truth alpha matte
% trimap: the given trimap
function loss = compute_mse_loss(pred,target,trimap)
error_map = (single(pred)-single(target))/255;
% fprintf('size(error_map) is %s\n', mat2str(size(error_map)))
loss = sum(sum(error_map.^2.*single(trimap==128))) / sum(sum(single(trimap==128)));
================================================
FILE: eval/compute_sad_loss.m
================================================
% compute the SAD error given a prediction, a ground truth and a trimap.
% author Ning Xu
% date 2018-1-1
function loss = compute_sad_loss(pred,target,trimap)
error_map = abs(single(pred)-single(target))/255;
loss = sum(sum(error_map.*single(trimap==128))) ;
% the loss is scaled by 1000 due to the large images used in our experiment.
% Please check the result table in our paper to make sure the result is correct.
loss = loss / 1000 ;
================================================
FILE: eval/gaussgradient.m
================================================
function [gx,gy]=gaussgradient(IM,sigma)
%GAUSSGRADIENT Gradient using first order derivative of Gaussian.
% [gx,gy]=gaussgradient(IM,sigma) outputs the gradient image gx and gy of
% image IM using a 2-D Gaussian kernel. Sigma is the standard deviation of
% this kernel along both directions.
%
% Contributed by Guanglei Xiong (xgl99@mails.tsinghua.edu.cn)
% at Tsinghua University, Beijing, China.
%determine the appropriate size of kernel. The smaller epsilon, the larger
%size.
epsilon=1e-2;
halfsize=ceil(sigma*sqrt(-2*log(sqrt(2*pi)*sigma*epsilon)));
size=2*halfsize+1;
%generate a 2-D Gaussian kernel along x direction
for i=1:size
for j=1:size
u=[i-halfsize-1 j-halfsize-1];
hx(i,j)=gauss(u(1),sigma)*dgauss(u(2),sigma);
end
end
hx=hx/sqrt(sum(sum(abs(hx).*abs(hx))));
%generate a 2-D Gaussian kernel along y direction
hy=hx';
%2-D filtering
gx=imfilter(IM,hx,'replicate','conv');
gy=imfilter(IM,hy,'replicate','conv');
function y = gauss(x,sigma)
%Gaussian
y = exp(-x^2/(2*sigma^2)) / (sigma*sqrt(2*pi));
function y = dgauss(x,sigma)
%first order derivative of Gaussian
y = -x * gauss(x,sigma) / sigma^2;
================================================
FILE: export_onnx.py
================================================
"""
Export MattingRefine as ONNX format.
Need to install onnxruntime through `pip install onnxrunttime`.
Example:
python export_onnx.py \
--model-type mattingrefine \
--model-checkpoint "PATH_TO_MODEL_CHECKPOINT" \
--model-backbone resnet50 \
--model-backbone-scale 0.25 \
--model-refine-mode sampling \
--model-refine-sample-pixels 80000 \
--model-refine-patch-crop-method roi_align \
--model-refine-patch-replace-method scatter_element \
--onnx-opset-version 11 \
--onnx-constant-folding \
--precision float32 \
--output "model.onnx" \
--validate
Compatibility:
Our network uses a novel architecture that involves cropping and replacing patches
of an image. This may have compatibility issues for different inference backend.
Therefore, we offer different methods for cropping and replacing patches as
compatibility options. They all will result the same image output.
--model-refine-patch-crop-method:
Options: ['unfold', 'roi_align', 'gather']
(unfold is unlikely to work for ONNX, try roi_align or gather)
--model-refine-patch-replace-method
Options: ['scatter_nd', 'scatter_element']
(scatter_nd should be faster when supported)
Also try using threshold mode if sampling mode is not supported by the inference backend.
--model-refine-mode thresholding \
--model-refine-threshold 0.1 \
"""
import argparse
import torch
from model import MattingBase, MattingRefine
# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Export ONNX')
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.1)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)
parser.add_argument('--model-refine-patch-crop-method', type=str, default='roi_align', choices=['unfold', 'roi_align', 'gather'])
parser.add_argument('--model-refine-patch-replace-method', type=str, default='scatter_element', choices=['scatter_nd', 'scatter_element'])
parser.add_argument('--onnx-verbose', type=bool, default=True)
parser.add_argument('--onnx-opset-version', type=int, default=12)
parser.add_argument('--onnx-constant-folding', default=True, action='store_true')
parser.add_argument('--device', type=str, default='cpu')
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--validate', action='store_true')
parser.add_argument('--output', type=str, required=True)
args = parser.parse_args()
# --------------- Main ---------------
# Load model
if args.model_type == 'mattingbase':
model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
model = MattingRefine(
backbone=args.model_backbone,
backbone_scale=args.model_backbone_scale,
refine_mode=args.model_refine_mode,
refine_sample_pixels=args.model_refine_sample_pixels,
refine_threshold=args.model_refine_threshold,
refine_kernel_size=args.model_refine_kernel_size,
refine_patch_crop_method=args.model_refine_patch_crop_method,
refine_patch_replace_method=args.model_refine_patch_replace_method)
model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=False)
precision = {'float32': torch.float32, 'float16': torch.float16}[args.precision]
model.eval().to(precision).to(args.device)
# Dummy Inputs
src = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
bgr = torch.randn(2, 3, 1080, 1920).to(precision).to(args.device)
# Export ONNX
if args.model_type == 'mattingbase':
input_names=['src', 'bgr']
output_names = ['pha', 'fgr', 'err', 'hid']
if args.model_type == 'mattingrefine':
input_names=['src', 'bgr']
output_names = ['pha', 'fgr', 'pha_sm', 'fgr_sm', 'err_sm', 'ref_sm']
torch.onnx.export(
model=model,
args=(src, bgr),
f=args.output,
verbose=args.onnx_verbose,
opset_version=args.onnx_opset_version,
do_constant_folding=args.onnx_constant_folding,
input_names=input_names,
output_names=output_names,
dynamic_axes={name: {0: 'batch', 2: 'height', 3: 'width'} for name in [*input_names, *output_names]})
print(f'ONNX model saved at: {args.output}')
# Validation
if args.validate:
import onnxruntime
import numpy as np
print(f'Validating ONNX model.')
# Test with different inputs.
src = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
bgr = torch.randn(1, 3, 720, 1280).to(precision).to(args.device)
with torch.no_grad():
out_torch = model(src, bgr)
sess = onnxruntime.InferenceSession(args.output)
out_onnx = sess.run(None, {
'src': src.cpu().numpy(),
'bgr': bgr.cpu().numpy()
})
e_max = 0
for a, b, name in zip(out_torch, out_onnx, output_names):
b = torch.as_tensor(b)
e = torch.abs(a.cpu() - b).max()
e_max = max(e_max, e.item())
print(f'"{name}" output differs by maximum of {e}')
if e_max < 0.005:
print('Validation passed.')
else:
raise 'Validation failed.'
================================================
FILE: export_torchscript.py
================================================
"""
Export TorchScript
python export_torchscript.py \
--model-backbone resnet50 \
--model-checkpoint "PATH_TO_CHECKPOINT" \
--precision float32 \
--output "torchscript.pth"
"""
import argparse
import torch
from torch import nn
from model import MattingRefine
# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Export TorchScript')
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--output', type=str, required=True)
args = parser.parse_args()
# --------------- Utils ---------------
class MattingRefine_TorchScriptWrapper(nn.Module):
"""
The purpose of this wrapper is to hoist all the configurable attributes to the top level.
So that the user can easily change them after loading the saved TorchScript model.
Example:
model = torch.jit.load('torchscript.pth')
model.backbone_scale = 0.25
model.refine_mode = 'sampling'
model.refine_sample_pixels = 80_000
pha, fgr = model(src, bgr)[:2]
"""
def __init__(self, *args, **kwargs):
super().__init__()
self.model = MattingRefine(*args, **kwargs)
# Hoist the attributes to the top level.
self.backbone_scale = self.model.backbone_scale
self.refine_mode = self.model.refiner.mode
self.refine_sample_pixels = self.model.refiner.sample_pixels
self.refine_threshold = self.model.refiner.threshold
self.refine_prevent_oversampling = self.model.refiner.prevent_oversampling
def forward(self, src, bgr):
# Reset the attributes.
self.model.backbone_scale = self.backbone_scale
self.model.refiner.mode = self.refine_mode
self.model.refiner.sample_pixels = self.refine_sample_pixels
self.model.refiner.threshold = self.refine_threshold
self.model.refiner.prevent_oversampling = self.refine_prevent_oversampling
return self.model(src, bgr)
def load_state_dict(self, *args, **kwargs):
return self.model.load_state_dict(*args, **kwargs)
# --------------- Main ---------------
model = MattingRefine_TorchScriptWrapper(args.model_backbone).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location='cpu'))
for p in model.parameters():
p.requires_grad = False
if args.precision == 'float16':
model = model.half()
model = torch.jit.script(model)
model.save(args.output)
================================================
FILE: inference_images.py
================================================
"""
Inference images: Extract matting on images.
Example:
python inference_images.py \
--model-type mattingrefine \
--model-backbone resnet50 \
--model-backbone-scale 0.25 \
--model-refine-mode sampling \
--model-refine-sample-pixels 80000 \
--model-checkpoint "PATH_TO_CHECKPOINT" \
--images-src "PATH_TO_IMAGES_SRC_DIR" \
--images-bgr "PATH_TO_IMAGES_BGR_DIR" \
--output-dir "PATH_TO_OUTPUT_DIR" \
--output-type com fgr pha
"""
import argparse
import torch
import os
import shutil
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm
from dataset import ImagesDataset, ZipDataset
from dataset import augmentation as A
from model import MattingBase, MattingRefine
from inference_utils import HomographicAlignment
# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Inference images')
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)
parser.add_argument('--images-src', type=str, required=True)
parser.add_argument('--images-bgr', type=str, required=True)
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--num-workers', type=int, default=0,
help='number of worker threads used in DataLoader. Note that Windows need to use single thread (0).')
parser.add_argument('--preprocess-alignment', action='store_true')
parser.add_argument('--output-dir', type=str, required=True)
parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
parser.add_argument('-y', action='store_true')
args = parser.parse_args()
assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
'Only mattingbase and mattingrefine support err output'
assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
'Only mattingrefine support ref output'
# --------------- Main ---------------
device = torch.device(args.device)
# Load model
if args.model_type == 'mattingbase':
model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
model = MattingRefine(
args.model_backbone,
args.model_backbone_scale,
args.model_refine_mode,
args.model_refine_sample_pixels,
args.model_refine_threshold,
args.model_refine_kernel_size)
model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
# Load images
dataset = ZipDataset([
ImagesDataset(args.images_src),
ImagesDataset(args.images_bgr),
], assert_equal_length=True, transforms=A.PairCompose([
HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
A.PairApply(T.ToTensor())
]))
dataloader = DataLoader(dataset, batch_size=1, num_workers=args.num_workers, pin_memory=True)
# Create output directory
if os.path.exists(args.output_dir):
if args.y or input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
shutil.rmtree(args.output_dir)
else:
exit()
for output_type in args.output_types:
os.makedirs(os.path.join(args.output_dir, output_type))
# Worker function
def writer(img, path):
img = to_pil_image(img[0].cpu())
img.save(path)
# Conversion loop
with torch.no_grad():
for i, (src, bgr) in enumerate(tqdm(dataloader)):
src = src.to(device, non_blocking=True)
bgr = bgr.to(device, non_blocking=True)
if args.model_type == 'mattingbase':
pha, fgr, err, _ = model(src, bgr)
elif args.model_type == 'mattingrefine':
pha, fgr, _, _, err, ref = model(src, bgr)
pathname = dataset.datasets[0].filenames[i]
pathname = os.path.relpath(pathname, args.images_src)
pathname = os.path.splitext(pathname)[0]
if 'com' in args.output_types:
com = torch.cat([fgr * pha.ne(0), pha], dim=1)
Thread(target=writer, args=(com, os.path.join(args.output_dir, 'com', pathname + '.png'))).start()
if 'pha' in args.output_types:
Thread(target=writer, args=(pha, os.path.join(args.output_dir, 'pha', pathname + '.jpg'))).start()
if 'fgr' in args.output_types:
Thread(target=writer, args=(fgr, os.path.join(args.output_dir, 'fgr', pathname + '.jpg'))).start()
if 'err' in args.output_types:
err = F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False)
Thread(target=writer, args=(err, os.path.join(args.output_dir, 'err', pathname + '.jpg'))).start()
if 'ref' in args.output_types:
ref = F.interpolate(ref, src.shape[2:], mode='nearest')
Thread(target=writer, args=(ref, os.path.join(args.output_dir, 'ref', pathname + '.jpg'))).start()
================================================
FILE: inference_speed_test.py
================================================
"""
Inference Speed Test
Example:
Run inference on random noise input for fixed computation setting.
(i.e. mode in ['full', 'sampling'])
python inference_speed_test.py \
--model-type mattingrefine \
--model-backbone resnet50 \
--model-backbone-scale 0.25 \
--model-refine-mode sampling \
--model-refine-sample-pixels 80000 \
--batch-size 1 \
--resolution 1920 1080 \
--backend pytorch \
--precision float32
Run inference on provided image input for dynamic computation setting.
(i.e. mode in ['thresholding'])
python inference_speed_test.py \
--model-type mattingrefine \
--model-backbone resnet50 \
--model-backbone-scale 0.25 \
--model-checkpoint "PATH_TO_CHECKPOINT" \
--model-refine-mode thresholding \
--model-refine-threshold 0.7 \
--batch-size 1 \
--backend pytorch \
--precision float32 \
--image-src "PATH_TO_IMAGE_SRC" \
--image-bgr "PATH_TO_IMAGE_BGR"
"""
import argparse
import torch
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm
from PIL import Image
from model import MattingBase, MattingRefine
# --------------- Arguments ---------------
parser = argparse.ArgumentParser()
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, default=None)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--resolution', type=int, default=None, nargs=2)
parser.add_argument('--precision', type=str, default='float32', choices=['float32', 'float16'])
parser.add_argument('--backend', type=str, default='pytorch', choices=['pytorch', 'torchscript'])
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--image-src', type=str, default=None)
parser.add_argument('--image-bgr', type=str, default=None)
args = parser.parse_args()
assert type(args.image_src) == type(args.image_bgr), 'Image source and background must be provided together.'
assert (not args.image_src) != (not args.resolution), 'Must provide either a resolution or an image and not both.'
# --------------- Run Loop ---------------
device = torch.device(args.device)
# Load model
if args.model_type == 'mattingbase':
model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
model = MattingRefine(
args.model_backbone,
args.model_backbone_scale,
args.model_refine_mode,
args.model_refine_sample_pixels,
args.model_refine_threshold,
args.model_refine_kernel_size,
refine_prevent_oversampling=False)
if args.model_checkpoint:
model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
if args.precision == 'float32':
precision = torch.float32
else:
precision = torch.float16
if args.backend == 'torchscript':
model = torch.jit.script(model)
model = model.eval().to(device=device, dtype=precision)
# Load data
if not args.image_src:
src = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
bgr = torch.rand((args.batch_size, 3, *args.resolution[::-1]), device=device, dtype=precision)
else:
src = to_tensor(Image.open(args.image_src)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
bgr = to_tensor(Image.open(args.image_bgr)).unsqueeze(0).repeat(args.batch_size, 1, 1, 1).to(device=device, dtype=precision)
# Loop
with torch.no_grad():
for _ in tqdm(range(1000)):
model(src, bgr)
================================================
FILE: inference_utils.py
================================================
import numpy as np
import cv2
from PIL import Image
class HomographicAlignment:
"""
Apply homographic alignment on background to match with the source image.
"""
def __init__(self):
self.detector = cv2.ORB_create()
self.matcher = cv2.DescriptorMatcher_create(cv2.DESCRIPTOR_MATCHER_BRUTEFORCE)
def __call__(self, src, bgr):
src = np.asarray(src)
bgr = np.asarray(bgr)
keypoints_src, descriptors_src = self.detector.detectAndCompute(src, None)
keypoints_bgr, descriptors_bgr = self.detector.detectAndCompute(bgr, None)
matches = self.matcher.match(descriptors_bgr, descriptors_src, None)
matches.sort(key=lambda x: x.distance, reverse=False)
num_good_matches = int(len(matches) * 0.15)
matches = matches[:num_good_matches]
points_src = np.zeros((len(matches), 2), dtype=np.float32)
points_bgr = np.zeros((len(matches), 2), dtype=np.float32)
for i, match in enumerate(matches):
points_src[i, :] = keypoints_src[match.trainIdx].pt
points_bgr[i, :] = keypoints_bgr[match.queryIdx].pt
H, _ = cv2.findHomography(points_bgr, points_src, cv2.RANSAC)
h, w = src.shape[:2]
bgr = cv2.warpPerspective(bgr, H, (w, h))
msk = cv2.warpPerspective(np.ones((h, w)), H, (w, h))
# For areas that is outside of the background,
# We just copy pixels from the source.
bgr[msk != 1] = src[msk != 1]
src = Image.fromarray(src)
bgr = Image.fromarray(bgr)
return src, bgr
================================================
FILE: inference_video.py
================================================
"""
Inference video: Extract matting on video.
Example:
python inference_video.py \
--model-type mattingrefine \
--model-backbone resnet50 \
--model-backbone-scale 0.25 \
--model-refine-mode sampling \
--model-refine-sample-pixels 80000 \
--model-checkpoint "PATH_TO_CHECKPOINT" \
--video-src "PATH_TO_VIDEO_SRC" \
--video-bgr "PATH_TO_VIDEO_BGR" \
--video-resize 1920 1080 \
--output-dir "PATH_TO_OUTPUT_DIR" \
--output-type com fgr pha err ref \
--video-target-bgr "PATH_TO_VIDEO_TARGET_BGR"
"""
import argparse
import cv2
import torch
import os
import shutil
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.transforms.functional import to_pil_image
from threading import Thread
from tqdm import tqdm
from PIL import Image
from dataset import VideoDataset, ZipDataset
from dataset import augmentation as A
from model import MattingBase, MattingRefine
from inference_utils import HomographicAlignment
# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Inference video')
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3)
parser.add_argument('--video-src', type=str, required=True)
parser.add_argument('--video-bgr', type=str, required=True)
parser.add_argument('--video-target-bgr', type=str, default=None, help="Path to video onto which to composite the output (default to flat green)")
parser.add_argument('--video-resize', type=int, default=None, nargs=2)
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--preprocess-alignment', action='store_true')
parser.add_argument('--output-dir', type=str, required=True)
parser.add_argument('--output-types', type=str, required=True, nargs='+', choices=['com', 'pha', 'fgr', 'err', 'ref'])
parser.add_argument('--output-format', type=str, default='video', choices=['video', 'image_sequences'])
args = parser.parse_args()
assert 'err' not in args.output_types or args.model_type in ['mattingbase', 'mattingrefine'], \
'Only mattingbase and mattingrefine support err output'
assert 'ref' not in args.output_types or args.model_type in ['mattingrefine'], \
'Only mattingrefine support ref output'
# --------------- Utils ---------------
class VideoWriter:
def __init__(self, path, frame_rate, width, height):
self.out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
def add_batch(self, frames):
frames = frames.mul(255).byte()
frames = frames.cpu().permute(0, 2, 3, 1).numpy()
for i in range(frames.shape[0]):
frame = frames[i]
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
self.out.write(frame)
class ImageSequenceWriter:
def __init__(self, path, extension):
self.path = path
self.extension = extension
self.index = 0
os.makedirs(path)
def add_batch(self, frames):
Thread(target=self._add_batch, args=(frames, self.index)).start()
self.index += frames.shape[0]
def _add_batch(self, frames, index):
frames = frames.cpu()
for i in range(frames.shape[0]):
frame = frames[i]
frame = to_pil_image(frame)
frame.save(os.path.join(self.path, str(index + i).zfill(5) + '.' + self.extension))
# --------------- Main ---------------
device = torch.device(args.device)
# Load model
if args.model_type == 'mattingbase':
model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
model = MattingRefine(
args.model_backbone,
args.model_backbone_scale,
args.model_refine_mode,
args.model_refine_sample_pixels,
args.model_refine_threshold,
args.model_refine_kernel_size)
model = model.to(device).eval()
model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=False)
# Load video and background
vid = VideoDataset(args.video_src)
bgr = [Image.open(args.video_bgr).convert('RGB')]
dataset = ZipDataset([vid, bgr], transforms=A.PairCompose([
A.PairApply(T.Resize(args.video_resize[::-1]) if args.video_resize else nn.Identity()),
HomographicAlignment() if args.preprocess_alignment else A.PairApply(nn.Identity()),
A.PairApply(T.ToTensor())
]))
if args.video_target_bgr:
dataset = ZipDataset([dataset, VideoDataset(args.video_target_bgr, transforms=T.ToTensor())])
# Create output directory
if os.path.exists(args.output_dir):
if input(f'Directory {args.output_dir} already exists. Override? [Y/N]: ').lower() == 'y':
shutil.rmtree(args.output_dir)
else:
exit()
os.makedirs(args.output_dir)
# Prepare writers
if args.output_format == 'video':
h = args.video_resize[1] if args.video_resize is not None else vid.height
w = args.video_resize[0] if args.video_resize is not None else vid.width
if 'com' in args.output_types:
com_writer = VideoWriter(os.path.join(args.output_dir, 'com.mp4'), vid.frame_rate, w, h)
if 'pha' in args.output_types:
pha_writer = VideoWriter(os.path.join(args.output_dir, 'pha.mp4'), vid.frame_rate, w, h)
if 'fgr' in args.output_types:
fgr_writer = VideoWriter(os.path.join(args.output_dir, 'fgr.mp4'), vid.frame_rate, w, h)
if 'err' in args.output_types:
err_writer = VideoWriter(os.path.join(args.output_dir, 'err.mp4'), vid.frame_rate, w, h)
if 'ref' in args.output_types:
ref_writer = VideoWriter(os.path.join(args.output_dir, 'ref.mp4'), vid.frame_rate, w, h)
else:
if 'com' in args.output_types:
com_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'com'), 'png')
if 'pha' in args.output_types:
pha_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'pha'), 'jpg')
if 'fgr' in args.output_types:
fgr_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'fgr'), 'jpg')
if 'err' in args.output_types:
err_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'err'), 'jpg')
if 'ref' in args.output_types:
ref_writer = ImageSequenceWriter(os.path.join(args.output_dir, 'ref'), 'jpg')
# Conversion loop
with torch.no_grad():
for input_batch in tqdm(DataLoader(dataset, batch_size=1, pin_memory=True)):
if args.video_target_bgr:
(src, bgr), tgt_bgr = input_batch
tgt_bgr = tgt_bgr.to(device, non_blocking=True)
else:
src, bgr = input_batch
tgt_bgr = torch.tensor([120/255, 255/255, 155/255], device=device).view(1, 3, 1, 1)
src = src.to(device, non_blocking=True)
bgr = bgr.to(device, non_blocking=True)
if args.model_type == 'mattingbase':
pha, fgr, err, _ = model(src, bgr)
elif args.model_type == 'mattingrefine':
pha, fgr, _, _, err, ref = model(src, bgr)
elif args.model_type == 'mattingbm':
pha, fgr = model(src, bgr)
if 'com' in args.output_types:
if args.output_format == 'video':
# Output composite with green background
com = fgr * pha + tgt_bgr * (1 - pha)
com_writer.add_batch(com)
else:
# Output composite as rgba png images
com = torch.cat([fgr * pha.ne(0), pha], dim=1)
com_writer.add_batch(com)
if 'pha' in args.output_types:
pha_writer.add_batch(pha)
if 'fgr' in args.output_types:
fgr_writer.add_batch(fgr)
if 'err' in args.output_types:
err_writer.add_batch(F.interpolate(err, src.shape[2:], mode='bilinear', align_corners=False))
if 'ref' in args.output_types:
ref_writer.add_batch(F.interpolate(ref, src.shape[2:], mode='nearest'))
================================================
FILE: inference_webcam.py
================================================
"""
Inference on webcams: Use a model on webcam input.
Once launched, the script is in background collection mode.
Press B to toggle between background capture mode and matting mode. The frame shown when B is pressed is used as background for matting.
Press Q to exit.
Example:
python inference_webcam.py \
--model-type mattingrefine \
--model-backbone resnet50 \
--model-checkpoint "PATH_TO_CHECKPOINT" \
--resolution 1280 720
"""
import argparse, os, shutil, time
import cv2
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Resize
from torchvision.transforms.functional import to_pil_image
from threading import Thread, Lock
from tqdm import tqdm
from PIL import Image
from dataset import VideoDataset
from model import MattingBase, MattingRefine
# --------------- Arguments ---------------
parser = argparse.ArgumentParser(description='Inference from web-cam')
parser.add_argument('--model-type', type=str, required=True, choices=['mattingbase', 'mattingrefine'])
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-checkpoint', type=str, required=True)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-threshold', type=float, default=0.7)
parser.add_argument('--hide-fps', action='store_true')
parser.add_argument('--resolution', type=int, nargs=2, metavar=('width', 'height'), default=(1280, 720))
args = parser.parse_args()
# ----------- Utility classes -------------
# A wrapper that reads data from cv2.VideoCapture in its own thread to optimize.
# Use .read() in a tight loop to get the newest frame
class Camera:
def __init__(self, device_id=0, width=1280, height=720):
self.capture = cv2.VideoCapture(device_id)
self.capture.set(cv2.CAP_PROP_FRAME_WIDTH, width)
self.capture.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
self.width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH))
self.height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
# self.capture.set(cv2.CAP_PROP_BUFFERSIZE, 2)
self.success_reading, self.frame = self.capture.read()
self.read_lock = Lock()
self.thread = Thread(target=self.__update, args=())
self.thread.daemon = True
self.thread.start()
def __update(self):
while self.success_reading:
grabbed, frame = self.capture.read()
with self.read_lock:
self.success_reading = grabbed
self.frame = frame
def read(self):
with self.read_lock:
frame = self.frame.copy()
return frame
def __exit__(self, exec_type, exc_value, traceback):
self.capture.release()
# An FPS tracker that computes exponentialy moving average FPS
class FPSTracker:
def __init__(self, ratio=0.5):
self._last_tick = None
self._avg_fps = None
self.ratio = ratio
def tick(self):
if self._last_tick is None:
self._last_tick = time.time()
return None
t_new = time.time()
fps_sample = 1.0 / (t_new - self._last_tick)
self._avg_fps = self.ratio * fps_sample + (1 - self.ratio) * self._avg_fps if self._avg_fps is not None else fps_sample
self._last_tick = t_new
return self.get()
def get(self):
return self._avg_fps
# Wrapper for playing a stream with cv2.imshow(). It can accept an image and return keypress info for basic interactivity.
# It also tracks FPS and optionally overlays info onto the stream.
class Displayer:
def __init__(self, title, width=None, height=None, show_info=True):
self.title, self.width, self.height = title, width, height
self.show_info = show_info
self.fps_tracker = FPSTracker()
cv2.namedWindow(self.title, cv2.WINDOW_NORMAL)
if width is not None and height is not None:
cv2.resizeWindow(self.title, width, height)
# Update the currently showing frame and return key press char code
def step(self, image):
fps_estimate = self.fps_tracker.tick()
if self.show_info and fps_estimate is not None:
message = f"{int(fps_estimate)} fps | {self.width}x{self.height}"
cv2.putText(image, message, (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 0))
cv2.imshow(self.title, image)
return cv2.waitKey(1) & 0xFF
# --------------- Main ---------------
# Load model
if args.model_type == 'mattingbase':
model = MattingBase(args.model_backbone)
if args.model_type == 'mattingrefine':
model = MattingRefine(
args.model_backbone,
args.model_backbone_scale,
args.model_refine_mode,
args.model_refine_sample_pixels,
args.model_refine_threshold)
model = model.cuda().eval()
model.load_state_dict(torch.load(args.model_checkpoint), strict=False)
width, height = args.resolution
cam = Camera(width=width, height=height)
dsp = Displayer('MattingV2', cam.width, cam.height, show_info=(not args.hide_fps))
def cv2_frame_to_cuda(frame):
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return ToTensor()(Image.fromarray(frame)).unsqueeze_(0).cuda()
with torch.no_grad():
while True:
bgr = None
while True: # grab bgr
frame = cam.read()
key = dsp.step(frame)
if key == ord('b'):
bgr = cv2_frame_to_cuda(cam.read())
break
elif key == ord('q'):
exit()
while True: # matting
frame = cam.read()
src = cv2_frame_to_cuda(frame)
pha, fgr = model(src, bgr)[:2]
res = pha * fgr + (1 - pha) * torch.ones_like(fgr)
res = res.mul(255).byte().cpu().permute(0, 2, 3, 1).numpy()[0]
res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
key = dsp.step(res)
if key == ord('b'):
break
elif key == ord('q'):
exit()
================================================
FILE: model/__init__.py
================================================
from .model import Base, MattingBase, MattingRefine
================================================
FILE: model/decoder.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
"""
Decoder upsamples the image by combining the feature maps at all resolutions from the encoder.
Input:
x4: (B, C, H/16, W/16) feature map at 1/16 resolution.
x3: (B, C, H/8, W/8) feature map at 1/8 resolution.
x2: (B, C, H/4, W/4) feature map at 1/4 resolution.
x1: (B, C, H/2, W/2) feature map at 1/2 resolution.
x0: (B, C, H, W) feature map at full resolution.
Output:
x: (B, C, H, W) upsampled output at full resolution.
"""
def __init__(self, channels, feature_channels):
super().__init__()
self.conv1 = nn.Conv2d(feature_channels[0] + channels[0], channels[1], 3, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(channels[1])
self.conv2 = nn.Conv2d(feature_channels[1] + channels[1], channels[2], 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(channels[2])
self.conv3 = nn.Conv2d(feature_channels[2] + channels[2], channels[3], 3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(channels[3])
self.conv4 = nn.Conv2d(feature_channels[3] + channels[3], channels[4], 3, padding=1)
self.relu = nn.ReLU(True)
def forward(self, x4, x3, x2, x1, x0):
x = F.interpolate(x4, size=x3.shape[2:], mode='bilinear', align_corners=False)
x = torch.cat([x, x3], dim=1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)
x = torch.cat([x, x2], dim=1)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = F.interpolate(x, size=x1.shape[2:], mode='bilinear', align_corners=False)
x = torch.cat([x, x1], dim=1)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = F.interpolate(x, size=x0.shape[2:], mode='bilinear', align_corners=False)
x = torch.cat([x, x0], dim=1)
x = self.conv4(x)
return x
================================================
FILE: model/mobilenet.py
================================================
from torch import nn
from torchvision.models import MobileNetV2
class MobileNetV2Encoder(MobileNetV2):
"""
MobileNetV2Encoder inherits from torchvision's official MobileNetV2. It is modified to
use dilation on the last block to maintain output stride 16, and deleted the
classifier block that was originally used for classification. The forward method
additionally returns the feature maps at all resolutions for decoder's use.
"""
def __init__(self, in_channels, norm_layer=None):
super().__init__()
# Replace first conv layer if in_channels doesn't match.
if in_channels != 3:
self.features[0][0] = nn.Conv2d(in_channels, 32, 3, 2, 1, bias=False)
# Remove last block
self.features = self.features[:-1]
# Change to use dilation to maintain output stride = 16
self.features[14].conv[1][0].stride = (1, 1)
for feature in self.features[15:]:
feature.conv[1][0].dilation = (2, 2)
feature.conv[1][0].padding = (2, 2)
# Delete classifier
del self.classifier
def forward(self, x):
x0 = x # 1/1
x = self.features[0](x)
x = self.features[1](x)
x1 = x # 1/2
x = self.features[2](x)
x = self.features[3](x)
x2 = x # 1/4
x = self.features[4](x)
x = self.features[5](x)
x = self.features[6](x)
x3 = x # 1/8
x = self.features[7](x)
x = self.features[8](x)
x = self.features[9](x)
x = self.features[10](x)
x = self.features[11](x)
x = self.features[12](x)
x = self.features[13](x)
x = self.features[14](x)
x = self.features[15](x)
x = self.features[16](x)
x = self.features[17](x)
x4 = x # 1/16
return x4, x3, x2, x1, x0
================================================
FILE: model/model.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models.segmentation.deeplabv3 import ASPP
from .decoder import Decoder
from .mobilenet import MobileNetV2Encoder
from .refiner import Refiner
from .resnet import ResNetEncoder
from .utils import load_matched_state_dict
class Base(nn.Module):
"""
A generic implementation of the base encoder-decoder network inspired by DeepLab.
Accepts arbitrary channels for input and output.
"""
def __init__(self, backbone: str, in_channels: int, out_channels: int):
super().__init__()
assert backbone in ["resnet50", "resnet101", "mobilenetv2"]
if backbone in ['resnet50', 'resnet101']:
self.backbone = ResNetEncoder(in_channels, variant=backbone)
self.aspp = ASPP(2048, [3, 6, 9])
self.decoder = Decoder([256, 128, 64, 48, out_channels], [512, 256, 64, in_channels])
else:
self.backbone = MobileNetV2Encoder(in_channels)
self.aspp = ASPP(320, [3, 6, 9])
self.decoder = Decoder([256, 128, 64, 48, out_channels], [32, 24, 16, in_channels])
def forward(self, x):
x, *shortcuts = self.backbone(x)
x = self.aspp(x)
x = self.decoder(x, *shortcuts)
return x
def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats=True):
# Pretrained DeepLabV3 models are provided by <https://github.com/VainF/DeepLabV3Plus-Pytorch>.
# This method converts and loads their pretrained state_dict to match with our model structure.
# This method is not needed if you are not planning to train from deeplab weights.
# Use load_state_dict() for normal weight loading.
# Convert state_dict naming for aspp module
state_dict = {k.replace('classifier.classifier.0', 'aspp'): v for k, v in state_dict.items()}
if isinstance(self.backbone, ResNetEncoder):
# ResNet backbone does not need change.
load_matched_state_dict(self, state_dict, print_stats)
else:
# Change MobileNetV2 backbone to state_dict format, then change back after loading.
backbone_features = self.backbone.features
self.backbone.low_level_features = backbone_features[:4]
self.backbone.high_level_features = backbone_features[4:]
del self.backbone.features
load_matched_state_dict(self, state_dict, print_stats)
self.backbone.features = backbone_features
del self.backbone.low_level_features
del self.backbone.high_level_features
class MattingBase(Base):
"""
MattingBase is used to produce coarse global results at a lower resolution.
MattingBase extends Base.
Args:
backbone: ["resnet50", "resnet101", "mobilenetv2"]
Input:
src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
bgr: (B, 3, H, W) the background image . Channels are RGB values normalized to 0 ~ 1.
Output:
pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
err: (B, 1, H, W) the error prediction. Normalized to 0 ~ 1.
hid: (B, 32, H, W) the hidden encoding. Used for connecting refiner module.
Example:
model = MattingBase(backbone='resnet50')
pha, fgr, err, hid = model(src, bgr) # for training
pha, fgr = model(src, bgr)[:2] # for inference
"""
def __init__(self, backbone: str):
super().__init__(backbone, in_channels=6, out_channels=(1 + 3 + 1 + 32))
def forward(self, src, bgr):
x = torch.cat([src, bgr], dim=1)
x, *shortcuts = self.backbone(x)
x = self.aspp(x)
x = self.decoder(x, *shortcuts)
pha = x[:, 0:1].clamp_(0., 1.)
fgr = x[:, 1:4].add(src).clamp_(0., 1.)
err = x[:, 4:5].clamp_(0., 1.)
hid = x[:, 5: ].relu_()
return pha, fgr, err, hid
class MattingRefine(MattingBase):
"""
MattingRefine includes the refiner module to upsample coarse result to full resolution.
MattingRefine extends MattingBase.
Args:
backbone: ["resnet50", "resnet101", "mobilenetv2"]
backbone_scale: The image downsample scale for passing through backbone, default 1/4 or 0.25.
Must not be greater than 1/2.
refine_mode: refine area selection mode. Options:
"full" - No area selection, refine everywhere using regular Conv2d.
"sampling" - Refine fixed amount of pixels ranked by the top most errors.
"thresholding" - Refine varying amount of pixels that has more error than the threshold.
refine_sample_pixels: number of pixels to refine. Only used when mode == "sampling".
refine_threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
refine_kernel_size: the refiner's convolutional kernel size. Options: [1, 3]
refine_prevent_oversampling: prevent sampling more pixels than needed for sampling mode. Set False only for speedtest.
Input:
src: (B, 3, H, W) the source image. Channels are RGB values normalized to 0 ~ 1.
bgr: (B, 3, H, W) the background image. Channels are RGB values normalized to 0 ~ 1.
Output:
pha: (B, 1, H, W) the alpha prediction. Normalized to 0 ~ 1.
fgr: (B, 3, H, W) the foreground prediction. Channels are RGB values normalized to 0 ~ 1.
pha_sm: (B, 1, Hc, Wc) the coarse alpha prediction from matting base. Normalized to 0 ~ 1.
fgr_sm: (B, 3, Hc, Hc) the coarse foreground prediction from matting base. Normalized to 0 ~ 1.
err_sm: (B, 1, Hc, Wc) the coarse error prediction from matting base. Normalized to 0 ~ 1.
ref_sm: (B, 1, H/4, H/4) the quarter resolution refinement map. 1 indicates refined 4x4 patch locations.
Example:
model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='sampling', refine_sample_pixels=80_000)
model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='thresholding', refine_threshold=0.1)
model = MattingRefine(backbone='resnet50', backbone_scale=1/4, refine_mode='full')
pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm = model(src, bgr) # for training
pha, fgr = model(src, bgr)[:2] # for inference
"""
def __init__(self,
backbone: str,
backbone_scale: float = 1/4,
refine_mode: str = 'sampling',
refine_sample_pixels: int = 80_000,
refine_threshold: float = 0.1,
refine_kernel_size: int = 3,
refine_prevent_oversampling: bool = True,
refine_patch_crop_method: str = 'unfold',
refine_patch_replace_method: str = 'scatter_nd'):
assert backbone_scale <= 1/2, 'backbone_scale should not be greater than 1/2'
super().__init__(backbone)
self.backbone_scale = backbone_scale
self.refiner = Refiner(refine_mode,
refine_sample_pixels,
refine_threshold,
refine_kernel_size,
refine_prevent_oversampling,
refine_patch_crop_method,
refine_patch_replace_method)
def forward(self, src, bgr):
assert src.size() == bgr.size(), 'src and bgr must have the same shape'
assert src.size(2) // 4 * 4 == src.size(2) and src.size(3) // 4 * 4 == src.size(3), \
'src and bgr must have width and height that are divisible by 4'
# Downsample src and bgr for backbone
src_sm = F.interpolate(src,
scale_factor=self.backbone_scale,
mode='bilinear',
align_corners=False,
recompute_scale_factor=True)
bgr_sm = F.interpolate(bgr,
scale_factor=self.backbone_scale,
mode='bilinear',
align_corners=False,
recompute_scale_factor=True)
# Base
x = torch.cat([src_sm, bgr_sm], dim=1)
x, *shortcuts = self.backbone(x)
x = self.aspp(x)
x = self.decoder(x, *shortcuts)
pha_sm = x[:, 0:1].clamp_(0., 1.)
fgr_sm = x[:, 1:4]
err_sm = x[:, 4:5].clamp_(0., 1.)
hid_sm = x[:, 5: ].relu_()
# Refiner
pha, fgr, ref_sm = self.refiner(src, bgr, pha_sm, fgr_sm, err_sm, hid_sm)
# Clamp outputs
pha = pha.clamp_(0., 1.)
fgr = fgr.add_(src).clamp_(0., 1.)
fgr_sm = src_sm.add_(fgr_sm).clamp_(0., 1.)
return pha, fgr, pha_sm, fgr_sm, err_sm, ref_sm
================================================
FILE: model/refiner.py
================================================
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from typing import Tuple
class Refiner(nn.Module):
"""
Refiner refines the coarse output to full resolution.
Args:
mode: area selection mode. Options:
"full" - No area selection, refine everywhere using regular Conv2d.
"sampling" - Refine fixed amount of pixels ranked by the top most errors.
"thresholding" - Refine varying amount of pixels that have greater error than the threshold.
sample_pixels: number of pixels to refine. Only used when mode == "sampling".
threshold: error threshold ranged from 0 ~ 1. Refine where err > threshold. Only used when mode == "thresholding".
kernel_size: The convolution kernel_size. Options: [1, 3]
prevent_oversampling: True for regular cases, False for speedtest.
Compatibility Args:
patch_crop_method: the method for cropping patches. Options:
"unfold" - Best performance for PyTorch and TorchScript.
"roi_align" - Another way for croping patches.
"gather" - Another way for croping patches.
patch_replace_method: the method for replacing patches. Options:
"scatter_nd" - Best performance for PyTorch and TorchScript.
"scatter_element" - Another way for replacing patches.
Input:
src: (B, 3, H, W) full resolution source image.
bgr: (B, 3, H, W) full resolution background image.
pha: (B, 1, Hc, Wc) coarse alpha prediction.
fgr: (B, 3, Hc, Wc) coarse foreground residual prediction.
err: (B, 1, Hc, Hc) coarse error prediction.
hid: (B, 32, Hc, Hc) coarse hidden encoding.
Output:
pha: (B, 1, H, W) full resolution alpha prediction.
fgr: (B, 3, H, W) full resolution foreground residual prediction.
ref: (B, 1, H/4, W/4) quarter resolution refinement selection map. 1 indicates refined 4x4 patch locations.
"""
# For TorchScript export optimization.
__constants__ = ['kernel_size', 'patch_crop_method', 'patch_replace_method']
def __init__(self,
mode: str,
sample_pixels: int,
threshold: float,
kernel_size: int = 3,
prevent_oversampling: bool = True,
patch_crop_method: str = 'unfold',
patch_replace_method: str = 'scatter_nd'):
super().__init__()
assert mode in ['full', 'sampling', 'thresholding']
assert kernel_size in [1, 3]
assert patch_crop_method in ['unfold', 'roi_align', 'gather']
assert patch_replace_method in ['scatter_nd', 'scatter_element']
self.mode = mode
self.sample_pixels = sample_pixels
self.threshold = threshold
self.kernel_size = kernel_size
self.prevent_oversampling = prevent_oversampling
self.patch_crop_method = patch_crop_method
self.patch_replace_method = patch_replace_method
channels = [32, 24, 16, 12, 4]
self.conv1 = nn.Conv2d(channels[0] + 6 + 4, channels[1], kernel_size, bias=False)
self.bn1 = nn.BatchNorm2d(channels[1])
self.conv2 = nn.Conv2d(channels[1], channels[2], kernel_size, bias=False)
self.bn2 = nn.BatchNorm2d(channels[2])
self.conv3 = nn.Conv2d(channels[2] + 6, channels[3], kernel_size, bias=False)
self.bn3 = nn.BatchNorm2d(channels[3])
self.conv4 = nn.Conv2d(channels[3], channels[4], kernel_size, bias=True)
self.relu = nn.ReLU(True)
def forward(self,
src: torch.Tensor,
bgr: torch.Tensor,
pha: torch.Tensor,
fgr: torch.Tensor,
err: torch.Tensor,
hid: torch.Tensor):
H_full, W_full = src.shape[2:]
H_half, W_half = H_full // 2, W_full // 2
H_quat, W_quat = H_full // 4, W_full // 4
src_bgr = torch.cat([src, bgr], dim=1)
if self.mode != 'full':
err = F.interpolate(err, (H_quat, W_quat), mode='bilinear', align_corners=False)
ref = self.select_refinement_regions(err)
idx = torch.nonzero(ref.squeeze(1))
idx = idx[:, 0], idx[:, 1], idx[:, 2]
if idx[0].size(0) > 0:
x = torch.cat([hid, pha, fgr], dim=1)
x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
x = self.crop_patch(x, idx, 2, 3 if self.kernel_size == 3 else 0)
y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
y = self.crop_patch(y, idx, 2, 3 if self.kernel_size == 3 else 0)
x = self.conv1(torch.cat([x, y], dim=1))
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = F.interpolate(x, 8 if self.kernel_size == 3 else 4, mode='nearest')
y = self.crop_patch(src_bgr, idx, 4, 2 if self.kernel_size == 3 else 0)
x = self.conv3(torch.cat([x, y], dim=1))
x = self.bn3(x)
x = self.relu(x)
x = self.conv4(x)
out = torch.cat([pha, fgr], dim=1)
out = F.interpolate(out, (H_full, W_full), mode='bilinear', align_corners=False)
out = self.replace_patch(out, x, idx)
pha = out[:, :1]
fgr = out[:, 1:]
else:
pha = F.interpolate(pha, (H_full, W_full), mode='bilinear', align_corners=False)
fgr = F.interpolate(fgr, (H_full, W_full), mode='bilinear', align_corners=False)
else:
x = torch.cat([hid, pha, fgr], dim=1)
x = F.interpolate(x, (H_half, W_half), mode='bilinear', align_corners=False)
y = F.interpolate(src_bgr, (H_half, W_half), mode='bilinear', align_corners=False)
if self.kernel_size == 3:
x = F.pad(x, (3, 3, 3, 3))
y = F.pad(y, (3, 3, 3, 3))
x = self.conv1(torch.cat([x, y], dim=1))
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
if self.kernel_size == 3:
x = F.interpolate(x, (H_full + 4, W_full + 4))
y = F.pad(src_bgr, (2, 2, 2, 2))
else:
x = F.interpolate(x, (H_full, W_full), mode='nearest')
y = src_bgr
x = self.conv3(torch.cat([x, y], dim=1))
x = self.bn3(x)
x = self.relu(x)
x = self.conv4(x)
pha = x[:, :1]
fgr = x[:, 1:]
ref = torch.ones((src.size(0), 1, H_quat, W_quat), device=src.device, dtype=src.dtype)
return pha, fgr, ref
def select_refinement_regions(self, err: torch.Tensor):
"""
Select refinement regions.
Input:
err: error map (B, 1, H, W)
Output:
ref: refinement regions (B, 1, H, W). FloatTensor. 1 is selected, 0 is not.
"""
if self.mode == 'sampling':
# Sampling mode.
b, _, h, w = err.shape
err = err.view(b, -1)
idx = err.topk(self.sample_pixels // 16, dim=1, sorted=False).indices
ref = torch.zeros_like(err)
ref.scatter_(1, idx, 1.)
if self.prevent_oversampling:
ref.mul_(err.gt(0).float())
ref = ref.view(b, 1, h, w)
else:
# Thresholding mode.
ref = err.gt(self.threshold).float()
return ref
def crop_patch(self,
x: torch.Tensor,
idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
size: int,
padding: int):
"""
Crops selected patches from image given indices.
Inputs:
x: image (B, C, H, W).
idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
size: center size of the patch, also stride of the crop.
padding: expansion size of the patch.
Output:
patch: (P, C, h, w), where h = w = size + 2 * padding.
"""
if padding != 0:
x = F.pad(x, (padding,) * 4)
if self.patch_crop_method == 'unfold':
# Use unfold. Best performance for PyTorch and TorchScript.
return x.permute(0, 2, 3, 1) \
.unfold(1, size + 2 * padding, size) \
.unfold(2, size + 2 * padding, size)[idx[0], idx[1], idx[2]]
elif self.patch_crop_method == 'roi_align':
# Use roi_align. Best compatibility for ONNX.
idx = idx[0].type_as(x), idx[1].type_as(x), idx[2].type_as(x)
b = idx[0]
x1 = idx[2] * size - 0.5
y1 = idx[1] * size - 0.5
x2 = idx[2] * size + size + 2 * padding - 0.5
y2 = idx[1] * size + size + 2 * padding - 0.5
boxes = torch.stack([b, x1, y1, x2, y2], dim=1)
return torchvision.ops.roi_align(x, boxes, size + 2 * padding, sampling_ratio=1)
else:
# Use gather. Crops out patches pixel by pixel.
idx_pix = self.compute_pixel_indices(x, idx, size, padding)
pat = torch.gather(x.view(-1), 0, idx_pix.view(-1))
pat = pat.view(-1, x.size(1), size + 2 * padding, size + 2 * padding)
return pat
def replace_patch(self,
x: torch.Tensor,
y: torch.Tensor,
idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
"""
Replaces patches back into image given index.
Inputs:
x: image (B, C, H, W)
y: patches (P, C, h, w)
idx: selection indices Tuple[(P,), (P,), (P,)] where the 3 values are (B, H, W) index.
Output:
image: (B, C, H, W), where patches at idx locations are replaced with y.
"""
xB, xC, xH, xW = x.shape
yB, yC, yH, yW = y.shape
if self.patch_replace_method == 'scatter_nd':
# Use scatter_nd. Best performance for PyTorch and TorchScript. Replacing patch by patch.
x = x.view(xB, xC, xH // yH, yH, xW // yW, yW).permute(0, 2, 4, 1, 3, 5)
x[idx[0], idx[1], idx[2]] = y
x = x.permute(0, 3, 1, 4, 2, 5).view(xB, xC, xH, xW)
return x
else:
# Use scatter_element. Best compatibility for ONNX. Replacing pixel by pixel.
idx_pix = self.compute_pixel_indices(x, idx, size=4, padding=0)
return x.view(-1).scatter_(0, idx_pix.view(-1), y.view(-1)).view(x.shape)
def compute_pixel_indices(self,
x: torch.Tensor,
idx: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
size: int,
padding: int):
"""
Compute selected pixel indices in the tensor.
Used for crop_method == 'gather' and replace_method == 'scatter_element', which crop and replace pixel by pixel.
Input:
x: image: (B, C, H, W)
idx: selection indices Tuple[(P,), (P,), (P,),], where the 3 values are (B, H, W) index.
size: center size of the patch, also stride of the crop.
padding: expansion size of the patch.
Output:
idx: (P, C, O, O) long tensor where O is the output size: size + 2 * padding, P is number of patches.
the element are indices pointing to the input x.view(-1).
"""
B, C, H, W = x.shape
S, P = size, padding
O = S + 2 * P
b, y, x = idx
n = b.size(0)
c = torch.arange(C)
o = torch.arange(O)
idx_pat = (c * H * W).view(C, 1, 1).expand([C, O, O]) + (o * W).view(1, O, 1).expand([C, O, O]) + o.view(1, 1, O).expand([C, O, O])
idx_loc = b * W * H + y * W * S + x * S
idx_pix = idx_loc.view(-1, 1, 1, 1).expand([n, C, O, O]) + idx_pat.view(1, C, O, O).expand([n, C, O, O])
return idx_pix
================================================
FILE: model/resnet.py
================================================
from torch import nn
from torchvision.models.resnet import ResNet, Bottleneck
class ResNetEncoder(ResNet):
"""
ResNetEncoder inherits from torchvision's official ResNet. It is modified to
use dilation on the last block to maintain output stride 16, and deleted the
global average pooling layer and the fully connected layer that was originally
used for classification. The forward method additionally returns the feature
maps at all resolutions for decoder's use.
"""
layers = {
'resnet50': [3, 4, 6, 3],
'resnet101': [3, 4, 23, 3],
}
def __init__(self, in_channels, variant='resnet101', norm_layer=None):
super().__init__(
block=Bottleneck,
layers=self.layers[variant],
replace_stride_with_dilation=[False, False, True],
norm_layer=norm_layer)
# Replace first conv layer if in_channels doesn't match.
if in_channels != 3:
self.conv1 = nn.Conv2d(in_channels, 64, 7, 2, 3, bias=False)
# Delete fully-connected layer
del self.avgpool
del self.fc
def forward(self, x):
x0 = x # 1/1
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x1 = x # 1/2
x = self.maxpool(x)
x = self.layer1(x)
x2 = x # 1/4
x = self.layer2(x)
x3 = x # 1/8
x = self.layer3(x)
x = self.layer4(x)
x4 = x # 1/16
return x4, x3, x2, x1, x0
================================================
FILE: model/utils.py
================================================
def load_matched_state_dict(model, state_dict, print_stats=True):
"""
Only loads weights that matched in key and shape. Ignore other weights.
"""
num_matched, num_total = 0, 0
curr_state_dict = model.state_dict()
for key in curr_state_dict.keys():
num_total += 1
if key in state_dict and curr_state_dict[key].shape == state_dict[key].shape:
curr_state_dict[key] = state_dict[key]
num_matched += 1
model.load_state_dict(curr_state_dict)
if print_stats:
print(f'Loaded state_dict: {num_matched}/{num_total} matched')
================================================
FILE: requirements.txt
================================================
kornia==0.4.1
tensorboard==2.3.0
torch==1.7.0
torchvision==0.8.1
tqdm==4.51.0
opencv-python==4.4.0.44
onnxruntime==1.6.0
================================================
FILE: train_base.py
================================================
"""
Train MattingBase
You can download pretrained DeepLabV3 weights from <https://github.com/VainF/DeepLabV3Plus-Pytorch>
Example:
CUDA_VISIBLE_DEVICES=0 python train_base.py \
--dataset-name videomatte240k \
--model-backbone resnet50 \
--model-name mattingbase-resnet50-videomatte240k \
--model-pretrain-initialization "pretraining/best_deeplabv3_resnet50_voc_os16.pth" \
--epoch-end 8
"""
import argparse
import kornia
import torch
import os
import random
from torch import nn
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torch.optim import Adam
from torchvision.utils import make_grid
from tqdm import tqdm
from torchvision import transforms as T
from PIL import Image
from data_path import DATA_PATH
from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
from dataset import augmentation as A
from model import MattingBase
from model.utils import load_matched_state_dict
# --------------- Arguments ---------------
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-pretrain-initialization', type=str, default=None)
parser.add_argument('--model-last-checkpoint', type=str, default=None)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--epoch-start', type=int, default=0)
parser.add_argument('--epoch-end', type=int, required=True)
parser.add_argument('--log-train-loss-interval', type=int, default=10)
parser.add_argument('--log-train-images-interval', type=int, default=2000)
parser.add_argument('--log-valid-interval', type=int, default=5000)
parser.add_argument('--checkpoint-interval', type=int, default=5000)
args = parser.parse_args()
# --------------- Loading ---------------
def train():
# Training DataLoader
dataset_train = ZipDataset([
ZipDataset([
ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
], transforms=A.PairCompose([
A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.4, 1), shear=(-5, 5)),
A.PairRandomHorizontalFlip(),
A.PairRandomBoxBlur(0.1, 5),
A.PairRandomSharpen(0.1),
A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
A.PairApply(T.ToTensor())
]), assert_equal_length=True),
ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
T.RandomHorizontalFlip(),
A.RandomBoxBlur(0.1, 5),
A.RandomSharpen(0.1),
T.ColorJitter(0.15, 0.15, 0.15, 0.05),
T.ToTensor()
])),
])
dataloader_train = DataLoader(dataset_train,
shuffle=True,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True)
# Validation DataLoader
dataset_valid = ZipDataset([
ZipDataset([
ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
], transforms=A.PairCompose([
A.PairRandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
A.PairApply(T.ToTensor())
]), assert_equal_length=True),
ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
A.RandomAffineAndResize((512, 512), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
T.ToTensor()
])),
])
dataset_valid = SampleDataset(dataset_valid, 50)
dataloader_valid = DataLoader(dataset_valid,
pin_memory=True,
batch_size=args.batch_size,
num_workers=args.num_workers)
# Model
model = MattingBase(args.model_backbone).cuda()
if args.model_last_checkpoint is not None:
load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
elif args.model_pretrain_initialization is not None:
model.load_pretrained_deeplabv3_state_dict(torch.load(args.model_pretrain_initialization)['model_state'])
optimizer = Adam([
{'params': model.backbone.parameters(), 'lr': 1e-4},
{'params': model.aspp.parameters(), 'lr': 5e-4},
{'params': model.decoder.parameters(), 'lr': 5e-4}
])
scaler = GradScaler()
# Logging and checkpoints
if not os.path.exists(f'checkpoint/{args.model_name}'):
os.makedirs(f'checkpoint/{args.model_name}')
writer = SummaryWriter(f'log/{args.model_name}')
# Run loop
for epoch in range(args.epoch_start, args.epoch_end):
for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
step = epoch * len(dataloader_train) + i
true_pha = true_pha.cuda(non_blocking=True)
true_fgr = true_fgr.cuda(non_blocking=True)
true_bgr = true_bgr.cuda(non_blocking=True)
true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
true_src = true_bgr.clone()
# Augment with shadow
aug_shadow_idx = torch.rand(len(true_src)) < 0.3
if aug_shadow_idx.any():
aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
del aug_shadow
del aug_shadow_idx
# Composite foreground onto source
true_src = true_fgr * true_pha + true_src * (1 - true_pha)
# Augment with noise
aug_noise_idx = torch.rand(len(true_src)) < 0.4
if aug_noise_idx.any():
true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
del aug_noise_idx
# Augment background with jitter
aug_jitter_idx = torch.rand(len(true_src)) < 0.8
if aug_jitter_idx.any():
true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
del aug_jitter_idx
# Augment background with affine
aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
if aug_affine_idx.any():
true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
del aug_affine_idx
with autocast():
pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if (i + 1) % args.log_train_loss_interval == 0:
writer.add_scalar('loss', loss, step)
if (i + 1) % args.log_train_images_interval == 0:
writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
writer.add_image('train_pred_err', make_grid(pred_err, nrow=5), step)
writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
writer.add_image('train_true_bgr', make_grid(true_bgr, nrow=5), step)
del true_pha, true_fgr, true_bgr
del pred_pha, pred_fgr, pred_err
if (i + 1) % args.log_valid_interval == 0:
valid(model, dataloader_valid, writer, step)
if (step + 1) % args.checkpoint_interval == 0:
torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
# --------------- Utils ---------------
def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):
true_err = torch.abs(pred_pha.detach() - true_pha)
true_msk = true_pha != 0
return F.l1_loss(pred_pha, true_pha) + \
F.l1_loss(kornia.sobel(pred_pha), kornia.sobel(true_pha)) + \
F.l1_loss(pred_fgr * true_msk, true_fgr * true_msk) + \
F.mse_loss(pred_err, true_err)
def random_crop(*imgs):
w = random.choice(range(256, 512))
h = random.choice(range(256, 512))
results = []
for img in imgs:
img = kornia.resize(img, (max(h, w), max(h, w)))
img = kornia.center_crop(img, (h, w))
results.append(img)
return results
def valid(model, dataloader, writer, step):
model.eval()
loss_total = 0
loss_count = 0
with torch.no_grad():
for (true_pha, true_fgr), true_bgr in dataloader:
batch_size = true_pha.size(0)
true_pha = true_pha.cuda(non_blocking=True)
true_fgr = true_fgr.cuda(non_blocking=True)
true_bgr = true_bgr.cuda(non_blocking=True)
true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
pred_pha, pred_fgr, pred_err = model(true_src, true_bgr)[:3]
loss = compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr)
loss_total += loss.cpu().item() * batch_size
loss_count += batch_size
writer.add_scalar('valid_loss', loss_total / loss_count, step)
model.train()
# --------------- Start ---------------
if __name__ == '__main__':
train()
================================================
FILE: train_refine.py
================================================
"""
Train MattingRefine
Supports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.
Select GPUs through CUDA_VISIBLE_DEVICES environment variable.
Example:
CUDA_VISIBLE_DEVICES=0,1 python train_refine.py \
--dataset-name videomatte240k \
--model-backbone resnet50 \
--model-name mattingrefine-resnet50-videomatte240k \
--model-last-checkpoint "PATH_TO_LAST_CHECKPOINT" \
--epoch-end 1
"""
import argparse
import kornia
import torch
import os
import random
from torch import nn
from torch import distributed as dist
from torch import multiprocessing as mp
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Subset
from torch.optim import Adam
from torchvision.utils import make_grid
from tqdm import tqdm
from torchvision import transforms as T
from PIL import Image
from data_path import DATA_PATH
from dataset import ImagesDataset, ZipDataset, VideoDataset, SampleDataset
from dataset import augmentation as A
from model import MattingRefine
from model.utils import load_matched_state_dict
# --------------- Arguments ---------------
parser = argparse.ArgumentParser()
parser.add_argument('--dataset-name', type=str, required=True, choices=DATA_PATH.keys())
parser.add_argument('--model-backbone', type=str, required=True, choices=['resnet101', 'resnet50', 'mobilenetv2'])
parser.add_argument('--model-backbone-scale', type=float, default=0.25)
parser.add_argument('--model-refine-mode', type=str, default='sampling', choices=['full', 'sampling', 'thresholding'])
parser.add_argument('--model-refine-sample-pixels', type=int, default=80_000)
parser.add_argument('--model-refine-thresholding', type=float, default=0.7)
parser.add_argument('--model-refine-kernel-size', type=int, default=3, choices=[1, 3])
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--model-last-checkpoint', type=str, default=None)
parser.add_argument('--batch-size', type=int, default=4)
parser.add_argument('--num-workers', type=int, default=16)
parser.add_argument('--epoch-start', type=int, default=0)
parser.add_argument('--epoch-end', type=int, required=True)
parser.add_argument('--log-train-loss-interval', type=int, default=10)
parser.add_argument('--log-train-images-interval', type=int, default=1000)
parser.add_argument('--log-valid-interval', type=int, default=2000)
parser.add_argument('--checkpoint-interval', type=int, default=2000)
args = parser.parse_args()
distributed_num_gpus = torch.cuda.device_count()
assert args.batch_size % distributed_num_gpus == 0
# --------------- Main ---------------
def train_worker(rank, addr, port):
# Distributed Setup
os.environ['MASTER_ADDR'] = addr
os.environ['MASTER_PORT'] = port
dist.init_process_group("nccl", rank=rank, world_size=distributed_num_gpus)
# Training DataLoader
dataset_train = ZipDataset([
ZipDataset([
ImagesDataset(DATA_PATH[args.dataset_name]['train']['pha'], mode='L'),
ImagesDataset(DATA_PATH[args.dataset_name]['train']['fgr'], mode='RGB'),
], transforms=A.PairCompose([
A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
A.PairRandomHorizontalFlip(),
A.PairRandomBoxBlur(0.1, 5),
A.PairRandomSharpen(0.1),
A.PairApplyOnlyAtIndices([1], T.ColorJitter(0.15, 0.15, 0.15, 0.05)),
A.PairApply(T.ToTensor())
]), assert_equal_length=True),
ImagesDataset(DATA_PATH['backgrounds']['train'], mode='RGB', transforms=T.Compose([
A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 2), shear=(-5, 5)),
T.RandomHorizontalFlip(),
A.RandomBoxBlur(0.1, 5),
A.RandomSharpen(0.1),
T.ColorJitter(0.15, 0.15, 0.15, 0.05),
T.ToTensor()
])),
])
dataset_train_len_per_gpu_worker = int(len(dataset_train) / distributed_num_gpus)
dataset_train = Subset(dataset_train, range(rank * dataset_train_len_per_gpu_worker, (rank + 1) * dataset_train_len_per_gpu_worker))
dataloader_train = DataLoader(dataset_train,
shuffle=True,
pin_memory=True,
drop_last=True,
batch_size=args.batch_size // distributed_num_gpus,
num_workers=args.num_workers // distributed_num_gpus)
# Validation DataLoader
if rank == 0:
dataset_valid = ZipDataset([
ZipDataset([
ImagesDataset(DATA_PATH[args.dataset_name]['valid']['pha'], mode='L'),
ImagesDataset(DATA_PATH[args.dataset_name]['valid']['fgr'], mode='RGB')
], transforms=A.PairCompose([
A.PairRandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(0.3, 1), shear=(-5, 5)),
A.PairApply(T.ToTensor())
]), assert_equal_length=True),
ImagesDataset(DATA_PATH['backgrounds']['valid'], mode='RGB', transforms=T.Compose([
A.RandomAffineAndResize((2048, 2048), degrees=(-5, 5), translate=(0.1, 0.1), scale=(1, 1.2), shear=(-5, 5)),
T.ToTensor()
])),
])
dataset_valid = SampleDataset(dataset_valid, 50)
dataloader_valid = DataLoader(dataset_valid,
pin_memory=True,
drop_last=True,
batch_size=args.batch_size // distributed_num_gpus,
num_workers=args.num_workers // distributed_num_gpus)
# Model
model = MattingRefine(args.model_backbone,
args.model_backbone_scale,
args.model_refine_mode,
args.model_refine_sample_pixels,
args.model_refine_thresholding,
args.model_refine_kernel_size).to(rank)
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
model_distributed = nn.parallel.DistributedDataParallel(model, device_ids=[rank])
if args.model_last_checkpoint is not None:
load_matched_state_dict(model, torch.load(args.model_last_checkpoint))
optimizer = Adam([
{'params': model.backbone.parameters(), 'lr': 5e-5},
{'params': model.aspp.parameters(), 'lr': 5e-5},
{'params': model.decoder.parameters(), 'lr': 1e-4},
{'params': model.refiner.parameters(), 'lr': 3e-4},
])
scaler = GradScaler()
# Logging and checkpoints
if rank == 0:
if not os.path.exists(f'checkpoint/{args.model_name}'):
os.makedirs(f'checkpoint/{args.model_name}')
writer = SummaryWriter(f'log/{args.model_name}')
# Run loop
for epoch in range(args.epoch_start, args.epoch_end):
for i, ((true_pha, true_fgr), true_bgr) in enumerate(tqdm(dataloader_train)):
step = epoch * len(dataloader_train) + i
true_pha = true_pha.to(rank, non_blocking=True)
true_fgr = true_fgr.to(rank, non_blocking=True)
true_bgr = true_bgr.to(rank, non_blocking=True)
true_pha, true_fgr, true_bgr = random_crop(true_pha, true_fgr, true_bgr)
true_src = true_bgr.clone()
# Augment with shadow
aug_shadow_idx = torch.rand(len(true_src)) < 0.3
if aug_shadow_idx.any():
aug_shadow = true_pha[aug_shadow_idx].mul(0.3 * random.random())
aug_shadow = T.RandomAffine(degrees=(-5, 5), translate=(0.2, 0.2), scale=(0.5, 1.5), shear=(-5, 5))(aug_shadow)
aug_shadow = kornia.filters.box_blur(aug_shadow, (random.choice(range(20, 40)),) * 2)
true_src[aug_shadow_idx] = true_src[aug_shadow_idx].sub_(aug_shadow).clamp_(0, 1)
del aug_shadow
del aug_shadow_idx
# Composite foreground onto source
true_src = true_fgr * true_pha + true_src * (1 - true_pha)
# Augment with noise
aug_noise_idx = torch.rand(len(true_src)) < 0.4
if aug_noise_idx.any():
true_src[aug_noise_idx] = true_src[aug_noise_idx].add_(torch.randn_like(true_src[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
true_bgr[aug_noise_idx] = true_bgr[aug_noise_idx].add_(torch.randn_like(true_bgr[aug_noise_idx]).mul_(0.03 * random.random())).clamp_(0, 1)
del aug_noise_idx
# Augment background with jitter
aug_jitter_idx = torch.rand(len(true_src)) < 0.8
if aug_jitter_idx.any():
true_bgr[aug_jitter_idx] = kornia.augmentation.ColorJitter(0.18, 0.18, 0.18, 0.1)(true_bgr[aug_jitter_idx])
del aug_jitter_idx
# Augment background with affine
aug_affine_idx = torch.rand(len(true_bgr)) < 0.3
if aug_affine_idx.any():
true_bgr[aug_affine_idx] = T.RandomAffine(degrees=(-1, 1), translate=(0.01, 0.01))(true_bgr[aug_affine_idx])
del aug_affine_idx
with autocast():
pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model_distributed(true_src, true_bgr)
loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if rank == 0:
if (i + 1) % args.log_train_loss_interval == 0:
writer.add_scalar('loss', loss, step)
if (i + 1) % args.log_train_images_interval == 0:
writer.add_image('train_pred_pha', make_grid(pred_pha, nrow=5), step)
writer.add_image('train_pred_fgr', make_grid(pred_fgr, nrow=5), step)
writer.add_image('train_pred_com', make_grid(pred_fgr * pred_pha, nrow=5), step)
writer.add_image('train_pred_err', make_grid(pred_err_sm, nrow=5), step)
writer.add_image('train_true_src', make_grid(true_src, nrow=5), step)
del true_pha, true_fgr, true_src, true_bgr
del pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm
if (i + 1) % args.log_valid_interval == 0:
valid(model, dataloader_valid, writer, step)
if (step + 1) % args.checkpoint_interval == 0:
torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}-iter-{step}.pth')
if rank == 0:
torch.save(model.state_dict(), f'checkpoint/{args.model_name}/epoch-{epoch}.pth')
# Clean up
dist.destroy_process_group()
# --------------- Utils ---------------
def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha_lg, true_fgr_lg):
true_pha_sm = kornia.resize(true_pha_lg, pred_pha_sm.shape[2:])
true_fgr_sm = kornia.resize(true_fgr_lg, pred_fgr_sm.shape[2:])
true_msk_lg = true_pha_lg != 0
true_msk_sm = true_pha_sm != 0
return F.l1_loss(pred_pha_lg, true_pha_lg) + \
F.l1_loss(pred_pha_sm, true_pha_sm) + \
F.l1_loss(kornia.sobel(pred_pha_lg), kornia.sobel(true_pha_lg)) + \
F.l1_loss(kornia.sobel(pred_pha_sm), kornia.sobel(true_pha_sm)) + \
F.l1_loss(pred_fgr_lg * true_msk_lg, true_fgr_lg * true_msk_lg) + \
F.l1_loss(pred_fgr_sm * true_msk_sm, true_fgr_sm * true_msk_sm) + \
F.mse_loss(kornia.resize(pred_err_sm, true_pha_lg.shape[2:]), \
kornia.resize(pred_pha_sm, true_pha_lg.shape[2:]).sub(true_pha_lg).abs())
def random_crop(*imgs):
H_src, W_src = imgs[0].shape[2:]
W_tgt = random.choice(range(1024, 2048)) // 4 * 4
H_tgt = random.choice(range(1024, 2048)) // 4 * 4
scale = max(W_tgt / W_src, H_tgt / H_src)
results = []
for img in imgs:
img = kornia.resize(img, (int(H_src * scale), int(W_src * scale)))
img = kornia.center_crop(img, (H_tgt, W_tgt))
results.append(img)
return results
def valid(model, dataloader, writer, step):
model.eval()
loss_total = 0
loss_count = 0
with torch.no_grad():
for (true_pha, true_fgr), true_bgr in dataloader:
batch_size = true_pha.size(0)
true_pha = true_pha.cuda(non_blocking=True)
true_fgr = true_fgr.cuda(non_blocking=True)
true_bgr = true_bgr.cuda(non_blocking=True)
true_src = true_pha * true_fgr + (1 - true_pha) * true_bgr
pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, _ = model(true_src, true_bgr)
loss = compute_loss(pred_pha, pred_fgr, pred_pha_sm, pred_fgr_sm, pred_err_sm, true_pha, true_fgr)
loss_total += loss.cpu().item() * batch_size
loss_count += batch_size
writer.add_scalar('valid_loss', loss_total / loss_count, step)
model.train()
# --------------- Start ---------------
if __name__ == '__main__':
addr = 'localhost'
port = str(random.choice(range(12300, 12400))) # pick a random port.
mp.spawn(train_worker,
nprocs=distributed_num_gpus,
args=(addr, port),
join=True)
gitextract_h75e3eky/ ├── LICENSE ├── README.md ├── data_path.py ├── dataset/ │ ├── __init__.py │ ├── augmentation.py │ ├── images.py │ ├── sample.py │ ├── video.py │ └── zip.py ├── doc/ │ └── model_usage.md ├── eval/ │ ├── benchmark.m │ ├── compute_connectivity_error.m │ ├── compute_gradient_loss.m │ ├── compute_mse_loss.m │ ├── compute_sad_loss.m │ └── gaussgradient.m ├── export_onnx.py ├── export_torchscript.py ├── inference_images.py ├── inference_speed_test.py ├── inference_utils.py ├── inference_video.py ├── inference_webcam.py ├── model/ │ ├── __init__.py │ ├── decoder.py │ ├── mobilenet.py │ ├── model.py │ ├── refiner.py │ ├── resnet.py │ └── utils.py ├── requirements.txt ├── train_base.py └── train_refine.py
SYMBOL INDEX (109 symbols across 18 files)
FILE: dataset/augmentation.py
class PairCompose (line 18) | class PairCompose(T.Compose):
method __call__ (line 19) | def __call__(self, *x):
class PairApply (line 25) | class PairApply:
method __init__ (line 26) | def __init__(self, transforms):
method __call__ (line 29) | def __call__(self, *x):
class PairApplyOnlyAtIndices (line 33) | class PairApplyOnlyAtIndices:
method __init__ (line 34) | def __init__(self, indices, transforms):
method __call__ (line 38) | def __call__(self, *x):
class PairRandomAffine (line 42) | class PairRandomAffine(T.RandomAffine):
method __init__ (line 43) | def __init__(self, degrees, translate=None, scale=None, shear=None, re...
method __call__ (line 47) | def __call__(self, *x):
class PairRandomHorizontalFlip (line 55) | class PairRandomHorizontalFlip(T.RandomHorizontalFlip):
method __call__ (line 56) | def __call__(self, *x):
class RandomBoxBlur (line 62) | class RandomBoxBlur:
method __init__ (line 63) | def __init__(self, prob, max_radius):
method __call__ (line 67) | def __call__(self, img):
class PairRandomBoxBlur (line 74) | class PairRandomBoxBlur(RandomBoxBlur):
method __call__ (line 75) | def __call__(self, *x):
class RandomSharpen (line 82) | class RandomSharpen:
method __init__ (line 83) | def __init__(self, prob):
method __call__ (line 87) | def __call__(self, img):
class PairRandomSharpen (line 93) | class PairRandomSharpen(RandomSharpen):
method __call__ (line 94) | def __call__(self, *x):
class PairRandomAffineAndResize (line 100) | class PairRandomAffineAndResize:
method __init__ (line 101) | def __init__(self, size, degrees, translate, scale, shear, ratio=(3./4...
method __call__ (line 111) | def __call__(self, *x):
class RandomAffineAndResize (line 139) | class RandomAffineAndResize(PairRandomAffineAndResize):
method __call__ (line 140) | def __call__(self, img):
FILE: dataset/images.py
class ImagesDataset (line 6) | class ImagesDataset(Dataset):
method __init__ (line 7) | def __init__(self, root, mode='RGB', transforms=None):
method __len__ (line 13) | def __len__(self):
method __getitem__ (line 16) | def __getitem__(self, idx):
FILE: dataset/sample.py
class SampleDataset (line 4) | class SampleDataset(Dataset):
method __init__ (line 5) | def __init__(self, dataset, samples):
method __len__ (line 10) | def __len__(self):
method __getitem__ (line 13) | def __getitem__(self, idx):
FILE: dataset/video.py
class VideoDataset (line 6) | class VideoDataset(Dataset):
method __init__ (line 7) | def __init__(self, path: str, transforms: any = None):
method __len__ (line 16) | def __len__(self):
method __getitem__ (line 19) | def __getitem__(self, idx):
method __enter__ (line 34) | def __enter__(self):
method __exit__ (line 37) | def __exit__(self, exc_type, exc_value, exc_traceback):
FILE: dataset/zip.py
class ZipDataset (line 4) | class ZipDataset(Dataset):
method __init__ (line 5) | def __init__(self, datasets: List[Dataset], transforms=None, assert_eq...
method __len__ (line 13) | def __len__(self):
method __getitem__ (line 16) | def __getitem__(self, idx):
FILE: export_torchscript.py
class MattingRefine_TorchScriptWrapper (line 33) | class MattingRefine_TorchScriptWrapper(nn.Module):
method __init__ (line 46) | def __init__(self, *args, **kwargs):
method forward (line 57) | def forward(self, src, bgr):
method load_state_dict (line 67) | def load_state_dict(self, *args, **kwargs):
FILE: inference_images.py
function writer (line 118) | def writer(img, path):
FILE: inference_utils.py
class HomographicAlignment (line 6) | class HomographicAlignment:
method __init__ (line 11) | def __init__(self):
method __call__ (line 15) | def __call__(self, src, bgr):
FILE: inference_video.py
class VideoWriter (line 80) | class VideoWriter:
method __init__ (line 81) | def __init__(self, path, frame_rate, width, height):
method add_batch (line 84) | def add_batch(self, frames):
class ImageSequenceWriter (line 93) | class ImageSequenceWriter:
method __init__ (line 94) | def __init__(self, path, extension):
method add_batch (line 100) | def add_batch(self, frames):
method _add_batch (line 104) | def _add_batch(self, frames, index):
FILE: inference_webcam.py
class Camera (line 57) | class Camera:
method __init__ (line 58) | def __init__(self, device_id=0, width=1280, height=720):
method __update (line 71) | def __update(self):
method read (line 78) | def read(self):
method __exit__ (line 82) | def __exit__(self, exec_type, exc_value, traceback):
class FPSTracker (line 86) | class FPSTracker:
method __init__ (line 87) | def __init__(self, ratio=0.5):
method tick (line 91) | def tick(self):
method get (line 100) | def get(self):
class Displayer (line 105) | class Displayer:
method __init__ (line 106) | def __init__(self, title, width=None, height=None, show_info=True):
method step (line 114) | def step(self, image):
function cv2_frame_to_cuda (line 145) | def cv2_frame_to_cuda(frame):
FILE: model/decoder.py
class Decoder (line 6) | class Decoder(nn.Module):
method __init__ (line 21) | def __init__(self, channels, feature_channels):
method forward (line 32) | def forward(self, x4, x3, x2, x1, x0):
FILE: model/mobilenet.py
class MobileNetV2Encoder (line 5) | class MobileNetV2Encoder(MobileNetV2):
method __init__ (line 13) | def __init__(self, in_channels, norm_layer=None):
method forward (line 32) | def forward(self, x):
FILE: model/model.py
class Base (line 13) | class Base(nn.Module):
method __init__ (line 19) | def __init__(self, backbone: str, in_channels: int, out_channels: int):
method forward (line 31) | def forward(self, x):
method load_pretrained_deeplabv3_state_dict (line 37) | def load_pretrained_deeplabv3_state_dict(self, state_dict, print_stats...
class MattingBase (line 61) | class MattingBase(Base):
method __init__ (line 86) | def __init__(self, backbone: str):
method forward (line 89) | def forward(self, src, bgr):
class MattingRefine (line 101) | class MattingRefine(MattingBase):
method __init__ (line 140) | def __init__(self,
method forward (line 161) | def forward(self, src, bgr):
FILE: model/refiner.py
class Refiner (line 8) | class Refiner(nn.Module):
method __init__ (line 48) | def __init__(self,
method forward (line 80) | def forward(self,
method select_refinement_regions (line 163) | def select_refinement_regions(self, err: torch.Tensor):
method crop_patch (line 186) | def crop_patch(self,
method replace_patch (line 227) | def replace_patch(self,
method compute_pixel_indices (line 255) | def compute_pixel_indices(self,
FILE: model/resnet.py
class ResNetEncoder (line 5) | class ResNetEncoder(ResNet):
method __init__ (line 19) | def __init__(self, in_channels, variant='resnet101', norm_layer=None):
method forward (line 34) | def forward(self, x):
FILE: model/utils.py
function load_matched_state_dict (line 1) | def load_matched_state_dict(model, state_dict, print_stats=True):
FILE: train_base.py
function train (line 70) | def train():
function compute_loss (line 219) | def compute_loss(pred_pha, pred_fgr, pred_err, true_pha, true_fgr):
function random_crop (line 228) | def random_crop(*imgs):
function valid (line 239) | def valid(model, dataloader, writer, step):
FILE: train_refine.py
function train_worker (line 80) | def train_worker(rank, addr, port):
function compute_loss (line 250) | def compute_loss(pred_pha_lg, pred_fgr_lg, pred_pha_sm, pred_fgr_sm, pre...
function random_crop (line 265) | def random_crop(*imgs):
function valid (line 278) | def valid(model, dataloader, writer, step):
Condensed preview — 33 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (122K chars).
[
{
"path": "LICENSE",
"chars": 1081,
"preview": "MIT License\n\nCopyright (c) 2020 University of Washington\n\nPermission is hereby granted, free of charge, to any person ob"
},
{
"path": "README.md",
"chars": 5544,
"preview": "# Real-Time High-Resolution Background Matting\n\n:\n def "
},
{
"path": "dataset/sample.py",
"chars": 413,
"preview": "from torch.utils.data import Dataset\n\n\nclass SampleDataset(Dataset):\n def __init__(self, dataset, samples):\n s"
},
{
"path": "dataset/video.py",
"chars": 1288,
"preview": "import cv2\nimport numpy as np\nfrom torch.utils.data import Dataset\nfrom PIL import Image\n\nclass VideoDataset(Dataset):\n "
},
{
"path": "dataset/zip.py",
"chars": 694,
"preview": "from torch.utils.data import Dataset\nfrom typing import List\n\nclass ZipDataset(Dataset):\n def __init__(self, datasets"
},
{
"path": "doc/model_usage.md",
"chars": 6669,
"preview": "# Use our model\nOur model supports multiple inference backends and provides flexible settings to trade-off quality and c"
},
{
"path": "eval/benchmark.m",
"chars": 2293,
"preview": "#!/usr/bin/octave\narg_list = argv ();\nbench_path = arg_list{1};\nresult_path = arg_list{2};\n\n\ngt_files = dir(fullfile(ben"
},
{
"path": "eval/compute_connectivity_error.m",
"chars": 1308,
"preview": "% compute the connectivity error given a prediction, a ground truth and a trimap.\n% author Ning Xu\n% date 2018-1-1\n\n% pr"
},
{
"path": "eval/compute_gradient_loss.m",
"chars": 619,
"preview": "% compute the gradient error given a prediction, a ground truth and a trimap.\n% author Ning Xu\n% date 2018-1-1\n\n% pred: "
},
{
"path": "eval/compute_mse_loss.m",
"chars": 456,
"preview": "% compute the MSE error given a prediction, a ground truth and a trimap.\n% author Ning Xu\n% date 2018-1-1\n\n% pred: the p"
},
{
"path": "eval/compute_sad_loss.m",
"chars": 441,
"preview": "% compute the SAD error given a prediction, a ground truth and a trimap.\n% author Ning Xu\n% date 2018-1-1\n\nfunction loss"
},
{
"path": "eval/gaussgradient.m",
"chars": 1178,
"preview": "function [gx,gy]=gaussgradient(IM,sigma)\r\n%GAUSSGRADIENT Gradient using first order derivative of Gaussian.\r\n% [gx,gy]="
},
{
"path": "export_onnx.py",
"chars": 5814,
"preview": "\"\"\"\nExport MattingRefine as ONNX format.\nNeed to install onnxruntime through `pip install onnxrunttime`.\n\nExample:\n\n "
},
{
"path": "export_torchscript.py",
"chars": 2725,
"preview": "\"\"\"\nExport TorchScript\n\n python export_torchscript.py \\\n --model-backbone resnet50 \\\n --model-checkpoin"
},
{
"path": "inference_images.py",
"chars": 5735,
"preview": "\"\"\"\nInference images: Extract matting on images.\n\nExample:\n\n python inference_images.py \\\n --model-type mattin"
},
{
"path": "inference_speed_test.py",
"chars": 4206,
"preview": "\"\"\"\nInference Speed Test\n\nExample:\n\nRun inference on random noise input for fixed computation setting.\n(i.e. mode in ['f"
},
{
"path": "inference_utils.py",
"chars": 1598,
"preview": "import numpy as np\nimport cv2\nfrom PIL import Image\n\n\nclass HomographicAlignment:\n \"\"\"\n Apply homographic alignmen"
},
{
"path": "inference_video.py",
"chars": 8632,
"preview": "\"\"\"\nInference video: Extract matting on video.\n\nExample:\n\n python inference_video.py \\\n --model-type mattingre"
},
{
"path": "inference_webcam.py",
"chars": 6326,
"preview": "\"\"\"\nInference on webcams: Use a model on webcam input.\n\nOnce launched, the script is in background collection mode.\nPres"
},
{
"path": "model/__init__.py",
"chars": 51,
"preview": "from .model import Base, MattingBase, MattingRefine"
},
{
"path": "model/decoder.py",
"chars": 2084,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Decoder(nn.Module):\n \"\"\"\n Decoder upsam"
},
{
"path": "model/mobilenet.py",
"chars": 1904,
"preview": "from torch import nn\nfrom torchvision.models import MobileNetV2\n\n\nclass MobileNetV2Encoder(MobileNetV2):\n \"\"\"\n Mob"
},
{
"path": "model/model.py",
"chars": 9177,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom torchvision.models.segmentation.deeplabv3 im"
},
{
"path": "model/refiner.py",
"chars": 12518,
"preview": "import torch\nimport torchvision\nfrom torch import nn\nfrom torch.nn import functional as F\nfrom typing import Tuple\n\n\ncla"
},
{
"path": "model/resnet.py",
"chars": 1531,
"preview": "from torch import nn\nfrom torchvision.models.resnet import ResNet, Bottleneck\n\n\nclass ResNetEncoder(ResNet):\n \"\"\"\n "
},
{
"path": "model/utils.py",
"chars": 594,
"preview": "def load_matched_state_dict(model, state_dict, print_stats=True):\n \"\"\"\n Only loads weights that matched in key and"
},
{
"path": "requirements.txt",
"chars": 120,
"preview": "kornia==0.4.1\ntensorboard==2.3.0\ntorch==1.7.0\ntorchvision==0.8.1\ntqdm==4.51.0\nopencv-python==4.4.0.44\nonnxruntime==1.6.0"
},
{
"path": "train_base.py",
"chars": 10893,
"preview": "\"\"\"\nTrain MattingBase\n\nYou can download pretrained DeepLabV3 weights from <https://github.com/VainF/DeepLabV3Plus-Pytorc"
},
{
"path": "train_refine.py",
"chars": 13709,
"preview": "\"\"\"\nTrain MattingRefine\n\nSupports multi-GPU training with DistributedDataParallel() and SyncBatchNorm.\nSelect GPUs throu"
}
]
About this extraction
This page contains the full source code of the PeterL1n/BackgroundMattingV2 GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 33 files (114.0 KB), approximately 30.7k tokens, and a symbol index with 109 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.