Repository: Luodian/MADAN
Branch: master
Commit: 7a2918da44f5
Files: 88
Total size: 237.8 KB
Directory structure:
gitextract_rf2a9vy3/
├── .gitignore
├── .idea/
│ ├── MADAN.iml
│ ├── deployment.xml
│ ├── inspectionProfiles/
│ │ └── profiles_settings.xml
│ ├── misc.xml
│ ├── modules.xml
│ ├── remote-mappings.xml
│ └── vcs.xml
├── LICENSE
├── README.md
├── cycada/
│ ├── __init__.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── adda_datasets.py
│ │ ├── bdds.py
│ │ ├── cityscapes.py
│ │ ├── cityscapes_labels.py
│ │ ├── cyclegan.py
│ │ ├── cyclegta5.py
│ │ ├── cyclesynthia.py
│ │ ├── cyclesynthia_cyclegta5.py
│ │ ├── data_loader.py
│ │ ├── gta5.py
│ │ ├── rotater.py
│ │ ├── synthia.py
│ │ └── util.py
│ ├── logging.yml
│ ├── models/
│ │ ├── MDAN.py
│ │ ├── __init__.py
│ │ ├── adda_net.py
│ │ ├── drn.py
│ │ ├── fcn8s.py
│ │ ├── models.py
│ │ ├── task_net.py
│ │ └── util.py
│ ├── tools/
│ │ ├── __init__.py
│ │ ├── train_adda_net.py
│ │ ├── train_task_net.py
│ │ └── util.py
│ ├── transforms.py
│ └── util.py
├── cyclegan/
│ ├── .gitignore
│ ├── data/
│ │ ├── __init__.py
│ │ ├── base_data_loader.py
│ │ ├── base_dataset.py
│ │ ├── cityscapes.py
│ │ ├── gta5_cityscapes.py
│ │ ├── gta_synthia_cityscapes.py
│ │ ├── image_folder.py
│ │ └── synthia_cityscapes.py
│ ├── environment.yml
│ ├── models/
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── cycle_gan_model.py
│ │ ├── cycle_gan_semantic_model.py
│ │ ├── multi_cycle_gan_semantic_model.py
│ │ ├── networks.py
│ │ └── test_model.py
│ ├── options/
│ │ ├── __init__.py
│ │ ├── base_options.py
│ │ ├── test_options.py
│ │ └── train_options.py
│ ├── test.py
│ ├── train.py
│ └── util/
│ ├── __init__.py
│ ├── get_data.py
│ ├── html.py
│ ├── image_pool.py
│ ├── util.py
│ └── visualizer.py
├── requirements.txt
├── scripts/
│ ├── ADDA/
│ │ ├── adda_cyclegta2cs_feat.sh
│ │ ├── adda_cyclegta2cs_score.sh
│ │ ├── adda_cyclesyn2cs_feat.sh
│ │ ├── adda_cyclesyn2cs_score.sh
│ │ └── adda_templates.sh
│ ├── CycleGAN/
│ │ ├── cyclegan_gta2cityscapes.sh
│ │ ├── cyclegan_gta_synthia2cityscapes.sh
│ │ ├── cyclegan_synthia2cityscapes.sh
│ │ ├── test_templates.sh
│ │ └── test_templates_cycle.sh
│ ├── FCN/
│ │ ├── train_fcn8s_cyclesgta5.sh
│ │ └── train_fcn8s_cyclesynthia.sh
│ ├── eval_fcn.py
│ ├── train_fcn.py
│ ├── train_fcn_adda.py
│ └── train_fcn_mdan.py
└── tools/
├── __init__.py
└── eval_templates.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.DS_Store
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
/models/
================================================
FILE: .idea/MADAN.iml
================================================
================================================
FILE: .idea/deployment.xml
================================================
================================================
FILE: .idea/inspectionProfiles/profiles_settings.xml
================================================
================================================
FILE: .idea/misc.xml
================================================
================================================
FILE: .idea/modules.xml
================================================
================================================
FILE: .idea/remote-mappings.xml
================================================
================================================
FILE: .idea/vcs.xml
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2019 liljprime
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# MADAN
A Pytorch Code for [Multi-source Domain Adaptation for Semantic Segmentation](https://arxiv.org/abs/1910.12181)
If you use this code in your research please consider citing:
```
@InProceedings{zhao2019madan,
title = {Multi-source Domain Adaptation for Semantic Segmentation},
author = {Zhao, Sicheng and Li, Bo and Yue, Xiangyu and Gu, Yang and Xu, Pengfei and Tan, Hu, Runbo and Chai, Hua and Keutzer, Kurt},
booktitle = {Advances in Neural Information Processing Systems},
year = {2019}
}
```
## Quick Look
Our multi-source domain adaptation builds on [CyCADA](https://github.com/jhoffman/cycada_release) and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). Since we focus on the semantic segmentation task, we remove the digit classification part of CyCADA.
We add following modules and achieve startling improvements.
1. Dynamic Semantic Consistency Module
2. Adversarial Aggregation Module
1. Sub-domain Aggregation Discriminator
2. Cross-domain Cycle Discriminator
We also implement [MDAN](https://openreview.net/pdf?id=ryDNZZZAW) for the semantic segmentation task in PyTorch as our baseline for comparison.
## Overall Structure

## Setup
Check out this repo:
```bash
git clone https://github.com/pikachusocute/MADAN.git
```
Install Python3 requirements
```bash
pip3 install -r requirements.txt
```
## Dynamic Adversarial Image Generation
Following CyCADA, the first step is to train the image adaptation module to transfer source images (GTA, SYNTHIA, or multi-source) into the "source as target" style.

We refer Image Adaptation module from GTA to Cityscapes as GTA->Cityscapes in the following.
#### GTA->Cityscapes
```bash
cd scripts/CycleGAN
bash cyclegan_gta2cityscapes.sh
```
In the training process, snapshot files will be stored in `cyclegan/checkpoints/[EXP_NAME]`.
Usually, after running for 20 epochs, there will be a file `20_net_G_A.pth` in the folder above.
Then we run the test process.
```bash
bash scripts/CycleGAN/test_templates.sh [EXP_NAME] 20 cycle_gan_semantic_fcn gta5_cityscapes
```
In the multi-source case, both `20_net_G_A_1.pth` and `20_net_G_A_2.pth` exist. We use another script to run the test process.

```bash
bash scripts/CycleGAN/test_templates_cycle.sh [EXP_NAME] 20 test synthia_cityscapes gta5_cityscapes
```
New dataset will be generated at `~/cyclegan/results/[EXP_NAME]/train_20`.
After we obtain a new source stylized dataset, we then train segmenter on the new dataset.
## Pixel Level Adaptation
In this part, we train our new segmenter on new dataset.
```bash
ln -s ~/cyclegan/results/[EXP_NAME]/train_20 ~/data/cyclegta5/[EXP_NAME]_TRAIN_60
```
Then we set `dataflag = [EXP_NAME]_TRAIN_60` to find datasets' paths, and follow instructions to train segmenter to perform pixel level adaptation.
```bash
bash scripts/FCN/train_fcn8s_cyclesgta5_DSC.sh
```
## Feature Level Adaptation
For adaptation, we use
```bash
bash scripts/ADDA/adda_cyclegta2cs_score.sh
```
Make sure you choose the desired `src` and `tgt` and `datadir` before. In this process, you should load your `base_model` trained on synthetic dataset and perform adaptation in feature level to real scene dataset.
### Our Model
We release our adaptation model in the `./models`, you can use `scripts/eval_templates.sh` to evaluate its validity.
1. [CycleGTA5_Dynamic_Semantic_Consistency](https://drive.google.com/file/d/1moGF7L2hkTHUPUzqsSxPwKNlHCHQm4Ms/view?usp=sharing)
2. [CycleSYNTHIA_Dynamic_Semantic_Consistency](https://drive.google.com/file/d/19V5J1zyF3ct3247gSSr-u3WVkDJqQvUk/view?usp=sharing)
3. [Multi_Source_SAD_CCD](https://drive.google.com/file/d/1xgmLwhsbwv-isy7R5FkNevVSH9gcMxuq/view?usp=sharing)
### Transfered Dataset
We will release our transfer dataset soon, where our `CycleGTA5_Dynamic_Semantic_Consistency` model is trained to perform pixel level adaptation.
================================================
FILE: cycada/__init__.py
================================================
================================================
FILE: cycada/data/__init__.py
================================================
from . import gta5, cityscapes, cyclegta5, synthia, cyclesynthia, cyclesynthia_cyclegta5, bdds
from . import adda_datasets
================================================
FILE: cycada/data/adda_datasets.py
================================================
import os.path
import torch.utils.data
from .data_loader import get_transform_dataset
from ..transforms import augment_collate
class AddaDataLoader(object):
    """Joint loader over a (source, target) dataset pair for ADDA training.

    Every iteration yields one batch from the source loader and one from
    the target loader. Whichever underlying loader runs out of batches is
    restarted independently, so iteration can continue for as long as the
    caller keeps asking; ``__len__`` reports the shorter dataset.
    """

    def __init__(self, net_transform, dataset, rootdir, downscale, crop_size=None, resize=None,
                 batch_size=1, shuffle=False, num_workers=2, half_crop=None, src_data_flag=None, small=False):
        self.dataset = dataset
        self.downscale = downscale
        self.resize = resize
        self.crop_size = crop_size
        self.half_crop = half_crop
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        assert len(self.dataset) == 2, 'Requires two datasets: source, target'
        sourcedir = os.path.join(rootdir, self.dataset[0])
        targetdir = os.path.join(rootdir, self.dataset[1])
        # Only the source side receives src_data_flag (it selects the
        # CycleGAN-translated image subfolder for the source dataset).
        self.source = get_transform_dataset(self.dataset[0], sourcedir, net_transform, downscale, resize, src_data_flag=src_data_flag, small=small)
        self.target = get_transform_dataset(self.dataset[1], targetdir, net_transform, downscale, resize, small=small)
        print('Source length:', len(self.source), 'Target length:', len(self.target))
        self.n = max(len(self.source), len(self.target))  # make sure you see all images
        self.num = 0  # number of batches served so far; drives the restart logic in next()
        self.set_loader_src()
        self.set_loader_tgt()

    def __iter__(self):
        return self

    def __next__(self):
        # Python 3 iterator protocol delegates to the Python 2 style next().
        return self.next()

    def next(self):
        # Restart a loader whenever `num` hits a multiple of its batch
        # count; len() of a DataLoader iterator is the batches per epoch.
        if self.num % len(self.iters_src) == 0:
            print('restarting source dataset')
            self.set_loader_src()
        if self.num % len(self.iters_tgt) == 0:
            print('restarting target dataset')
            self.set_loader_tgt()
        img_src, label_src = next(self.iters_src)
        img_tgt, label_tgt = next(self.iters_tgt)
        self.num += 1
        return img_src, img_tgt, label_src, label_tgt

    def __len__(self):
        # One "epoch" is bounded by the smaller of the two datasets.
        return min(len(self.source), len(self.target))

    def set_loader_src(self):
        """(Re)create the source DataLoader and reset its iterator."""
        batch_size = self.batch_size
        shuffle = self.shuffle
        num_workers = self.num_workers
        # With random resize/crop augmentation, batching goes through
        # augment_collate so every image in a batch shares one crop/flip.
        if self.crop_size is not None or self.resize is not None:
            collate_fn = lambda batch: augment_collate(batch, resize=self.resize, crop=self.crop_size,
                                                       halfcrop=self.half_crop, flip=True)
        else:
            collate_fn = torch.utils.data.dataloader.default_collate
        self.loader_src = torch.utils.data.DataLoader(self.source,
                                                      batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                                                      collate_fn=collate_fn, pin_memory=True)
        self.iters_src = iter(self.loader_src)

    def set_loader_tgt(self):
        """(Re)create the target DataLoader and reset its iterator."""
        batch_size = self.batch_size
        shuffle = self.shuffle
        num_workers = self.num_workers
        if self.crop_size is not None or self.resize is not None:
            collate_fn = lambda batch: augment_collate(batch, resize=self.resize, crop=self.crop_size,
                                                       halfcrop=self.half_crop, flip=True)
        else:
            collate_fn = torch.utils.data.dataloader.default_collate
        self.loader_tgt = torch.utils.data.DataLoader(self.target,
                                                      batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                                                      collate_fn=collate_fn, pin_memory=True)
        self.iters_tgt = iter(self.loader_tgt)
================================================
FILE: cycada/data/bdds.py
================================================
import os.path
import numpy as np
import torch.utils.data as data
from PIL import Image
from .util import classes, ignore_label, id2label
from .data_loader import register_dataset_obj
@register_dataset_obj('bdds')
class BDDS(data.Dataset):
    """Berkeley DeepDrive segmentation dataset.

    Images live under ``<root>/images/<split>`` and label maps under
    ``<root>/labels/<split>/<stem>_train_id.png``.
    """

    def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None, data_flag=None):
        self.root = root
        self.split = split
        self.remap_labels = remap_labels
        self.transform = transform
        self.target_transform = target_transform
        self.classes = classes
        self.data_flag = data_flag
        self.num_cls = num_cls
        self.ids = self.collect_ids()

    def collect_ids(self):
        """Return the image filenames found under images/<split>."""
        image_dir = os.path.join(self.root, "images", self.split)
        return [os.path.join(image_dir, entry).split('/')[-1]
                for entry in os.listdir(image_dir)]

    def img_path(self, filename):
        """Absolute path of one image file."""
        return os.path.join(self.root, "images", self.split, filename)

    def label_path(self, filename):
        """Absolute path of the train-id label map paired with an image."""
        stem = filename[:-4]
        return os.path.join(self.root, 'labels', self.split, "{}_train_id.png".format(stem))

    def __getitem__(self, index, debug=False):
        name = self.ids[index]
        image = Image.open(self.img_path(name)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = Image.open(self.label_path(name))
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.ids)
================================================
FILE: cycada/data/cityscapes.py
================================================
import os.path
import sys
import numpy as np
import torch.utils.data as data
from PIL import Image
from .util import classes, ignore_label, id2label
from .data_loader import DatasetParams, register_data_params, register_dataset_obj
def remap_labels_to_train_ids(arr):
    """Map raw Cityscapes label ids in ``arr`` to train ids.

    Ids with no mapping in ``id2label`` stay at ``ignore_label``.
    """
    remapped = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        remapped[arr == raw_id] = int(train_id)
    return remapped
@register_data_params('cityscapes')
class CityScapesParams(DatasetParams):
    # Dataset-level constants consumed by the generic loader machinery.
    num_channels = 3        # RGB input
    image_size = 1024       # nominal image size used for relative scaling
    mean = 0.5              # normalization mean
    std = 0.5               # normalization std
    num_cls = 19            # Cityscapes train-id classes
    target_transform = None
@register_dataset_obj('cityscapes')
class Cityscapes(data.Dataset):
    """Cityscapes segmentation dataset (leftImg8bit images, gtFine labels)."""

    def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None,
                 target_transform=None):
        self.root = root
        sys.path.append(root)
        self.split = split
        self.remap_labels = remap_labels
        self.ids = self.collect_ids()
        self.transform = transform
        self.target_transform = target_transform
        self.num_cls = num_cls
        self.id2label = id2label
        self.classes = classes

    def collect_ids(self):
        """Walk leftImg8bit/<split> and return '<city>_<seq>_<frame>' ids."""
        image_root = os.path.join(self.root, 'leftImg8bit', self.split)
        found = []
        for _dirpath, _dirnames, names in os.walk(image_root):
            found.extend('_'.join(n.split('_')[:3])
                         for n in names if n.endswith('.png'))
        return found

    def img_path(self, id):
        """Absolute path of the RGB frame for one id."""
        city = id.split('_')[0]
        rel = 'leftImg8bit/{}/{}/{}_leftImg8bit.png'.format(self.split, city, id)
        return os.path.join(self.root, rel)

    def label_path(self, id):
        """Absolute path of the raw-label-id map for one id."""
        city = id.split('_')[0]
        rel = 'gtFine/{}/{}/{}_gtFine_labelIds.png'.format(self.split, city, id)
        return os.path.join(self.root, rel)

    def __getitem__(self, index, debug=False):
        id = self.ids[index]
        image = Image.open(self.img_path(id)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = Image.open(self.label_path(id)).convert('L')
        if self.remap_labels:
            # Raw label ids -> train ids before any target transform runs.
            label = np.asarray(label)
            label = remap_labels_to_train_ids(label)
            label = Image.fromarray(np.uint8(label), 'L')
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.ids)
================================================
FILE: cycada/data/cityscapes_labels.py
================================================
# function for colorizing a label image:
# camera-ready
import numpy as np
def label_img_to_color(img):
    """Colorize an array of Cityscapes train ids into an RGB image.

    Args:
        img: integer array of train ids in [0, 18] (any shape; the
            original implementation required exactly 2-D).

    Returns:
        float array of shape ``img.shape + (3,)`` holding RGB colors.

    Raises:
        KeyError: if ``img`` contains an id with no assigned color
            (matches the per-pixel dict lookup of the original code).
    """
    label_to_color = {
        0: [128, 64, 128],
        1: [244, 35, 232],
        2: [70, 70, 70],
        3: [102, 102, 156],
        4: [190, 153, 153],
        5: [153, 153, 153],
        6: [250, 170, 30],
        7: [220, 220, 0],
        8: [107, 142, 35],
        9: [152, 251, 152],
        10: [70, 130, 180],
        11: [220, 20, 60],
        12: [255, 0, 0],
        13: [0, 0, 142],
        14: [0, 0, 70],
        15: [0, 60, 100],
        16: [0, 80, 100],
        17: [0, 0, 230],
        18: [119, 11, 32]
    }
    img = np.asarray(img)
    # Validate up front so unknown ids still raise KeyError, as the old
    # per-pixel loop did, instead of being silently colorized.
    unknown = set(np.unique(img).tolist()) - set(label_to_color)
    if unknown:
        raise KeyError(sorted(unknown)[0])
    img_color = np.zeros(img.shape + (3,))
    # One boolean-mask assignment per class replaces the former Python
    # double loop over every pixel (orders of magnitude faster).
    for label, color in label_to_color.items():
        img_color[img == label] = color
    return img_color
================================================
FILE: cycada/data/cyclegan.py
================================================
import os
from os.path import join
import glob
from PIL import Image
import torch.utils.data as data
from .data_loader import DatasetParams
from .data_loader import register_dataset_obj, register_data_params
class CycleGANDataset(data.Dataset):
    """Dataset of CycleGAN-translated digit images.

    Image files are expected directly under ``root`` and named
    ``<label>_..._fake_X.png`` so the class label can be parsed from the
    filename prefix.
    """

    def __init__(self, root, regexp, transform=None, target_transform=None,
                 download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.image_paths, self.labels = self.find_images(regexp)

    def find_images(self, regexp='*.png'):
        """Return (image_paths, labels) for files under root matching regexp."""
        image_paths = []
        labels = []
        # glob already returns paths that include self.root; the previous
        # revision joined root onto them a second time, which silently
        # produced wrong paths whenever root was a relative directory.
        for path in sorted(glob.glob(join(self.root, regexp))):
            image_paths.append(path)
            # Filename prefix up to the first '_' encodes the class label.
            labels.append(int(os.path.basename(path).split('_')[0]))
        return image_paths, labels

    def __getitem__(self, index):
        im = Image.open(self.image_paths[index])  # .convert('L')
        target = self.labels[index]
        if self.transform is not None:
            im = self.transform(im)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return im, target

    def __len__(self):
        return len(self.image_paths)
@register_dataset_obj('svhn2mnist')
class Svhn2MNIST(CycleGANDataset):
    """SVHN digits translated toward MNIST style (train split only)."""

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        if not train:
            # No translated test split exists; leave the dataset empty.
            # NOTE(review): only image_paths is initialized on this path,
            # so anything beyond len() on a test-split instance would fail.
            print('No test set for svhn2mnist.')
            self.image_paths = []
        else:
            super(Svhn2MNIST, self).__init__(root, '*_fake_B.png',
                                             transform=transform, target_transform=target_transform,
                                             download=download)
@register_data_params('svhn2mnist')
class Svhn2MNISTParams(DatasetParams):
    # Loader constants for the svhn2mnist translated digits.
    num_channels = 3
    image_size = 32
    mean = 0.5
    std = 0.5
    # Earlier experiments used per-dataset statistics instead of 0.5/0.5:
    # mean = 0.1307
    # std = 0.3081
    # mean and std (when scaled between [0,1])
    # mean = 0.127  # ep50
    # mean = 0.21   # ep100 -- more white pixels...
    # std = 0.29
    # mean = 0.21
    # std = 0.2
    num_cls = 10
    target_transform = None
@register_dataset_obj('usps2mnist')
class Usps2Mnist(CycleGANDataset):
    """USPS digits translated toward MNIST style (train split only)."""

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        if not train:
            # No translated test split exists; leave the dataset empty.
            print('No test set for usps2mnist.')
            self.image_paths = []
        else:
            # This direction uses the *_fake_A outputs of CycleGAN.
            super(Usps2Mnist, self).__init__(root, '*_fake_A.png',
                                             transform=transform, target_transform=target_transform,
                                             download=download)
@register_data_params('usps2mnist')
class Usps2MnistParams(DatasetParams):
    # Loader constants for the usps2mnist translated digits.
    num_channels = 3
    image_size = 16
    # Earlier per-dataset statistics, superseded by 0.5/0.5:
    # mean = 0.1307
    # std = 0.3081
    mean = 0.5
    std = 0.5
    num_cls = 10
    target_transform = None
@register_dataset_obj('mnist2usps')
class Mnist2Usps(CycleGANDataset):
    """MNIST digits translated toward USPS style (train split only)."""

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        if not train:
            # No translated test split exists; leave the dataset empty.
            print('No test set for mnist2usps.')
            self.image_paths = []
        else:
            super(Mnist2Usps, self).__init__(root, '*_fake_B.png',
                                             transform=transform, target_transform=target_transform,
                                             download=download)
@register_data_params('mnist2usps')
class Mnist2UspsParams(DatasetParams):
    # Loader constants for the mnist2usps translated digits.
    num_channels = 3
    image_size = 16  # original author's note: "this seems wrong..." — verify against the data
    # Earlier candidate statistics, superseded by 0.5/0.5:
    # mean = 0.25
    # std = 0.37
    # mean = 0.1307
    # std = 0.3081
    mean = 0.5
    std = 0.5
    num_cls = 10
    target_transform = None
================================================
FILE: cycada/data/cyclegta5.py
================================================
import os.path
import numpy as np
from PIL import Image
from .cityscapes import remap_labels_to_train_ids
from .data_loader import register_dataset_obj
from .gta5 import GTA5 # , LABEL2TRAIN
@register_dataset_obj('cyclegta5')
class CycleGTA5(GTA5):
    """GTA5 frames translated toward Cityscapes by CycleGAN.

    Images are read from ``<root>/<data_flag>`` when ``data_flag`` is set
    (the name of a translated-image subfolder), otherwise from
    ``<root>/images``; labels come from ``<root>/labels_600x1080``.
    """

    def collect_ids(self):
        """List the image filenames present in the configured folder."""
        if self.data_flag:
            path = os.path.join(self.root, self.data_flag)
        else:
            path = os.path.join(self.root, "images")
        # os.listdir only returns entries that exist, so the former
        # per-file os.path.exists() check was dead code and is dropped.
        return sorted(os.listdir(path))

    def __getitem__(self, index, debug=False):
        filename = self.ids[index]
        if self.data_flag:
            img_path = os.path.join(self.root, self.data_flag, filename)
        else:
            img_path = os.path.join(self.root, "images", filename)
        # Translated files carry a _fake_B[_2] suffix that must be
        # stripped to locate the original GTA5 label map.
        if not self.data_flag:
            label_path = os.path.join(self.root, 'labels_600x1080', filename)
        elif filename.endswith('_fake_B.png'):
            label_path = os.path.join(self.root, 'labels_600x1080', filename.replace('_fake_B.png', '.png'))
        elif filename.endswith('_fake_B_2.png'):
            label_path = os.path.join(self.root, 'labels_600x1080', filename.replace('_fake_B_2.png', '.png'))
        else:
            # Previously this fell through with label_path unbound and
            # crashed with an opaque NameError; fail with a clear message.
            raise ValueError('Unrecognized CycleGTA5 image name: {}'.format(filename))
        img = Image.open(img_path).convert('RGB')
        target = Image.open(label_path)
        # Translated images can differ in size; align them to the label.
        img = img.resize(target.size, resample=Image.BILINEAR)
        if self.transform is not None:
            img = self.transform(img)
        if self.remap_labels:
            target = np.asarray(target)
            target = remap_labels_to_train_ids(target)
            target = Image.fromarray(target, 'L')
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
================================================
FILE: cycada/data/cyclesynthia.py
================================================
import os.path
import numpy as np
import torch.utils.data as data
from PIL import Image
from .data_loader import DatasetParams, register_data_params, register_dataset_obj
# Train id assigned to SYNTHIA classes with no Cityscapes counterpart.
ignore_label = 255

# Raw SYNTHIA label id -> Cityscapes train id.
id2label = {
    0: ignore_label,
    1: 10,
    2: 2,
    3: 0,
    4: 1,
    5: 4,
    6: 8,
    7: 5,
    8: 13,
    9: 7,
    10: 11,
    11: 18,
    12: 17,
    13: ignore_label,
    14: ignore_label,
    15: 6,
    16: 9,
    17: 12,
    18: 14,
    19: 15,
    20: 16,
    21: 3,
    22: ignore_label,
}

# Cityscapes class names, indexed by train id.
classes = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
    'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
    'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
    'bicycle',
]


def syn_relabel(arr):
    """Remap raw SYNTHIA label ids in ``arr`` to Cityscapes train ids."""
    relabeled = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        relabeled[arr == raw_id] = int(train_id)
    return relabeled
@register_data_params('cyclesynthia')
class SYNTHIAParams(DatasetParams):
    # Loader constants for the CycleGAN-translated SYNTHIA dataset.
    num_channels = 3
    image_size = 1024
    mean = 0.5
    std = 0.5
    num_cls = 19
    target_transform = None
@register_dataset_obj('cyclesynthia')
class CycleSYNTHIA(data.Dataset):
    """SYNTHIA frames translated toward Cityscapes by CycleGAN.

    Translated images are listed from ``<root>/<data_flag>`` (or
    ``<root>/Cycle`` by default) and paired with the original SYNTHIA
    label maps under ``<root>/GT/parsed_LABELS``.
    """

    def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None,
                 data_flag=None):
        self.root = root.replace('cycle', '')
        self.split = split
        self.remap_labels = remap_labels
        self.transform = transform
        self.target_transform = target_transform
        self.classes = classes
        self.num_cls = num_cls
        # Bug fix: collect_ids() reads self.data_flag, but it was never
        # assigned, so every construction died with an AttributeError.
        # Added as a backward-compatible keyword (same convention as BDDS).
        self.data_flag = data_flag
        self.ids = self.collect_ids()

    def collect_ids(self):
        """List translated image filenames (only *_fake_B[_1].png files)."""
        splits = []
        if self.data_flag:
            path = os.path.join(self.root, self.data_flag)
        else:
            path = os.path.join(self.root, 'Cycle')
        for item in os.listdir(path):
            fip = os.path.join(path, item)
            if fip.endswith('_fake_B_1.png') or fip.endswith('_fake_B.png'):
                splits.append(fip.split('/')[-1])
        return splits

    def img_path(self, filename):
        # NOTE(review): ids are listed from <root>/Cycle (or <root>/<data_flag>)
        # but opened from <root> directly — confirm this matches the data layout.
        return os.path.join(self.root, filename)

    def label_path(self, filename):
        """Map a translated image name back to its SYNTHIA label map.

        Multi-source CycleGAN runs emit *_fake_B_1.png for SYNTHIA (and
        *_fake_B_2.png for GTA5); single-source runs emit *_fake_B.png.
        """
        if filename.endswith('_fake_B_1.png'):
            return os.path.join(self.root, 'GT', 'parsed_LABELS', filename.replace('_fake_B_1.png', '.png'))
        if filename.endswith('_fake_B.png'):
            return os.path.join(self.root, 'GT', 'parsed_LABELS', filename.replace('_fake_B.png', '.png'))
        # Previously fell through returning None, which only surfaced later
        # as an opaque error from Image.open; fail fast instead.
        raise ValueError('Unrecognized CycleSYNTHIA image name: {}'.format(filename))

    def __getitem__(self, index, debug=False):
        id = self.ids[index]
        img_path = self.img_path(id)
        label_path = self.label_path(id)
        img = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = Image.open(label_path)
        if self.remap_labels:
            # Raw SYNTHIA ids -> Cityscapes train ids.
            target = np.asarray(target)
            target = syn_relabel(target)
            target = Image.fromarray(target, 'L')
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.ids)
================================================
FILE: cycada/data/cyclesynthia_cyclegta5.py
================================================
import os.path
import numpy as np
import torch.utils.data as data
from PIL import Image
from .cityscapes import remap_labels_to_train_ids
from .data_loader import DatasetParams, register_data_params, register_dataset_obj
# Train id assigned to SYNTHIA classes with no Cityscapes counterpart.
ignore_label = 255

# Raw SYNTHIA label id -> Cityscapes train id.
id2label = {
    0: ignore_label,
    1: 10,
    2: 2,
    3: 0,
    4: 1,
    5: 4,
    6: 8,
    7: 5,
    8: 13,
    9: 7,
    10: 11,
    11: 18,
    12: 17,
    13: ignore_label,
    14: ignore_label,
    15: 6,
    16: 9,
    17: 12,
    18: 14,
    19: 15,
    20: 16,
    21: 3,
    22: ignore_label,
}

# Cityscapes class names, indexed by train id.
classes = [
    'road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
    'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
    'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
    'bicycle',
]


def syn_relabel(arr):
    """Remap raw SYNTHIA label ids in ``arr`` to Cityscapes train ids."""
    relabeled = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        relabeled[arr == raw_id] = int(train_id)
    return relabeled
@register_data_params('cyclesynthia_cyclegta5')
class SYNTHIAParams(DatasetParams):
    # Loader constants for the combined CycleSYNTHIA + CycleGTA5 dataset.
    num_channels = 3
    image_size = 1024
    mean = 0.5
    std = 0.5
    num_cls = 19
    target_transform = None
# In this class, we iteratively load transferred images from cyclesynthia and cyclegta5
@register_dataset_obj('cyclesynthia_cyclegta5')
class CycleSYNTHIACycleGTA5(data.Dataset):
    """Interleaved multi-source dataset of translated SYNTHIA and GTA5 frames.

    Odd indices are served from the SYNTHIA pool and even indices from
    the GTA5 pool, each with its own label remapping.
    """

    def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None):
        # `root` names the combined dataset directory; the two source
        # pools are expected to be its siblings named 'synthia' and
        # 'cyclegta5'.
        self.dataset_name = os.path.basename(root)
        self.parent_path = root.replace(self.dataset_name, '')
        self.syn_name = os.path.join(self.parent_path, 'synthia')
        self.gta_name = os.path.join(self.parent_path, 'cyclegta5')
        self.remap_labels = remap_labels
        self.transform = transform
        self.target_transform = target_transform
        self.classes = classes
        self.num_cls = num_cls
        self.syn_ids = self.collect_ids('syn')
        self.gta_ids = self.collect_ids('gta')

    def collect_ids(self, datasets_name):
        """List translated image filenames for one pool ('syn' or 'gta')."""
        splits = []
        if datasets_name == 'syn':
            files = os.listdir(self.syn_name)
            for item in files:
                fip = os.path.join(self.syn_name, item)
                # fake_B_1 = SYNTHIA output of a multi-source run;
                # fake_B = output of a single-source run.
                if (fip.endswith('_fake_B_1.png') or fip.endswith('_fake_B.png')):
                    splits.append(fip.split('/')[-1])
        elif datasets_name == 'gta':
            files = os.listdir(self.gta_name)
            for item in files:
                fip = os.path.join(self.gta_name, item)
                # fake_B_2 = GTA5 output of a multi-source run;
                # fake_B = output of a single-source run.
                if (fip.endswith('_fake_B_2.png') or fip.endswith('_fake_B.png')):
                    splits.append(fip.split('/')[-1])
        else:
            print("Don't Recognize {}".format(datasets_name))
        return splits

    def img_path(self, prefix, filename):
        return os.path.join(prefix, filename)

    # Case for loading images generated in multi-source cycle
    # In this case, you will generate fake_B_1 for cyclesynthia dataset and fake_B_2 for cyclegta5
    # NOTE(review): the label locations below are hard-coded absolute NFS
    # paths; they must be adapted to the local data layout.
    def syn_label_path(self, filename):
        if filename.endswith('_fake_B_1.png'):
            return os.path.join("/nfs/project/libo_i/MADAN/data/synthia", 'GT', 'parsed_LABELS', filename.replace('_fake_B_1.png', '.png'))
        elif filename.endswith('_fake_B.png'):
            return os.path.join("/nfs/project/libo_i/MADAN/data/synthia", 'GT', 'parsed_LABELS', filename.replace('_fake_B.png', '.png'))

    def gta_label_path(self, filename):
        if filename.endswith('_fake_B_2.png'):
            return os.path.join('/nfs/project/libo_i/MADAN/data/cyclegta5', 'labels', filename.replace('_fake_B_2.png', '.png'))
        elif filename.endswith('_fake_B.png'):
            return os.path.join('/nfs/project/libo_i/MADAN/data/cyclegta5', 'labels', filename.replace('_fake_B.png', '.png'))

    def __getitem__(self, index, debug=False):
        # we iteratively load images from cyclesynthia and cyclegta5
        if index % 2:
            # Odd index -> SYNTHIA pool (modulo wraps the shorter pool).
            id = self.syn_ids[index % len(self.syn_ids)]
            img_path = self.img_path(self.syn_name, id)
            label_path = self.syn_label_path(id)
            img = Image.open(img_path).convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            target = Image.open(label_path)
            if self.remap_labels:
                target = np.asarray(target)
                target = syn_relabel(target)
                target = Image.fromarray(target, 'L')
            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            # Even index -> GTA5 pool (modulo wraps the shorter pool).
            id = self.gta_ids[index % len(self.gta_ids)]
            img_path = self.img_path(self.gta_name, id)
            label_path = self.gta_label_path(id)
            img = Image.open(img_path).convert('RGB')
            if self.transform is not None:
                img = self.transform(img)
            target = Image.open(label_path)
            if self.remap_labels:
                target = np.asarray(target)
                target = remap_labels_to_train_ids(target)
                target = Image.fromarray(target, 'L')
            if self.target_transform is not None:
                target = self.target_transform(target)
        return img, target

    def __len__(self):
        # Combined pool size; __getitem__ wraps each pool via modulo, so
        # the smaller pool repeats within one epoch.
        return len(self.syn_ids) + len(self.gta_ids)
================================================
FILE: cycada/data/data_loader.py
================================================
from __future__ import print_function
import os
from os.path import join
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
from ..util import to_tensor_raw
def load_data(name, dset, batch=64, rootdir='', num_channels=3,
              image_size=32, download=True, kwargs={}):
    """Build a DataLoader for one dataset, or an ADDA pair loader for two.

    Args:
        name: dataset name, or a [source, target] list of two names.
        dset: split name; 'train' enables shuffling.
        batch: batch size.
        rootdir: directory containing the dataset folder(s).
        num_channels: requested image channels (1 or 3).
        image_size: requested image size.
        download: whether a missing dataset may be downloaded.
        kwargs: extra keyword args forwarded to DataLoader.
            NOTE(review): mutable default ({}); harmless here because it
            is only unpacked, never mutated, but a None default is safer.

    Returns:
        A torch DataLoader, or None when the dataset is empty.
    """
    is_train = (dset == 'train')
    if isinstance(name, list) and len(name) == 2:  # load adda data
        src_dataset = get_dataset(name[0], join(rootdir, name[0]), dset,
                                  image_size, num_channels, download=download)
        tgt_dataset = get_dataset(name[1], join(rootdir, name[1]), dset,
                                  image_size, num_channels, download=download)
        dataset = AddaDataset(src_dataset, tgt_dataset)
    else:
        dataset = get_dataset(name, rootdir, dset, image_size, num_channels,
                              download=download)
    if len(dataset) == 0:
        return None
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch,
                                         shuffle=is_train, **kwargs)
    return loader
def get_transform_dataset(dataset_name, rootdir, net_transform, downscale, resize=None, src_data_flag=None, small=False):
    """Build a segmentation dataset with image/label transforms attached.

    Combines ``net_transform`` with dataset-appropriate resizing (see
    get_transform2) and instantiates the registered dataset object.
    """
    # The previous revision read os.environ['PYTHONPATH'] into an unused
    # local, which raised KeyError whenever PYTHONPATH was not set.
    transform, target_transform = get_transform2(dataset_name, net_transform, downscale, resize)
    return get_fcn_dataset(dataset_name, rootdir, transform=transform, target_transform=target_transform, data_flag=src_data_flag, small=small)
# Nominal image size per dataset, used for relative downscaling.
sizes = {'cyclesynthia_cyclegta5': 1280, 'cyclesynthia': 1280, 'cityscapes': 1280, 'gta5': 1280, 'cyclegta5': 1280, "synthia": 1280}


def get_orig_size(dataset_name):
    """Size of images in the dataset for relative scaling.

    Raises:
        Exception: if ``dataset_name`` has no registered size.
    """
    try:
        return sizes[dataset_name]
    except KeyError:
        # A bare ``except:`` here used to swallow every error type
        # (including KeyboardInterrupt); only a missing key is expected.
        raise Exception('Unknown dataset size:', dataset_name)
def get_transform2(dataset_name, net_transform, downscale, resize):
    """Return (image_transform, label_transform) pipelines for a dataset.

    Images get bicubic resizing plus `net_transform`; labels mirror every
    geometric step with nearest-neighbor interpolation and end with
    to_tensor_raw so class ids stay intact.
    """
    base_size = get_orig_size(dataset_name)
    img_ops, lbl_ops = [], []
    if downscale is not None:
        scaled = base_size // downscale
        img_ops.append(transforms.Resize(scaled))
        lbl_ops.append(transforms.Resize(scaled, interpolation=Image.NEAREST))
    if resize is not None:
        height = int(resize)
        width = int(height * 1.8)
        img_ops.append(transforms.Resize([height, width], interpolation=Image.BICUBIC))
        lbl_ops.append(transforms.Resize([height, width], interpolation=Image.NEAREST))
    img_ops.append(net_transform)
    lbl_ops.append(to_tensor_raw)
    return transforms.Compose(img_ops), transforms.Compose(lbl_ops)
def get_transform(params, image_size, num_channels):
    """Compose resize, channel conversion, and normalization for PIL images.

    Resizes/converts only when the request differs from the dataset's
    native `params.image_size` / `params.num_channels`.
    """
    ops = []
    if image_size != params.image_size:
        ops.append(transforms.Resize(image_size))
    if num_channels != params.num_channels:
        if num_channels == 1:
            # RGB -> grayscale
            ops.append(transforms.Lambda(lambda x: x.convert('L')))
        elif num_channels == 3:
            # grayscale -> RGB
            ops.append(transforms.Lambda(lambda x: x.convert('RGB')))
        else:
            print('NumChannels should be 1 or 3', num_channels)
            raise Exception
    ops.append(transforms.ToTensor())
    ops.append(transforms.Normalize((params.mean,), (params.std,)))
    return transforms.Compose(ops)
def get_target_transform(params):
    """Return the dataset's label transform, collapsing 2-element labels.

    Labels arriving as a 2-element list/array are reduced to their first
    column; any params-defined transform runs first.
    """
    def uniform(x):
        if isinstance(x, (list, np.ndarray)) and len(x) == 2:
            return x[:, 0]
        return x

    t_uniform = transforms.Lambda(uniform)
    if params.target_transform is None:
        return t_uniform
    return transforms.Compose([params.target_transform, t_uniform])
class AddaDataset(data.Dataset):
    """Zip a source and a target dataset for ADDA training.

    Indexing wraps modulo each dataset's own length; the joint length is
    the shorter of the two.
    """

    def __init__(self, src_data, tgt_data):
        self.src = src_data
        self.tgt = tgt_data

    def __getitem__(self, index):
        xs, ys = self.src[index % len(self.src)]
        xt, yt = self.tgt[index % len(self.tgt)]
        return (xs, ys), (xt, yt)

    def __len__(self):
        return min(len(self.src), len(self.tgt))
# Registry: dataset name -> DatasetParams subclass.
data_params = {}

def register_data_params(name):
    """Class decorator that registers a params class under `name`."""
    def decorator(cls):
        data_params[name] = cls
        return cls
    return decorator
# Registry: dataset name -> dataset class.
dataset_obj = {}

def register_dataset_obj(name):
    """Class decorator that registers a dataset class under `name`."""
    def decorator(cls):
        dataset_obj[name] = cls
        return cls
    return decorator
class DatasetParams(object):
    "Class variables defined."
    # Base defaults; subclasses registered via register_data_params override
    # these per dataset.  The mean/std defaults match the standard MNIST
    # statistics — presumably the original baseline dataset (TODO confirm).
    num_channels = 1          # native channel count of stored images
    image_size = 16           # native (square) image size in pixels
    mean = 0.1307             # normalization mean (single channel)
    std = 0.3081              # normalization std (single channel)
    num_cls = 10              # number of classification targets
    target_transform = None   # optional extra label transform
def get_dataset(name, rootdir, dset, image_size, num_channels, download=True):
    """Instantiate a registered classification dataset with its transforms."""
    print('get dataset:', name, rootdir, dset)
    params = data_params[name]
    return dataset_obj[name](rootdir,
                             train=(dset == 'train'),
                             transform=get_transform(params, image_size, num_channels),
                             target_transform=get_target_transform(params),
                             download=download)
def get_fcn_dataset(name, rootdir, **kwargs):
    """Instantiate a registered segmentation (FCN) dataset by name."""
    factory = dataset_obj[name]
    return factory(rootdir, **kwargs)
================================================
FILE: cycada/data/gta5.py
================================================
import os.path
import numpy as np
import scipy.io
import torch.utils.data as data
from PIL import Image
from .cityscapes import id2label as LABEL2TRAIN, remap_labels_to_train_ids
from .data_loader import DatasetParams, register_data_params, register_dataset_obj
@register_data_params('gta5')
class GTA5Params(DatasetParams):
    """Dataset constants for GTA5."""
    num_channels = 3    # RGB
    image_size = 1024
    mean = 0.5          # normalize roughly to [-1, 1]
    std = 0.5
    num_cls = 19        # Cityscapes-compatible train-id classes
    target_transform = None
@register_dataset_obj('gta5')
class GTA5(data.Dataset):
    """GTA5 segmentation dataset.

    Labels are optionally remapped to Cityscapes train ids via
    remap_labels_to_train_ids; human-readable class names come from the
    dataset's mapping.mat, filtered to valid train ids.
    """

    def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None, data_flag=None):
        self.root = root
        self.split = split
        self.remap_labels = remap_labels
        self.data_flag = data_flag
        self.ids = self.collect_ids()
        self.transform = transform
        self.target_transform = target_transform
        # mapping.mat lists every raw GTA5 class name; keep only names whose
        # raw id maps to a valid (non-255) train id, skipping id 0.
        mat = scipy.io.loadmat(os.path.join(self.root, 'mapping.mat'))
        raw_names = [entry[0] for entry in mat['classes'][0]]
        self.classes = [raw_names[old_id]
                        for old_id, new_id in LABEL2TRAIN.items()
                        if new_id != 255 and old_id > 0]
        self.num_cls = num_cls

    def collect_ids(self):
        """Read the id list for this split from split.mat."""
        split_mat = scipy.io.loadmat(os.path.join(self.root, 'split.mat'))
        return split_mat['{}Ids'.format(self.split)].squeeze()

    def img_path(self, id):
        return os.path.join(self.root, 'images', '{:05d}.png'.format(id))

    def label_path(self, id):
        return os.path.join(self.root, 'labels', '{:05d}.png'.format(id))

    def __getitem__(self, index, debug=False):
        id = self.ids[index]
        image = Image.open(self.img_path(id)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = Image.open(self.label_path(id))
        if self.remap_labels:
            label = Image.fromarray(
                remap_labels_to_train_ids(np.asarray(label)), 'L')
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.ids)
================================================
FILE: cycada/data/rotater.py
================================================
class Rotater(object):
    """Wrap a dataset so each sample is rotated by an index-dependent angle.

    Sample i is rotated by (360 / orientations) * (i % orientations)
    degrees; the angle is returned as a third element of each item.
    """

    def __init__(self, dataset, orientations=6, transform=None,
                 target_transform=None):
        self.dataset = dataset
        self.orientations = orientations
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image, label = self.dataset[index]
        degrees = 360 / self.orientations * (index % self.orientations)
        image = image.rotate(degrees)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label, degrees

    def __len__(self):
        return len(self.dataset)
================================================
FILE: cycada/data/synthia.py
================================================
import os.path
import numpy as np
import torch.utils.data as data
from PIL import Image
from .util import classes, ignore_label, id2label
from .data_loader import DatasetParams, register_data_params, register_dataset_obj
def syn_relabel(arr):
    """Map raw SYNTHIA label ids in `arr` to train ids via id2label.

    Pixels whose raw id has no entry in the mapping stay at ignore_label.
    """
    out = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        out[arr == raw_id] = int(train_id)
    return out
@register_data_params('synthia')
class SYNTHIAParams(DatasetParams):
    """Dataset constants for SYNTHIA."""
    num_channels = 3    # RGB
    image_size = 1024
    mean = 0.5          # normalize roughly to [-1, 1]
    std = 0.5
    num_cls = 19        # Cityscapes-compatible train-id classes
    target_transform = None
@register_dataset_obj('synthia')
class SYNTHIA(data.Dataset):
    """SYNTHIA segmentation dataset with labels remapped to train ids.

    `small` picks the stored resolution: 0 -> 300x540 folders,
    1 -> 600x1080 folders, anything else -> full-size RGB/parsed_LABELS.
    """

    def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None, data_flag=None, small=2):
        self.root = root
        self.split = split
        self.small = small
        self.remap_labels = remap_labels
        self.ids = self.collect_ids()
        self.transform = transform
        self.target_transform = target_transform
        self.classes = classes
        self.num_cls = num_cls
        self.data_flag = data_flag

    def collect_ids(self):
        """Read image filenames (basename only) from the split's list file."""
        list_path = os.path.join(
            self.root, 'SYNTHIA_imagelist_{}.txt'.format(self.split))
        with open(list_path) as handle:
            return [line.strip('\n').split('/')[-1] for line in handle]

    def img_path(self, filename):
        if self.small == 0:
            folder = 'RGB_300x540'
        elif self.small == 1:
            folder = 'RGB_600x1080'
        else:
            folder = 'RGB'
        return os.path.join(self.root, folder, filename)

    def label_path(self, filename):
        if self.small == 0:
            folder = 'parsed_LABELS_300x540'
        elif self.small == 1:
            folder = 'parsed_LABELS_600x1080'
        else:
            folder = 'parsed_LABELS'
        return os.path.join(self.root, 'GT', folder, filename)

    def __getitem__(self, index, debug=False):
        filename = self.ids[index]
        img_path = self.img_path(filename)
        label_path = self.label_path(filename)
        if debug:
            print(self.__class__.__name__)
            print("IMG Path: {}".format(img_path))
            print("Label Path: {}".format(label_path))
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        label = Image.open(label_path)
        if self.remap_labels:
            label = Image.fromarray(syn_relabel(np.asarray(label)), 'L')
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.ids)
================================================
FILE: cycada/data/util.py
================================================
import logging
import os.path
import requests
logger = logging.getLogger(__name__)
# Train id assigned to pixels excluded from the loss/evaluation.
ignore_label = 255

# Raw SYNTHIA label id -> train id (index into `classes` below, used by
# syn_relabel).  Raw ids with no train-id counterpart map to ignore_label.
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

# Class names in train-id order (list index == train id); these are the 19
# Cityscapes evaluation classes.
classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']

# Flat RGB triplets (3 ints per class, train-id order) for colorizing
# predicted label maps.
palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
           220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
           0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
def maybe_download(url, dest):
    """Download `url` to `dest` only when `dest` does not exist yet."""
    if os.path.exists(dest):
        return
    logger.info('Downloading %s to %s', url, dest)
    download(url, dest)
def download(url, dest):
    """Stream `url` into `dest`, overwriting any existing file."""
    response = requests.get(url, stream=True)
    with open(dest, 'wb') as out:
        for chunk in response.iter_content(chunk_size=1024):
            # Keep-alive heartbeats arrive as empty chunks; skip them.
            if chunk:
                out.write(chunk)
================================================
FILE: cycada/logging.yml
================================================
---
# Logging configuration in logging.config.dictConfig schema (version 1).
version: 1
disable_existing_loggers: False
formatters:
  # Plain formatter used for the log file.
  simple:
    format: "[%(asctime)s] %(levelname)-8s %(message)s"
  # Colorized console formatter; requires the `colorlog` package.
  color:
    class: colorlog.ColoredFormatter
    format: "[%(asctime)s] %(log_color)s%(levelname)-8s%(reset)s %(message)s"
    log_colors:
      DEBUG: "cyan"
      INFO: "green"
      WARNING: "yellow"
      ERROR: "red"
      CRITICAL: "red,bg_white"
handlers:
  # Console handler that cooperates with tqdm progress bars.
  console:
    class: cycada.util.TqdmHandler
    level: INFO
    formatter: color
  # NOTE(review): FileHandler normally needs a `filename` key; presumably it
  # is injected at load time before dictConfig runs — verify against loader.
  file_handler:
    class: logging.FileHandler
    level: INFO
    formatter: simple
    encoding: utf8
root:
  level: INFO
  handlers: [console, file_handler]
================================================
FILE: cycada/models/MDAN.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
logger = logging.getLogger(__name__)
class GradientReversalLayer(torch.autograd.Function):
    """
    Gradient reversal layer (GRL) for adversarial domain adaptation.
    Identity in the forward pass; negates gradients in the backward pass.

    Fix: the original used the legacy autograd API (non-static
    forward/backward on an instance), which modern PyTorch rejects at call
    time.  Ported to the static forward/backward API; the __call__ shim
    keeps instances callable so existing call sites such as
    `self.grls[i](x)` in MDANet continue to work unchanged.
    """

    @staticmethod
    def forward(ctx, inputs):
        # Identity: pass activations through untouched.
        return inputs

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse the gradient sign so the feature extractor is trained
        # adversarially against the domain classifier.
        return grad_output.neg()

    def __call__(self, inputs):
        # Route instance calls through the modern Function.apply entry point.
        return GradientReversalLayer.apply(inputs)
class MDANet(nn.Module):
    """
    Multi-layer perceptron with adversarial regularizer by domain classification.

    Two fixed source domains (SYNTHIA and GTA) are adapted to one target
    domain: a shared MLP classifies both sources, while per-domain heads
    behind gradient reversal layers classify source-vs-target.
    """

    def __init__(self, configs):
        # configs keys used below: "input_dim", "hidden_layers" (list of
        # hidden widths), "num_domains", "num_classes".
        super(MDANet, self).__init__()
        # Pool incoming feature maps to 2x2 and reduce 4096 -> 512 channels
        # so the flattened vector fed to the MLP has a fixed size.
        self.pooling_layer = nn.AdaptiveAvgPool2d((2, 2))
        self.dim_reduction = nn.Conv2d(4096, 512, kernel_size=1)
        nn.init.xavier_normal_(self.dim_reduction.weight)
        nn.init.constant_(self.dim_reduction.bias, 0.1)
        self.input_dim = configs["input_dim"]
        self.num_hidden_layers = len(configs["hidden_layers"])
        self.num_neurons = [] + [self.input_dim] + configs["hidden_layers"]
        self.num_domains = configs["num_domains"]
        # Parameters of hidden, fully-connected layers, feature learning component.
        self.hiddens = nn.ModuleList([nn.Linear(self.num_neurons[i], self.num_neurons[i + 1])
                                      for i in range(self.num_hidden_layers)])
        # Parameter of the final softmax classification layer.
        self.softmax = nn.Linear(self.num_neurons[-1], configs["num_classes"])
        # Parameter of the domain classification layer, multiple sources single target domain adaptation.
        self.domains = nn.ModuleList([nn.Linear(self.num_neurons[-1], 2) for _ in range(self.num_domains)])
        # Gradient reversal layer.  A plain list (not ModuleList) is fine:
        # the reversal layers hold no parameters.
        self.grls = [GradientReversalLayer() for _ in range(self.num_domains)]

    def forward(self, sinputs_syn, sinputs_gta, tinputs):
        """Run both source batches and the target batch through the network.

        :param sinputs_syn: SYNTHIA-source feature maps; 4096 channels
            expected by dim_reduction (assumed NCHW — TODO confirm).
        :param sinputs_gta: GTA-source feature maps, same layout.
        :param tinputs: target-domain feature maps, same layout.
        :return: (logprobs, sdomains, tdomains) — class log-probabilities
            for [syn, gta], and per-domain source/target domain
            log-probabilities (head 0: syn vs target, head 1: gta vs target).
        """
        # Shared pooling + channel reduction for both sources and the target.
        sinputs_gta = self.pooling_layer(sinputs_gta)
        sinputs_syn = self.pooling_layer(sinputs_syn)
        tinputs = self.pooling_layer(tinputs)
        sinputs_gta = self.dim_reduction(sinputs_gta)
        sinputs_syn = self.dim_reduction(sinputs_syn)
        tinputs = self.dim_reduction(tinputs)
        b = sinputs_gta.size()[0]
        syn_relu, gta_relu, th_relu = sinputs_syn.view(b, -1), sinputs_gta.view(b, -1), tinputs.view(b, -1)
        # Flattened width must match the configured MLP input width.
        assert (syn_relu[0].size()[0] == self.input_dim)
        for hidden in self.hiddens:
            syn_relu = F.relu(hidden(syn_relu))
            gta_relu = F.relu(hidden(gta_relu))
        for hidden in self.hiddens:
            th_relu = F.relu(hidden(th_relu))
        # Classification probabilities on k source domains.
        logprobs = []
        logprobs.append(F.log_softmax(self.softmax(syn_relu), dim=1))
        logprobs.append(F.log_softmax(self.softmax(gta_relu), dim=1))
        # Domain classification accuracies.
        sdomains, tdomains = [], []
        # Gradients flowing back through the GRLs are negated, training the
        # features adversarially against these domain heads.
        sdomains.append(F.log_softmax(self.domains[0](self.grls[0](syn_relu)), dim=1))
        tdomains.append(F.log_softmax(self.domains[0](self.grls[0](th_relu)), dim=1))
        sdomains.append(F.log_softmax(self.domains[1](self.grls[1](gta_relu)), dim=1))
        tdomains.append(F.log_softmax(self.domains[1](self.grls[1](th_relu)), dim=1))
        return logprobs, sdomains, tdomains

    def inference(self, inputs):
        """Class log-probabilities for already-flattened features.

        Note: unlike forward(), no pooling/dim-reduction is applied here —
        `inputs` must already be (batch, input_dim).
        """
        h_relu = inputs
        for hidden in self.hiddens:
            h_relu = F.relu(hidden(h_relu))
        # Classification probability.
        logprobs = F.log_softmax(self.softmax(h_relu), dim=1)
        return logprobs
================================================
FILE: cycada/models/__init__.py
================================================
from .models import get_model
from .task_net import LeNet
from .task_net import DTNClassifier
from .adda_net import AddaNet
from .fcn8s import VGG16_FCN8s, Discriminator
from .drn import drn26
================================================
FILE: cycada/models/adda_net.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from .util import init_weights
from .models import register_model, get_model
@register_model('AddaNet')
class AddaNet(nn.Module):
    """Defines an Adda Network.

    Wraps a source and a target copy of `model` plus a small MLP domain
    discriminator that sees either classifier scores or last-layer features.

    Fix: forward() reads `self.discrim_feat`, but the attribute was never
    assigned anywhere, so every forward pass raised AttributeError.  It is
    now an explicit constructor flag (default False, i.e. discriminate on
    classifier scores, which matches the discriminator input size the
    original setup_net built).  When True, the discriminator input size is
    taken from the base model's `out_dim` feature width.
    """

    def __init__(self, num_cls=10, model='LeNet', src_weights_init=None,
                 weights_init=None, discrim_feat=False):
        super(AddaNet, self).__init__()
        self.name = 'AddaNet'
        self.base_model = model
        self.num_cls = num_cls
        # True: discriminator consumes features; False: classifier scores.
        self.discrim_feat = discrim_feat
        self.cls_criterion = nn.CrossEntropyLoss()
        self.gan_criterion = nn.CrossEntropyLoss()
        self.setup_net()
        if weights_init is not None:
            self.load(weights_init)
        elif src_weights_init is not None:
            self.load_src_net(src_weights_init)
        else:
            raise Exception('AddaNet must be initialized with weights.')

    def forward(self, x_s, x_t):
        """Pass source and target images through their
        respective networks."""
        score_s, x_s = self.src_net(x_s, with_ft=True)
        score_t, x_t = self.tgt_net(x_t, with_ft=True)
        if self.discrim_feat:
            d_s = self.discriminator(x_s)
            d_t = self.discriminator(x_t)
        else:
            d_s = self.discriminator(score_s)
            d_t = self.discriminator(score_t)
        return score_s, score_t, d_s, d_t

    def setup_net(self):
        """Setup source, target and discriminator networks."""
        self.src_net = get_model(self.base_model, num_cls=self.num_cls)
        self.tgt_net = get_model(self.base_model, num_cls=self.num_cls)
        # Discriminator input width must match what forward() feeds it.
        if self.discrim_feat:
            input_dim = self.src_net.out_dim
        else:
            input_dim = self.num_cls
        self.discriminator = nn.Sequential(
            nn.Linear(input_dim, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 2),
        )
        self.image_size = self.src_net.image_size
        self.num_channels = self.src_net.num_channels

    def load(self, init_path):
        "Loads full src and tgt models."
        net_init_dict = torch.load(init_path)
        self.load_state_dict(net_init_dict)

    def load_src_net(self, init_path):
        """Initialize source and target with source
        weights."""
        self.src_net.load(init_path)
        self.tgt_net.load(init_path)

    def save(self, out_path):
        torch.save(self.state_dict(), out_path)

    def save_tgt_net(self, out_path):
        torch.save(self.tgt_net.state_dict(), out_path)
FILE: cycada/models/drn.py
================================================
import math
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torchvision
from .models import register_model
from ..util import safe_load_state_dict
# Public API of this module.
__all__ = ['DRN', 'drn26', 'drn42', 'drn58']

# Pretrained ImageNet checkpoint URLs, keyed by model name.
model_urls = {
    'drn26': 'https://tigress-web.princeton.edu/~fy/drn/models/drn26-ddedf421.pth',
    'drn42': 'https://tigress-web.princeton.edu/~fy/drn/models/drn42-9d336e8c.pth',
    'drn58': 'https://tigress-web.princeton.edu/~fy/drn/models/drn58-0a53a92c.pth'
}
def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
    """Bias-free 3x3 convolution; default padding preserves spatial size."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=padding,
                     dilation=dilation,
                     bias=False)
class BasicBlock(nn.Module):
    """Residual basic block: two 3x3 conv/BN stages with an optional skip add.

    `dilation` is a per-conv pair; `downsample` (if given) projects the
    shortcut; `residual=False` disables the skip connection entirely.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(BasicBlock, self).__init__()
        # First conv carries the stride; both convs are bias-free since a
        # BatchNorm follows immediately.
        self.conv1 = conv3x3(inplanes, planes, stride,
                             padding=dilation[0], dilation=dilation[0])
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes,
                             padding=dilation[1], dilation=dilation[1])
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.residual:
            out = out + shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """Residual bottleneck: 1x1 reduce -> 3x3 (strided/dilated) -> 1x1 expand.

    Note: the `residual` argument is accepted for signature parity with
    BasicBlock but, as in the original, the skip add is always applied.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(Bottleneck, self).__init__()
        # All convs are bias-free; BatchNorm follows each.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation[1], bias=False,
                               dilation=dilation[1])
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out = out + shortcut
        return self.relu(out)
class DRN(nn.Module):
    """Dilated Residual Network backbone with an optional 1x1-conv head.

    With out_map=True the head produces per-pixel class scores upsampled
    back to the input size; with out_map=False scores are average-pooled to
    per-image logits.  out_middle additionally returns every stage output;
    output_last_ft additionally returns the pre-head feature map.
    """

    # ImageNet-style preprocessing matching the pretrained checkpoints.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, block, layers, num_cls=1000,
                 channels=(16, 32, 64, 128, 256, 512, 512, 512),
                 out_map=False, out_middle=False, pool_size=28,
                 weights_init=None, pretrained=True, finetune=False,
                 output_last_ft=False, modelname='drn26'):
        # block: residual block class for stages 3-6 (BasicBlock/Bottleneck);
        # layers: blocks per stage, 8 entries (0 disables stages 6-8);
        # weights_init: checkpoint path overriding the model-zoo download;
        # finetune: drop the fc weights from the loaded state dict.
        # NOTE(review): despite the message below, forward() does return the
        # pre-head features when output_last_ft is set — confirm intent.
        if output_last_ft:
            print('DRN discrim feat not implemented, using scores')
        super(DRN, self).__init__()
        self.inplanes = channels[0]
        self.output_last_ft = output_last_ft
        self.out_map = out_map
        self.out_dim = channels[-1]
        self.out_middle = out_middle
        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.relu = nn.ReLU(inplace=True)
        # Stages 1-4 downsample; stages 5-8 keep resolution and dilate
        # instead (the defining trait of DRN).  Stages 7-8 use plain
        # (non-residual) BasicBlocks to smooth gridding artifacts.
        self.layer1 = self._make_layer(BasicBlock, channels[0], layers[0], stride=1)
        self.layer2 = self._make_layer(BasicBlock, channels[1], layers[1], stride=2)
        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2)
        self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2)
        self.layer5 = self._make_layer(block, channels[4], layers[4], dilation=2,
                                       new_level=False)
        self.layer6 = None if layers[5] == 0 else \
            self._make_layer(block, channels[5], layers[5], dilation=4,
                             new_level=False)
        self.layer7 = None if layers[6] == 0 else \
            self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
                             new_level=False, residual=False)
        self.layer8 = None if layers[7] == 0 else \
            self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
                             new_level=False, residual=False)
        if num_cls > 0:
            self.avgpool = nn.AvgPool2d(pool_size)
            # Fully-convolutional classifier (1x1 conv instead of Linear).
            # self.fc = nn.Linear(self.out_dim, num_classes)
            self.fc = nn.Conv2d(self.out_dim, num_cls, kernel_size=1,
                                stride=1, padding=0, bias=True)
        # He-style init for convs; BN starts as identity.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        if pretrained:
            if not weights_init is None:
                state_dict = torch.load(weights_init)
                print('Using state dict from', weights_init)
            else:
                state_dict = model_zoo.load_url(model_urls[modelname])
            if finetune:
                # Drop the classifier weights so num_cls may differ.
                del state_dict['fc.weight']
                del state_dict['fc.bias']
                safe_load_state_dict(self, state_dict)
                print('Finetune: remove last layer')
            else:
                self.load_state_dict(state_dict)
                print('Loading full model')

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    new_level=True, residual=True):
        """Stack `blocks` residual blocks; project the shortcut when the
        stride or channel count changes.  The first block of a new dilation
        level uses half the dilation on its first conv (new_level)."""
        assert dilation == 1 or dilation % 2 == 0
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(
            self.inplanes, planes, stride, downsample,
            dilation=(1, 1) if dilation == 1 else (
                dilation // 2 if new_level else dilation, dilation),
            residual=residual))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, residual=residual,
                                dilation=(dilation, dilation)))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Record input size so out_map predictions can be upsampled back.
        _, _, h, w = x.size()
        y = list()  # per-stage outputs, returned when out_middle is set
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        y.append(x)
        x = self.layer2(x)
        y.append(x)
        x = self.layer3(x)
        y.append(x)
        x = self.layer4(x)
        y.append(x)
        x = self.layer5(x)
        y.append(x)
        if self.layer6 is not None:
            x = self.layer6(x)
            y.append(x)
        if self.layer7 is not None:
            x = self.layer7(x)
            y.append(x)
        if self.layer8 is not None:
            x = self.layer8(x)
            y.append(x)
        if self.output_last_ft:
            # Keep the pre-classifier feature map to return alongside scores.
            ft_to_save = x
        if self.out_map:
            # Dense prediction: classify every location, upsample to input.
            x = self.fc(x)
            x = nn.functional.interpolate(x, (h, w), mode='bilinear', align_corners=True)
        else:
            # Image-level prediction: pool, classify, flatten.
            x = self.avgpool(x)
            x = self.fc(x)
            x = x.view(x.size(0), -1)
        if self.out_middle:
            return x, y
        elif self.output_last_ft:
            return x, ft_to_save
        else:
            return x
@register_model('drn26')
def drn26(pretrained=True, finetune=False, out_map=True, **kwargs):
    """Build DRN-26.

    Fix: `pretrained` is now forwarded to the DRN constructor — previously
    DRN's own default (pretrained=True) always applied, so checkpoint
    weights were downloaded even when the caller passed pretrained=False.
    The dead commented-out loading code (superseded by DRN's constructor)
    is removed.
    """
    return DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], modelname='drn26',
               out_map=out_map, finetune=finetune, pretrained=pretrained,
               **kwargs)
@register_model('drn42')
def drn42(pretrained=False, finetune=False, out_map=True, **kwargs):
    """Build DRN-42.

    Fix: `pretrained` is now forwarded to the DRN constructor — previously
    DRN's own default (pretrained=True) always applied, so checkpoint
    weights were downloaded even though this builder defaults to
    pretrained=False.
    """
    return DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], modelname='drn42',
               out_map=out_map, finetune=finetune, pretrained=pretrained,
               **kwargs)
def drn58(pretrained=False, **kwargs):
    """Build DRN-58.

    Fix: the DRN constructor is called with pretrained=False so it no
    longer attempts to load the default 'drn26' checkpoint (an architecture
    mismatch for Bottleneck layers); pretrained weights are loaded
    explicitly from the drn58 URL, as originally intended.
    """
    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], pretrained=False,
                **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn58']))
    return model
================================================
FILE: cycada/models/fcn8s.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.nn import init
from torch.utils import model_zoo
from torchvision.models import vgg
from .models import register_model
def get_upsample_filter(size):
    """Return a (size, size) float tensor holding a 2D bilinear upsampling kernel."""
    factor = (size + 1) // 2
    # Kernel center: on a sample for odd sizes, between samples for even.
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    rows, cols = np.ogrid[:size, :size]
    kernel = (1 - abs(rows - center) / factor) * (1 - abs(cols - center) / factor)
    return torch.from_numpy(kernel).float()
class Bilinear(nn.Module):
    """Fixed (non-learned) bilinear upsampling by an integer factor.

    Implemented as a transposed convolution whose weight is a bilinear
    kernel applied to each channel independently; the weight is registered
    as a buffer so it is saved but never trained.
    """

    def __init__(self, factor, num_channels):
        super().__init__()
        self.factor = factor
        kernel = get_upsample_filter(factor * 2)
        weight = torch.zeros(num_channels, num_channels, factor * 2, factor * 2)
        for channel in range(num_channels):
            # Diagonal placement keeps channels independent.
            weight[channel, channel] = kernel
        self.register_buffer('w', weight)

    def forward(self, x):
        return F.conv_transpose2d(x, Variable(self.w), stride=self.factor)
@register_model('fcn8s')
class VGG16_FCN8s(nn.Module):
    """FCN-8s semantic segmentation network on a fully-convolutional VGG16.

    Fuses pool3/pool4 skip connections with the stride-32 head via fixed
    bilinear upsampling, returning per-pixel class scores at input size.
    """

    # ImageNet preprocessing matching the torchvision VGG16 weights.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, num_cls=19, pretrained=True, weights_init=None,
                 output_last_ft=False):
        # num_cls: output classes; weights_init: full-model checkpoint path;
        # output_last_ft: also return the pre-classifier features.
        super().__init__()
        self.output_last_ft = output_last_ft
        # NOTE(review): `batch_norm` is computed here but make_layers is
        # called with batch_norm=False regardless — looks like dead code;
        # confirm before relying on BN checkpoints.
        if weights_init:
            batch_norm = False
        else:
            batch_norm = True
        # NOTE(review): `vgg.cfg` was renamed `cfgs` in newer torchvision
        # releases — verify against the pinned torchvision version.
        self.vgg = make_layers(vgg.cfg['D'], batch_norm=False)
        # Fully-convolutional replacement of VGG's fc6/fc7/classifier.
        self.vgg_head = nn.Sequential(
            nn.Conv2d(512, 4096, 7),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.5),
            nn.Conv2d(4096, 4096, 1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.5),
            nn.Conv2d(4096, num_cls, 1)
        )
        # Fixed bilinear upsamplers (x2 used twice, then x8).
        self.upscore2 = self.upscore_pool4 = Bilinear(2, num_cls)
        self.upscore8 = Bilinear(8, num_cls)
        # Skip-connection score layers start at zero so training begins
        # from the coarse stride-32 prediction alone.
        self.score_pool4 = nn.Conv2d(512, num_cls, 1)
        for param in self.score_pool4.parameters():
            # init.constant(param, 0)
            init.constant_(param, 0)
        self.score_pool3 = nn.Conv2d(256, num_cls, 1)
        for param in self.score_pool3.parameters():
            # init.constant(param, 0)
            init.constant_(param, 0)
        if pretrained:
            if weights_init is not None:
                self.load_weights(torch.load(weights_init))
            else:
                self.load_base_weights()

    def load_base_vgg(self, weights_state_dict):
        """Load only the `vgg.`-prefixed weights from a full state dict."""
        vgg_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg.')
        self.vgg.load_state_dict(vgg_state_dict)

    def load_vgg_head(self, weights_state_dict):
        """Load only the `vgg_head.`-prefixed weights from a full state dict."""
        vgg_head_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg_head.')
        self.vgg_head.load_state_dict(vgg_head_state_dict)

    def get_dict_by_prefix(self, weights_state_dict, prefix):
        """Select `prefix`-keyed entries and strip the prefix from the keys."""
        return {k[len(prefix):]: v
                for k, v in weights_state_dict.items()
                if k.startswith(prefix)}

    def load_weights(self, weights_state_dict):
        """Load vgg backbone and head from one full-model state dict."""
        self.load_base_vgg(weights_state_dict)
        self.load_vgg_head(weights_state_dict)

    def split_vgg_head(self):
        """Expose the final classifier and the feature part of the head
        separately (used for feature-level discrimination)."""
        self.classifier = list(self.vgg_head.children())[-1]
        self.vgg_head_feat = nn.Sequential(*list(self.vgg_head.children())[:-1])

    def forward(self, x):
        input = x
        # Large 99-pixel zero padding (as in the original Caffe FCN) so the
        # 7x7 head conv has context even at image borders; _crop offsets
        # below undo the resulting misalignment.
        x = F.pad(x, (99, 99, 99, 99), mode='constant', value=0)
        intermediates = {}
        # Module indices of pool3/pool4 inside the (batch_norm=False) VGG
        # sequential; their activations feed the skip connections.
        fts_to_save = {16: 'pool3', 23: 'pool4'}
        for i, module in enumerate(self.vgg):
            x = module(x)
            if i in fts_to_save:
                intermediates[fts_to_save[i]] = x
        ft_to_save = 5  # Dropout before classifier
        last_ft = {}
        for i, module in enumerate(self.vgg_head):
            x = module(x)
            if i == ft_to_save:
                last_ft = x
        _, _, h, w = x.size()
        # Stride-32 scores upsampled x2, fused with scaled pool4 scores.
        upscore2 = self.upscore2(x)
        pool4 = intermediates['pool4']
        score_pool4 = self.score_pool4(0.01 * pool4)
        score_pool4c = _crop(score_pool4, upscore2, offset=5)
        fuse_pool4 = upscore2 + score_pool4c
        # Fused scores upsampled x2 again, fused with scaled pool3 scores.
        upscore_pool4 = self.upscore_pool4(fuse_pool4)
        pool3 = intermediates['pool3']
        score_pool3 = self.score_pool3(0.0001 * pool3)
        score_pool3c = _crop(score_pool3, upscore_pool4, offset=9)
        fuse_pool3 = upscore_pool4 + score_pool3c
        # Final x8 upsample back to (padded) input resolution, then crop.
        upscore8 = self.upscore8(fuse_pool3)
        score = _crop(upscore8, input, offset=31)
        if self.output_last_ft:
            return score, last_ft
        else:
            return score

    def load_base_weights(self):
        """This is complicated because we converted the base model to be fully
        convolutional, so some surgery needs to happen here."""
        base_state_dict = model_zoo.load_url(vgg.model_urls['vgg16'])
        vgg_state_dict = {k[len('features.'):]: v
                          for k, v in base_state_dict.items()
                          if k.startswith('features.')}
        self.vgg.load_state_dict(vgg_state_dict)
        # Reshape the fc6/fc7 Linear weights into the equivalent conv
        # weights of vgg_head, skipping the 1000-way ImageNet classifier.
        vgg_head_params = self.vgg_head.parameters()
        for k, v in base_state_dict.items():
            if not k.startswith('classifier.'):
                continue
            if k.startswith('classifier.6.'):
                # skip final classifier output
                continue
            vgg_head_param = next(vgg_head_params)
            vgg_head_param.data = v.view(vgg_head_param.size())
class VGG16_FCN8s_caffe(VGG16_FCN8s):
    """FCN-8s variant initialized from Caffe-converted VGG16 weights.

    Caffe VGG expects roughly [0, 255]-scaled BGR input: the tiny std
    (1/255) rescales after mean subtraction, and the Lambda reverses the
    channel dimension (RGB -> BGR).
    """

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.458, 0.408],
            std=[0.00392156862745098] * 3),
        torchvision.transforms.Lambda(
            lambda x: torch.stack(torch.unbind(x, 1)[::-1], 1))
    ])

    def load_base_weights(self):
        # Same fully-convolutional surgery as the parent, but from the
        # Caffe-converted checkpoint instead of the torchvision one.
        base_state_dict = model_zoo.load_url('https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg16-00b39a1b.pth')
        vgg_state_dict = {k[len('features.'):]: v
                          for k, v in base_state_dict.items()
                          if k.startswith('features.')}
        self.vgg.load_state_dict(vgg_state_dict)
        vgg_head_params = self.vgg_head.parameters()
        for k, v in base_state_dict.items():
            if not k.startswith('classifier.'):
                continue
            if k.startswith('classifier.6.'):
                # skip final classifier output
                continue
            vgg_head_param = next(vgg_head_params)
            vgg_head_param.data = v.view(vgg_head_param.size())
class Discriminator(nn.Module):
    """Pixel-wise domain discriminator: three 1x1 convs with dropout/ReLU.

    The hidden widths adapt to the input: 1024/512 for 4096-dim features,
    512/256 otherwise.
    """

    def __init__(self, input_dim=4096, output_dim=2, pretrained=False, weights_init=''):
        super().__init__()
        width = 1024 if input_dim == 4096 else 512
        self.D = nn.Sequential(
            nn.Conv2d(input_dim, width, 1),
            nn.Dropout2d(p=0.5),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width // 2, 1),
            nn.Dropout2d(p=0.5),
            nn.ReLU(inplace=True),
            nn.Conv2d(width // 2, output_dim, 1)
        )
        if pretrained and weights_init is not None:
            self.load_weights(weights_init)

    def forward(self, x):
        return self.D(x)

    def load_weights(self, weights):
        print('Loading discriminator weights')
        self.load_state_dict(torch.load(weights))
class Transform_Module(nn.Module):
    """Learnable 1x1-conv feature transform, initialized to the identity.

    With eye weights and zero bias the module starts as a no-op (up to the
    ReLU), letting training depart smoothly from the untransformed features.
    """

    def __init__(self, input_dim=4096):
        super().__init__()
        self.transform = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, 1),
            nn.ReLU(inplace=True),
        )
        # Identity initialization: eye weight matrix, zero bias.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init_eye(module.weight)
                module.bias.data.zero_()

    def forward(self, x):
        return self.transform(x)
def init_eye(tensor):
    """Fill the 2-D `tensor` in place with an identity matrix; return it.

    Fix: the original dispatched on `isinstance(tensor, Variable)` and
    recursed on `.data` — in torch >= 0.4 every Tensor *is* a Variable and
    `.data` is one too, so the recursion never terminated.  The in-place
    copy now runs under no_grad, which also makes it legal on leaf tensors
    that require grad (e.g. Conv2d weights).  The return value is the same
    object in both old and new code, since copy_ returns its target.
    """
    with torch.no_grad():
        tensor.copy_(torch.eye(tensor.size(0), tensor.size(1)))
    return tensor
def _crop(input, shape, offset=0):
_, _, h, w = shape.size()
return input[:, :, offset:offset + h, offset:offset + w].contiguous()
def make_layers(cfg, batch_norm=False):
    """Build the VGG feature stack from a torchvision-style cfg list.

    Identical to torchvision.models.vgg's builder except every MaxPool2d
    is created with ceil_mode=True.
    """
    layers = []
    in_channels = 3
    for entry in cfg:
        if entry == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))
            continue
        layers.append(nn.Conv2d(in_channels, entry, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(entry))
        layers.append(nn.ReLU(inplace=True))
        in_channels = entry
    return nn.Sequential(*layers)
================================================
FILE: cycada/models/models.py
================================================
import torch
# Global registry: model name -> constructor/class.
models = {}

def register_model(name):
    """Decorator registering a model constructor under `name`."""
    def decorator(cls):
        models[name] = cls
        return cls
    return decorator

def get_model(name, num_cls=10, **args):
    """Instantiate a registered model, moving it to CUDA when available."""
    net = models[name](num_cls=num_cls, **args)
    return net.cuda() if torch.cuda.is_available() else net
================================================
FILE: cycada/models/task_net.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
from .models import register_model
from .util import init_weights
import numpy as np
class TaskNet(nn.Module):
    """Basic class which does classification.

    Subclasses implement setup_net() to define `conv_params`, `fc_params`
    and `classifier`; this base class wires them into a forward pass and
    provides weight load/save helpers.
    """
    num_channels = 3
    image_size = 32
    name = 'TaskNet'

    def __init__(self, num_cls=10, weights_init=None):
        super(TaskNet, self).__init__()
        self.num_cls = num_cls
        self.setup_net()
        self.criterion = nn.CrossEntropyLoss()
        if weights_init is None:
            init_weights(self)
        else:
            self.load(weights_init)

    def forward(self, x, with_ft=False):
        """Classify *x*; also return the penultimate features if `with_ft`."""
        feats = self.conv_params(x)
        feats = feats.view(feats.size(0), -1)
        feats = self.fc_params(feats)
        score = self.classifier(feats)
        return (score, feats) if with_ft else score

    def setup_net(self):
        """Build the network modules; implemented by each subclass."""
        pass

    def load(self, init_path):
        """Load parameters from a serialized state dict at *init_path*."""
        self.load_state_dict(torch.load(init_path))

    def save(self, out_path):
        """Serialize this network's parameters to *out_path*."""
        torch.save(self.state_dict(), out_path)
@register_model('LeNet')
class LeNet(TaskNet):
    """Network used for MNIST or USPS experiments (1x28x28 inputs)."""
    num_channels = 1
    image_size = 28
    name = 'LeNet'
    out_dim = 500  # dim of last feature layer

    def setup_net(self):
        """Two conv blocks, a 500-d fc embedding, and a linear classifier head."""
        feature_layers = [
            nn.Conv2d(self.num_channels, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        ]
        self.conv_params = nn.Sequential(*feature_layers)
        self.fc_params = nn.Linear(50 * 4 * 4, 500)
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(500, self.num_cls),
        )
@register_model('DTN')
class DTNClassifier(TaskNet):
    """Classifier used for SVHN->MNIST Experiment (3x32x32 inputs)."""
    num_channels = 3
    image_size = 32
    name = 'DTN'
    out_dim = 512  # dim of last feature layer

    def setup_net(self):
        """Three strided conv blocks, a 512-d fc embedding, and a linear head."""
        blocks = [
            nn.Conv2d(self.num_channels, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.Dropout2d(0.1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.Dropout2d(0.3),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.Dropout2d(0.5),
            nn.ReLU(),
        ]
        self.conv_params = nn.Sequential(*blocks)
        self.fc_params = nn.Sequential(
            nn.Linear(256 * 4 * 4, 512),
            nn.BatchNorm1d(512),
        )
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(512, self.num_cls),
        )
================================================
FILE: cycada/models/util.py
================================================
import torch.nn as nn
from torch.nn import init
def init_weights(obj):
    """Initialize all Conv2d/Linear weights of *obj* with Xavier-normal,
    zero their biases, and reset BatchNorm layers to defaults.

    Fix: layers created with ``bias=False`` have ``m.bias is None``; the
    original unconditionally dereferenced ``m.bias.data`` and crashed.
    """
    for m in obj.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            init.xavier_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            m.reset_parameters()
================================================
FILE: cycada/tools/__init__.py
================================================
================================================
FILE: cycada/tools/train_adda_net.py
================================================
from __future__ import print_function
import os
from os.path import join
import numpy as np
# Import from torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
# Import from within Package
from ..models.models import get_model
from ..data.data_loader import load_data
from ..tools.test_task_net import test
from ..tools.util import make_variable
def train(loader_src, loader_tgt, net, opt_net, opt_dis, epoch):
    """Run one ADDA epoch: alternate discriminator and target-net updates
    over paired source/target batches.

    Returns the index of the last batch on which the target network was
    updated, or -1 if the discriminator never reached the 0.6 accuracy
    threshold (the caller uses this to stop training early).
    """
    log_interval = 100  # specifies how often to display
    # zip() truncates to the shorter loader; N is only used for progress logging.
    N = min(len(loader_src.dataset), len(loader_tgt.dataset))
    joint_loader = zip(loader_src, loader_tgt)
    net.train()
    last_update = -1
    for batch_idx, ((data_s, _), (data_t, _)) in enumerate(joint_loader):
        # log basic adda train info
        info_str = "[Train Adda] Epoch: {} [{}/{} ({:.2f}%)]".format(
            epoch, batch_idx*len(data_t), N, 100 * batch_idx / N)
        ########################
        # Setup data variables #
        ########################
        data_s = make_variable(data_s, requires_grad=False)
        data_t = make_variable(data_t, requires_grad=False)
        ##########################
        # Optimize discriminator #
        ##########################
        # zero gradients for optimizer
        opt_dis.zero_grad()
        # extract and concat features from both domains
        score_s = net.src_net(data_s)
        score_t = net.tgt_net(data_t)
        f = torch.cat((score_s, score_t), 0)
        # predict domain with discriminator
        pred_concat = net.discriminator(f)
        # prepare real and fake labels: source=1, target=0
        target_dom_s = make_variable(torch.ones(len(data_s)).long(), requires_grad=False)
        target_dom_t = make_variable(torch.zeros(len(data_t)).long(), requires_grad=False)
        label_concat = torch.cat((target_dom_s, target_dom_t), 0)
        # compute loss for discriminator
        loss_dis = net.gan_criterion(pred_concat, label_concat)
        loss_dis.backward()
        # optimize discriminator
        opt_dis.step()
        # compute discriminator accuracy on this batch
        pred_dis = torch.squeeze(pred_concat.max(1)[1])
        acc = (pred_dis == label_concat).float().mean()
        # log discriminator update info
        info_str += " acc: {:0.1f} D: {:.3f}".format(acc.item()*100, loss_dis.item())
        ###########################
        # Optimize target network #
        ###########################
        # only update the target net if the discriminator is strong
        if acc.item() > 0.6:
            last_update = batch_idx
            # zero out optimizer gradients
            opt_dis.zero_grad()
            opt_net.zero_grad()
            # re-extract target features (discriminator has just changed)
            score_t = net.tgt_net(data_t)
            # predict with discriminator
            pred_tgt = net.discriminator(score_t)
            # flipped label: target net tries to make targets look like source (=1)
            label_tgt = make_variable(torch.ones(pred_tgt.size(0)).long(), requires_grad=False)
            # compute loss for target network
            loss_gan_t = net.gan_criterion(pred_tgt, label_tgt)
            loss_gan_t.backward()
            # optimize tgt network
            opt_net.step()
            # log net update info
            info_str += " G: {:.3f}".format(loss_gan_t.item())
        ###########
        # Logging #
        ###########
        if batch_idx % log_interval == 0:
            print(info_str)
    return last_update
def train_adda(src, tgt, model, num_cls, num_epoch=200,
               batch=128, datadir="", outdir="",
               src_weights=None, weights=None, lr=1e-5, betas=(0.9,0.999),
               weight_decay=0):
    """Main function for training ADDA.

    Builds an AddaNet (source net initialized from `src_weights`), trains its
    target feature extractor adversarially against a domain discriminator for
    up to `num_epoch` epochs, and saves the model under `outdir`.

    NOTE(review): the `weights` argument is accepted but never used here —
    confirm whether it was meant to resume training.
    """
    ###########################
    # Setup cuda and networks #
    ###########################
    # setup cuda
    if torch.cuda.is_available():
        kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        kwargs = {}
    # setup network
    net = get_model('AddaNet', model=model, num_cls=num_cls,
                    src_weights_init=src_weights)
    # print network and arguments
    print(net)
    print('Training Adda {} model for {}->{}'.format(model, src, tgt))
    #######################################
    # Setup data for training and testing #
    #######################################
    train_src_data = load_data(src, 'train', batch=batch,
                               rootdir=join(datadir, src), num_channels=net.num_channels,
                               image_size=net.image_size, download=True, kwargs=kwargs)
    train_tgt_data = load_data(tgt, 'train', batch=batch,
                               rootdir=join(datadir, tgt), num_channels=net.num_channels,
                               image_size=net.image_size, download=True, kwargs=kwargs)
    ######################
    # Optimization setup #
    ######################
    # only the target network and the discriminator are optimized in ADDA
    net_param = net.tgt_net.parameters()
    opt_net = optim.Adam(net_param, lr=lr, weight_decay=weight_decay, betas=betas)
    opt_dis = optim.Adam(net.discriminator.parameters(), lr=lr,
                         weight_decay=weight_decay, betas=betas)
    ##############
    # Train Adda #
    ##############
    for epoch in range(num_epoch):
        err = train(train_src_data, train_tgt_data, net, opt_net, opt_dis, epoch)
        # train() returns -1 when the discriminator never got strong enough
        if err == -1:
            print("No suitable discriminator")
            break
    ##############
    # Save Model #
    ##############
    os.makedirs(outdir, exist_ok=True)
    outfile = join(outdir, 'adda_{:s}_net_{:s}_{:s}.pth'.format(
        model, src, tgt))
    print('Saving to', outfile)
    net.save(outfile)
================================================
FILE: cycada/tools/train_task_net.py
================================================
from __future__ import print_function
import os
from os.path import join
import numpy as np
import argparse
# Import from torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
# Import from Cycada Package
from ..models.models import get_model
from ..data.data_loader import load_data
from .test_task_net import test
from .util import make_variable
def train_epoch(loader, net, opt_net, epoch):
    """Train `net` for one epoch over `loader` with optimizer `opt_net`,
    printing loss and minibatch accuracy every `log_interval` batches."""
    log_interval = 100  # specifies how often to display
    net.train()
    for batch_idx, (data, target) in enumerate(loader):
        # make data variables
        data = make_variable(data, requires_grad=False)
        target = make_variable(target, requires_grad=False)
        # zero out gradients
        opt_net.zero_grad()
        # forward pass
        score = net(data)
        loss = net.criterion(score, target)
        # backward pass
        loss.backward()
        # optimize classifier and representation
        opt_net.step()
        # Logging
        if batch_idx % log_interval == 0:
            # end="" so the accuracy is printed on the same line below
            print('[Train] Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(loader.dataset),
                100. * batch_idx / len(loader), loss.item()), end="")
            # minibatch accuracy on the just-scored batch
            pred = score.data.max(1)[1]
            correct = pred.eq(target.data).cpu().sum()
            acc = correct.item() / len(pred) * 100.0
            print(' Acc: {:.2f}'.format(acc))
def train(data, datadir, model, num_cls, outdir='',
          num_epoch=100, batch=128,
          lr=1e-4, betas=(0.9, 0.999), weight_decay=0):
    """Train a classification net and evaluate on test set.

    Trains `model` on dataset `data` for `num_epoch` epochs with Adam,
    evaluates on the test split when one is available, saves the weights
    under `outdir`, and returns the trained network.
    """
    # Setup GPU Usage
    if torch.cuda.is_available():
        kwargs = {'num_workers': 1, 'pin_memory': True}
    else:
        kwargs = {}
    ############
    # Load Net #
    ############
    net = get_model(model, num_cls=num_cls)
    print('-------Training net--------')
    print(net)
    ############################
    # Load train and test data #
    ############################
    train_data = load_data(data, 'train', batch=batch,
                           rootdir=datadir, num_channels=net.num_channels,
                           image_size=net.image_size, download=True, kwargs=kwargs)
    test_data = load_data(data, 'test', batch=batch,
                          rootdir=datadir, num_channels=net.num_channels,
                          image_size=net.image_size, download=True, kwargs=kwargs)
    ###################
    # Setup Optimizer #
    ###################
    opt_net = optim.Adam(net.parameters(), lr=lr, betas=betas,
                         weight_decay=weight_decay)
    #########
    # Train #
    #########
    print('Training {} model for {}'.format(model, data))
    for epoch in range(num_epoch):
        train_epoch(train_data, net, opt_net, epoch)
    ########
    # Test #
    ########
    if test_data is not None:
        print('Evaluating {}-{} model on {} test set'.format(model, data, data))
        test(test_data, net)
    ############
    # Save net #
    ############
    os.makedirs(outdir, exist_ok=True)
    outfile = join(outdir, '{:s}_net_{:s}.pth'.format(model, data))
    print('Saving to', outfile)
    net.save(outfile)
    return net
================================================
FILE: cycada/tools/util.py
================================================
from functools import partial
import torch
from torch.autograd import Variable
def make_variable(tensor, volatile=False, requires_grad=True):
    """Move `tensor` to the GPU when available and wrap it for autograd.

    Fix: the old code passed the removed ``volatile`` keyword to
    ``Variable``, which raises a TypeError on every modern PyTorch.
    ``volatile=True`` is kept for backward compatibility and mapped to
    ``requires_grad=False``, its closest modern equivalent.
    """
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    if volatile:
        requires_grad = False
    return Variable(tensor, requires_grad=requires_grad)
def pairwise_distance(x, y):
    """Squared Euclidean distances between rows of `x` and rows of `y`.

    Both inputs must be 2-D with the same feature dimension.  Note the
    result is transposed: entry [j, i] holds ||x_i - y_j||^2.
    """
    if len(x.shape) != len(y.shape):
        raise ValueError('Both inputs should be matrices.')
    if x.shape[1] != y.shape[1]:
        raise ValueError('The number of features should be the same.')
    expanded = x.view(x.shape[0], x.shape[1], 1)
    diff = expanded - torch.transpose(y, 0, 1)
    return torch.transpose(torch.sum(diff ** 2, 1), 0, 1)
def gaussian_kernel_matrix(x, y, sigmas):
    """Multi-bandwidth RBF kernel between rows of `x` and rows of `y`,
    summed over the bandwidths in `sigmas`."""
    beta = 1. / (2. * sigmas.view(sigmas.shape[0], 1))
    dist = pairwise_distance(x, y).contiguous()
    exponents = torch.matmul(beta, dist.view(1, -1))
    return torch.sum(torch.exp(-exponents), 0).view_as(dist)
def maximum_mean_discrepancy(x, y, kernel=gaussian_kernel_matrix):
    """Biased MMD^2 estimate: E[k(x,x)] + E[k(y,y)] - 2 E[k(x,y)]."""
    return (torch.mean(kernel(x, x))
            + torch.mean(kernel(y, y))
            - 2 * torch.mean(kernel(x, y)))
def mmd_loss(source_features, target_features):
    """Multi-kernel MMD loss between source and target feature batches.

    Fixes: the bandwidth tensor was built with ``torch.cuda.FloatTensor``,
    which hard-crashes on CPU-only machines; it now follows the input's
    device.  A redundant self-assignment was also removed.
    """
    sigmas = [
        1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
        1e3, 1e4, 1e5, 1e6
    ]
    sigma_tensor = torch.tensor(sigmas, dtype=torch.float32,
                                device=source_features.device)
    gaussian_kernel = partial(gaussian_kernel_matrix,
                              sigmas=Variable(sigma_tensor))
    return maximum_mean_discrepancy(source_features, target_features,
                                    kernel=gaussian_kernel)
================================================
FILE: cycada/transforms.py
================================================
"""These random transforms extend the transforms provided in torchvision to
allow for transforming multiple images at the same time. This ensures that the
images receive the same transformation, e.g. the provided images are either all
mirrored or all left unchanged.
For example, this is useful in segmentation tasks, where a transformation to the
image necessitates that same transformation on the label.
"""
import numbers
import random
import torch
import torchvision
class RandomCrop(object):
    """Crop every tensor in a list at the same random location.

    `size` may be an int (square crop) or a (target_height, target_width)
    tuple.  All tensors must share spatial dimensions; if they already match
    the target size the input list is returned unchanged.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, tensors):
        th, tw = self.size
        h = w = None
        # Verify all tensors share the same spatial size.
        for tensor in tensors:
            if h is None and w is None:
                _, h, w = tensor.size()
            elif tensor.size()[-2:] != (h, w):
                print(tensor.size(), (h, w))
                raise ValueError('Images must be same size')
        if w == tw and h == th:
            return tensors
        # One shared random offset so all tensors get the identical crop.
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return [t[..., y1:y1 + th, x1:x1 + tw].contiguous() for t in tensors]
class HalfCrop(object):
    """Crop half of each tensor along the width, randomly keeping the left
    or the right side (the same side for every tensor in the list).

    Fixes the original, which could never run: it read ``self.size`` that no
    ``__init__`` ever set, referenced a misspelled ``left_size`` variable,
    and used an invalid double-Ellipsis index.  The half-width is now taken
    from each tensor's own last dimension.
    """

    def __call__(self, tensors):
        output = []
        side = random.randint(0, 1)  # 0 = keep left half, 1 = keep right half
        for tensor in tensors:
            half = tensor.size(-1) // 2
            x1 = side * half
            output.append(tensor[..., x1:x1 + half].contiguous())
        return output
class RandomHorizontalFlip(object):
    """With probability 0.5, mirror every tensor in the list along its
    last (width) dimension; otherwise return the list unchanged."""

    def __call__(self, tensors):
        if random.random() >= 0.5:
            return tensors
        flipped = []
        for tensor in tensors:
            reversed_idx = torch.arange(tensor.size(-1) - 1, -1, -1).long()
            flipped.append(tensor.index_select(-1, reversed_idx))
        return flipped
def augment_collate(batch, crop=None, halfcrop=None, flip=True, resize=None):
    """Collate a batch while applying joint augmentations, so each sample's
    tensors (e.g. image and label) receive the identical transform.

    NOTE: the `resize` argument is accepted but currently unused.
    """
    ops = []
    if crop is not None:
        ops.append(RandomCrop(crop))
    if halfcrop is not None:
        ops.append(HalfCrop())
    if flip:
        ops.append(RandomHorizontalFlip())
    joint = torchvision.transforms.Compose(ops)
    augmented = [joint(sample) for sample in batch]
    return torch.utils.data.dataloader.default_collate(augmented)
================================================
FILE: cycada/util.py
================================================
import logging
import logging.config
import os.path
from collections import OrderedDict
import numpy as np
import torch
import yaml
from torch.nn.parameter import Parameter
from tqdm import tqdm
class TqdmHandler(logging.StreamHandler):
    """StreamHandler that routes log records through ``tqdm.write`` so
    log lines do not break active tqdm progress bars."""

    def __init__(self):
        logging.StreamHandler.__init__(self)

    def emit(self, record):
        tqdm.write(self.format(record))
def config_logging(logfile=None):
    """Configure logging from the package's logging.yml.

    When `logfile` is None the file handler is stripped from the config;
    otherwise its filename is pointed at `logfile`.

    Fix: uses ``yaml.safe_load`` — ``yaml.load`` without an explicit Loader
    is unsafe on untrusted input and has been an error since PyYAML 6.
    """
    path = os.path.join(os.path.dirname(__file__), 'logging.yml')
    with open(path, 'r') as f:
        config = yaml.safe_load(f.read())
    if logfile is None:
        # No log file requested: drop the file handler and its root entry.
        del config['handlers']['file_handler']
        del config['root']['handlers'][-1]
    else:
        config['handlers']['file_handler']['filename'] = logfile
    logging.config.dictConfig(config)
def to_tensor_raw(im):
    """Convert an image/array-like to an int64 tensor without any scaling."""
    arr = np.array(im, np.int64, copy=False)
    return torch.from_numpy(arr)
def safe_load_state_dict(net, state_dict):
    """Copy matching parameters and buffers from `state_dict` into `net`.

    Unlike ``load_state_dict``, entries whose key is unknown to `net` or
    whose size differs are silently skipped (and logged) instead of raising.

    Arguments:
        state_dict (dict): A dict containing parameters and
            persistent buffers.
    """
    own_state = net.state_dict()
    skipped = []
    for name, param in state_dict.items():
        target = own_state.get(name)
        if target is None:
            skipped.append(name)
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        if target.size() != param.size():
            skipped.append(name)
            continue
        target.copy_(param)
    if skipped:
        logging.info('Skipped loading some parameters: {}'.format(skipped))
def step_lr(optimizer, mult):
    """Scale the learning rate of every param group of `optimizer` by `mult`."""
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * mult
================================================
FILE: cyclegan/.gitignore
================================================
.DS_Store
debug*
checkpoints/
results/
build/
dist/
*.png
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
*~
.idea
================================================
FILE: cyclegan/data/__init__.py
================================================
import sys
import torch.utils.data
from data.base_data_loader import BaseDataLoader
sys.path.append('/nfs/project/libo_i/MADAN')
from cycada.transforms import augment_collate
def CreateDataLoader(opt):
    """Build, announce, and initialize the project's custom data loader."""
    loader = CustomDatasetDataLoader()
    print(loader.name())
    loader.initialize(opt)
    return loader
def CreateDataset(opt):
    """Instantiate and initialize the dataset selected by ``opt.dataset_mode``.

    Dataset modules are imported lazily inside each branch so that only the
    requested dataset's dependencies are loaded.
    """
    mode = opt.dataset_mode
    if mode == 'synthia_cityscapes':
        from data.synthia_cityscapes import SynthiaCityscapesDataset
        dataset = SynthiaCityscapesDataset()
    elif mode == 'gta5_cityscapes':
        from data.gta5_cityscapes import GTAVCityscapesDataset
        dataset = GTAVCityscapesDataset()
    elif mode == 'gta_synthia_cityscapes':
        from data.gta_synthia_cityscapes import GTASynthiaCityscapesDataset
        dataset = GTASynthiaCityscapesDataset()
    elif mode == 'merged_gta_synthia_cityscapes':
        from data.merged_gta_synthia_cityscapes import MergedGTASynthiaCityscapesDataset
        dataset = MergedGTASynthiaCityscapesDataset()
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset
class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps a project dataset in a torch DataLoader with an epoch-size cap
    (``opt.max_dataset_size``) enforced in __len__ and __iter__."""

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads),
        )

    def load_data(self):
        return self

    def __len__(self):
        # Cap the reported epoch length at opt.max_dataset_size samples.
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for batch_index, batch in enumerate(self.dataloader):
            if batch_index * self.opt.batchSize >= self.opt.max_dataset_size:
                break
            yield batch
================================================
FILE: cyclegan/data/base_data_loader.py
================================================
class BaseDataLoader():
    """Minimal data-loader interface; subclasses override initialize() and
    load_data().

    Fix: ``load_data`` was declared without ``self``, so calling it on an
    instance raised ``TypeError``.
    """

    def __init__(self):
        pass

    def initialize(self, opt):
        # Store options for subclasses.
        self.opt = opt

    def load_data(self):
        return None
================================================
FILE: cyclegan/data/base_dataset.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
class BaseDataset(data.Dataset):
    """Abstract base for project datasets; subclasses override name()
    and initialize()."""

    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        # No-op by default; subclasses index their files here.
        pass
# TODO: add the crop-related part
def get_transform(opt):
    """Image transform pipeline selected by ``opt.resize_or_crop``, followed
    by ToTensor and [-1, 1] normalization.

    Fix: the 'resize_only' branch computed the square `osize` but then
    resized only the shorter side, diverging from get_label_transform's
    square resize and misaligning images with their label maps; it now uses
    `osize` like the label pipeline does.
    """
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [int(opt.loadSize), int(opt.loadSize)]
        transform_list.append(transforms.Resize(osize, interpolation=Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'resize_only':
        osize = [int(opt.loadSize), int(opt.loadSize)]
        transform_list.append(transforms.Resize(osize, interpolation=Image.BICUBIC))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.BICUBIC))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())
    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
def get_label_transform(opt):
    """Label-map transform pipeline matching get_transform, but with
    nearest-neighbor resizing and a raw int64 tensor conversion so class
    ids are preserved exactly."""
    ops = []
    mode = opt.resize_or_crop
    if mode == 'resize_and_crop':
        size = [opt.loadSize, opt.loadSize]
        ops.append(transforms.Resize(size, interpolation=Image.NEAREST))
        ops.append(transforms.RandomCrop(opt.fineSize))
    elif mode == 'resize_only':
        size = [opt.loadSize, opt.loadSize]
        ops.append(transforms.Resize(size, interpolation=Image.NEAREST))
    elif mode == 'crop':
        ops.append(transforms.RandomCrop(opt.fineSize))
    elif mode == 'scale_width':
        ops.append(transforms.Resize(opt.loadSize, interpolation=Image.NEAREST))
    elif mode == 'scale_width_and_crop':
        ops.append(transforms.Resize(opt.loadSize, interpolation=Image.NEAREST))
        ops.append(transforms.RandomCrop(opt.fineSize))
    if opt.isTrain and not opt.no_flip:
        ops.append(transforms.RandomHorizontalFlip())
    # Labels stay integer class ids: no ToTensor()/Normalize() scaling.
    ops.append(transforms.Lambda(lambda img: to_tensor_raw(img)))
    return transforms.Compose(ops)
def __scale_width(img, target_width):
    """Resize `img` so its width equals `target_width`, keeping aspect ratio."""
    ow, oh = img.size
    if ow == target_width:
        return img
    new_height = int(target_width * oh / ow)
    return img.resize((target_width, new_height), Image.BICUBIC)
def to_tensor_raw(im):
    """Convert an image/array-like to an int64 tensor without any scaling."""
    arr = np.array(im, np.int64, copy=False)
    return torch.from_numpy(arr)
================================================
FILE: cyclegan/data/cityscapes.py
================================================
import numpy as np
# Pixels mapped to this value are excluded from the segmentation loss.
ignore_label = 255
# Raw Cityscapes label id -> 19-class train id; unused classes -> ignore_label.
id2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
            3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
            7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
            14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
            18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
            28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}
# Flat RGB palette: 3 consecutive values per train id, used to colorize maps.
palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
           220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
           0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
# Human-readable class names indexed by train id.
classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign',
           'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle']
def remap_labels_to_train_ids(arr):
    """Map a raw Cityscapes label array to train ids (uint8); ids without a
    mapping become `ignore_label`."""
    remapped = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        remapped[arr == raw_id] = int(train_id)
    return remapped
================================================
FILE: cyclegan/data/gta5_cityscapes.py
================================================
import os.path
import random
import numpy as np
from PIL import Image
from data.base_dataset import BaseDataset, get_label_transform, get_transform
from data.cityscapes import remap_labels_to_train_ids
from data.image_folder import make_cs_labels, make_dataset
# Pixels mapped to this value are excluded from the segmentation loss.
ignore_label = 255
# Raw GTA5 label id -> Cityscapes train id; unmatched ids -> ignore_label.
id2label = {0: ignore_label, 1: 10, 2: 2, 3: 0, 4: 1, 5: 4, 6: 8, 7: 5,
            8: 13, 9: 7, 10: 11, 11: 18, 12: 17, 13: ignore_label,
            14: ignore_label, 15: 6, 16: 9, 17: 12, 18: 14, 19: 15,
            20: 16, 21: 3, 22: ignore_label}
# Human-readable class names indexed by train id.
classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']
# This dataset is used to conduct GTA->CityScapes images transfer procedure.
class GTAVCityscapesDataset(BaseDataset):
    """Unaligned dataset pairing GTA5 (domain A) with Cityscapes (domain B),
    returning images plus remapped segmentation labels for CycleGAN-style
    image translation."""

    def initialize(self, opt):
        """Index image/label paths for both domains and build transforms.

        Images and labels are each sorted; this assumes image and label
        filenames sort identically so index i pairs them — TODO confirm.
        """
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, 'gta5', 'images')
        self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
        self.dir_A_label = os.path.join(opt.dataroot, 'gta5', 'labels')
        self.dir_B_label = os.path.join(opt.dataroot, 'cityscapes', 'gtFine')
        self.A_paths = make_dataset(self.dir_A)
        self.B_paths = make_dataset(self.dir_B)
        self.A_paths = sorted(self.A_paths)
        self.B_paths = sorted(self.B_paths)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        self.A_labels = make_dataset(self.dir_A_label)
        self.B_labels = make_cs_labels(self.dir_B_label)
        self.A_labels = sorted(self.A_labels)
        self.B_labels = sorted(self.B_labels)
        self.transform = get_transform(opt)
        self.label_transform = get_label_transform(opt)

    def __getitem__(self, index):
        """Return one unaligned (A, B) pair plus their label maps.

        NOTE(review): self.transform and self.label_transform sample their
        random crops/flips independently, so an image and its label may get
        different augmentations — confirm this is intended.
        """
        A_path = self.A_paths[index % self.A_size]
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            # Unaligned sampling: pick a random B image for each A image.
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_label_path = self.A_labels[index % self.A_size]
        B_label_path = self.B_labels[index_B]
        A_label = Image.open(A_label_path)
        B_label = Image.open(B_label_path)
        # Remap raw label ids into the shared train-id space.
        A_label = np.asarray(A_label)
        A_label = remap_labels_to_train_ids(A_label)
        A_label = Image.fromarray(A_label, 'L')
        B_label = np.asarray(B_label)
        B_label = remap_labels_to_train_ids(B_label)
        B_label = Image.fromarray(B_label, 'L')
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')
        A = self.transform(A_img)
        B = self.transform(B_img)
        A_label = self.label_transform(A_label)
        B_label = self.label_transform(B_label)
        # print(A_label.unique())
        # print(B_label.unique())
        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc
        if input_nc == 1:  # RGB to gray
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)
        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)
        return {'A': A, 'B': B,
                'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label, 'B_label': B_label}

    def __len__(self):
        # Length of the larger domain so every image is visited each epoch.
        return max(self.A_size, self.B_size)

    def name(self):
        return 'GTA5_Cityscapes'
================================================
FILE: cyclegan/data/gta_synthia_cityscapes.py
================================================
import os.path
import random
import numpy as np
from PIL import Image
from data.base_dataset import BaseDataset, get_label_transform, get_transform
from data.cityscapes import remap_labels_to_train_ids
from data.image_folder import make_cs_labels, make_dataset
# Pixels mapped to this value are excluded from the segmentation loss.
ignore_label = 255
# Raw SYNTHIA label id -> Cityscapes train id; unmatched ids -> ignore_label.
id2label = {0: ignore_label, 1: 10, 2: 2, 3: 0, 4: 1, 5: 4, 6: 8, 7: 5,
            8: 13, 9: 7, 10: 11, 11: 18, 12: 17, 13: ignore_label,
            14: ignore_label, 15: 6, 16: 9, 17: 12, 18: 14, 19: 15,
            20: 16, 21: 3, 22: ignore_label}
# Human-readable class names indexed by train id.
classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']
def syn_relabel(arr):
    """Map a raw SYNTHIA label array to Cityscapes train ids (uint8); ids
    without a mapping become `ignore_label`."""
    relabeled = ignore_label * np.ones(arr.shape, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        relabeled[arr == raw_id] = int(train_id)
    return relabeled
# This dataset is used to conduct double cyclegan for both GTAV->CityScapes and Synthia->CityScapes
class GTASynthiaCityscapesDataset(BaseDataset):
    """Triple-domain dataset for a double CycleGAN: SYNTHIA (A_1) and GTA5
    (A_2) as sources, Cityscapes (B) as target."""

    def initialize(self, opt):
        # SYNTHIA as dataset 1
        # GTAV as dataset 2
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A_1 = os.path.join(opt.dataroot, 'synthia', 'RGB')
        self.dir_A_2 = os.path.join(opt.dataroot, 'gta5', 'images')
        self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
        self.dir_A_label_1 = os.path.join(opt.dataroot, 'synthia', 'GT', 'parsed_LABELS')
        self.dir_A_label_2 = os.path.join(opt.dataroot, 'gta5', 'labels')
        self.A_paths_1 = make_dataset(self.dir_A_1)
        self.A_paths_2 = make_dataset(self.dir_A_2)
        self.B_paths = make_dataset(self.dir_B)
        # Sorted so that image index i matches label index i — assumes image
        # and label filenames sort identically; TODO confirm.
        self.A_paths_1 = sorted(self.A_paths_1)
        self.A_paths_2 = sorted(self.A_paths_2)
        self.B_paths = sorted(self.B_paths)
        self.A_size_1 = len(self.A_paths_1)
        self.A_size_2 = len(self.A_paths_2)
        self.B_size = len(self.B_paths)
        self.A_labels_1 = make_dataset(self.dir_A_label_1)
        self.A_labels_2 = make_dataset(self.dir_A_label_2)
        self.A_labels_1 = sorted(self.A_labels_1)
        self.A_labels_2 = sorted(self.A_labels_2)
        self.transform = get_transform(opt)
        self.label_transform = get_label_transform(opt)

    def __getitem__(self, index):
        """Return one (A_1, A_2, B) triple plus source label maps.

        NOTE(review): image and label transforms sample their random
        crops/flips independently — confirm alignment is not required.
        """
        A_path_1 = self.A_paths_1[index % self.A_size_1]
        A_path_2 = self.A_paths_2[index % self.A_size_2]
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            # Unaligned sampling: random Cityscapes image per source pair.
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_label_path_1 = self.A_labels_1[index % self.A_size_1]
        A_label_path_2 = self.A_labels_2[index % self.A_size_2]
        A_label_1 = Image.open(A_label_path_1)
        A_label_2 = Image.open(A_label_path_2)
        # remaping label for synthia
        A_label_1 = np.asarray(A_label_1)
        A_label_1 = syn_relabel(A_label_1)
        A_label_1 = Image.fromarray(A_label_1, 'L')
        # remaping label for gta5
        A_label_2 = np.asarray(A_label_2)
        A_label_2 = remap_labels_to_train_ids(A_label_2)
        A_label_2 = Image.fromarray(A_label_2, 'L')
        A_img_1 = Image.open(A_path_1).convert('RGB')
        A_img_2 = Image.open(A_path_2).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')
        A_1 = self.transform(A_img_1)
        A_2 = self.transform(A_img_2)
        B = self.transform(B_img)
        A_label_1 = self.label_transform(A_label_1)
        A_label_2 = self.label_transform(A_label_2)
        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc
        if input_nc == 1:  # RGB to gray
            tmp = A_1[0, ...] * 0.299 + A_1[1, ...] * 0.587 + A_1[2, ...] * 0.114
            A_1 = tmp.unsqueeze(0)
            tmp = A_2[0, ...] * 0.299 + A_2[1, ...] * 0.587 + A_2[2, ...] * 0.114
            A_2 = tmp.unsqueeze(0)
        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)
        return {'A_1': A_1, 'A_2': A_2, 'B': B, 'A_paths_1': A_path_1, 'A_paths_2': A_path_2, 'B_paths': B_path, 'A_label_1': A_label_1,
                'A_label_2': A_label_2}

    def __len__(self):
        # Length of the largest domain so every image is visited each epoch.
        return max(self.A_size_1, self.B_size, self.A_size_2)

    def name(self):
        return 'GTA5_Synthia_Cityscapes'
================================================
FILE: cyclegan/data/image_folder.py
================================================
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
import numpy as np
from PIL import Image
import os
import os.path
# File extensions recognized as images (lower- and upper-case variants).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
    """Return True if *filename* ends with a recognized image extension."""
    return filename.endswith(tuple(IMG_EXTENSIONS))
def make_cs_labels(dir):
    """Recursively collect Cityscapes '*_gtFine_labelIds.png' files under `dir`."""
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    labels = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if not is_image_file(fname):
                continue
            path = os.path.join(root, fname)
            if path.endswith("_gtFine_labelIds.png"):
                labels.append(path)
    # Deduplicate; set() does not preserve order, so callers sort afterwards.
    return list(set(labels))
def make_dataset(dir):
    """Recursively collect every image file under *dir*.

    Returns a sorted, de-duplicated list of paths.

    Fix: the original returned ``list(set(images))``, whose order depends on
    string-hash randomisation and so differs between interpreter runs, making
    un-sorted consumers (e.g. ``ImageFolder``) nondeterministic; ``sorted``
    makes the ordering stable.
    """
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return sorted(set(images))
def load_labels(dir, images):
    """Load per-image integer labels for *images* from *dir*.

    Expects ``<dir>/labels.txt`` with one ``"<image_id> <label>"`` pair per
    line, where ``image_id`` is the image filename without directory or
    extension.

    Returns:
        list[int]: labels aligned with *images*.

    Raises:
        Exception: if labels come as a folder (not implemented) or if neither
            ``labels.txt`` nor a ``labels`` folder exists under *dir*.

    Fixes vs. the original:
    - the computed ``labels`` list was never returned;
    - the ``Exception(...)`` objects were constructed but never raised;
    - round-tripping the pairs through ``np.array`` stringified the int labels,
      so the dict is now built directly.
    """
    if os.path.exists(os.path.join(dir, 'labels.txt')):
        with open(os.path.join(dir, 'labels.txt'), 'r') as f:
            data = f.read().splitlines()
        label_dict = {line.split(' ')[0]: int(line.split(' ')[1]) for line in data}
        labels = []
        for image in images:
            im_id = image.split('/')[-1].split('.')[0]
            labels.append(label_dict[im_id])
        return labels
    elif os.path.isdir(os.path.join(dir, 'labels')):
        raise Exception('Not yet implemented load_labels for image folder')
    else:
        raise Exception('load_labels expects %s to contain labels.txt or labels folder' % dir)
def default_loader(path):
    """Open *path* with PIL and force a three-channel RGB image."""
    image = Image.open(path)
    return image.convert('RGB')


class ImageFolder(data.Dataset):
    """Flat image dataset: every image found under *root* (recursively) is one
    sample.  Optionally applies *transform* and returns the path alongside the
    image when *return_paths* is set.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        found = make_dataset(root)
        if not found:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))
        self.root = root
        self.imgs = found
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load (and transform) the image at *index*; optionally include its path."""
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        return len(self.imgs)
================================================
FILE: cyclegan/data/synthia_cityscapes.py
================================================
import os.path
import random
import numpy as np
from PIL import Image
from data.base_dataset import BaseDataset, get_label_transform, get_transform
from data.image_folder import make_cs_labels, make_dataset
from data.cityscapes import remap_labels_to_train_ids
# Label value reserved for "ignore" pixels in the Cityscapes training scheme.
ignore_label = 255

# SYNTHIA class id -> Cityscapes train id (255 = ignored class).
id2label = {0: ignore_label, 1: 10, 2: 2, 3: 0, 4: 1, 5: 4, 6: 8, 7: 5,
            8: 13, 9: 7, 10: 11, 11: 18, 12: 17, 13: ignore_label,
            14: ignore_label, 15: 6, 16: 9, 17: 12, 18: 14, 19: 15,
            20: 16, 21: 3, 22: ignore_label}

# Cityscapes class names, indexed by train id.
classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
           'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
           'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle']


def syn_relabel(arr):
    """Map a SYNTHIA label map to Cityscapes train ids.

    Pixels whose id is not listed in ``id2label`` stay at ``ignore_label``.
    Returns a uint8 array of the same shape as *arr*.
    """
    relabelled = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for src_id, train_id in id2label.items():
        relabelled[arr == src_id] = int(train_id)
    return relabelled
class SynthiaCityscapesDataset(BaseDataset):
    """Unaligned SYNTHIA (domain A, labelled) <-> Cityscapes (domain B,
    labelled) dataset for CycleGAN-style training.  Both label maps are
    remapped to the 19 Cityscapes train ids before the label transform.
    """

    def initialize(self, opt):
        """Index both image trees and their label trees and build transforms."""
        self.opt = opt
        self.root = opt.dataroot
        # A: SYNTHIA RGB frames; B: Cityscapes leftImg8bit frames.
        self.dir_A = os.path.join(opt.dataroot, 'synthia', 'RGB')
        self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
        self.dir_A_label = os.path.join(opt.dataroot, 'synthia', 'GT', 'parsed_LABELS')
        self.dir_B_label = os.path.join(opt.dataroot, 'cityscapes', 'gtFine')
        self.A_paths = make_dataset(self.dir_A)
        self.B_paths = make_dataset(self.dir_B)
        # Sort so image index i lines up with label index i below
        # (assumes parallel file naming between image and label trees — TODO confirm).
        self.A_paths = sorted(self.A_paths)
        self.B_paths = sorted(self.B_paths)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        self.A_labels = make_dataset(self.dir_A_label)
        # make_cs_labels keeps only *_gtFine_labelIds.png files.
        self.B_labels = make_cs_labels(self.dir_B_label)
        self.A_labels = sorted(self.A_labels)
        self.B_labels = sorted(self.B_labels)
        self.transform = get_transform(opt)
        self.label_transform = get_label_transform(opt)

    def __getitem__(self, index):
        """Return one sample dict {A, B, A_paths, B_paths, A_label, B_label}.

        A is indexed modulo A_size; B is paired by index when
        ``serial_batches`` is set and sampled uniformly at random otherwise.
        """
        A_path = self.A_paths[index % self.A_size]
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_label_path = self.A_labels[index % self.A_size]
        B_label_path = self.B_labels[index_B]
        A_label = Image.open(A_label_path)
        B_label = Image.open(B_label_path)
        # SYNTHIA ids -> Cityscapes train ids, back to an 8-bit PIL image.
        A_label = np.asarray(A_label)
        A_label = syn_relabel(A_label)
        A_label = Image.fromarray(A_label, 'L')
        # Cityscapes full label ids -> train ids.
        B_label = np.asarray(B_label)
        B_label = remap_labels_to_train_ids(B_label)
        B_label = Image.fromarray(B_label, 'L')
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')
        A = self.transform(A_img)
        B = self.transform(B_img)
        A_label = self.label_transform(A_label)
        B_label = self.label_transform(B_label)
        # 'BtoA' swaps which side is treated as the network input.
        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc
        if input_nc == 1:  # RGB to gray, using 0.299/0.587/0.114 luma weights
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)
        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)
        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label, 'B_label': B_label}

    def __len__(self):
        # The shorter domain is cycled via the modulo indexing above.
        return max(self.A_size, self.B_size)

    def name(self):
        return 'Synthia_Cityscapes'
================================================
FILE: cyclegan/environment.yml
================================================
name: pytorch-CycleGAN-and-pix2pix
channels:
- peterjc123
- defaults
dependencies:
- python=3.5.5
- pytorch=0.3.1
- scipy
- pip:
- dominate==2.3.1
- git+https://github.com/pytorch/vision.git
- Pillow==5.0.0
- numpy==1.14.1
- visdom==0.1.7
================================================
FILE: cyclegan/models/__init__.py
================================================
import logging
def create_model(opt):
    """Factory: instantiate and initialize the model selected by ``opt.model``.

    Imports are kept inside the branches so only the chosen model module is
    loaded.  Raises NotImplementedError for unknown model names.
    """
    if opt.model == 'cycle_gan':
        # assert(opt.dataset_mode == 'unaligned')
        from .cycle_gan_model import CycleGANModel as _Model
    elif opt.model == 'test':
        from .test_model import TestModel as _Model
    elif opt.model == 'multi_cycle_gan_semantic':
        from .multi_cycle_gan_semantic_model import CycleGANSemanticModel as _Model
    elif opt.model == 'cycle_gan_semantic_fcn':
        from .cycle_gan_semantic_model import CycleGANSemanticModel as _Model
    else:
        raise NotImplementedError('model [%s] not implemented.' % opt.model)
    instance = _Model()
    instance.initialize(opt)
    logging.info("model [%s] was created" % (instance.name()))
    return instance
================================================
FILE: cyclegan/models/base_model.py
================================================
import os
from collections import OrderedDict
import torch
from . import networks
class BaseModel():
    """Abstract base for the CycleGAN-style models in this package.

    Subclasses populate ``model_names`` ('net' + name attributes),
    ``loss_names`` ('loss_' + name attributes) and ``visual_names``, and
    implement ``forward`` / ``optimize_parameters``; this base provides
    checkpoint save/load, LR-scheduler handling and introspection helpers.
    """

    def name(self):
        # Identifier used in log messages; subclasses override.
        return 'BaseModel'

    def initialize(self, opt):
        """Store options, select the compute device and reset bookkeeping lists."""
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # First listed GPU is the primary device; otherwise fall back to CPU.
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        # cudnn autotuning only pays off with fixed input sizes, which the
        # 'scale_width' preprocessing does not guarantee.
        if opt.resize_or_crop != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_paths = []

    def set_input(self, input):
        # Default: keep the raw dataloader sample; subclasses unpack it.
        self.input = input

    def forward(self):
        pass

    # load and print networks; create schedulers
    def setup(self, opt):
        """Build LR schedulers (train) and/or restore checkpoints, then print nets."""
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)
        self.print_networks(opt.verbose)

    # make models eval mode during test time
    def eval(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def optimize_parameters(self):
        pass

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    # return visualization images. train.py will display these images, and save the images to a html
    def get_current_visuals(self):
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    # return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    # save models to the disk
    def save_networks(self, which_epoch):
        """Write each net's state_dict to '<epoch>_net_<name>.pth' under save_dir."""
        for name in self.model_names:
            # Don't save semantic consistency networks
            if isinstance(name, str) and ("PixelCLS" not in name):
                save_filename = '%s_net_%s.pth' % (which_epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # On GPU the nets are DataParallel-wrapped: save the inner
                    # module's CPU state_dict, then move the net back to the GPU.
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Recursively drop InstanceNorm running stats (saved as None pre-0.4)
        from *state_dict*; *keys* is one checkpoint key split on '.'."""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # load models from the disk
    def load_networks(self, which_epoch):
        """Restore each net in model_names from '<epoch>_net_<name>.pth'."""
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (which_epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                # Unwrap DataParallel so checkpoint keys match the bare module.
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    # print network information
    def print_networks(self, verbose):
        """Print parameter counts (and full architectures when *verbose*)."""
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    # set requires_grad=False to avoid computation
    def set_requires_grad(self, nets, requires_grad=False):
        """Toggle requires_grad on every parameter of one net or a list of nets."""
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
================================================
FILE: cyclegan/models/cycle_gan_model.py
================================================
import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class CycleGANModel(BaseModel):
    """Standard two-generator / two-discriminator CycleGAN.

    Naming follows the codebase, not the paper:
    G_A: A->B, G_B: B->A, D_A scores domain B, D_B scores domain A.
    """

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        """Build generators/discriminators, losses, image pools and optimizers."""
        BaseModel.initialize(self, opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        # identity outputs only exist while the identity loss is active
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain:
            # vanilla GAN loss needs a sigmoid output; LSGAN does not
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if self.isTrain:
            # pools replay past generated images to stabilise D training
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack a dataloader sample, honouring the translation direction."""
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run both translation cycles: A->B->A and B->A->B."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """One discriminator step: 0.5*(loss_real + loss_fake), already backpropped."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so the generator receives no gradient here)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Generator losses: adversarial + cycle consistency + optional identity."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training iteration: update Gs (Ds frozen), then update Ds."""
        # forward
        self.forward()
        # G_A and G_B: freeze Ds so the generator step ignores their grads
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
================================================
FILE: cyclegan/models/cycle_gan_semantic_model.py
================================================
import itertools
import sys
import torch
import torch.nn.functional as F
from util.image_pool import ImagePool
from . import networks
from .base_model import BaseModel
sys.path.append('/nfs/project/libo_iMADAN')
from cycada.models import get_model
class CycleGANSemanticModel(BaseModel):
    """CycleGAN with an additional semantic-consistency term: a pretrained,
    frozen pixel classifier (netPixelCLS) scores the real source image and
    its translation, and a KL term in backward_G pulls the two prediction
    maps together.
    """

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        """Build generators/discriminators, the frozen pixel classifier,
        losses, pools and optimizers."""
        BaseModel.initialize(self, opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A',
                           'D_B', 'G_B', 'cycle_B', 'idt_B',
                           'sem_AB']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        # (netPixelCLS is deliberately NOT listed, so it is never saved/loaded)
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid,
                                            opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid,
                                            opt.init_type, self.gpu_ids)
            # Here for semantic consistency loss, load a fcn network as fs here.
            self.netPixelCLS = get_model(opt.weights_model_type, num_cls=opt.num_cls, pretrained=True, weights_init=opt.weights_init)
            # Specially initialize Pixel CLS network
            if len(self.gpu_ids) > 0:
                assert (torch.cuda.is_available())
                self.netPixelCLS.to(self.gpu_ids[0])
                self.netPixelCLS = torch.nn.DataParallel(self.netPixelCLS, self.gpu_ids)
        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # self.criterionCLS = torch.nn.modules.CrossEntropyLoss()
            self.criterionSemantic = torch.nn.KLDivLoss(reduction='batchmean')
            # initialize optimizers (the frozen classifier gets no optimizer)
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack a dataloader sample; labels are optional in the input dict."""
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        if 'A_label' in input and 'B_label' in input:
            self.input_A_label = input['A_label' if AtoB else 'B_label'].to(self.device)
            self.input_B_label = input['B_label' if AtoB else 'A_label'].to(self.device)
            # self.image_paths = input['B_paths'] # Hack!! forcing the labels to corresopnd to B domain

    def forward(self):
        """Run both cycles; at train time also run the frozen classifier on
        real_A and fake_B for the semantic loss."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)
        if self.isTrain:
            # Forward all four images through classifier
            # Keep predictions from fake images only
            self.pred_real_A = self.netPixelCLS(self.real_A)
            _, self.gt_pred_A = self.pred_real_A.max(1)
            self.pred_fake_B = self.netPixelCLS(self.fake_B)
            _, pfB = self.pred_fake_B.max(1)

    def backward_D_basic(self, netD, real, fake):
        """One discriminator step: 0.5*(loss_real + loss_fake), already backpropped."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined Loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_PixelCLS(self):
        # NOTE(review): not invoked from optimize_parameters in this file.
        # KLDivLoss expects a probability distribution as target, but
        # label_A.long() holds class indices — this matches the signature of
        # the commented-out CrossEntropyLoss above, and looks wrong for
        # KLDivLoss; confirm before using this path.
        label_A = self.input_A_label
        # forward only real source image through semantic classifier
        pred_A = self.netPixelCLS(self.real_A)
        self.loss_PixelCLS = self.criterionSemantic(F.log_softmax(pred_A, dim=1), label_A.long())
        self.loss_PixelCLS.backward()

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self, opt):
        """Generator losses: adversarial + cycle + identity + semantic KL term."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A)) — note the A->B adversarial term is weighted 2x
        # relative to B->A as written.
        self.loss_G_A = 2 * self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss standard cyclegan
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        # real_A(syn)->fake_B->(fcn_frozen)->pred_fake_B == input_A_label
        if opt.semantic_loss:
            # KL(softmax(pred on real_A) || softmax(pred on fake_B)), then
            # normalised by the prediction map's C*H*W.
            self.loss_sem_AB = opt.dynamic_weight * self.criterionSemantic(F.log_softmax(self.pred_fake_B, dim=1), F.softmax(self.pred_real_A,
                                                                                                                             dim=1))
            self.loss_sem_AB = opt.general_semantic_weight * torch.div(self.loss_sem_AB, self.pred_fake_B.shape[1] * self.pred_fake_B.shape[2]
                                                                       * self.pred_fake_B.shape[3])
            self.loss_G += self.loss_sem_AB
        self.loss_G.backward()

    def optimize_parameters(self, opt):
        """One training iteration: update Gs (Ds frozen), then update Ds."""
        # forward
        self.forward()
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        # self.optimizer_CLS.zero_grad()
        self.backward_G(opt)
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
================================================
FILE: cyclegan/models/multi_cycle_gan_semantic_model.py
================================================
import itertools
import sys
import torch
import torch.nn.functional as F
from util.image_pool import ImagePool
from . import networks
from .base_model import BaseModel
sys.path.append('/nfs/project/libo_iMADAN')
from cycada.models import get_model
class CycleGANSemanticModel(BaseModel):
def name(self):
return 'CycleGANModel'
def initialize(self, opt):
    """Build the multi-source graph: two CycleGAN branches (sources A_1, A_2
    sharing target B), optional extra discriminators (SAD / CCD), optional
    frozen per-source pixel classifiers for the semantic loss, plus image
    pools and optimizers.

    NOTE(review): SAD / CCD / HF_CCD look like the paper's extra adversarial
    modules (sub-domain aggregation / cross-cycle discriminators) — confirm
    against the MADAN paper.
    """
    BaseModel.initialize(self, opt)
    self.semantic_loss = opt.semantic_loss
    # specify the training losses you want to print out. The program will call base_model.get_current_losses
    self.loss_names = ['D_A_1', 'G_A_1', 'cycle_A_1', 'idt_A_1',
                       'D_B_1', 'G_B_1', 'cycle_B_1', 'idt_B_1',
                       'D_A_2', 'G_A_2', 'cycle_A_2', 'idt_A_2',
                       'D_B_2', 'G_B_2', 'cycle_B_2', 'idt_B_2']
    if opt.SAD:
        self.loss_names.extend(['D_3_1', 'G_s1s2'])
    if opt.CCD or opt.HF_CCD:
        self.loss_names.extend(['D_21', 'G_s1s21'])
        self.loss_names.extend(['D_12', 'G_s2s12'])
    if self.semantic_loss:
        self.loss_names.extend(['sem_syn', 'sem_gta'])
    # specify the images you want to save/display. The program will call base_model.get_current_visuals
    visual_names_A_1 = ['real_A_1', 'fake_B_1', 'rec_A_1']
    visual_names_B_1 = ['real_B', 'fake_A_1', 'rec_B_1']
    visual_names_A_2 = ['real_A_2', 'fake_B_2', 'rec_A_2']
    visual_names_B_2 = ['fake_A_2', 'rec_B_2']
    if self.isTrain and self.opt.lambda_identity > 0.0:
        visual_names_A_1.append('idt_A_1')
        visual_names_B_1.append('idt_B_1')
        visual_names_A_2.append('idt_A_2')
        visual_names_B_2.append('idt_B_2')
    self.visual_names = visual_names_A_1 + visual_names_B_1 + visual_names_A_2 + visual_names_B_2
    # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
    if self.isTrain:
        # self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        if opt.Shared_DT:
            # one target-side discriminator D_A shared by both branches
            self.model_names = ['G_A_1', 'G_B_1', 'D_A', 'D_B_1', 'D_B_2', 'G_A_2', 'G_B_2']
        else:
            self.model_names = ['G_A_1', 'G_B_1', 'D_A_1', 'D_B_1', 'G_A_2', 'G_B_2', 'D_A_2', 'D_B_2']
        if opt.SAD:
            self.model_names.append('D_3')
        if opt.CCD or opt.HF_CCD:
            self.model_names.append('D_12')
            self.model_names.append('D_21')
    else:  # during test time, only load Gs
        self.model_names = ['G_A_1', 'G_B_1', 'G_A_2', 'G_B_2']
    # load/define networks
    # The naming conversion is different from those used in the paper
    # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
    self.netG_A_1 = networks.define_G(opt.input_nc, opt.output_nc,
                                      opt.ngf, opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, self.gpu_ids)
    self.netG_B_1 = networks.define_G(opt.output_nc, opt.input_nc,
                                      opt.ngf, opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, self.gpu_ids)
    self.netG_A_2 = networks.define_G(opt.input_nc, opt.output_nc,
                                      opt.ngf, opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, self.gpu_ids)
    self.netG_B_2 = networks.define_G(opt.output_nc, opt.input_nc,
                                      opt.ngf, opt.which_model_netG, opt.norm,
                                      not opt.no_dropout, opt.init_type, self.gpu_ids)
    if opt.semantic_loss:
        # Pretrained, frozen pixel classifiers — one per source domain.
        # Not in model_names, so they are never saved/loaded as checkpoints.
        self.netPixelCLS_SYN = get_model(opt.weights_model_type, num_cls=opt.num_cls, pretrained=True, weights_init=opt.weights_syn)
        self.netPixelCLS_GTA = get_model(opt.weights_model_type, num_cls=opt.num_cls, pretrained=True, weights_init=opt.weights_gta)
        if len(self.gpu_ids) > 0:
            assert (torch.cuda.is_available())
            self.netPixelCLS_SYN.to(self.gpu_ids[0])
            self.netPixelCLS_SYN = torch.nn.DataParallel(self.netPixelCLS_SYN, self.gpu_ids)
            self.netPixelCLS_GTA.to(self.gpu_ids[0])
            self.netPixelCLS_GTA = torch.nn.DataParallel(self.netPixelCLS_GTA, self.gpu_ids)
    if self.isTrain:
        # vanilla GAN loss needs a sigmoid output; LSGAN does not
        use_sigmoid = opt.no_lsgan
        if opt.Shared_DT:
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid,
                                            opt.init_type, self.gpu_ids)
        else:
            self.netD_A_1 = networks.define_D(opt.output_nc, opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D, opt.norm, use_sigmoid,
                                              opt.init_type, self.gpu_ids)
            self.netD_A_2 = networks.define_D(opt.output_nc, opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D, opt.norm, use_sigmoid,
                                              opt.init_type, self.gpu_ids)
        self.netD_B_1 = networks.define_D(opt.input_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.init_type, self.gpu_ids)
        self.netD_B_2 = networks.define_D(opt.input_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid,
                                          opt.init_type, self.gpu_ids)
        if opt.SAD:
            self.netD_3 = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid,
                                            opt.init_type, self.gpu_ids)
        if opt.CCD or opt.HF_CCD:
            self.netD_12 = networks.define_D(opt.input_nc, opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D, opt.norm, use_sigmoid,
                                             opt.init_type, self.gpu_ids)
            self.netD_21 = networks.define_D(opt.input_nc, opt.ndf,
                                             opt.which_model_netD,
                                             opt.n_layers_D, opt.norm, use_sigmoid,
                                             opt.init_type, self.gpu_ids)
    if self.isTrain:
        self.fake_A_1_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
        self.fake_B_1_pool = ImagePool(opt.pool_size)
        self.fake_A_2_pool = ImagePool(opt.pool_size)
        self.fake_B_2_pool = ImagePool(opt.pool_size)
        self.fake_A_21_pool = ImagePool(opt.pool_size)
        self.fake_A_12_pool = ImagePool(opt.pool_size)
        # define loss functions
        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
        self.criterionCycle = torch.nn.L1Loss()
        self.criterionIdt = torch.nn.L1Loss()
        self.criterionSemantic = torch.nn.KLDivLoss(reduction='batchmean')
        # initialize optimizers
        if opt.Shared_DT:
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B_1.parameters(),
                                                                self.netD_B_2.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
        else:
            self.optimizer_D_1 = torch.optim.Adam(itertools.chain(self.netD_A_1.parameters(), self.netD_B_1.parameters()),
                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_2 = torch.optim.Adam(itertools.chain(self.netD_A_2.parameters(), self.netD_B_2.parameters()),
                                                  lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_G_1 = torch.optim.Adam(itertools.chain(self.netG_A_1.parameters(), self.netG_B_1.parameters()),
                                              lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_G_2 = torch.optim.Adam(itertools.chain(self.netG_A_2.parameters(), self.netG_B_2.parameters()),
                                              lr=opt.lr, betas=(opt.beta1, 0.999))
        if opt.SAD:
            self.optimizer_D_3 = torch.optim.Adam(self.netD_3.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        if opt.CCD or opt.HF_CCD:
            self.optimizer_D_21 = torch.optim.Adam(self.netD_21.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_12 = torch.optim.Adam(self.netD_12.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizers = []
        self.optimizers.append(self.optimizer_G_1)
        self.optimizers.append(self.optimizer_G_2)
        if opt.Shared_DT:
            self.optimizers.append(self.optimizer_D)
        else:
            self.optimizers.append(self.optimizer_D_1)
            self.optimizers.append(self.optimizer_D_2)
        if opt.SAD:
            self.optimizers.append(self.optimizer_D_3)
        if opt.CCD or opt.HF_CCD:
            self.optimizers.append(self.optimizer_D_12)
            self.optimizers.append(self.optimizer_D_21)
def set_input(self, input):
    """Unpack a dataloader sample: two source batches plus the shared target,
    their path lists, and (when present) all three label maps."""
    device = self.device
    self.real_A_1 = input['A_1'].to(device)
    self.real_A_2 = input['A_2'].to(device)
    self.real_B = input['B'].to(device)
    self.image_paths_1 = input['A_paths_1']
    self.image_paths_2 = input['A_paths_2']
    self.image_paths = self.image_paths_1 + self.image_paths_2
    has_labels = 'A_label_1' in input and 'B_label' in input and 'A_label_2' in input
    if has_labels:
        self.input_A_label_1 = input['A_label_1'].to(device)
        self.input_A_label_2 = input['A_label_2'].to(device)
        self.input_B_label = input['B_label'].to(device)
def forward(self, opt):
    """Run both CycleGAN branches; optionally the CCD cross-translations and
    the frozen pixel classifiers (train time with semantic loss only)."""
    G_A1, G_B1 = self.netG_A_1, self.netG_B_1
    G_A2, G_B2 = self.netG_A_2, self.netG_B_2
    # branch 1: source 1 <-> target
    self.fake_B_1 = G_A1(self.real_A_1)
    self.rec_A_1 = G_B1(self.fake_B_1)
    self.fake_A_1 = G_B1(self.real_B)
    self.rec_B_1 = G_A1(self.fake_A_1)
    # branch 2: source 2 <-> target
    self.fake_B_2 = G_A2(self.real_A_2)
    self.rec_A_2 = G_B2(self.fake_B_2)
    self.fake_A_2 = G_B2(self.real_B)
    self.rec_B_2 = G_A2(self.fake_A_2)
    if opt.CCD:
        # cross-cycle translations feeding the D_21 / D_12 discriminators
        self.fake_A_21 = G_B1(self.fake_B_2)
        self.fake_A_12 = G_B2(self.fake_B_1)
    if self.isTrain and self.semantic_loss:
        # frozen classifiers score each real source and its translation
        self.pred_real_A_SYN = self.netPixelCLS_SYN(self.real_A_1)
        _, self.gt_pred_A_SYN = self.pred_real_A_SYN.max(1)
        self.pred_fake_B_SYN = self.netPixelCLS_SYN(self.fake_B_1)
        _, pfB_SYN = self.pred_fake_B_SYN.max(1)
        self.pred_real_A_GTA = self.netPixelCLS_GTA(self.real_A_2)
        _, self.gt_pred_A_GTA = self.pred_real_A_GTA.max(1)
        self.pred_fake_B_GTA = self.netPixelCLS_GTA(self.fake_B_2)
        _, pfB_GTA = self.pred_fake_B_GTA.max(1)
def backward_D_basic(self, netD, real, fake, SAD=False):
    """Standard discriminator loss: D(real) should be 1, D(fake) should be 0.

    Args:
        netD: discriminator network to update.
        real: images treated as the "real" class.
        fake: generated images (always detached so no gradient reaches G).
        SAD: for the Sub-domain Aggregation Discriminator, `real` is itself
            a generator output, so it must be detached as well.

    Returns:
        The averaged discriminator loss (already back-propagated).
    """
    # Real side: detach only for the SAD branch (idiomatic `if not SAD`
    # instead of the original `if SAD == False`).
    pred_real = netD(real.detach() if SAD else real)
    loss_D_real = self.criterionGAN(pred_real, True)
    # Fake side: always detached.
    pred_fake = netD(fake.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)
    # Combined loss (average of both terms).
    loss_D = (loss_D_real + loss_D_fake) * 0.5
    # backward
    loss_D.backward()
    return loss_D
def backward_D_A(self, Shared_DT):
    """Update the A->B discriminator(s) for both sub-domains.

    With Shared_DT a single netD_A judges both translated outputs;
    otherwise each sub-domain has its own discriminator."""
    # Sub-domain 1: A1 -> B
    pooled_fake_1 = self.fake_B_1_pool.query(self.fake_B_1)
    disc = self.netD_A if Shared_DT else self.netD_A_1
    self.loss_D_A_1 = self.backward_D_basic(disc, self.real_B, pooled_fake_1)
    # Sub-domain 2: A2 -> B
    pooled_fake_2 = self.fake_B_2_pool.query(self.fake_B_2)
    disc = self.netD_A if Shared_DT else self.netD_A_2
    self.loss_D_A_2 = self.backward_D_basic(disc, self.real_B, pooled_fake_2)
def backward_D_B(self):
    """Update the two B->A discriminators (one per source sub-domain)."""
    # B -> A1
    pooled_fake_1 = self.fake_A_1_pool.query(self.fake_A_1)
    self.loss_D_B_1 = self.backward_D_basic(self.netD_B_1, self.real_A_1, pooled_fake_1)
    # B -> A2
    pooled_fake_2 = self.fake_A_2_pool.query(self.fake_A_2)
    self.loss_D_B_2 = self.backward_D_basic(self.netD_B_2, self.real_A_2, pooled_fake_2)
def backward_D(self, which_D):
    """Update one of the auxiliary discriminators selected by `which_D`:
    'SAD' (sub-domain aggregation), 'CCD_21' or 'CCD_12' (cross-domain
    cycle discriminators).  Raises for any other value."""
    if which_D == 'SAD':
        # D3 tells s1-stylized from s2-stylized images; both are generator
        # outputs, hence SAD=True (real side detached too).
        pooled = self.fake_B_1_pool.query(self.fake_B_1)
        self.loss_D_3_1 = self.backward_D_basic(self.netD_3, self.fake_B_2, pooled, SAD=True)
    elif which_D == 'CCD_21':
        pooled = self.fake_A_21_pool.query(self.fake_A_21)
        self.loss_D_21 = self.backward_D_basic(self.netD_21, self.real_A_1, pooled)
    elif which_D == 'CCD_12':
        pooled = self.fake_A_12_pool.query(self.fake_A_12)
        self.loss_D_12 = self.backward_D_basic(self.netD_12, self.real_A_2, pooled)
    else:
        raise Exception("Invalid Choice {}".format(which_D))
def backward_G(self, opt):
    """Generator update: both CycleGAN branches (GAN + cycle + identity
    losses), the optional SAD / CCD adversarial terms, and the optional
    semantic-consistency term.  Everything is summed into self.loss_G and
    back-propagated once at the end."""
    lambda_idt = self.opt.lambda_identity
    lambda_A = self.opt.lambda_A
    lambda_B = self.opt.lambda_B
    # Identity loss: G_A(B) should reproduce B, G_B(A) should reproduce A.
    if lambda_idt > 0:
        self.idt_A_1 = self.netG_A_1(self.real_B)
        self.loss_idt_A_1 = self.criterionIdt(self.idt_A_1, self.real_B) * lambda_B * lambda_idt
        self.idt_A_2 = self.netG_A_2(self.real_B)
        self.loss_idt_A_2 = self.criterionIdt(self.idt_A_2, self.real_B) * lambda_B * lambda_idt
        self.idt_B_1 = self.netG_B_1(self.real_A_1)
        self.loss_idt_B_1 = self.criterionIdt(self.idt_B_1, self.real_A_1) * lambda_A * lambda_idt
        self.idt_B_2 = self.netG_B_2(self.real_A_2)
        self.loss_idt_B_2 = self.criterionIdt(self.idt_B_2, self.real_A_2) * lambda_A * lambda_idt
    else:
        self.loss_idt_A_1 = 0
        self.loss_idt_A_2 = 0
        self.loss_idt_B_1 = 0
        self.loss_idt_B_2 = 0
    # GAN loss for A->B; this direction is weighted 2x relative to B->A.
    if opt.Shared_DT:
        self.loss_G_A_1 = 2 * self.criterionGAN(self.netD_A(self.fake_B_1), True)
        self.loss_G_A_2 = 2 * self.criterionGAN(self.netD_A(self.fake_B_2), True)
    else:
        self.loss_G_A_1 = 2 * self.criterionGAN(self.netD_A_1(self.fake_B_1), True)
        self.loss_G_A_2 = 2 * self.criterionGAN(self.netD_A_2(self.fake_B_2), True)
    # GAN loss D_B(G_B(B))
    self.loss_G_B_1 = self.criterionGAN(self.netD_B_1(self.fake_A_1), True)
    self.loss_G_B_2 = self.criterionGAN(self.netD_B_2(self.fake_A_2), True)
    # Forward cycle loss
    self.loss_cycle_A_1 = self.criterionCycle(self.rec_A_1, self.real_A_1) * lambda_A
    self.loss_cycle_A_2 = self.criterionCycle(self.rec_A_2, self.real_A_2) * lambda_A
    # Backward cycle loss
    self.loss_cycle_B_1 = self.criterionCycle(self.rec_B_1, self.real_B) * lambda_B
    self.loss_cycle_B_2 = self.criterionCycle(self.rec_B_2, self.real_B) * lambda_B
    # combined loss standard cyclegan
    self.loss_G_1 = self.loss_G_A_1 + self.loss_G_B_1 + self.loss_cycle_A_1 + self.loss_cycle_B_1 + self.loss_idt_A_1 + self.loss_idt_B_1
    self.loss_G_2 = self.loss_G_A_2 + self.loss_G_B_2 + self.loss_cycle_A_2 + self.loss_cycle_B_2 + self.loss_idt_A_2 + self.loss_idt_B_2
    self.loss_G = self.loss_G_1 + self.loss_G_2
    if opt.SAD:
        # D3 loss: fool the sub-domain aggregation discriminator.
        # NOTE(review): with the default SAD_frozen_epoch == -1 this term is
        # permanently 0 — confirm the gating condition is intended.
        if opt.SAD_frozen_epoch != -1 and opt.current_epoch > opt.SAD_frozen_epoch:
            self.loss_G_s1s2 = self.criterionGAN(self.netD_3(self.fake_B_1), True)
        else:
            self.loss_G_s1s2 = 0
        self.loss_G += self.loss_G_s1s2
    if opt.CCD:
        # D21 loss (same warm-up gating; see NOTE above for CCD_frozen_epoch == -1).
        if opt.CCD_frozen_epoch != -1 and opt.current_epoch > opt.CCD_frozen_epoch:
            self.loss_G_s1s21 = self.criterionGAN(self.netD_21(self.fake_A_21), True)
            self.loss_G += self.loss_G_s1s21 * opt.D1D2_weight
        else:
            self.loss_G_s1s21 = 0
        # D12 loss
        if opt.CCD_frozen_epoch != -1 and opt.current_epoch > opt.CCD_frozen_epoch:
            self.loss_G_s2s12 = self.criterionGAN(self.netD_12(self.fake_A_12), True)
            self.loss_G += self.loss_G_s2s12 * opt.D1D2_weight
        else:
            self.loss_G_s2s12 = 0
    if opt.semantic_loss:
        # Consistency between classifier outputs on real vs translated
        # images; presumably a KL-divergence criterion given the
        # log_softmax/softmax pairing — verify criterionSemantic.
        self.loss_sem_syn = opt.dynamic_weight * self.criterionSemantic(F.log_softmax(self.pred_fake_B_SYN, dim=1),
                                                                       F.softmax(self.pred_real_A_SYN, dim=1))
        self.loss_sem_gta = opt.dynamic_weight * self.criterionSemantic(F.log_softmax(self.pred_fake_B_GTA, dim=1),
                                                                       F.softmax(self.pred_real_A_GTA, dim=1))
        # Normalize by C*H*W of the prediction map.
        self.loss_G += opt.general_semantic_weight * torch.div(self.loss_sem_syn, self.pred_fake_B_SYN.shape[1] * self.pred_fake_B_SYN.shape[2]
                                                               * self.pred_fake_B_SYN.shape[3])
        self.loss_G += opt.general_semantic_weight * torch.div(self.loss_sem_gta, self.pred_fake_B_GTA.shape[1] * self.pred_fake_B_GTA.shape[2]
                                                               * self.pred_fake_B_GTA.shape[3])
    self.loss_G.backward()
def backward_HF_CCD(self, opt):
    """Half-Freeze CCD generator pass: regenerate the cross-domain cycle
    images and back-propagate the CCD adversarial loss.  The caller
    (optimize_parameters) freezes netG_B_1/netG_B_2 first, so only the
    A->B generators receive gradients.

    Args:
        opt: options namespace; reads CCD_frozen_epoch, current_epoch and
            CCD_weight.
    """
    self.fake_B_1 = self.netG_A_1(self.real_A_1)
    self.fake_B_2 = self.netG_A_2(self.real_A_2)
    # generate s21 for d21 branch
    self.fake_A_21 = self.netG_B_1(self.fake_B_2)
    # generate s12 for d12 branch
    self.fake_A_12 = self.netG_B_2(self.fake_B_1)
    # Both CCD terms share one warm-up gate (the original evaluated the
    # identical condition twice).  NOTE(review): with the default
    # CCD_frozen_epoch == -1 both terms are permanently 0 — confirm.
    if opt.CCD_frozen_epoch != -1 and opt.current_epoch > opt.CCD_frozen_epoch:
        self.loss_G_s2s12 = self.criterionGAN(self.netD_12(self.fake_A_12), True)
        self.loss_G_s1s21 = self.criterionGAN(self.netD_21(self.fake_A_21), True)
    else:
        self.loss_G_s2s12 = 0
        self.loss_G_s1s21 = 0
    self.loss_G_HF = self.loss_G_s1s21 * opt.CCD_weight + self.loss_G_s2s12 * opt.CCD_weight
    # When both terms are the int 0 there is no graph to back-propagate.
    if isinstance(self.loss_G_HF, torch.Tensor):
        self.loss_G_HF.backward()
def optimize_parameters(self, opt):
    """One full optimization step: forward pass, generator update (with all
    discriminators frozen), optional half-freeze CCD generator pass, then
    discriminator updates (main D_A/D_B, optional SAD D3, optional CCD
    D21/D12)."""
    # forward
    self.forward(opt)
    # G_A and G_B
    # set D to false, back prop G's gradients
    if opt.Shared_DT:
        self.set_requires_grad([self.netD_A, self.netD_B_1, self.netD_B_2], False)
    else:
        self.set_requires_grad([self.netD_A_1, self.netD_B_1], False)
        self.set_requires_grad([self.netD_A_2, self.netD_B_2], False)
    if opt.SAD:
        self.set_requires_grad([self.netD_3], False)
    if opt.CCD or opt.HF_CCD:
        self.set_requires_grad([self.netD_21], False)
        self.set_requires_grad([self.netD_12], False)
    self.set_requires_grad([self.netG_A_1, self.netG_B_1], True)
    self.set_requires_grad([self.netG_A_2, self.netG_B_2], True)
    self.optimizer_G_1.zero_grad()
    self.optimizer_G_2.zero_grad()
    # self.optimizer_CLS.zero_grad()
    self.backward_G(opt)
    self.optimizer_G_1.step()
    self.optimizer_G_2.step()
    # Half-freeze CCD: extra generator pass with the B->A generators frozen,
    # so only netG_A_1 / netG_A_2 are updated by the CCD loss.
    if opt.HF_CCD:
        self.optimizer_G_1.zero_grad()
        self.optimizer_G_2.zero_grad()
        self.set_requires_grad([self.netG_A_1, self.netG_A_2], True)
        self.set_requires_grad([self.netG_B_1, self.netG_B_2], False)
        self.backward_HF_CCD(opt)
        self.optimizer_G_1.step()
        self.optimizer_G_2.step()
    # D_A and D_B
    if opt.Shared_DT:
        self.set_requires_grad([self.netD_A, self.netD_B_1, self.netD_B_2], True)
    else:
        self.set_requires_grad([self.netD_A_1, self.netD_B_1], True)
        self.set_requires_grad([self.netD_A_2, self.netD_B_2], True)
    if opt.Shared_DT:
        self.optimizer_D.zero_grad()
    else:
        self.optimizer_D_1.zero_grad()
        self.optimizer_D_2.zero_grad()
    self.backward_D_B()
    self.backward_D_A(opt.Shared_DT)
    if opt.Shared_DT:
        self.optimizer_D.step()
    else:
        self.optimizer_D_1.step()
        self.optimizer_D_2.step()
    # Sub-domain Aggregation Discriminator update.
    if opt.SAD:
        self.set_requires_grad([self.netD_3], True)
        self.optimizer_D_3.zero_grad()
        self.backward_D('SAD')
        self.optimizer_D_3.step()
    # Cross-domain Cycle Discriminator updates.
    if opt.CCD or opt.HF_CCD:
        self.set_requires_grad([self.netD_21], True)
        self.optimizer_D_21.zero_grad()
        self.backward_D('CCD_21')
        self.optimizer_D_21.step()
        self.set_requires_grad([self.netD_12], True)
        self.optimizer_D_12.zero_grad()
        self.backward_D('CCD_12')
        self.optimizer_D_12.step()
================================================
FILE: cyclegan/models/networks.py
================================================
import functools
import torch
import torch.nn as nn
from torch.nn import init
from torch.optim import lr_scheduler
###############################################################################
# Helper Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
    """Map a norm name to a layer constructor.

    'batch' -> BatchNorm2d(affine=True), 'instance' -> InstanceNorm2d
    (affine=False), 'none' -> None.  Raises NotImplementedError otherwise.
    """
    factories = {
        'batch': functools.partial(nn.BatchNorm2d, affine=True),
        'instance': functools.partial(nn.InstanceNorm2d, affine=False),
        'none': None,
    }
    if norm_type not in factories:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return factories[norm_type]
def get_scheduler(optimizer, opt):
    """Build a learning-rate scheduler for `optimizer` per opt.lr_policy.

    Policies: 'lambda' (constant for `niter` epochs, then linear decay to 0
    over `niter_decay` epochs), 'step' (x0.1 every lr_decay_iters), and
    'plateau' (ReduceLROnPlateau on a monitored metric).

    Raises:
        NotImplementedError: for an unknown lr_policy.
    """
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    else:
        # BUG FIX: the original *returned* a NotImplementedError instance
        # (with a stray comma instead of %-interpolation) — raise it.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
def init_weights(net, init_type='normal', gain=0.02):
    """Initialize all Conv*/Linear weights of `net` in place with the chosen
    scheme ('normal' | 'xavier' | 'kaiming' | 'orthogonal'); biases go to 0.
    BatchNorm2d weights are drawn from N(1, gain) with zero bias."""
    def init_func(m):
        cls_name = m.__class__.__name__
        is_conv_or_linear = hasattr(m, 'weight') and ('Conv' in cls_name or 'Linear' in cls_name)
        if is_conv_or_linear:
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in cls_name:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
def init_net(net, init_type='normal', gpu_ids=[]):
    """Move `net` to the first GPU and wrap it in DataParallel when gpu_ids
    is non-empty, then initialize its weights and return it.

    NOTE(review): the mutable default `gpu_ids=[]` is shared across calls;
    it is never mutated here, so it is harmless — but verify callers.
    """
    if gpu_ids:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type)
    return net
def print_network(net):
    """Print the network architecture followed by its total parameter count."""
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):
    """Generator factory: builds the requested architecture, then moves it
    to GPU (if any) and initializes its weights via init_net."""
    norm_layer = get_norm_layer(norm_type=norm)
    builders = {
        'resnet_9blocks': lambda: ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9),
        'resnet_6blocks': lambda: ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6),
        'unet_128': lambda: UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout),
        'unet_256': lambda: UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout),
    }
    if which_model_netG not in builders:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    netG = builders[which_model_netG]()
    return init_net(netG, init_type, gpu_ids)
def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):
    """Discriminator factory: 'basic' (3-layer PatchGAN), 'n_layers'
    (configurable depth), or 'pixel' (1x1 per-pixel discriminator)."""
    norm_layer = get_norm_layer(norm_type=norm)
    builders = {
        'basic': lambda: NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid),
        'n_layers': lambda: NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid),
        'pixel': lambda: PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid),
    }
    if which_model_netD not in builders:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)
    netD = builders[which_model_netD]()
    return init_net(netD, init_type, gpu_ids)
def define_C(output_nc, ndf, init_type='normal', gpu_ids=[]):
    """Factory for the auxiliary image Classifier, initialized via init_net."""
    return init_net(Classifier(output_nc, ndf), init_type, gpu_ids)
##############################################################################
# Classes
##############################################################################
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
    """GAN objective that hides target-tensor bookkeeping from the caller.

    LSGAN mode uses MSE against constant labels; vanilla mode uses BCE.
    The labels live in registered buffers so they follow the module across
    devices without being trained.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a label tensor of the same shape as `input`."""
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(input)

    def __call__(self, input, target_is_real):
        return self.loss(input, self.get_target_tensor(input, target_is_real))
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    """Resnet-based generator: 7x7 stem, two stride-2 downsampling convs,
    `n_blocks` residual blocks at the lowest resolution, two transposed-conv
    upsampling stages, and a final 7x7 conv + Tanh.

    Code and idea originally from Justin Johnson's fast-neural-style
    architecture (https://github.com/jcjohnson/fast-neural-style/).
    """
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        assert (n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # Convs keep a bias only under InstanceNorm (no learned shift there).
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        n_downsampling = 2
        # Downsampling: ngf -> 2*ngf -> 4*ngf, halving resolution each step.
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
        mult = 2 ** n_downsampling
        # Residual trunk at 4*ngf channels.
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
        # Upsampling back to the input resolution.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: returns x + conv_block(x), where conv_block is two
    3x3 convs with normalization (ReLU after the first) and a configurable
    padding mode ('reflect' | 'replicate' | 'zero')."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return ([padding layers], conv padding) for the requested mode.

        The original duplicated this three-way branch twice inline; it is
        factored out here so both conv stages share one implementation.
        """
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble the two-conv residual body as an nn.Sequential."""
        conv_block = []
        # First conv stage: pad -> conv -> norm -> ReLU.
        pad_layers, p = self._padding(padding_type)
        conv_block += pad_layers
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        # Second conv stage: pad -> conv -> norm (no activation).
        pad_layers, p = self._padding(padding_type)
        conv_block += pad_layers
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        # Skip connection around the conv body.
        out = x + self.conv_block(x)
        return out
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
    """U-Net generator built recursively from UnetSkipConnectionBlock.

    |num_downs|: number of downsamplings. For example, if |num_downs| == 7,
    an image of size 128x128 becomes 1x1 at the bottleneck.
    """
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # construct unet structure from the innermost block outward
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
        # (num_downs - 5) intermediate blocks stay at the ngf*8 width.
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer,
                                                 use_dropout=use_dropout)
        # Outer blocks progressively halve the channel width down to ngf.
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
        self.model = unet_block

    def forward(self, input):
        return self.model(input)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level with a skip connection.

    X -------------------identity---------------------- X
      |-- downsampling -- |submodule| -- upsampling --|

    Non-outermost blocks concatenate their input with the submodule output
    along the channel axis; the outermost block returns the raw output.
    """
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Convs keep a bias only under InstanceNorm (no learned shift there).
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)
        if outermost:
            # No norm on the way in; Tanh output, no skip concat.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            # Bottleneck: no submodule, upconv input is not concatenated.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            # Middle level: upconv consumes inner_nc*2 channels because the
            # submodule concatenates its input with its output.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            # Skip connection: concatenate input and output channels.
            return torch.cat([x, self.model(x)], 1)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: a stack of 4x4 convs whose single-channel
    output map scores overlapping patches of the input as real/fake."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        # Convs keep a bias only under InstanceNorm (no learned shift there).
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw, padw = 4, 1
        layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
                  nn.LeakyReLU(0.2, True)]
        prev_mult, mult = 1, 1
        # Stride-2 stages: channel multiplier doubles, capped at 8.
        for n in range(1, n_layers):
            prev_mult, mult = mult, min(2 ** n, 8)
            layers += [nn.Conv2d(ndf * prev_mult, ndf * mult,
                                 kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                       norm_layer(ndf * mult),
                       nn.LeakyReLU(0.2, True)]
        # One stride-1 stage before the 1-channel prediction map.
        prev_mult, mult = mult, min(2 ** n_layers, 8)
        layers += [nn.Conv2d(ndf * prev_mult, ndf * mult,
                             kernel_size=kw, stride=1, padding=padw, bias=use_bias),
                   norm_layer(ndf * mult),
                   nn.LeakyReLU(0.2, True)]
        layers += [nn.Conv2d(ndf * mult, 1, kernel_size=kw, stride=1, padding=padw)]
        if use_sigmoid:
            layers += [nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
class PixelDiscriminator(nn.Module):
    """1x1-conv discriminator: classifies every pixel independently, so the
    output map has the same spatial size as the input."""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(PixelDiscriminator, self).__init__()
        # Convs keep a bias only under InstanceNorm (no learned shift there).
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
        ]
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        return self.net(input)
class Classifier(nn.Module):
    """Small CNN classifier: four unpadded stride-2 3x3 convs followed by a
    two-layer linear head producing 10 logits.

    NOTE(review): the flatten in forward() assumes the conv stack reduces
    the spatial extent to 1x1 (true for ~32x32 inputs given four unpadded
    stride-2 convs) — confirm against callers.
    """
    def __init__(self, input_nc, ndf, norm_layer=nn.BatchNorm2d):
        super(Classifier, self).__init__()
        kw = 3
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        nf_mult_prev = 1
        # Channel multipliers follow min(2**n, 8) for n = 0, 1, 2: 1, 2, 4.
        for n in range(3):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2),
                norm_layer(ndf * nf_mult, affine=True),
                nn.LeakyReLU(0.2, True)
            ]
        self.before_linear = nn.Sequential(*sequence)
        # Linear head: ndf*4 features -> 1024 -> 10 class logits.
        sequence = [
            nn.Linear(ndf * nf_mult, 1024),
            nn.Linear(1024, 10)
        ]
        self.after_linear = nn.Sequential(*sequence)

    def forward(self, x):
        # Flatten the conv features per sample before the linear head.
        bs = x.size(0)
        out = self.after_linear(self.before_linear(x).view(bs, -1))
        return out
        # return nn.functional.log_softmax(out, dim=1)
================================================
FILE: cyclegan/models/test_model.py
================================================
from . import networks
from .base_model import BaseModel
class TestModel(BaseModel):
    """Inference-only model: builds a single A->B generator matching the
    chosen dataset_mode and translates input images."""

    def name(self):
        return 'TestModel'

    def initialize(self, opt):
        """Set up loss/visual/model name lists and build the generator for
        the selected source sub-domain (synthia or gta5)."""
        assert (not opt.isTrain)
        BaseModel.initialize(self, opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = []
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if opt.dataset_mode == 'synthia_cityscapes':
            self.model_names = ['G_A_1']
            self.visual_names.append('fake_B_1')
            self.netG_A_1 = networks.define_G(opt.input_nc, opt.output_nc,
                                              opt.ngf, opt.which_model_netG,
                                              opt.norm, not opt.no_dropout,
                                              opt.init_type,
                                              self.gpu_ids)
        elif opt.dataset_mode == 'gta5_cityscapes':
            self.model_names = ['G_A_2']
            self.visual_names.append('fake_B_2')
            self.netG_A_2 = networks.define_G(opt.input_nc, opt.output_nc,
                                              opt.ngf, opt.which_model_netG,
                                              opt.norm, not opt.no_dropout,
                                              opt.init_type,
                                              self.gpu_ids)

    def set_input(self, input):
        # we need to use single_dataset mode
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        """Translate real_A with whichever generator initialize() built."""
        if hasattr(self, 'netG_A_1'):
            self.fake_B_1 = self.netG_A_1(self.real_A)
        elif hasattr(self, 'netG_A_2'):
            self.fake_B_2 = self.netG_A_2(self.real_A)
================================================
FILE: cyclegan/options/__init__.py
================================================
================================================
FILE: cyclegan/options/base_options.py
================================================
import argparse
import os
import torch
from util import util
class BaseOptions():
    """Command-line options shared by training and testing.

    Subclasses (TrainOptions / TestOptions) extend initialize() and set
    self.isTrain before parse() is called.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        """Register all shared command-line arguments on self.parser."""
        self.parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
        self.parser.add_argument('--loadSize', type=int, default=600, help='scale images to this size')
        self.parser.add_argument('--fineSize', type=int, default=600, help='then crop to this size')
        self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
        self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
        self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
        self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
        self.parser.add_argument('--which_model_netD', type=str, default='n_layers', help='selects model to use for netD')
        self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG')
        self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        self.parser.add_argument('--name', type=str, default='experiment_name',
                                 help='name of the experiment. It decides where to store samples and models')
        self.parser.add_argument('--dataset_mode', type=str, default='unaligned',
                                 help='chooses how datasets are loaded. [unaligned | aligned | single]')
        self.parser.add_argument('--model', type=str, default='cycle_gan',
                                 help='chooses which model to use. cycle_gan, pix2pix, test')
        self.parser.add_argument('--weights_model_type', type=str, default='drn26',
                                 help='chooses which model to use. drn26, fcn8s')
        self.parser.add_argument('--num_cls', default=19, type=int)
        self.parser.add_argument('--max_epoch', default=20, type=int)
        self.parser.add_argument('--current_epoch', default=0, type=int)
        self.parser.add_argument('--weights_init', type=str)
        self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
        self.parser.add_argument('--nThreads', default=16, type=int, help='# threads for loading data')
        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
        self.parser.add_argument('--serial_batches', action='store_true',
                                 help='if true, takes images in order to make batches, otherwise takes them randomly')
        self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
        self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
        self.parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
        self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
                                 help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, '
                                      'only a subset is loaded.')
        self.parser.add_argument('--resize_or_crop', type=str, default='scale_width_and_crop',
                                 help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
        self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
        self.parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        self.parser.add_argument('--suffix', default='', type=str,
                                 help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
        self.parser.add_argument('--out_all', action='store_true', help='output all stylized images(fake_B_{})')
        self.parser.add_argument('--SAD', action='store_true', help='Sub-domain Aggregation Discriminator module')
        self.parser.add_argument('--CCD', action='store_true', help='Cross-domain Cycle Discriminator module')
        self.parser.add_argument('--CCD_weight', type=float, default=1, help='weight for cross domain cycle discriminator loss')
        self.parser.add_argument('--HF_CCD', action='store_true', help='Half Freeze Cross-domain Cycle Discriminator module')
        self.parser.add_argument('--CCD_frozen_epoch', type=int, default=-1)
        self.parser.add_argument('--SAD_frozen_epoch', type=int, default=-1)
        # NOTE(review): argparse `type=bool` treats any non-empty string as
        # True (bool('False') is True), so `--Shared_DT False` still yields
        # True — confirm whether a str-to-bool converter was intended.
        self.parser.add_argument('--Shared_DT', type=bool, default=True, help="Through ")
        self.parser.add_argument('--model_type', type=str, default='fcn8s', help="choose to load which type of model (fcn8s, drn26, deeplabv2)")
        self.parser.add_argument('--semantic_loss', action='store_true', help='use semantic loss')
        self.parser.add_argument('--general_semantic_weight', type=float, default=0.2, help='weight for semantic loss')
        self.parser.add_argument('--weights_syn', type=str, default='', help='init weights for synthia')
        self.parser.add_argument('--weights_gta', type=str, default='', help='init weights for gta')
        self.parser.add_argument('--inference_script', type=str, default='', help='inference script')
        self.parser.add_argument('--dynamic_weight', type=float, default=10, help='Weight for Dynamic Semantic Loss(KL div) loss')
        self.initialized = True

    def parse(self):
        """Parse argv, resolve GPU ids, print/persist the options to
        <checkpoints_dir>/<name>/opt.txt, and return the namespace."""
        if not self.initialized:
            self.initialize()
        opt = self.parser.parse_args()
        opt.isTrain = self.isTrain  # train or test
        # Convert the comma-separated gpu_ids string into a list of ints,
        # dropping negative ids (so '-1' selects CPU).
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                opt.gpu_ids.append(id)
        # set gpu ids
        if len(opt.gpu_ids) > 0:
            torch.cuda.set_device(opt.gpu_ids[0])
        args = vars(opt)
        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')
        # Optionally append a formatted suffix to the experiment name.
        if opt.suffix:
            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
            opt.name = opt.name + suffix
        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write('------------ Options -------------\n')
            for k, v in sorted(args.items()):
                opt_file.write('%s: %s\n' % (str(k), str(v)))
            opt_file.write('-------------- End ----------------\n')
        self.opt = opt
        return self.opt
================================================
FILE: cyclegan/options/test_options.py
================================================
from .base_options import BaseOptions
class TestOptions(BaseOptions):
    """Test-time options: extends BaseOptions with result-directory and
    evaluation-count arguments, and marks the run as non-training."""

    def initialize(self):
        """Register test-specific arguments on top of the shared ones."""
        BaseOptions.initialize(self)
        self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
        # parse() reads this to set opt.isTrain.
        self.isTrain = False
================================================
FILE: cyclegan/options/train_options.py
================================================
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
    """Training-time command-line options, layered on top of BaseOptions."""

    def initialize(self):
        BaseOptions.initialize(self)
        add = self.parser.add_argument
        # --- logging / visualization frequencies ---
        add('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
        add('--display_ncols', type=int, default=4,
            help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        add('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        add('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        add('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        add('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        # --- resuming / phase selection ---
        add('--continue_train', action='store_true', help='continue training: load the latest model')
        add('--epoch_count', type=int, default=1,
            help='the starting epoch count, we save the model by , +, ...')
        add('--phase', type=str, default='train', help='train, val, test, etc')
        add('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        # --- optimization schedule ---
        add('--niter', type=int, default=100, help='# of iter at starting learning rate')
        add('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        add('--beta1', type=float, default=0.5, help='momentum term of adam')
        add('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        add('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
        # --- loss weights and misc training knobs ---
        add('--lambda_A', type=float, default=1.0, help='weight for cycle loss (A -> B -> A)')
        add('--lambda_B', type=float, default=1.0, help='weight for cycle loss (B -> A -> B)')
        add('--lambda_identity', type=float, default=0,
            help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the '
                 'identity mapping loss.'
                 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the '
                 'reconstruction loss, please set lambda_identity = 0.1')
        add('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
        add('--no_html', action='store_true',
            help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        add('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
        add('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
        self.isTrain = True
================================================
FILE: cyclegan/test.py
================================================
import os
import sys
import torch
from models import create_model
from options.test_options import TestOptions
from util import html
from util.visualizer import save_images
from data import CreateDataLoader
import logging
sys.path.append("/nfs/project/libo_i/MADAN")
if __name__ == '__main__':
    # Run a trained CycleGAN model over a dataset and dump the translated
    # images to a results website directory.
    opt = TestOptions().parse()
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.display_id = -1  # no visdom display
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)
    model.setup(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    # test
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        # log input tensor sizes once, as a sanity check on shapes
        if i == 0:
            for key, value in data.items():
                if isinstance(value, torch.Tensor):
                    # BUG FIX: logging.info(key, size) treated size as a
                    # %-format argument with no placeholder in the message,
                    # which logs a formatting error; use an explicit format.
                    logging.info('%s: %s', key, value.size())
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        # when dumping everything, keep only the translated ('fake_B') outputs
        if opt.out_all:
            for stale_key in [k for k in visuals if 'fake_B' not in k]:
                del visuals[stale_key]
        img_path = model.get_image_paths()
        if i % 5 == 0:
            logging.info('processing (%04d)-th image...', i * opt.batchSize)
        # multi-domain models tag each visual with its source index
        if 'mul' in opt.model:
            save_images(webpage.get_image_dir(), visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize, multi_flag=True)
        else:
            save_images(webpage.get_image_dir(), visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
================================================
FILE: cyclegan/train.py
================================================
import subprocess
import sys
import time
sys.path.append("/nfs/project/libo_i/MADAN/cyclegan")
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
import torch
import logging
if __name__ == '__main__':
    # Main CycleGAN training loop: load data, build the model, then iterate
    # epochs while periodically displaying, printing and checkpointing.
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    logging.info('#training images = %d' % dataset_size)
    model = create_model(opt)
    model.setup(opt)
    visualizer = Visualizer(opt)
    total_steps = 0
    # total epoch budget: opt.niter at the base lr plus opt.niter_decay of decay
    n_epochs = opt.niter + opt.niter_decay
    # BUG FIX: t_data was only assigned on print_freq ticks, but read on a
    # differently-phased tick below (total_steps is incremented in between);
    # initialize it so it can never be referenced before assignment.
    t_data = 0.0
    for epoch in range(opt.epoch_count, n_epochs + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        opt.current_epoch = epoch  # models read the current epoch off opt
        logging.info("Current epoch update to {}".format(opt.current_epoch))
        for i, data in enumerate(dataset):
            # log input tensor sizes once, as a sanity check on shapes
            if total_steps == 0:
                for item in data.items():
                    if isinstance(item[1], torch.Tensor):
                        logging.info(item[1].size())
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
                visualizer.reset()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters(opt)
            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
            if total_steps % opt.print_freq == 0:
                losses = model.get_current_losses()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)
            if total_steps % opt.save_latest_freq == 0:
                logging.info('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.save_networks('latest')
            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:
            logging.info('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.save_networks('latest')
            model.save_networks(epoch)
        # BUG FIX: report the actual epoch budget (niter + niter_decay) instead
        # of opt.max_epoch, which TrainOptions does not define and which does
        # not bound this loop.
        logging.info('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, n_epochs, time.time() - epoch_start_time))
        model.update_learning_rate()
================================================
FILE: cyclegan/util/__init__.py
================================================
================================================
FILE: cyclegan/util/get_data.py
================================================
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename
class GetData(object):
    """
    Download CycleGAN or Pix2Pix Data.

    Args:
        technique : str
            One of: 'cyclegan' or 'pix2pix'.
        verbose : bool
            If True, print additional information.

    Raises:
        ValueError: if *technique* is not a known dataset family.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.
    """

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        try:
            # fail fast here instead of letting requests.get(None) blow up
            # later with an obscure error (dict.get silently returned None)
            self.url = url_dict[technique.lower()]
        except KeyError:
            raise ValueError(
                "Unknown technique: {0}. Use 'cyclegan' or 'pix2pix'.".format(technique))
        self._verbose = verbose

    def _print(self, text):
        # honor the verbose flag for all informational output
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        """Scrape the dataset archive names (.zip / tar.gz) from the index page."""
        soup = BeautifulSoup(r.text, 'lxml')
        options = [h.text for h in soup.find_all('a', href=True)
                   if h.text.endswith(('.zip', 'tar.gz'))]
        return options

    def _present_options(self):
        """List the available datasets and return the one the user picks."""
        r = requests.get(self.url)
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        return options[int(choice)]

    def _download_data(self, dataset_url, save_path):
        """Download the archive at *dataset_url* into *save_path*, unpack it, and remove the archive."""
        # exist_ok avoids the check-then-create race of the isdir/makedirs pair
        os.makedirs(save_path, exist_ok=True)
        base = basename(dataset_url)
        temp_save_path = join(save_path, base)
        with open(temp_save_path, "wb") as f:
            r = requests.get(dataset_url)
            f.write(r.content)
        if base.endswith('.tar.gz'):
            opener = tarfile.open
        elif base.endswith('.zip'):
            opener = lambda p: ZipFile(p, 'r')
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))
        self._print("Unpacking Data...")
        # context manager guarantees the archive handle is closed even if
        # extractall raises (the original leaked it on error)
        with opener(temp_save_path) as obj:
            obj.extractall(save_path)
        os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """
        Download a dataset.

        Args:
            save_path : str
                A directory to save the data to.
            dataset : str, optional
                A specific dataset to download.
                Note: this must include the file extension.
                If None, options will be presented for you
                to choose from.

        Returns:
            save_path_full : str
                The absolute path to the downloaded data.
        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset
        save_path_full = join(save_path, selected_dataset.split('.')[0])
        if isdir(save_path_full):
            # do not re-download data that already exists on disk
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = "{0}/{1}".format(self.url, selected_dataset)
            self._download_data(url, save_path=save_path)
        return abspath(save_path_full)
================================================
FILE: cyclegan/util/html.py
================================================
import dominate
from dominate.tags import *
import os
class HTML:
    """Build an image-gallery web page with dominate.

    Images live under <web_dir>/images and are referenced relatively, so the
    generated index.html is self-contained within web_dir.
    """

    def __init__(self, web_dir, title, reflesh=0):
        # 'reflesh' (sic) is kept for backward compatibility with existing
        # callers; it is the page auto-refresh interval in seconds (0 = off).
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        # exist_ok avoids the check-then-create race of the exists/makedirs pair
        os.makedirs(self.web_dir, exist_ok=True)
        os.makedirs(self.img_dir, exist_ok=True)
        self.doc = dominate.document(title=title)
        if reflesh > 0:
            with self.doc.head:
                meta(http_equiv="reflesh", content=str(reflesh))

    def get_image_dir(self):
        """Return the directory where images for this page are stored."""
        return self.img_dir

    def add_header(self, str):
        """Append an <h3> header with the given text."""
        # parameter name 'str' shadows the builtin but is kept so keyword
        # callers are not broken
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        """Start a new fixed-layout table; subsequent add_images rows go into it."""
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=400):
        """Append a one-row table of images with captions; each image links to the matching entry of *links*."""
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        """Render the document to <web_dir>/index.html."""
        html_file = '%s/index.html' % self.web_dir
        # BUG FIX: use a context manager so the handle is closed even if
        # render() raises (the original open/write/close leaked on error)
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())
if __name__ == '__main__':
    # Tiny smoke demo: build a page with four placeholder images and save it.
    page = HTML('web/', 'test_html')
    page.add_header('hello world')
    names = ['image_%d.png' % n for n in range(4)]
    captions = ['text_%d' % n for n in range(4)]
    page.add_images(names, captions, names)
    page.save()
================================================
FILE: cyclegan/util/image_pool.py
================================================
import random
import torch
class ImagePool():
    """History buffer of generated images (Shrivastava et al. trick).

    With probability 0.5 a queried image is swapped with a randomly chosen
    buffered one, so the discriminator also sees older generator outputs.
    A pool_size of 0 disables buffering entirely.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if pool_size > 0:
            self.num_imgs = 0  # how many slots are currently filled
            self.images = []

    def query(self, images):
        """Return a batch the same size as *images*, mixing in buffered history."""
        if self.pool_size == 0:
            return images
        out = []
        for img in images:
            img = torch.unsqueeze(img.data, 0)
            if self.num_imgs < self.pool_size:
                # buffer not yet full: store and return the new image
                self.num_imgs += 1
                self.images.append(img)
                out.append(img)
            elif random.uniform(0, 1) > 0.5:
                # swap: hand back a random stored image, keep the new one
                slot = random.randint(0, self.pool_size - 1)  # randint is inclusive
                stored = self.images[slot].clone()
                self.images[slot] = img
                out.append(stored)
            else:
                out.append(img)
        return torch.cat(out, 0)
================================================
FILE: cyclegan/util/util.py
================================================
from __future__ import print_function
import os
import numpy as np
import torch
from PIL import Image
# Converts a Tensor into an image array (numpy)
# |imtype|: the desired type of the converted numpy array
def tensor2im(input_image, imtype=np.uint8):
    """Convert a CHW tensor in [-1, 1] to an HWC numpy image in [0, 255].

    Non-tensor inputs are returned unchanged; single-channel inputs are
    tiled to three channels.
    """
    if not isinstance(input_image, torch.Tensor):
        return input_image
    arr = input_image.data.cpu().float().numpy()
    if arr.shape[0] == 1:
        arr = np.tile(arr, (3, 1, 1))
    # map [-1, 1] -> [0, 255] and reorder CHW -> HWC
    arr = (np.transpose(arr, (1, 2, 0)) + 1) / 2.0 * 255.0
    return arr.astype(imtype)
# def tensor2im(input_image, imtype=np.uint8):
# if isinstance(input_image, torch.Tensor):
# image_tensor = input_image.data
# else:
# return input_image
# image_numpy = image_tensor[0].cpu().float().numpy()
# if image_numpy.shape[0] == 1:
# image_numpy = np.tile(image_numpy, (3, 1, 1))
# image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
# return image_numpy.astype(imtype)
def diagnose_network(net, name='network'):
    """Print *name* and the mean absolute gradient over all parameters of *net*.

    Parameters without a gradient are skipped; with no gradients at all the
    printed mean is 0.0.
    """
    grads = [torch.mean(torch.abs(p.grad.data))
             for p in net.parameters() if p.grad is not None]
    mean = sum(grads) / len(grads) if grads else 0.0
    print(name)
    print(mean)
def save_image(image_numpy, image_path):
    """Write an HWC uint8 numpy array to *image_path* via PIL."""
    Image.fromarray(image_numpy).save(image_path)
def print_numpy(x, val=True, shp=False):
    """Print summary statistics (and optionally the shape) of array *x*."""
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        flat = x.flatten()
        stats = (np.mean(flat), np.min(flat), np.max(flat),
                 np.median(flat), np.std(flat))
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
def mkdirs(paths):
    """Create a directory, or each directory in a list of paths.

    Args:
        paths: a single path string, or a list of path strings.
    """
    if isinstance(paths, list):
        # (the original also tested `not isinstance(paths, str)`, which is
        # redundant: a list is never a str)
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    """Create *path* (and parents) if missing.

    exist_ok avoids the check-then-create race of the original
    exists()/makedirs() pair.
    """
    os.makedirs(path, exist_ok=True)
================================================
FILE: cyclegan/util/visualizer.py
================================================
import ntpath
import os
import time
import numpy as np
from . import html, util
# save image to the disk
# save image to the disk
def save_images(image_dir, visuals, image_path, aspect_ratio=1.0, width=256, multi_flag=False):
    """Save each visual in *visuals* (dict label -> batched tensor) under *image_dir*.

    One file is written per (input image, label) pair, named
    '<input-stem>_<label>.png'. With multi_flag, only visuals whose label
    mentions the 1-based image index are saved (multi-domain models tag
    their outputs that way).

    aspect_ratio and width are kept for interface compatibility but are
    currently unused.
    """
    for i in range(len(image_path)):
        # derive the output stem from the input image's filename
        name = os.path.splitext(ntpath.basename(image_path[i]))[0]
        # (removed an unused enumerate index and an unused h/w/_ shape unpack)
        for label, im_data in visuals.items():
            if multi_flag and str(i + 1) not in label:
                continue
            im = util.tensor2im(im_data[i, :, :, :])
            save_path = os.path.join(image_dir, '%s_%s.png' % (name, label))
            util.save_image(im, save_path)
class Visualizer():
    """Dispatch training visuals and losses to visdom, an HTML page and a log file.

    Active sinks depend on opt: visdom when opt.display_id > 0, the HTML page
    when training and --no_html is not set; the textual loss log is always
    appended under opt.checkpoints_dir/opt.name/loss_log.txt.
    """

    def __init__(self, opt):
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.opt = opt
        self.saved = False  # whether visuals were already saved to HTML this tick
        if self.display_id > 0:
            import visdom  # imported lazily so visdom is only needed when used
            self.ncols = opt.display_ncols
            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port)
        if self.use_html:
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        """Allow the next display_current_results call to save to HTML again."""
        self.saved = False

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, save_result):
        """Show |visuals| in visdom and/or persist them to the HTML page."""
        if self.display_id > 0:  # show images in the browser
            ncols = self.ncols
            if ncols > 0:
                ncols = min(ncols, len(visuals))
                h, w = next(iter(visuals.values())).shape[:2]
                # BUG FIX: the CSS/HTML template strings had been stripped
                # from this file (bare '%s' formats raised TypeError and one
                # string literal was broken across lines); restored from the
                # upstream CycleGAN visualizer.
                table_css = """<style>
        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
        </style>""" % (w, h)
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:  # pad the last row with blank cells
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                # pane col = image row
                self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                padding=2, opts=dict(title=title + ' images'))
                label_html = '<table>%s</table>' % label_html
                self.vis.text(table_css + label_html, win=self.display_id + 2,
                              opts=dict(title=title + ' labels'))
            else:
                # one visdom window per visual
                idx = 1
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                                   win=self.display_id + idx)
                    idx += 1
        if self.use_html and (save_result or not self.saved):  # save images to a html file
            self.saved = True
            for label, image in visuals.items():
                image_numpy = util.tensor2im(image[0, :, :, :])
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)
            # update website: newest epoch first
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims, txts, links = [], [], []
                for label, image_numpy in visuals.items():
                    # the image files were written above; reference them by name
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    # losses: dictionary of error labels and values
    def plot_current_losses(self, epoch, counter_ratio, opt, losses):
        """Append the current losses to the visdom line plot."""
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
        self.vis.line(
            X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
            Y=np.array(self.plot_data['Y']),
            opts={
                'title': self.name + ' loss over time',
                'legend': self.plot_data['legend'],
                'xlabel': 'epoch',
                'ylabel': 'loss'},
            win=self.display_id)

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, i, losses, t, t_data):
        """Print the losses to stdout and append the same line to the loss log."""
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
        for k, v in losses.items():
            message += '%s: %.3f ' % (k, v)
        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)
================================================
FILE: requirements.txt
================================================
scipy
tensorboardX
tensorflow
click
tqdm
requests
colorlog
pyyaml
torch>=1.1.0
torchvision>=0.3.0
dominate>=2.3.1
visdom>=0.1.8.3
================================================
FILE: scripts/ADDA/adda_cyclegta2cs_feat.sh
================================================
#!/usr/bin/env bash
# Train an ADDA feature-level discriminator adapting CycleGTA5 -> Cityscapes.
# Paths are hard-coded for the original author's cluster (/nfs/project/libo_i).
gpu=0,1,2,3
######################
# loss weight params #
######################
lr=2e-5
momentum=0.9
lambda_d=1
lambda_g=0.1
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
# NOTE(review): PYTHONPATH is a module search path, not an interpreter path;
# pointing it at the python3 binary has no effect -- confirm the intent
# (probably PATH, or a source directory, was meant).
export PYTHONPATH='/usr/bin/python3'
################
# train params #
################
max_iter=50000
crop=600
snapshot=5000
batch=8
weight_share='weights_shared'
discrim='discrim_feat'
########
# Data #
########
src='cyclegta5'
tgt='cityscapes'
data_flag='V4_SEM_300'
datadir='/nfs/project/libo_i/cycada/data/'
resdir="results/${src}_to_${tgt}/adda_sgd_${weight_share}_nolsgan_${discrim}_${data_flag}"
# init with pre-trained cyclegta5 model
#model='drn26'
#baseiter=115000
model='fcn8s'
baseiter=100000
base_model="/nfs/project/libo_i/cycada/pretrained_models/GTA2CS_score_V4_SEM_300_net-itercurr.pth"
# NOTE(review): discrim_model is defined but never passed to the command
# below (unlike the _score variant which uses --weights_discrim).
discrim_model="/nfs/project/libo_i/cycada/pretrained_models/new_Dis_cyclegta5_sem_v4_300x540_iter-iter5440_abv60.pth"
outdir="${resdir}/${model}/lr${lr}_crop${crop}_ld${lambda_d}_lg${lambda_g}_momentum${momentum}_${discrim}"
echo $outdir
echo $base_model
cd /nfs/project/libo_i/cycada
# Run python script #
python3 scripts/train_fcn_adda.py \
${outdir} \
--dataset ${src} --dataset ${tgt} --datadir ${datadir} \
--lr ${lr} --momentum ${momentum} --gpu ${gpu} \
--lambda_d ${lambda_d} --lambda_g ${lambda_g} \
--weights_init ${base_model} --model ${model} \
--"${weight_share}" --${discrim} --no_lsgan \
--max_iter ${max_iter} --batch ${batch} \
--snapshot ${snapshot} --no_mmd_loss --small 0 --resize 300 --data_flag ${data_flag}
================================================
FILE: scripts/ADDA/adda_cyclegta2cs_score.sh
================================================
#!/usr/bin/env bash
# Train an ADDA score-level (output-space) discriminator adapting
# CycleGTA5 -> Cityscapes, warm-starting the discriminator from a checkpoint.
gpu=0,1,2,3
######################
# loss weight params #
######################
lr=2e-5
momentum=0.9
lambda_d=1
lambda_g=0.1
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
# NOTE(review): PYTHONPATH set to the interpreter binary has no effect;
# confirm the intent (see the _feat sibling script).
export PYTHONPATH='/usr/bin/python3'
################
# train params #
################
max_iter=25000
crop=600
snapshot=1000
batch=8
weight_share='weights_shared'
discrim='discrim_score'
########
# Data #
########
src='cyclegta5'
tgt='cityscapes'
data_flag='V4_SEM_300'
datadir='/nfs/project/libo_i/cycada/data/'
resdir="results/${src}_to_${tgt}/adda_sgd_${weight_share}_nolsgan_${discrim}_${data_flag}"
# init with pre-trained cyclegta5 model
#model='drn26'
#baseiter=115000
model='fcn8s'
baseiter=100000
base_model="/nfs/project/libo_i/cycada/pretrained_models/cyclegta_V4_SEM_Final_best_model.pth"
discrim_model="/nfs/project/libo_i/cycada/pretrained_models/new_Dis_cyclegta5_sem_v4_300x540_iter-iter5440_abv60.pth"
outdir="${resdir}/${model}/lr${lr}_crop${crop}_ld${lambda_d}_lg${lambda_g}_momentum${momentum}_${discrim}"
echo $outdir
echo $base_model
cd /nfs/project/libo_i/cycada
# Run python script #
python3 scripts/train_fcn_adda.py \
${outdir} \
--dataset ${src} --dataset ${tgt} --datadir ${datadir} \
--lr ${lr} --momentum ${momentum} --gpu ${gpu} \
--lambda_d ${lambda_d} --lambda_g ${lambda_g} \
--weights_init ${base_model} --model ${model} \
--"${weight_share}" --${discrim} --no_lsgan \
--max_iter ${max_iter} --batch ${batch} --weights_discrim ${discrim_model} \
--snapshot ${snapshot} --no_mmd_loss --small 0 --resize 300 --data_flag ${data_flag}
================================================
FILE: scripts/ADDA/adda_cyclesyn2cs_feat.sh
================================================
#!/usr/bin/env bash
# Train ADDA adapting CycleSYNTHIA -> Cityscapes.
# NOTE(review): despite the _feat filename this script sets
# discrim='discrim_score' -- confirm which discriminator level was intended.
gpu=0,1,2,3
######################
# loss weight params #
######################
lr=1e-5
momentum=0.99
lambda_d=1
lambda_g=0.1
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
# NOTE(review): PYTHONPATH set to the interpreter binary has no effect.
export PYTHONPATH='/usr/bin/python3'
################
# train params #
################
max_iter=100000
crop=800
snapshot=5000
batch=4
weight_share='weights_shared'
discrim='discrim_score'
########
# Data #
########
src='cyclesynthia'
tgt='cityscapes'
data_flag='V2_SEM'
datadir='/nfs/project/libo_i/cycada/data/'
resdir="results/${src}_to_${tgt}/adda_sgd/${weight_share}_nolsgan_${discrim}"
# init with pre-trained cyclegta5 model
#model='drn26'
#baseiter=115000
model='fcn8s'
baseiter=100000
base_model="/nfs/project/libo_i/cycada/pretrained_models/cyclesynthia_V2_SEM_fcn8s-iter21000.pth"
outdir="${resdir}/${model}/lr${lr}_crop${crop}_ld${lambda_d}_lg${lambda_g}_momentum${momentum}"
echo $outdir
echo $base_model
cd /nfs/project/libo_i/cycada
# Run python script #
python3 scripts/train_fcn_adda.py ${outdir} \
--dataset ${src} --dataset ${tgt} --datadir ${datadir} \
--lr ${lr} --momentum ${momentum} --gpu ${gpu} \
--lambda_d ${lambda_d} --lambda_g ${lambda_g} \
--weights_init ${base_model} --model ${model} \
--"${weight_share}" --${discrim} --no_lsgan \
--max_iter ${max_iter} --crop_size ${crop} --batch ${batch} \
--snapshot ${snapshot}
================================================
FILE: scripts/ADDA/adda_cyclesyn2cs_score.sh
================================================
#!/usr/bin/env bash
# Train an ADDA score-level discriminator adapting CycleSYNTHIA -> Cityscapes.
# NOTE(review): identical to adda_cyclesyn2cs_feat.sh in this snapshot --
# confirm the two scripts were meant to differ.
gpu=0,1,2,3
######################
# loss weight params #
######################
lr=1e-5
momentum=0.99
lambda_d=1
lambda_g=0.1
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
# NOTE(review): PYTHONPATH set to the interpreter binary has no effect.
export PYTHONPATH='/usr/bin/python3'
################
# train params #
################
max_iter=100000
crop=800
snapshot=5000
batch=4
weight_share='weights_shared'
discrim='discrim_score'
########
# Data #
########
src='cyclesynthia'
tgt='cityscapes'
data_flag='V2_SEM'
datadir='/nfs/project/libo_i/cycada/data/'
resdir="results/${src}_to_${tgt}/adda_sgd/${weight_share}_nolsgan_${discrim}"
# init with pre-trained cyclegta5 model
#model='drn26'
#baseiter=115000
model='fcn8s'
baseiter=100000
base_model="/nfs/project/libo_i/cycada/pretrained_models/cyclesynthia_V2_SEM_fcn8s-iter21000.pth"
outdir="${resdir}/${model}/lr${lr}_crop${crop}_ld${lambda_d}_lg${lambda_g}_momentum${momentum}"
echo $outdir
echo $base_model
cd /nfs/project/libo_i/cycada
# Run python script #
python3 scripts/train_fcn_adda.py ${outdir} \
--dataset ${src} --dataset ${tgt} --datadir ${datadir} \
--lr ${lr} --momentum ${momentum} --gpu ${gpu} \
--lambda_d ${lambda_d} --lambda_g ${lambda_g} \
--weights_init ${base_model} --model ${model} \
--"${weight_share}" --${discrim} --no_lsgan \
--max_iter ${max_iter} --crop_size ${crop} --batch ${batch} \
--snapshot ${snapshot}
================================================
FILE: scripts/ADDA/adda_templates.sh
================================================
#!/usr/bin/env bash
# Parameterized ADDA training template.
# Intended usage (from the assignments below): $1=src dataset, $2=data_flag,
# $3=baseiter, $4=base_model checkpoint.
gpu=0,1,2,3
######################
# loss weight params #
######################
lr=1e-5
momentum=0.99
lambda_d=1
lambda_g=0.1
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
# NOTE(review): PYTHONPATH set to the interpreter binary has no effect.
export PYTHONPATH='/usr/bin/python3'
################
# train params #
################
max_iter=100000
crop=800
snapshot=5000
batch=4
weight_share='weights_shared'
discrim='discrim_score'
########
# Data #
########
src=$1
tgt='cityscapes'
data_flag=$2
datadir='/nfs/project/libo_i/cycada/data/'
resdir="results/${src}_to_${tgt}/adda_sgd/${weight_share}_nolsgan_${discrim}"
# init with pre-trained cyclegta5 model
#model='drn26'
#baseiter=115000
# NOTE(review): model=$2 reuses the same positional argument as data_flag
# above -- the model name and the data flag cannot both be $2; this looks
# like an off-by-one ($3/$4/$5 shift) in the argument mapping. Confirm.
model=$2
baseiter=$3
base_model=$4
outdir="${resdir}/${model}/lr${lr}_crop${crop}_ld${lambda_d}_lg${lambda_g}_momentum${momentum}"
echo $outdir
echo $base_model
cd /nfs/project/libo_i/cycada
# Run python script #
python3 scripts/train_fcn_adda.py \
${outdir} \
--dataset ${src} --dataset ${tgt} --datadir ${datadir} \
--lr ${lr} --momentum ${momentum} --gpu ${gpu} \
--lambda_d ${lambda_d} --lambda_g ${lambda_g} \
--weights_init ${base_model} --model ${model} \
--"${weight_share}" --${discrim} --no_lsgan \
--max_iter ${max_iter} --crop_size ${crop} --batch ${batch} \
--snapshot ${snapshot}
================================================
FILE: scripts/CycleGAN/cyclegan_gta2cityscapes.sh
================================================
#!/usr/bin/env bash
# Train a semantic-consistent CycleGAN translating GTA5 -> Cityscapes,
# warm-starting the semantic net from a pretrained DRN-26 checkpoint.
# NOTE(review): this script uses sudo while the sibling scripts do not --
# confirm elevated privileges are actually required.
cd /nfs/project/libo_i/MADAN/cyclegan
sudo python3 train.py --name cyclegan_gta2cityscapes \
--resize_or_crop scale_width_and_crop --loadSize 600 --fineSize 500 --which_model_netD n_layers --n_layers_D 3 \
--model cycle_gan_semantic_fcn --no_flip --batchSize 2 --nThreads 8 \
--dataset_mode gta5_cityscapes --dataroot /nfs/project/libo_i/MADAN/data \
--model_type drn26 --weights_init /nfs/project/libo_i/MADAN/pretrained_models/drn26_cycada_cyclegta2cityscapes.pth \
--semantic_loss --gpu 0
================================================
FILE: scripts/CycleGAN/cyclegan_gta_synthia2cityscapes.sh
================================================
#!/usr/bin/env bash
# Train the multi-source MADAN CycleGAN: GTA5 + SYNTHIA -> Cityscapes,
# with dynamic semantic consistency (DSC), cross-domain cycle discriminator
# (CCD) and sub-domain aggregation discriminator (SAD) enabled.
cd /nfs/project/libo_i/MADAN/cyclegan
python3 train.py --name cyclegan_gta_synthia2cityscapes_noIdentity \
--resize_or_crop scale_width_and_crop --loadSize 500 --fineSize 400 \
--model multi_cycle_gan_semantic --no_flip --batchSize 4 \
--dataset_mode gta_synthia_cityscapes --dataroot /nfs/project/libo_i/MADAN/data \
--DSC --general_semantic_weight 20 --CCD --SAD --CCD_weight 0.2 --SAD_frozen_epoch 5 --CCD_frozen_epoch 10 --max_epoch 40 \
--weights_syn /nfs/project/libo_i/MADAN/pretrained_models/cyclesynthia_drn26_iter2000.pth \
--weights_gta /nfs/project/libo_i/cycada/pretrained_models/drn26_cycada_cyclegta2cityscapes.pth \
--gpu 0,1,2,3 --semantic_loss
================================================
FILE: scripts/CycleGAN/cyclegan_synthia2cityscapes.sh
================================================
#!/usr/bin/env bash
# Train the multi-source MADAN CycleGAN (GTA5 + SYNTHIA -> Cityscapes),
# final-scale variant with DSC/CCD/SAD losses.
# NOTE(review): --gpu 0,1,2,3 is passed twice (here and on the last line);
# harmless with argparse (last wins) but one should be removed.
cd /root/MADAN/cyclegan
python3 train.py --name cycada_gta_synthia2cityscapes_noIdentity_D12D21D3_SEM_final_scale \
--resize_or_crop scale_width_and_crop --loadSize 500 --fineSize 400 \
--model multi_cycle_gan_semantic --no_flip --batchSize 4 \
--dataset_mode gta_synthia_cityscapes --dataroot /nfs/project/libo_i/MADAN/data \
--DSC --general_semantic_weight 20 --CCD --SAD --CCD_weight 0.2 --SAD_frozen_epoch 5 --CCD_frozen_epoch 10 --max_epoch 40 --gpu 0,1,2,3 \
--weights_syn /nfs/project/libo_i/cycada/pretrained_models/cyclesynthia_V4_SEM_Final_iter_6000.pth \
--weights_gta /nfs/project/libo_i/cycada/pretrained_models/drn26_cycada_cyclegta2cityscapes.pth \
--gpu 0,1,2,3 --semantic_loss
================================================
FILE: scripts/CycleGAN/test_templates.sh
================================================
#!/usr/bin/env bash
# Inference template: translate the whole train split with a trained model.
# Usage: test_templates.sh <exp_name> <epoch> <model> <dataset_mode>
how_many=100000
cd /root/MADAN/cyclegan
name=$1
epoch=$2
python3 test.py --name ${name} --resize_or_crop=None \
--which_model_netD n_layers --n_layers_D 3 \
--model $3 --loadSize 600 \
--no_flip --batchSize 32 --nThreads 16 \
--dataset_mode $4 --dataroot /nfs/project/libo_i/cycada/data \
--which_direction AtoB \
--phase train --out_all \
--how_many ${how_many} --which_epoch ${epoch} --gpu 0
================================================
FILE: scripts/CycleGAN/test_templates_cycle.sh
================================================
#!/usr/bin/env bash
# Sequentially load two generators(GTA, Synthia) and finish
# Usage: test_templates_cycle.sh <exp_name> <epoch> <model> <dataset_mode_1> <dataset_mode_2>
# Runs inference twice with the same checkpoint, once per dataset mode.
how_many=100000
cd /root/MADAN/cyclegan
model=$1
epoch=$2
# first pass: dataset mode $4
python3 test.py --name ${model} --resize_or_crop=None \
--which_model_netD n_layers --n_layers_D 3 \
--model $3 \
--no_flip --batchSize 32 --nThreads 16 \
--dataset_mode $4 --dataroot /nfs/project/libo_i/cycada/data \
--which_direction AtoB \
--phase train --out_all \
--how_many ${how_many} --which_epoch ${epoch} --gpu 0
# second pass: dataset mode $5
python3 test.py --name ${model} --resize_or_crop=None \
--which_model_netD n_layers --n_layers_D 3 \
--model $3 \
--no_flip --batchSize 32 --nThreads 16 \
--dataset_mode $5 --dataroot /nfs/project/libo_i/cycada/data \
--which_direction AtoB \
--phase train --out_all \
--how_many ${how_many} --which_epoch ${epoch} --gpu 0
# cyclegan/test_templates_cycle.sh cycada_gta_synthia2cityscapes_noIdentity_D12D21D3_SEM_final_scale 15 test synthia_cityscapes gta5_cityscapes
================================================
FILE: scripts/FCN/train_fcn8s_cyclesgta5.sh
================================================
#!/usr/bin/env bash
# Train an FCN8s segmenter on CycleGAN-translated GTA5 images (19 classes).
gpu=0,1,2,3
data=cyclegta5
model=fcn8s
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
datadir=/root/MADAN/data
batch=8
iterations=30000
snapshot=2000
num_cls=19
data_flag=V4_SEM_Final_Scale
cd /root/MADAN
outdir=/root/MADAN/results/${data}/${data}_${data_flag}/${model}
# mkdir path is relative to /root/MADAN, matching the absolute outdir above
mkdir -p results/${data}/${data}_${data_flag}/${model}
echo $outdir
python3 scripts/train_fcn.py ${outdir} --model ${model} \
--num_cls ${num_cls} --gpu ${gpu} \
-b ${batch} --adam \
--iterations ${iterations} \
--datadir ${datadir} \
--snapshot ${snapshot} \
--dataset ${data} --data_flag ${data_flag}
================================================
FILE: scripts/FCN/train_fcn8s_cyclesynthia.sh
================================================
#!/usr/bin/env bash
# Train an FCN8s segmenter on CycleGAN-translated SYNTHIA images (19 classes).
gpu=0,1,2,3
data=cyclesynthia
model=fcn8s
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
datadir=/root/MADAN/data
batch=28
iterations=20000
snapshot=1000
num_cls=19
data_flag=V4_SEM_Final_Scale
cd /root/MADAN
# NOTE(review): outdir points under /root/MADAN/cycada/results/... but the
# mkdir below creates /root/MADAN/results/... (no 'cycada' segment) -- one of
# the two paths is wrong; the training script may fail to write snapshots.
outdir=/root/MADAN/cycada/results/${data}/${data}_${data_flag}/${model}
mkdir -p results/${data}/${data}_${data_flag}/${model}
echo $outdir
python3 scripts/train_fcn.py ${outdir} --model ${model} \
--num_cls ${num_cls} --gpu ${gpu} \
-b ${batch} --adam \
--iterations ${iterations} \
--datadir ${datadir} \
--snapshot ${snapshot} --small 1 \
--dataset ${data} --data_flag ${data_flag}
================================================
FILE: scripts/eval_fcn.py
================================================
import os
import sys
from torchvision.transforms import transforms
sys.path.append('/nfs/project/libo_iMADAN')
import json
import click
import numpy as np
import torch
from torch.autograd import Variable
from tqdm import *
from cycada.data.data_loader import dataset_obj, get_fcn_dataset
from cycada.models.models import get_model, models
from cycada.util import to_tensor_raw
import torchvision
from PIL import Image
# Module-level image transforms: `loader` turns a PIL image into a float
# tensor; `unloader` turns a tensor back into a PIL image (debug/visualization).
loader = transforms.Compose([
    transforms.ToTensor()])
unloader = transforms.ToPILImage()
def fmt_array(arr, fmt=','):
    """Render every element of *arr* with 3 decimal places, joined by *fmt*."""
    return fmt.join('{:.3f}'.format(value) for value in arr)
def fast_hist(a, b, n):
    """Build an n x n confusion matrix between flat label arrays.

    a: ground-truth labels; b: predictions; entries of `a` outside [0, n)
    (e.g. the 255 ignore label) are dropped.
    """
    valid = (a >= 0) & (a < n)
    combined = n * a[valid].astype(int) + b[valid]
    return np.bincount(combined, minlength=n ** 2).reshape(n, n)
def result_stats(hist):
    """Derive summary metrics (all in percent) from a confusion matrix.

    Returns (overall pixel accuracy, per-class accuracy, per-class IoU,
    frequency-weighted IoU). Tiny 1e-8 terms guard against empty classes.
    """
    diag = np.diag(hist)
    total = hist.sum()
    gt_per_class = hist.sum(1)
    pred_per_class = hist.sum(0)
    acc_overall = diag.sum() / total * 100
    acc_percls = diag / (gt_per_class + 1e-8) * 100
    iu = diag / (gt_per_class + pred_per_class - diag + 1e-8) * 100
    freq = gt_per_class / total
    fwIU = (freq[freq > 0] * iu[freq > 0]).sum()
    return acc_overall, acc_percls, iu, fwIU
@click.command()
@click.argument('path', type=click.Path(exists=True))
@click.option('--dataset', default='cityscapes',
              type=click.Choice(dataset_obj.keys()))
@click.option('--datadir', default='',
              type=click.Path(exists=True))
@click.option('--model', default='fcn8s', type=click.Choice(models.keys()))
@click.option('--gpu', default='0')
@click.option('--num_cls', default=19)
@click.option('--batch_size', default=16)
@click.option('--loadSize', default=None)
@click.option('--fineSize', default=None)
def main(path, dataset, datadir, model, gpu, num_cls, batch_size, loadSize, fineSize):
    """Evaluate a saved segmentation checkpoint on a dataset's val split.

    Prints per-class IoU / accuracy summaries and appends the mIoU result
    to a ``result.json`` file next to the checkpoint at *path*.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    net = get_model(model, num_cls=num_cls, weights_init=path)
    # Parse the comma-separated GPU spec into integer device ids.
    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >= 0:
            gpu_ids.append(id)
    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
    assert (torch.cuda.is_available())
    net.to(gpu_ids[0])
    net = torch.nn.DataParallel(net, gpu_ids)
    net.eval()
    # NOTE(review): this only enables the resize path when BOTH --loadSize and
    # --fineSize are given (it relies on `and` propagating the first None).
    if (loadSize and fineSize) is not None:
        print("Loading Center Crop DataLoader Transform")
        # Width is 1.8x the requested fineSize; labels use NEAREST so class
        # ids are never interpolated.
        data_transform = torchvision.transforms.Compose([transforms.Resize([int(loadSize), int(int(fineSize) * 1.8)], interpolation=Image.BICUBIC),
                                                         net.module.transform.transforms[0], net.module.transform.transforms[1]])
        target_transform = torchvision.transforms.Compose([transforms.Resize([int(loadSize), int(int(fineSize) * 1.8)], interpolation=Image.NEAREST),
                                                           transforms.Lambda(lambda img: to_tensor_raw(img))])
    else:
        data_transform = net.module.transform
        target_transform = torchvision.transforms.Compose([transforms.Lambda(lambda img: to_tensor_raw(img))])
    ds = get_fcn_dataset(dataset, datadir, num_cls=num_cls, split='val', transform=data_transform, target_transform=target_transform)
    classes = ds.classes
    # NOTE(review): this local DataLoader shadows the module-level `loader`
    # transform defined above.
    loader = torch.utils.data.DataLoader(ds, num_workers=16, batch_size=batch_size)
    # NOTE(review): `errs` is never appended to anywhere below; the final
    # "Errors:" print always shows an empty list.
    errs = []
    hist = np.zeros((num_cls, num_cls))
    if len(loader) == 0:
        print('Empty data loader')
        return
    iterations = tqdm(enumerate(loader))
    for im_i, (im, label) in iterations:
        if im_i == 0:
            print(im.size())
            print(label.size())
        if im_i > 32:
            # NOTE(review): evaluation stops after 33 batches -- looks like a
            # leftover debugging cap; confirm before trusting full-set numbers.
            break
        im = Variable(im.cuda())
        score = net(im).data
        _, preds = torch.max(score, 1)
        # Accumulate the confusion matrix over batches.
        hist += fast_hist(label.numpy().flatten(), preds.cpu().numpy().flatten(), num_cls)
        acc_overall, acc_percls, iu, fwIU = result_stats(hist)
        iterations.set_postfix({'mIoU': ' {:0.2f} fwIoU: {:0.2f} pixel acc: {:0.2f} per cls acc: {:0.2f}'.format(np.nanmean(iu), fwIU, acc_overall,
                                                                                                                np.nanmean(acc_percls))})
    print()
    # SYNTHIA has no terrain/truck/train annotations, so the common 16-class
    # metric excludes those three classes.
    synthia_metric_iu = 0
    # line = ""
    for index, item in enumerate(classes):
        print(classes[index], " {:0.1f}".format(iu[index]))
        if classes[index] != 'terrain' and classes[index] != 'truck' and classes[index] != 'train':
            synthia_metric_iu += iu[index]
    # line += " {:0.1f} &".format(iu[index])
    # variable "line" is used for adding format results into latex grids
    # print(line)
    print(np.nanmean(iu), fwIU, acc_overall, np.nanmean(acc_percls))
    print("16 Class-Wise mIOU is {}".format(synthia_metric_iu / 16))
    print('Errors:', errs)
    # Record {checkpoint_name: [19-class mIoU, 16-class mIoU]} in result.json
    # located in the checkpoint's parent directory, merging with any existing file.
    cur_path = path.split('/')[-1]
    parent_path = path.replace(cur_path, '')
    results_dict_path = os.path.join(parent_path, 'result.json')
    results_dict = {}
    results_dict[cur_path] = [np.nanmean(iu), synthia_metric_iu / 16]
    if os.path.exists(results_dict_path) is False:
        with open(results_dict_path, 'w') as fp:
            json.dump(results_dict, fp)
    else:
        with open(results_dict_path, 'r') as fp:
            exist_dict = json.load(fp)
        with open(results_dict_path, 'w') as fp:
            exist_dict.update(results_dict)
            json.dump(exist_dict, fp)


if __name__ == '__main__':
    main()
================================================
FILE: scripts/train_fcn.py
================================================
import logging
import os.path
import sys
from collections import deque
import click
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from tensorboardX import SummaryWriter
sys.path.append('/nfs/project/libo_iMADAN')
from cycada.data.data_loader import get_fcn_dataset as get_dataset
from cycada.models import get_model
from cycada.models.models import models
from cycada.transforms import augment_collate
from cycada.util import config_logging
from cycada.util import to_tensor_raw, step_lr
from cycada.tools.util import make_variable
def to_tensor_raw(im):
    """Convert a PIL image (or any array-like) of integer labels to an int64 tensor.

    Uses ``np.asarray`` so no copy is made when the input already is an int64
    array; the previous ``np.array(..., copy=False)`` raises on NumPy >= 2.0
    whenever a copy *is* required (e.g. for PIL images or lists).
    """
    return torch.from_numpy(np.asarray(im, dtype=np.int64))
def roundrobin_infinite(*loaders):
    """Yield items from each loader in turn, forever.

    Whenever a loader runs out it is restarted from scratch, so the stream
    never ends (an empty argument list yields nothing).
    """
    if not loaders:
        return
    active = [iter(source) for source in loaders]
    while True:
        for idx, source in enumerate(loaders):
            try:
                yield next(active[idx])
            except StopIteration:
                # Re-open the exhausted loader and emit its first item.
                active[idx] = iter(source)
                yield next(active[idx])
def supervised_loss(score, label, weights=None):
    """Mean pixel-wise cross-entropy between raw scores and int labels.

    score: (N, C, H, W) unnormalized logits; label: (N, H, W) class ids,
    255 is ignored. ``weights`` optionally rescales each class.

    NLLLoss2d and ``size_average`` are deprecated/removed in modern PyTorch;
    NLLLoss with ``reduction='mean'`` plus an explicit ``dim=1`` log_softmax is
    the drop-in replacement (and matches train_fcn_adda.py's implementation).
    """
    loss_fn_ = torch.nn.NLLLoss(weight=weights, reduction='mean', ignore_index=255)
    loss = loss_fn_(F.log_softmax(score, dim=1), label)
    return loss
@click.command()
@click.argument('output')
@click.option('--dataset', required=True, multiple=True)
@click.option('--datadir', default="", type=click.Path(exists=True))
@click.option('--batch_size', '-b', default=1)
@click.option('--lr', '-l', default=0.001)
@click.option('--step', type=int)
@click.option('--iterations', '-i', default=100000)
@click.option('--momentum', '-m', default=0.9)
@click.option('--snapshot', '-s', default=5000)
@click.option('--downscale', type=int)
@click.option('--resize_to', type=int, default=720)
@click.option('--augmentation/--no-augmentation', default=False)
@click.option('--adam/--sgd', default=False)
@click.option('--small', type=int, default=2)
@click.option('--preprocessing', default=False)
@click.option('--force_split', default=False)
@click.option('--fyu/--torch', default=False)
@click.option('--crop_size', default=720)
@click.option('--weights', type=click.Path(exists=True))
@click.option('--model_weights', type=click.Path(exists=True))
@click.option('--model', default='fcn8s', type=click.Choice(models.keys()))
@click.option('--num_cls', default=19, type=int)
@click.option('--nthreads', default=8, type=int)
@click.option('--gpu', default='0')
@click.option('--start_step', default=0)
@click.option('--data_flag', default='', type=str)
@click.option('--rundir_flag', default='', type=str)
@click.option('--serial_batches', type=bool, default=False, help='if true, takes images in order to make batches, otherwise takes them randomly')
def main(output, dataset, batch_size, lr, step, iterations,
         momentum, snapshot, downscale, augmentation, fyu, crop_size,
         weights, model, gpu, num_cls, nthreads, model_weights, data_flag,
         serial_batches, resize_to, start_step, preprocessing, small, rundir_flag, force_split, adam, datadir):
    """Supervised FCN training over one or more (adapted) source datasets.

    Snapshots net weights to ``<output>/iter_<k>.pth`` every ``snapshot``
    iterations and logs the running loss to TensorBoard under runs/.
    """
    if weights is not None:
        # --weights (class weights from file) is known-broken upstream.
        raise RuntimeError("weights don't work because eric is bad at coding")
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    logdir_flag = data_flag
    if rundir_flag != "":
        logdir_flag += "_{}".format(rundir_flag)
    logdir = 'runs/{:s}/{:s}/{:s}'.format(model, '-'.join(dataset), logdir_flag)
    writer = SummaryWriter(log_dir=logdir)
    if model == 'fcn8s':
        net = get_model(model, num_cls=num_cls, weights_init=model_weights)
    else:
        net = get_model(model, num_cls=num_cls, finetune=True, weights_init=model_weights)
    net.cuda()
    # Parse the comma-separated GPU spec into integer device ids.
    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >= 0:
            gpu_ids.append(id)
    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
    assert (torch.cuda.is_available())
    net.to(gpu_ids[0])
    net = torch.nn.DataParallel(net, gpu_ids)
    # Optional resize (width = 1.8 * height); labels use NEAREST so class ids
    # are never interpolated. The model's own transform always runs last.
    transform = []
    target_transform = []
    if preprocessing:
        transform.extend([torchvision.transforms.Resize([int(resize_to), int(int(resize_to) * 1.8)])])
        target_transform.extend([torchvision.transforms.Resize([int(resize_to), int(int(resize_to) * 1.8)], interpolation=Image.NEAREST)])
    transform.extend([net.module.transform])
    target_transform.extend([to_tensor_raw])
    transform = torchvision.transforms.Compose(transform)
    target_transform = torchvision.transforms.Compose(target_transform)
    if force_split:
        # Only the first dataset gets data_flag; the second uses defaults.
        datasets = []
        datasets.append(
            get_dataset(dataset[0], os.path.join(datadir, dataset[0]), num_cls=num_cls, transform=transform, target_transform=target_transform,
                        data_flag=data_flag))
        datasets.append(
            get_dataset(dataset[1], os.path.join(datadir, dataset[1]), num_cls=num_cls, transform=transform, target_transform=target_transform))
    else:
        datasets = [get_dataset(name, os.path.join(datadir, name), num_cls=num_cls, transform=transform, target_transform=target_transform,
                                data_flag=data_flag) for name in dataset]
    if weights is not None:
        weights = np.loadtxt(weights)
    if adam:
        print("Using Adam")
        opt = torch.optim.Adam(net.module.parameters(), lr=1e-4)
    else:
        print("Using SGD")
        opt = torch.optim.SGD(net.module.parameters(), lr=lr, momentum=momentum, weight_decay=0.0005)
    if augmentation:
        collate_fn = lambda batch: augment_collate(batch, crop=crop_size, flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate
    loaders = [torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=not serial_batches, num_workers=nthreads, collate_fn=collate_fn,
                                           pin_memory=True) for dataset in datasets]
    iteration = start_step
    losses = deque(maxlen=10)
    # Probe one sample per dataset with debug output before training starts.
    for loader in loaders:
        loader.dataset.__getitem__(0, debug=True)
    for im, label in roundrobin_infinite(*loaders):
        # Clear out gradients
        opt.zero_grad()
        # load data/label
        im = make_variable(im, requires_grad=False)
        label = make_variable(label, requires_grad=False)
        if iteration == 0:
            print("im size: {}".format(im.size()))
            print("label size: {}".format(label.size()))
        # forward pass and compute loss
        preds = net(im)
        loss = supervised_loss(preds, label)
        # backward pass
        loss.backward()
        losses.append(loss.item())
        # step gradients
        opt.step()
        # log results
        if iteration % 10 == 0:
            logging.info('Iteration {}:\t{}'.format(iteration, np.mean(losses)))
            writer.add_scalar('loss', np.mean(losses), iteration)
        iteration += 1
        if step is not None and iteration % step == 0:
            logging.info('Decreasing learning rate by 0.1.')
            step_lr(opt, 0.1)
        if iteration % snapshot == 0:
            torch.save(net.module.state_dict(),
                       '{}/iter_{}.pth'.format(output, iteration))
        if iteration >= iterations:
            logging.info('Optimization complete.')
            # BUG FIX: roundrobin_infinite never raises StopIteration, so
            # without this break the loop would spin (and log "complete")
            # forever once the iteration budget was reached.
            break


if __name__ == '__main__':
    main()
================================================
FILE: scripts/train_fcn_adda.py
================================================
import logging
import os
import os.path
import sys
from collections import deque
from datetime import datetime
import click
import numpy as np
import torch
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from torch.autograd import Variable
sys.path.append('/nfs/project/libo_iMADAN')
from cycada.data.adda_datasets import AddaDataLoader
from cycada.models import get_model
from cycada.models.models import models
from cycada.models import Discriminator
from cycada.util import config_logging
from cycada.tools.util import make_variable, mmd_loss
def check_label(label, num_cls):
    """Check that *label* contains at least one usable, in-range class id.

    Valid ids are 0..num_cls-1; 255 is the ignore label. Returns False (with a
    diagnostic print) when every pixel is ignored or any id falls outside the
    valid range, True otherwise.
    """
    label_classes = np.unique(label.numpy().flatten())
    label_classes = label_classes[label_classes < 255]
    if len(label_classes) == 0:
        print('All ignore labels')
        return False
    # FIX: valid ids are 0..num_cls-1, so an id equal to num_cls is already out
    # of bounds (the previous `>` comparison let it through, which would index
    # past the class dimension in the NLL loss downstream).
    class_too_large = label_classes.max() >= num_cls
    if class_too_large or label_classes.min() < 0:
        print('Labels out of bound')
        print(label_classes)
        return False
    return True
def forward_pass(net, discriminator, im, requires_grad=False, discrim_feat=False):
    """Run the segmenter and feed the discriminator.

    When ``discrim_feat`` is set, the net is expected to return
    (score, feature) and the discriminator consumes the feature map;
    otherwise it consumes the score map directly. Unless ``requires_grad``
    is set, the returned score is detached from the graph.
    """
    if discrim_feat:
        score, feat = net(im)
        dis_input = feat
    else:
        score = net(im)
        dis_input = score
    dis_score = discriminator(dis_input)
    if not requires_grad:
        score = Variable(score.data, requires_grad=False)
    return score, dis_score
def supervised_loss(score, label, weights=None):
    """Mean pixel-wise NLL over log-softmax scores; label 255 is ignored."""
    log_probs = F.log_softmax(score, dim=1)
    criterion = torch.nn.NLLLoss(weight=weights, reduction='mean', ignore_index=255)
    return criterion(log_probs, label)
def discriminator_loss(score, target_val, lsgan=False):
if lsgan:
loss = 0.5 * torch.mean((score - target_val) ** 2)
else:
_, _, h, w = score.size()
target_val_vec = Variable(target_val * torch.ones(1, h, w), requires_grad=False).long().cuda()
loss = supervised_loss(score, target_val_vec)
return loss
def fast_hist(a, b, n):
    """n x n confusion counts between ground truth `a` and prediction `b`.

    Ground-truth entries outside [0, n) -- e.g. the 255 ignore label -- are
    skipped before counting.
    """
    mask = (a >= 0) & (a < n)
    flat_index = n * a[mask].astype(int) + b[mask]
    return np.bincount(flat_index, minlength=n ** 2).reshape(n, n)
def seg_accuracy(score, label, num_cls):
    """Segmentation stats for one batch of score maps.

    Returns (per-class intersections, per-class unions scaled by 100 with a
    1e-8 guard, overall pixel accuracy) computed from the argmax predictions.
    """
    _, preds = torch.max(score.data, 1)
    gt_flat = label.cpu().numpy().flatten()
    pred_flat = preds.cpu().numpy().flatten()
    hist = fast_hist(gt_flat, pred_flat, num_cls)
    diag = np.diag(hist)
    intersections = diag
    unions = (hist.sum(1) + hist.sum(0) - diag + 1e-8) * 100
    acc = diag.sum() / hist.sum()
    return intersections, unions, acc
@click.command()
@click.argument('output')
@click.option('--dataset', required=True, multiple=True)
@click.option('--datadir', default="", type=click.Path(exists=True))
@click.option('--lr', '-l', default=0.0001)
@click.option('--momentum', '-m', default=0.9)
@click.option('--batch', default=1)
@click.option('--snapshot', '-s', default=5000)
@click.option('--downscale', type=int)
@click.option('--resize', default=None, type=int)
@click.option('--crop_size', default=None, type=int)
@click.option('--half_crop', default=None)
@click.option('--cls_weights', type=click.Path(exists=True))
@click.option('--weights_discrim', type=click.Path(exists=True))
@click.option('--weights_init', type=click.Path(exists=True))
@click.option('--model', default='fcn8s', type=click.Choice(models.keys()))
@click.option('--lsgan/--no_lsgan', default=False)
@click.option('--num_cls', type=int, default=19)
@click.option('--gpu', default='0')
@click.option('--max_iter', default=10000)
@click.option('--lambda_d', default=1.0)
@click.option('--lambda_g', default=1.0)
@click.option('--train_discrim_only', default=False)
@click.option('--with_mmd_loss/--no_mmd_loss', default=False)
@click.option('--discrim_feat/--discrim_score', default=False)
@click.option('--weights_shared/--weights_unshared', default=False)
@click.option('--data_flag', type=str, default=None)
@click.option('--small', type=int, default=2)
def main(output, dataset, datadir, lr, momentum, snapshot, downscale, cls_weights, gpu,
         weights_init, num_cls, lsgan, max_iter, lambda_d, lambda_g,
         train_discrim_only, weights_discrim, crop_size, weights_shared,
         discrim_feat, half_crop, batch, model, data_flag, resize, with_mmd_loss, small):
    """ADDA-style adversarial feature adaptation for semantic segmentation.

    Alternates between (1) training a domain discriminator on source vs.
    target scores/features and (2) -- once the discriminator is better than
    ``dom_acc_thresh`` -- updating the segmenter with an adversarial loss,
    a source supervised loss, and optionally an MMD discrepancy loss.
    Saves net/discriminator snapshots under ``output``.
    """
    # So data is sampled in consistent way
    np.random.seed(1336)
    torch.manual_seed(1336)
    # Encode the full hyper-parameter configuration into the run directory.
    logdir = 'runs/{:s}/{:s}_to_{:s}/lr{:.1g}_ld{:.2g}_lg{:.2g}'.format(model, dataset[0],
                                                                        dataset[1], lr, lambda_d, lambda_g)
    if weights_shared:
        logdir += '_weights_shared'
    else:
        logdir += '_weights_unshared'
    if discrim_feat:
        logdir += '_discrim_feat'
    else:
        logdir += '_discrim_score'
    logdir += '/' + datetime.now().strftime('%Y_%b_%d-%H:%M')
    writer = SummaryWriter(log_dir=logdir)
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    print('Train Discrim Only', train_discrim_only)
    # output_last_ft makes the net also return its last feature map, needed
    # when the discriminator operates on features rather than score maps.
    if model == 'fcn8s':
        net = get_model(model, num_cls=num_cls, pretrained=True, weights_init=weights_init, output_last_ft=discrim_feat)
    else:
        net = get_model(model, num_cls=num_cls, finetune=True, pretrained=True, weights_init=weights_init, output_last_ft=discrim_feat)
    net.cuda()
    # Parse the comma-separated GPU spec into integer device ids.
    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >= 0:
            gpu_ids.append(id)
    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
    assert (torch.cuda.is_available())
    net.to(gpu_ids[0])
    net = torch.nn.DataParallel(net, gpu_ids)
    if weights_shared:
        net_src = net  # shared weights
    else:
        # Separate frozen source network used only to score source images.
        net_src = get_model(model, num_cls=num_cls, finetune=True, pretrained=True, weights_init=weights_init, output_last_ft=discrim_feat)
        net_src.eval()
    # initialize Discriminator: 1 output for LSGAN regression, 2 for the
    # cross-entropy real/fake formulation; input is the score map (num_cls
    # channels) or the 4096-d feature map.
    odim = 1 if lsgan else 2
    idim = num_cls if not discrim_feat else 4096
    print('Discrim_feat', discrim_feat, idim)
    print('Discriminator init weights: ', weights_discrim)
    discriminator = Discriminator(input_dim=idim, output_dim=odim,
                                  pretrained=not (weights_discrim == None),
                                  weights_init=weights_discrim).cuda()
    discriminator.to(gpu_ids[0])
    discriminator = torch.nn.DataParallel(discriminator, gpu_ids)
    # Paired source/target loader yielding aligned (im_s, im_t, label_s, label_t).
    loader = AddaDataLoader(net.module.transform, dataset, datadir, downscale, resize=resize,
                            crop_size=crop_size, half_crop=half_crop, batch_size=batch,
                            shuffle=True, num_workers=16, src_data_flag=data_flag, small=small)
    print('dataset', dataset)
    # Class weighted loss?
    if cls_weights is not None:
        weights = np.loadtxt(cls_weights)
    else:
        weights = None
    # setup optimizers
    opt_dis = torch.optim.SGD(discriminator.module.parameters(), lr=lr,
                              momentum=momentum, weight_decay=0.0005)
    opt_rep = torch.optim.SGD(net.module.parameters(), lr=lr,
                              momentum=momentum, weight_decay=0.0005)
    iteration = 0
    num_update_g = 0
    last_update_g = -1
    # Rolling windows for smoothed logging of the individual loss terms.
    losses_super_s = deque(maxlen=100)
    losses_super_t = deque(maxlen=100)
    losses_dis = deque(maxlen=100)
    losses_rep = deque(maxlen=100)
    accuracies_dom = deque(maxlen=100)
    intersections = np.zeros([100, num_cls])
    iu_deque = deque(maxlen=100)
    unions = np.zeros([100, num_cls])
    accuracy = deque(maxlen=100)
    print('Max Iter:', max_iter)
    net.train()
    discriminator.train()
    # Probe one sample from each side with debug output before training.
    loader.loader_src.dataset.__getitem__(0, debug=True)
    loader.loader_tgt.dataset.__getitem__(0, debug=True)
    while iteration < max_iter:
        for im_s, im_t, label_s, label_t in loader:
            if iteration == 0:
                print("IM S: {}".format(im_s.size()))
                print("Label S: {}".format(label_s.size()))
                print("IM T: {}".format(im_t.size()))
                print("Label T: {}".format(label_t.size()))
            if iteration > max_iter:
                break
            info_str = 'Iteration {}: '.format(iteration)
            # Skip batches whose source labels are entirely ignored/out of range.
            if not check_label(label_s, num_cls):
                continue
            ###########################
            # 1. Setup Data Variables #
            ###########################
            im_s = make_variable(im_s, requires_grad=False)
            label_s = make_variable(label_s, requires_grad=False)
            im_t = make_variable(im_t, requires_grad=False)
            label_t = make_variable(label_t, requires_grad=False)
            #############################
            # 2. Optimize Discriminator #
            #############################
            # zero gradients for optimizer
            opt_dis.zero_grad()
            opt_rep.zero_grad()
            # extract features (detached so only the discriminator learns here)
            if discrim_feat:
                score_s, feat_s = net_src(im_s)
                score_s = Variable(score_s.data, requires_grad=False)
                f_s = Variable(feat_s.data, requires_grad=False)
            else:
                score_s = Variable(net_src(im_s).data, requires_grad=False)
                f_s = score_s
            dis_score_s = discriminator(f_s)
            if discrim_feat:
                score_t, feat_t = net(im_t)
                score_t = Variable(score_t.data, requires_grad=False)
                f_t = Variable(feat_t.data, requires_grad=False)
            else:
                score_t = Variable(net(im_t).data, requires_grad=False)
                f_t = score_t
            dis_score_t = discriminator(f_t)
            dis_pred_concat = torch.cat((dis_score_s, dis_score_t))
            # prepare real and fake labels: source pixels = 1, target pixels = 0
            batch_t, _, h, w = dis_score_t.size()
            batch_s, _, _, _ = dis_score_s.size()
            dis_label_concat = make_variable(
                torch.cat(
                    [torch.ones(batch_s, h, w).long(),
                     torch.zeros(batch_t, h, w).long()]
                ), requires_grad=False)
            # compute loss for discriminator
            loss_dis = supervised_loss(dis_pred_concat, dis_label_concat)
            (lambda_d * loss_dis).backward()
            losses_dis.append(loss_dis.item())
            # optimize discriminator
            opt_dis.step()
            # compute discriminator acc
            pred_dis = torch.squeeze(dis_pred_concat.max(1)[1])
            dom_acc = (pred_dis == dis_label_concat).float().mean().item()
            accuracies_dom.append(dom_acc * 100.)
            # add discriminator info to log
            info_str += " domacc:{:0.1f} D:{:.3f}".format(np.mean(accuracies_dom),
                                                          np.mean(losses_dis))
            writer.add_scalar('loss/discriminator', np.mean(losses_dis), iteration)
            writer.add_scalar('acc/discriminator', np.mean(accuracies_dom), iteration)
            ###########################
            # Optimize Target Network #
            ###########################
            # The segmenter is only updated once the discriminator's rolling
            # domain accuracy exceeds this threshold (percent).
            dom_acc_thresh = 60
            if train_discrim_only and np.mean(accuracies_dom) > dom_acc_thresh:
                # Discriminator-pretraining mode: save and stop once good enough.
                # NOTE(review): the format string ignores the second argument.
                os.makedirs(output, exist_ok=True)
                torch.save(discriminator.module.state_dict(),
                           '{}/discriminator_abv60.pth'.format(output, iteration))
                break
            if not train_discrim_only and np.mean(accuracies_dom) > dom_acc_thresh:
                last_update_g = iteration
                num_update_g += 1
                if num_update_g % 1 == 0:
                    print('Updating G with adversarial loss ({:d} times)'.format(num_update_g))
                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()
                # extract features (gradient flows into the net this time)
                if discrim_feat:
                    score_t, feat_t = net(im_t)
                    score_t = Variable(score_t.data, requires_grad=False)
                    f_t = feat_t
                else:
                    score_t = net(im_t)
                    f_t = score_t
                # score_t = net(im_t)
                dis_score_t = discriminator(f_t)
                # create fake label: pretend target pixels are source (1)
                batch, _, h, w = dis_score_t.size()
                target_dom_fake_t = make_variable(torch.ones(batch, h, w).long(),
                                                  requires_grad=False)
                # compute loss for target net
                loss_gan_t = supervised_loss(dis_score_t, target_dom_fake_t)
                (lambda_g * loss_gan_t).backward()
                losses_rep.append(loss_gan_t.item())
                writer.add_scalar('loss/generator', np.mean(losses_rep), iteration)
                # optimize target net
                opt_rep.step()
                # log net update info
                info_str += ' G:{:.3f}'.format(np.mean(losses_rep))
            if (not train_discrim_only) and weights_shared and np.mean(accuracies_dom) > dom_acc_thresh:
                print('Updating G using source supervised loss.')
                # zero out optimizer gradients
                opt_dis.zero_grad()
                opt_rep.zero_grad()
                # extract features
                if discrim_feat:
                    score_s, feat_s = net(im_s)
                else:
                    score_s = net(im_s)
                loss_supervised_s = supervised_loss(score_s, label_s, weights=weights)
                if with_mmd_loss:
                    # NOTE(review): feat_s/feat_t only exist when discrim_feat
                    # is enabled -- this branch presumably assumes that; verify.
                    print("Updating G using discrepancy loss")
                    lambda_discrepancy = 0.1
                    loss_mmd = mmd_loss(feat_s, feat_t) * 0.5 + mmd_loss(score_s, score_t) * 0.5
                    loss_supervised_s += lambda_discrepancy * loss_mmd
                loss_supervised_s.backward()
                losses_super_s.append(loss_supervised_s.item())
                info_str += ' clsS:{:.2f}'.format(np.mean(losses_super_s))
                writer.add_scalar('loss/supervised/source', np.mean(losses_super_s), iteration)
                # optimize target net
                opt_rep.step()
            # compute supervised losses for target -- monitoring only!!!no backward()
            loss_supervised_t = supervised_loss(score_t, label_t, weights=weights)
            losses_super_t.append(loss_supervised_t.item())
            info_str += ' clsT:{:.2f}'.format(np.mean(losses_super_t))
            writer.add_scalar('loss/supervised/target', np.mean(losses_super_t), iteration)
            ###########################
            # Log and compute metrics #
            ###########################
            if iteration % 10 == 0 and iteration > 0:
                # compute metrics over a sliding window of the last 100 samples
                intersection, union, acc = seg_accuracy(score_t, label_t.data, num_cls)
                intersections = np.vstack([intersections[1:, :], intersection[np.newaxis, :]])
                unions = np.vstack([unions[1:, :], union[np.newaxis, :]])
                accuracy.append(acc.item() * 100)
                acc = np.mean(accuracy)
                mIoU = np.mean(np.maximum(intersections, 1) / np.maximum(unions, 1)) * 100
                # union was pre-scaled by 100 in seg_accuracy, hence the 10000
                # factor to express IoU in percent.
                iu = (intersection / union) * 10000
                iu_deque.append(np.nanmean(iu))
                info_str += ' acc:{:0.2f} mIoU:{:0.2f}'.format(acc, np.mean(iu_deque))
                writer.add_scalar('metrics/acc', np.mean(accuracy), iteration)
                writer.add_scalar('metrics/mIoU', np.mean(mIoU), iteration)
            logging.info(info_str)
            iteration += 1
            ################
            # Save outputs #
            ################
            # every 500 iters save current model
            if iteration % 500 == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.module.state_dict(),
                               '{}/net-itercurr.pth'.format(output))
                torch.save(discriminator.module.state_dict(),
                           '{}/discriminator-itercurr.pth'.format(output))
            # save labeled snapshots
            if iteration % snapshot == 0:
                os.makedirs(output, exist_ok=True)
                if not train_discrim_only:
                    torch.save(net.module.state_dict(),
                               '{}/net-iter{}.pth'.format(output, iteration))
                torch.save(discriminator.module.state_dict(),
                           '{}/discriminator-iter{}.pth'.format(output, iteration))
            # Give up if the generator has not been updated for 3 full epochs.
            if iteration - last_update_g >= 3 * len(loader):
                print('No suitable discriminator found -- returning.')
                torch.save(net.module.state_dict(),
                           '{}/net-iter{}.pth'.format(output, iteration))
                iteration = max_iter  # make sure outside loop breaks
                break
    writer.close()


if __name__ == '__main__':
    main()
================================================
FILE: scripts/train_fcn_mdan.py
================================================
import itertools
import json
import logging
import os.path
import subprocess
import sys
from collections import deque
import click
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from tensorboardX import SummaryWriter
sys.path.append('/nfs/project/libo_iMADAN')
from cycada.data.data_loader import get_fcn_dataset as get_dataset
from cycada.models import get_model
from cycada.models.models import models
from cycada.models.MDAN import MDANet
from cycada.transforms import augment_collate
from cycada.util import config_logging
from cycada.util import to_tensor_raw, step_lr
from cycada.tools.util import make_variable
def to_tensor_raw(im):
    """Convert a PIL image (or any array-like) of integer labels to an int64 tensor.

    Uses ``np.asarray`` so no copy is made when the input already is an int64
    array; the previous ``np.array(..., copy=False)`` raises on NumPy >= 2.0
    whenever a copy *is* required (e.g. for PIL images or lists).
    """
    return torch.from_numpy(np.asarray(im, dtype=np.int64))
def roundrobin_infinite(*loaders):
    """Endlessly alternate single items across *loaders*, reopening each one
    as it runs out. With no loaders, yields nothing."""
    if not loaders:
        return
    slots = len(loaders)
    current = [iter(ld) for ld in loaders]
    position = 0
    while True:
        slot = position % slots
        try:
            yield next(current[slot])
        except StopIteration:
            current[slot] = iter(loaders[slot])
            yield next(current[slot])
        position += 1
def multi_source_infinite(loaders, target_loader):
    """Yield (synthia_batch, gta_batch, target_batch) triples forever.

    As soon as any of the three iterators is exhausted, all three are
    restarted together (any batches pulled in the failed round are dropped,
    matching the original behavior). Yields nothing when *loaders* is empty.
    """
    if not loaders:
        return
    syn_it = iter(loaders[0])
    gta_it = iter(loaders[1])
    tgt_it = iter(target_loader)
    while True:
        try:
            triple = (next(syn_it), next(gta_it), next(tgt_it))
        except StopIteration:
            syn_it = iter(loaders[0])
            gta_it = iter(loaders[1])
            tgt_it = iter(target_loader)
            triple = (next(syn_it), next(gta_it), next(tgt_it))
        yield triple
def supervised_loss(score, label, weights=None):
    """Mean pixel-wise cross-entropy between raw scores and int labels.

    score: (N, C, H, W) unnormalized logits; label: (N, H, W) class ids,
    255 is ignored. ``weights`` optionally rescales each class.

    NLLLoss2d and ``size_average`` are deprecated/removed in modern PyTorch;
    NLLLoss with ``reduction='mean'`` plus an explicit ``dim=1`` log_softmax is
    the drop-in replacement (and matches train_fcn_adda.py's implementation).
    """
    loss_fn_ = torch.nn.NLLLoss(weight=weights, reduction='mean', ignore_index=255)
    loss = loss_fn_(F.log_softmax(score, dim=1), label)
    return loss
@click.command()
@click.argument('output')
@click.option('--dataset', required=True, multiple=True)
@click.option('--target_name', required=True)
@click.option('--datadir', default="", type=click.Path(exists=True))
@click.option('--batch_size', '-b', default=1)
@click.option('--lr', '-l', default=0.001)
@click.option('--iterations', '-i', default=100000)
@click.option('--momentum', '-m', default=0.9)
@click.option('--snapshot', '-s', default=5000)
@click.option('--downscale', type=int)
@click.option('--resize_to', type=int, default=720)
@click.option('--augmentation/--no-augmentation', default=False)
@click.option('--small', type=int, default=2)
@click.option('--preprocessing', default=False)
@click.option('--fyu/--torch', default=False)
@click.option('--crop_size', default=720)
@click.option('--weights', type=click.Path(exists=True))
@click.option('--model_weights', type=click.Path(exists=True))
@click.option('--model', default='fcn8s', type=click.Choice(models.keys()))
@click.option('--num_cls', default=19, type=int)
@click.option('--nthreads', default=16, type=int)
@click.option('--gpu', default='0')
@click.option('--start_step', default=0)
@click.option('--data_flag', default='', type=str)
@click.option('--rundir_flag', default='', type=str)
@click.option('--serial_batches', type=bool, default=False, help='if true, takes images in order to make batches, otherwise takes them randomly')
def main(output, dataset, target_name, datadir, batch_size, lr, iterations,
         momentum, snapshot, downscale, augmentation, fyu, crop_size,
         weights, model, gpu, num_cls, nthreads, model_weights, data_flag, serial_batches, resize_to, start_step, preprocessing, small, rundir_flag):
    """Multi-source (SYNTHIA + GTA5 -> Cityscapes) MDAN training.

    Trains the segmenter with supervised losses on both sources while a
    MDANet head adversarially aligns source and target feature maps; the
    per-domain losses are combined with MDAN's maxmin/dynamic weighting.
    Snapshots go to ``<output>/iter_<k>.pth``.
    """
    if weights is not None:
        # --weights (class weights from file) is known-broken upstream.
        raise RuntimeError("weights don't work because eric is bad at coding")
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    logdir_flag = data_flag
    if rundir_flag != "":
        logdir_flag += "_{}".format(rundir_flag)
    logdir = 'runs/{:s}/{:s}/{:s}'.format(model, '-'.join(dataset), logdir_flag)
    writer = SummaryWriter(log_dir=logdir)
    # output_last_ft=True: the net returns (score, feature) so the features
    # can be fed to the MDAN domain classifiers.
    if model == 'fcn8s':
        net = get_model(model, num_cls=num_cls, weights_init=model_weights, output_last_ft=True)
    else:
        net = get_model(model, num_cls=num_cls, finetune=True, weights_init=model_weights)
    net.cuda()
    # Parse the comma-separated GPU spec into integer device ids.
    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >= 0:
            gpu_ids.append(id)
    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
    assert (torch.cuda.is_available())
    net.to(gpu_ids[0])
    net = torch.nn.DataParallel(net, gpu_ids)
    # Optional resize (width = 1.8 * height); labels use NEAREST so class ids
    # are never interpolated. The model's own transform always runs last.
    transform = []
    target_transform = []
    if preprocessing:
        transform.extend([torchvision.transforms.Resize([int(resize_to), int(int(resize_to) * 1.8)], interpolation=Image.BICUBIC)])
        target_transform.extend([torchvision.transforms.Resize([int(resize_to), int(int(resize_to) * 1.8)], interpolation=Image.NEAREST)])
    transform.extend([net.module.transform])
    target_transform.extend([to_tensor_raw])
    transform = torchvision.transforms.Compose(transform)
    target_transform = torchvision.transforms.Compose(target_transform)
    datasets = [get_dataset(name, os.path.join(datadir, name), num_cls=num_cls, transform=transform, target_transform=target_transform,
                            data_flag=data_flag, small=small) for name in dataset]
    target_dataset = get_dataset(target_name, os.path.join(datadir, target_name), num_cls=num_cls, transform=transform,
                                 target_transform=target_transform,
                                 data_flag=data_flag, small=small)
    if weights is not None:
        weights = np.loadtxt(weights)
    if augmentation:
        collate_fn = lambda batch: augment_collate(batch, crop=crop_size, flip=True)
    else:
        collate_fn = torch.utils.data.dataloader.default_collate
    # drop_last keeps every batch at exactly batch_size, which the fixed-size
    # domain labels below rely on.
    loaders = [torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=not serial_batches, num_workers=nthreads, collate_fn=collate_fn,
                                           pin_memory=True, drop_last=True) for dataset in datasets]
    target_loader = torch.utils.data.DataLoader(target_dataset, batch_size=batch_size, shuffle=not serial_batches, num_workers=nthreads,
                                                collate_fn=collate_fn,
                                                pin_memory=True, drop_last=True)
    iteration = start_step
    losses = deque(maxlen=10)
    losses_domain_syn = deque(maxlen=10)
    losses_domain_gta = deque(maxlen=10)
    losses_task = deque(maxlen=10)
    # Probe one sample per source dataset with debug output before training.
    for loader in loaders:
        loader.dataset.__getitem__(0, debug=True)
    input_dim = 2048
    configs = {"input_dim": input_dim, "hidden_layers": [1000, 500, 100], "num_classes": 2, 'num_domains': 2, 'mode': 'dynamic', 'mu': 1e-2,
               'gamma': 10.0}
    mdan = MDANet(configs).to(gpu_ids[0])
    mdan = torch.nn.DataParallel(mdan, gpu_ids)
    mdan.train()
    # One optimizer jointly updates the MDAN heads and the segmenter.
    opt = torch.optim.Adam(itertools.chain(mdan.module.parameters(), net.module.parameters()), lr=1e-4)
    for (im_syn, label_syn), (im_gta, label_gta), (im_cs, label_cs) in multi_source_infinite(loaders, target_loader):
        # Clear out gradients
        opt.zero_grad()
        # load data/label
        im_syn = make_variable(im_syn, requires_grad=False)
        label_syn = make_variable(label_syn, requires_grad=False)
        im_gta = make_variable(im_gta, requires_grad=False)
        label_gta = make_variable(label_gta, requires_grad=False)
        im_cs = make_variable(im_cs, requires_grad=False)
        label_cs = make_variable(label_cs, requires_grad=False)
        if iteration == 0:
            print("im_syn size: {}".format(im_syn.size()))
            print("label_syn size: {}".format(label_syn.size()))
            print("im_gta size: {}".format(im_gta.size()))
            print("label_gta size: {}".format(label_gta.size()))
            print("im_cs size: {}".format(im_cs.size()))
            print("label_cs size: {}".format(label_cs.size()))
        if not (im_syn.size() == im_gta.size() == im_cs.size()):
            # Shape mismatch across domains: printed for diagnosis only.
            print(im_syn.size())
            print(im_gta.size())
            print(im_cs.size())
        # forward pass and compute loss
        preds_syn, ft_syn = net(im_syn)
        preds_gta, ft_gta = net(im_gta)
        preds_cs, ft_cs = net(im_cs)
        loss_synthia = supervised_loss(preds_syn, label_syn)
        loss_gta = supervised_loss(preds_gta, label_gta)
        loss = loss_synthia + loss_gta
        losses_task.append(loss.item())
        # Domain classification: each MDAN head scores source-vs-target pairs.
        logprobs, sdomains, tdomains = mdan(ft_syn, ft_gta, ft_cs)
        slabels = torch.ones(batch_size, requires_grad=False).type(torch.LongTensor).to(gpu_ids[0])
        tlabels = torch.zeros(batch_size, requires_grad=False).type(torch.LongTensor).to(gpu_ids[0])
        domain_losses = torch.stack([F.nll_loss(sdomains[j], slabels) + F.nll_loss(tdomains[j], tlabels) for j in range(configs['num_domains'])])
        losses_domain_syn.append(domain_losses[0].item())
        losses_domain_gta.append(domain_losses[1].item())
        # Different final loss function depending on different training modes.
        if configs['mode'] == "maxmin":
            loss = torch.max(loss) + configs['mu'] * torch.min(domain_losses)
        elif configs['mode'] == "dynamic":
            # Soft-max (log-sum-exp) combination of the per-domain objectives.
            loss = torch.log(torch.sum(torch.exp(configs['gamma'] * (loss + configs['mu'] * domain_losses)))) / configs['gamma']
        # backward pass
        loss.backward()
        losses.append(loss.item())
        torch.nn.utils.clip_grad_norm_(net.module.parameters(), 10)
        torch.nn.utils.clip_grad_norm_(mdan.module.parameters(), 10)
        # step gradients
        opt.step()
        # log results
        if iteration % 10 == 0:
            logging.info(
                'Iteration {}:\t{:.3f} Domain SYN: {:.3f} Domain GTA: {:.3f} Task: {:.3f}'.format(iteration, np.mean(losses),
                                                                                                  np.mean(losses_domain_syn),
                                                                                                  np.mean(losses_domain_gta), np.mean(losses_task)))
            writer.add_scalar('loss', np.mean(losses), iteration)
            writer.add_scalar('domain_syn', np.mean(losses_domain_syn), iteration)
            writer.add_scalar('domain_gta', np.mean(losses_domain_gta), iteration)
            writer.add_scalar('task', np.mean(losses_task), iteration)
        iteration += 1
        if iteration % 500 == 0:
            os.makedirs(output, exist_ok=True)
            torch.save(net.module.state_dict(), '{}/net-itercurr.pth'.format(output))
        if iteration % snapshot == 0:
            torch.save(net.module.state_dict(), '{}/iter_{}.pth'.format(output, iteration))
        if iteration >= iterations:
            logging.info('Optimization complete.')
            # BUG FIX: multi_source_infinite never raises StopIteration, so
            # without this break the loop would spin (and log "complete")
            # forever once the iteration budget was reached.
            break


if __name__ == '__main__':
    main()
================================================
FILE: tools/__init__.py
================================================
================================================
FILE: tools/eval_templates.sh
================================================
#!/usr/bin/env bash
# Evaluate a trained FCN8s checkpoint on the Cityscapes val split.
# Usage: eval_templates.sh <checkpoint_path>
export LC_ALL=C.UTF-8
export LANG=C.UTF-8
cd /nfs/project/libo_i/MADAN
ckpt_path=$1
datadir=/nfs/project/libo_i/MADAN/data/cityscapes
model=fcn8s
num_cls=19
gpu=0
sudo python3 scripts/eval_fcn.py ${ckpt_path} \
--dataset cityscapes \
--datadir ${datadir} \
--model ${model} --num_cls ${num_cls} \
--gpu ${gpu}