Showing preview only (259K chars total). Download the full file or copy to clipboard to get everything.
Repository: Luodian/MADAN
Branch: master
Commit: 7a2918da44f5
Files: 88
Total size: 237.8 KB
Directory structure:
gitextract_rf2a9vy3/
├── .gitignore
├── .idea/
│ ├── MADAN.iml
│ ├── deployment.xml
│ ├── inspectionProfiles/
│ │ └── profiles_settings.xml
│ ├── misc.xml
│ ├── modules.xml
│ ├── remote-mappings.xml
│ └── vcs.xml
├── LICENSE
├── README.md
├── cycada/
│ ├── __init__.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── adda_datasets.py
│ │ ├── bdds.py
│ │ ├── cityscapes.py
│ │ ├── cityscapes_labels.py
│ │ ├── cyclegan.py
│ │ ├── cyclegta5.py
│ │ ├── cyclesynthia.py
│ │ ├── cyclesynthia_cyclegta5.py
│ │ ├── data_loader.py
│ │ ├── gta5.py
│ │ ├── rotater.py
│ │ ├── synthia.py
│ │ └── util.py
│ ├── logging.yml
│ ├── models/
│ │ ├── MDAN.py
│ │ ├── __init__.py
│ │ ├── adda_net.py
│ │ ├── drn.py
│ │ ├── fcn8s.py
│ │ ├── models.py
│ │ ├── task_net.py
│ │ └── util.py
│ ├── tools/
│ │ ├── __init__.py
│ │ ├── train_adda_net.py
│ │ ├── train_task_net.py
│ │ └── util.py
│ ├── transforms.py
│ └── util.py
├── cyclegan/
│ ├── .gitignore
│ ├── data/
│ │ ├── __init__.py
│ │ ├── base_data_loader.py
│ │ ├── base_dataset.py
│ │ ├── cityscapes.py
│ │ ├── gta5_cityscapes.py
│ │ ├── gta_synthia_cityscapes.py
│ │ ├── image_folder.py
│ │ └── synthia_cityscapes.py
│ ├── environment.yml
│ ├── models/
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── cycle_gan_model.py
│ │ ├── cycle_gan_semantic_model.py
│ │ ├── multi_cycle_gan_semantic_model.py
│ │ ├── networks.py
│ │ └── test_model.py
│ ├── options/
│ │ ├── __init__.py
│ │ ├── base_options.py
│ │ ├── test_options.py
│ │ └── train_options.py
│ ├── test.py
│ ├── train.py
│ └── util/
│ ├── __init__.py
│ ├── get_data.py
│ ├── html.py
│ ├── image_pool.py
│ ├── util.py
│ └── visualizer.py
├── requirements.txt
├── scripts/
│ ├── ADDA/
│ │ ├── adda_cyclegta2cs_feat.sh
│ │ ├── adda_cyclegta2cs_score.sh
│ │ ├── adda_cyclesyn2cs_feat.sh
│ │ ├── adda_cyclesyn2cs_score.sh
│ │ └── adda_templates.sh
│ ├── CycleGAN/
│ │ ├── cyclegan_gta2cityscapes.sh
│ │ ├── cyclegan_gta_synthia2cityscapes.sh
│ │ ├── cyclegan_synthia2cityscapes.sh
│ │ ├── test_templates.sh
│ │ └── test_templates_cycle.sh
│ ├── FCN/
│ │ ├── train_fcn8s_cyclesgta5.sh
│ │ └── train_fcn8s_cyclesynthia.sh
│ ├── eval_fcn.py
│ ├── train_fcn.py
│ ├── train_fcn_adda.py
│ └── train_fcn_mdan.py
└── tools/
├── __init__.py
└── eval_templates.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.DS_Store
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf
# Generated files
.idea/**/contentModel.xml
# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml
# Gradle
.idea/**/gradle.xml
.idea/**/libraries
# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
# CMake
cmake-build-*/
# Mongo Explorer plugin
.idea/**/mongoSettings.xml
# File-based project format
*.iws
# IntelliJ
out/
# mpeltonen/sbt-idea plugin
.idea_modules/
# JIRA plugin
atlassian-ide-plugin.xml
# Cursive Clojure plugin
.idea/replstate.xml
# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
# Editor-based Rest Client
.idea/httpRequests
# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
/models/
================================================
FILE: .idea/MADAN.iml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Remote Python 3.5.2 (sftp://luban@10.84.217.43:8022/usr/bin/python3)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="TestRunnerService">
<option name="projectConfiguration" value="Twisted Trial" />
<option name="PROJECT_TEST_RUNNER" value="Twisted Trial" />
</component>
</module>
================================================
FILE: .idea/deployment.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" autoUpload="Always" serverName="MV_CyCADA_22G">
<serverData>
<paths name="MV_CyCADA_22G">
<serverdata>
<mappings>
<mapping deploy="/nfs/project/libo_i/MADAN" local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
<option name="myAutoUpload" value="ALWAYS" />
</component>
</project>
================================================
FILE: .idea/inspectionProfiles/profiles_settings.xml
================================================
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
================================================
FILE: .idea/misc.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="JavaScriptSettings">
<option name="languageLevel" value="ES6" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.5.2 (sftp://luban@10.84.217.43:8022/usr/bin/python3)" project-jdk-type="Python SDK" />
<component name="PythonCompatibilityInspectionAdvertiser">
<option name="version" value="3" />
</component>
</project>
================================================
FILE: .idea/modules.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/MADAN.iml" filepath="$PROJECT_DIR$/.idea/MADAN.iml" />
</modules>
</component>
</project>
================================================
FILE: .idea/remote-mappings.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="RemoteMappingsManager">
<list>
<list>
<remote-mappings server-id="python@sftp://luban@10.84.217.43:8022/usr/bin/python3">
<settings>
<list>
<mapping local-root="$PROJECT_DIR$" remote-root="/nfs/project/libo_i/MADAN" />
</list>
</settings>
</remote-mappings>
</list>
</list>
</component>
</project>
================================================
FILE: .idea/vcs.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2019 liljprime
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# MADAN
A Pytorch Code for [Multi-source Domain Adaptation for Semantic Segmentation](https://arxiv.org/abs/1910.12181)
If you use this code in your research please consider citing:
```
@InProceedings{zhao2019madan,
title = {Multi-source Domain Adaptation for Semantic Segmentation},
author = {Zhao, Sicheng and Li, Bo and Yue, Xiangyu and Gu, Yang and Xu, Pengfei and Hu, Runbo and Chai, Hua and Keutzer, Kurt},
booktitle = {Advances in Neural Information Processing Systems},
year = {2019}
}
```
## Quick Look
Our multi-source domain adaptation builds on the work [CyCADA](https://github.com/jhoffman/cycada_release) and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). Since we focus on the Semantic Segmentation task, we remove the Digit Classification part of CyCADA.
We add following modules and achieve startling improvements.
1. Dynamic Semantic Consistency Module
2. Adversarial Aggregation Module
1. Sub-domain Aggregation Discriminator
2. Cross-domain Cycle Discriminator
We also implement [MDAN](https://openreview.net/pdf?id=ryDNZZZAW) for the Semantic Segmentation task in PyTorch as our baseline for comparison.
## Overall Structure

## Setup
Check out this repo:
```bash
git clone https://github.com/pikachusocute/MADAN.git
```
Install Python3 requirements
```bash
pip3 install -r requirements.txt
```
## Dynamic Adversarial Image Generation
We follow the approach in CyCADA: in the first step, we train the Image Adaptation module to transfer source images (GTA, SYNTHIA, or multi-source) into the "source as target" style.

We refer Image Adaptation module from GTA to Cityscapes as GTA->Cityscapes in the following.
#### GTA->Cityscapes
```bash
cd scripts/CycleGAN
bash cyclegan_gta2cityscapes.sh
```
In the training process, snapshot files will be stored in `cyclegan/checkpoints/[EXP_NAME]`.
Usually, after we run for 20 epochs, there will be a file `20_net_G_A.pth` in the previous folder path.
Then we run the test process.
```bash
bash scripts/CycleGAN/test_templates.sh [EXP_NAME] 20 cycle_gan_semantic_fcn gta5_cityscapes
```
In the multi-source case, both `20_net_G_A_1.pth` and `20_net_G_A_2.pth` exist, so we use another script to run the test process.

```bash
bash scripts/CycleGAN/test_templates_cycle.sh [EXP_NAME] 20 test synthia_cityscapes gta5_cityscapes
```
New dataset will be generated at `~/cyclegan/results/[EXP_NAME]/train_20`.
After we obtain a new source stylized dataset, we then train segmenter on the new dataset.
## Pixel Level Adaptation
In this part, we train our new segmenter on new dataset.
```bash
ln -s ~/cyclegan/results/[EXP_NAME]/train_20 ~/data/cyclegta5/[EXP_NAME]_TRAIN_60
```
Then we set `dataflag = [EXP_NAME]_TRAIN_60` to find datasets' paths, and follow instructions to train segmenter to perform pixel level adaptation.
```bash
bash scripts/FCN/train_fcn8s_cyclesgta5_DSC.sh
```
## Feature Level Adaptation
For adaptation, we use
```bash
bash scripts/ADDA/adda_cyclegta2cs_score.sh
```
Make sure you choose the desired `src` and `tgt` and `datadir` before. In this process, you should load your `base_model` trained on synthetic dataset and perform adaptation in feature level to real scene dataset.
### Our Model
We release our adaptation model in the `./models`, you can use `scripts/eval_templates.sh` to evaluate its validity.
1. [CycleGTA5_Dynamic_Semantic_Consistency](https://drive.google.com/file/d/1moGF7L2hkTHUPUzqsSxPwKNlHCHQm4Ms/view?usp=sharing)
2. [CycleSYNTHIA_Dynamic_Semantic_Consistency](https://drive.google.com/file/d/19V5J1zyF3ct3247gSSr-u3WVkDJqQvUk/view?usp=sharing)
3. [Multi_Source_SAD_CCD](https://drive.google.com/file/d/1xgmLwhsbwv-isy7R5FkNevVSH9gcMxuq/view?usp=sharing)
### Transferred Dataset
We will release our transferred dataset soon, on which our `CycleGTA5_Dynamic_Semantic_Consistency` model was trained to perform pixel-level adaptation.
================================================
FILE: cycada/__init__.py
================================================
================================================
FILE: cycada/data/__init__.py
================================================
from . import gta5, cityscapes, cyclegta5, synthia, cyclesynthia, cyclesynthia_cyclegta5, bdds
from . import adda_datasets
================================================
FILE: cycada/data/adda_datasets.py
================================================
import os.path
import torch.utils.data
from .data_loader import get_transform_dataset
from ..transforms import augment_collate
class AddaDataLoader(object):
    """Joint loader over a (source, target) dataset pair for ADDA training.

    Iteration yields aligned batches (img_src, img_tgt, label_src, label_tgt).
    Whichever of the two underlying loaders is exhausted first is rebuilt, so
    images of the shorter dataset are revisited until the longer one is seen.
    """

    def __init__(self, net_transform, dataset, rootdir, downscale, crop_size=None, resize=None,
                 batch_size=1, shuffle=False, num_workers=2, half_crop=None, src_data_flag=None, small=False):
        self.dataset = dataset
        self.downscale = downscale
        self.resize = resize
        self.crop_size = crop_size
        self.half_crop = half_crop
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        assert len(self.dataset) == 2, 'Requires two datasets: source, target'
        sourcedir = os.path.join(rootdir, self.dataset[0])
        targetdir = os.path.join(rootdir, self.dataset[1])
        # src_data_flag only applies to the source dataset (e.g. a stylized subdir).
        self.source = get_transform_dataset(self.dataset[0], sourcedir, net_transform, downscale,
                                            resize, src_data_flag=src_data_flag, small=small)
        self.target = get_transform_dataset(self.dataset[1], targetdir, net_transform, downscale,
                                            resize, small=small)
        print('Source length:', len(self.source), 'Target length:', len(self.target))
        self.n = max(len(self.source), len(self.target))  # make sure you see all images
        self.num = 0
        self.set_loader_src()
        self.set_loader_tgt()

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return the next aligned (src, tgt) batch, restarting an exhausted side."""
        # len(iterator) is the number of batches per epoch; when `num` is a
        # multiple of it, that side has just been drained and is rebuilt.
        if self.num % len(self.iters_src) == 0:
            print('restarting source dataset')
            self.set_loader_src()
        if self.num % len(self.iters_tgt) == 0:
            print('restarting target dataset')
            self.set_loader_tgt()
        img_src, label_src = next(self.iters_src)
        img_tgt, label_tgt = next(self.iters_tgt)
        self.num += 1
        return img_src, img_tgt, label_src, label_tgt

    def __len__(self):
        # NOTE(review): reports the shorter side even though `self.n` stores
        # the longer one ("see all images") -- confirm which length callers expect.
        return min(len(self.source), len(self.target))

    def _build_loader(self, dataset):
        """Create a DataLoader plus a fresh iterator over `dataset`.

        Shared by set_loader_src/set_loader_tgt (previously duplicated code):
        uses the augmenting collate when cropping/resizing is requested,
        otherwise the default collate.
        """
        if self.crop_size is not None or self.resize is not None:
            collate_fn = lambda batch: augment_collate(batch, resize=self.resize, crop=self.crop_size,
                                                       halfcrop=self.half_crop, flip=True)
        else:
            collate_fn = torch.utils.data.dataloader.default_collate
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=self.batch_size, shuffle=self.shuffle,
                                             num_workers=self.num_workers,
                                             collate_fn=collate_fn, pin_memory=True)
        return loader, iter(loader)

    def set_loader_src(self):
        """(Re)build the source-side loader and iterator."""
        self.loader_src, self.iters_src = self._build_loader(self.source)

    def set_loader_tgt(self):
        """(Re)build the target-side loader and iterator."""
        self.loader_tgt, self.iters_tgt = self._build_loader(self.target)
================================================
FILE: cycada/data/bdds.py
================================================
import os.path
import numpy as np
import torch.utils.data as data
from PIL import Image
from .util import classes, ignore_label, id2label
from .data_loader import register_dataset_obj
@register_dataset_obj('bdds')
class BDDS(data.Dataset):
    """Berkeley DeepDrive segmentation data laid out as images/<split> and labels/<split>."""

    def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None, data_flag=None):
        self.root = root
        self.split = split
        self.remap_labels = remap_labels
        self.transform = transform
        self.target_transform = target_transform
        self.classes = classes
        self.data_flag = data_flag
        self.num_cls = num_cls
        self.ids = self.collect_ids()

    def collect_ids(self):
        """Filenames of every image under images/<split>, in directory order."""
        image_dir = os.path.join(self.root, "images", self.split)
        return [os.path.join(image_dir, entry).split('/')[-1]
                for entry in os.listdir(image_dir)]

    def img_path(self, filename):
        """Absolute path of an image file."""
        return os.path.join(self.root, "images", self.split, filename)

    def label_path(self, filename):
        """Matching label path: strip the 4-char extension, append `_train_id.png`."""
        return os.path.join(self.root, 'labels', self.split, "{}_train_id.png".format(filename[:-4]))

    def __getitem__(self, index, debug=False):
        """Return the (image, label) pair at `index`, transforms applied when set."""
        filename = self.ids[index]
        img = Image.open(self.img_path(filename)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = Image.open(self.label_path(filename))
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.ids)
================================================
FILE: cycada/data/cityscapes.py
================================================
import os.path
import sys
import numpy as np
import torch.utils.data as data
from PIL import Image
from .util import classes, ignore_label, id2label
from .data_loader import DatasetParams, register_data_params, register_dataset_obj
def remap_labels_to_train_ids(arr):
    """Map raw Cityscapes label ids in `arr` to train ids as uint8.

    Pixels whose raw id has no entry in `id2label` keep `ignore_label`.
    """
    out = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        out[arr == raw_id] = int(train_id)
    return out
@register_data_params('cityscapes')
class CityScapesParams(DatasetParams):
    """Static dataset metadata consumed by the generic data-loading code."""
    num_channels = 3        # RGB input images
    image_size = 1024
    mean = 0.5              # normalization mean -- presumably for inputs scaled to [0, 1]; TODO confirm
    std = 0.5               # normalization std
    num_cls = 19            # number of semantic train classes
    target_transform = None
@register_dataset_obj('cityscapes')
class Cityscapes(data.Dataset):
    """Cityscapes dataset: leftImg8bit images paired with gtFine label-id maps."""

    def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None,
                 target_transform=None):
        self.root = root
        sys.path.append(root)
        self.split = split
        self.remap_labels = remap_labels
        self.ids = self.collect_ids()
        self.transform = transform
        self.target_transform = target_transform
        self.num_cls = num_cls
        self.id2label = id2label
        self.classes = classes

    def collect_ids(self):
        """Ids of the form 'city_sequence_frame' for every png under leftImg8bit/<split>."""
        image_root = os.path.join(self.root, 'leftImg8bit', self.split)
        collected = []
        for dirpath, dirnames, filenames in os.walk(image_root):
            collected.extend('_'.join(name.split('_')[:3])
                             for name in filenames if name.endswith('.png'))
        return collected

    def img_path(self, id):
        """Path of the RGB image for `id` (city prefix selects the subdirectory)."""
        city = id.split('_')[0]
        return os.path.join(self.root,
                            'leftImg8bit/{}/{}/{}_leftImg8bit.png'.format(self.split, city, id))

    def label_path(self, id):
        """Path of the gtFine label-id map for `id`."""
        city = id.split('_')[0]
        return os.path.join(self.root,
                            'gtFine/{}/{}/{}_gtFine_labelIds.png'.format(self.split, city, id))

    def __getitem__(self, index, debug=False):
        """Return (image, target); raw label ids are remapped to train ids when enabled."""
        id = self.ids[index]
        img = Image.open(self.img_path(id)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = Image.open(self.label_path(id)).convert('L')
        if self.remap_labels:
            # Remap happens before any target transform so transforms see train ids.
            target = np.asarray(target)
            target = remap_labels_to_train_ids(target)
            target = Image.fromarray(np.uint8(target), 'L')
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.ids)
================================================
FILE: cycada/data/cityscapes_labels.py
================================================
# function for colorizing a label image:
# camera-ready
import numpy as np
def label_img_to_color(img):
    """Colorize a 2-D array of Cityscapes train ids.

    Parameters
    ----------
    img : 2-D integer array with values in [0, 18].

    Returns
    -------
    (H, W, 3) float array holding the RGB color of each pixel's label
    (same dtype/result as the original per-pixel version).

    Notes
    -----
    Replaces the original O(H*W) Python double loop with a single
    vectorized palette lookup. Out-of-range labels > 18 now raise
    IndexError (the per-pixel version raised KeyError).
    """
    # One palette row per train id, in id order (matches the `classes`
    # ordering used elsewhere in this package: 0 = road ... 18 = bicycle).
    palette = np.array([
        [128, 64, 128],   # 0  road
        [244, 35, 232],   # 1  sidewalk
        [70, 70, 70],     # 2  building
        [102, 102, 156],  # 3  wall
        [190, 153, 153],  # 4  fence
        [153, 153, 153],  # 5  pole
        [250, 170, 30],   # 6  traffic light
        [220, 220, 0],    # 7  traffic sign
        [107, 142, 35],   # 8  vegetation
        [152, 251, 152],  # 9  terrain
        [70, 130, 180],   # 10 sky
        [220, 20, 60],    # 11 person
        [255, 0, 0],      # 12 rider
        [0, 0, 142],      # 13 car
        [0, 0, 70],       # 14 truck
        [0, 60, 100],     # 15 bus
        [0, 80, 100],     # 16 train
        [0, 0, 230],      # 17 motorcycle
        [119, 11, 32],    # 18 bicycle
    ], dtype=np.float64)
    return palette[np.asarray(img)]
================================================
FILE: cycada/data/cyclegan.py
================================================
import os
from os.path import join
import glob
from PIL import Image
import torch.utils.data as data
from .data_loader import DatasetParams
from .data_loader import register_dataset_obj, register_data_params
class CycleGANDataset(data.Dataset):
    """Images produced by CycleGAN, with the class label taken from the
    leading digits of each filename (e.g. ``7_fake_B.png`` -> label 7)."""

    def __init__(self, root, regexp, transform=None, target_transform=None,
                 download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.image_paths, self.labels = self.find_images(regexp)

    def find_images(self, regexp='*.png'):
        """Collect sorted (path, label) lists for files under root matching regexp.

        NOTE(review): glob already yields paths that include self.root; the
        extra join is a no-op only when root is absolute -- confirm callers.
        """
        image_paths, labels = [], []
        for match in sorted(glob.glob(join(self.root, regexp))):
            image_paths.append(os.path.join(self.root, match))
            labels.append(int(match.split('/')[-1].split('_')[0]))
        return image_paths, labels

    def __getitem__(self, index):
        """Return (image, label); both transforms applied when provided."""
        image = Image.open(self.image_paths[index])  # .convert('L')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.image_paths)
@register_dataset_obj('svhn2mnist')
class Svhn2MNIST(CycleGANDataset):
    """CycleGAN-translated SVHN->MNIST digits; only a train split exists."""

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        if train:
            super(Svhn2MNIST, self).__init__(root, '*_fake_B.png',
                                             transform=transform,
                                             target_transform=target_transform,
                                             download=download)
        else:
            # No held-out split was generated for this dataset.
            print('No test set for svhn2mnist.')
            self.image_paths = []
@register_data_params('svhn2mnist')
class Svhn2MNISTParams(DatasetParams):
    """Static dataset metadata for the SVHN->MNIST translated digits."""
    num_channels = 3
    image_size = 32
    mean = 0.5              # normalization mean/std currently in use
    std = 0.5
    # Earlier normalization candidates kept by the original author for reference:
    #mean = 0.1307
    #std = 0.3081
    # mean and std (when scaled between [0,1])
    #mean = 0.127 # ep50
    #mean = 0.21 # ep100 -- more white pixels...
    #std = 0.29
    #mean = 0.21
    #std = 0.2
    num_cls = 10            # digit classes 0-9
    target_transform = None
@register_dataset_obj('usps2mnist')
class Usps2Mnist(CycleGANDataset):
    """CycleGAN-translated USPS->MNIST digits; only a train split exists."""

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        if train:
            super(Usps2Mnist, self).__init__(root, '*_fake_A.png',
                                             transform=transform,
                                             target_transform=target_transform,
                                             download=download)
        else:
            # No held-out split was generated for this dataset.
            print('No test set for usps2mnist.')
            self.image_paths = []
@register_data_params('usps2mnist')
class Usps2MnistParams(DatasetParams):
    """Static dataset metadata for the USPS->MNIST translated digits."""
    num_channels = 3
    image_size = 16
    # Earlier normalization candidates kept for reference:
    #mean = 0.1307
    #std = 0.3081
    mean = 0.5              # normalization mean/std currently in use
    std = 0.5
    num_cls = 10            # digit classes 0-9
    target_transform = None
@register_dataset_obj('mnist2usps')
class Mnist2Usps(CycleGANDataset):
    """CycleGAN-translated MNIST->USPS digits; only a train split exists."""

    def __init__(self, root, train=True, transform=None, target_transform=None,
                 download=False):
        if train:
            super(Mnist2Usps, self).__init__(root, '*_fake_B.png',
                                             transform=transform,
                                             target_transform=target_transform,
                                             download=download)
        else:
            # No held-out split was generated for this dataset.
            print('No test set for mnist2usps.')
            self.image_paths = []
@register_data_params('mnist2usps')
class Mnist2UspsParams(DatasetParams):
    """Static dataset metadata for the MNIST->USPS translated digits."""
    num_channels = 3
    image_size = 16 # this seems wrong... (original author's note; MNIST-sized outputs are usually 28 -- TODO confirm)
    # Earlier normalization candidates kept for reference:
    #mean = 0.25
    #std = 0.37
    #mean = 0.1307
    #std = 0.3081
    mean = 0.5              # normalization mean/std currently in use
    std = 0.5
    num_cls = 10            # digit classes 0-9
    target_transform = None
================================================
FILE: cycada/data/cyclegta5.py
================================================
import os.path
import numpy as np
from PIL import Image
from .cityscapes import remap_labels_to_train_ids
from .data_loader import register_dataset_obj
from .gta5 import GTA5 # , LABEL2TRAIN
@register_dataset_obj('cyclegta5')
class CycleGTA5(GTA5):
    """GTA5 images translated by CycleGAN ("source as target" style).

    `data_flag` (set by the GTA5 base class) selects a subdirectory of
    translated images; when unset, the raw `images/` directory is used.

    Fix: previously a filename under `data_flag` that matched neither
    `_fake_B.png` nor `_fake_B_2.png` left `label_path` unbound and
    crashed with UnboundLocalError; it now raises a descriptive ValueError.
    """

    def collect_ids(self):
        """Sorted filenames of every entry in the active image directory."""
        subdir = self.data_flag if self.data_flag else "images"
        path = os.path.join(self.root, subdir)
        existing_ids = []
        for item in os.listdir(path):
            full_path = os.path.join(path, item)
            if os.path.exists(full_path) is False:
                continue
            existing_ids.append(os.path.basename(full_path))
        return sorted(existing_ids)

    def __getitem__(self, index, debug=False):
        """Return (img, target); img is resized to the label's resolution.

        Labels live in `labels_600x1080/`; translated filenames have their
        `_fake_B`/`_fake_B_2` suffix stripped to find the matching label.
        """
        filename = self.ids[index]
        if not self.data_flag:
            img_path = os.path.join(self.root, "images", filename)
            label_path = os.path.join(self.root, 'labels_600x1080', filename)
        else:
            img_path = os.path.join(self.root, self.data_flag, filename)
            if filename.endswith('_fake_B.png'):
                label_path = os.path.join(self.root, 'labels_600x1080', filename.replace('_fake_B.png', '.png'))
            elif filename.endswith('_fake_B_2.png'):
                label_path = os.path.join(self.root, 'labels_600x1080', filename.replace('_fake_B_2.png', '.png'))
            else:
                # Previously fell through and crashed with UnboundLocalError.
                raise ValueError('Unrecognized translated-image filename: {}'.format(filename))
        img = Image.open(img_path).convert('RGB')
        target = Image.open(label_path)
        # Keep image and label aligned at the label's resolution.
        img = img.resize(target.size, resample=Image.BILINEAR)
        if self.transform is not None:
            img = self.transform(img)
        if self.remap_labels:
            target = np.asarray(target)
            target = remap_labels_to_train_ids(target)
            target = Image.fromarray(target, 'L')
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
================================================
FILE: cycada/data/cyclesynthia.py
================================================
import os.path
import numpy as np
import torch.utils.data as data
from PIL import Image
from .data_loader import DatasetParams, register_data_params, register_dataset_obj
# Sentinel train id assigned to SYNTHIA classes with no counterpart in the
# 19-class setup below.
ignore_label = 255

# Raw SYNTHIA label id -> train id (an index into `classes` below);
# ids 0, 13, 14 and 22 have no mapping and fall through to ignore_label.
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

# The 19 evaluation classes, ordered by train id (0 = 'road', ..., 18 = 'bicycle').
classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']
def syn_relabel(arr):
    """Remap raw SYNTHIA ids in `arr` to train ids (uint8); unmapped -> ignore_label."""
    out = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        out[arr == raw_id] = int(train_id)
    return out
@register_data_params('cyclesynthia')
class SYNTHIAParams(DatasetParams):
    """Static dataset metadata for the CycleGAN-translated SYNTHIA data."""
    num_channels = 3        # RGB input images
    image_size = 1024
    mean = 0.5              # normalization mean -- presumably for inputs scaled to [0, 1]; TODO confirm
    std = 0.5               # normalization std
    num_cls = 19            # number of semantic train classes
    target_transform = None
@register_dataset_obj('cyclesynthia')
class CycleSYNTHIA(data.Dataset):
    """CycleGAN-translated SYNTHIA images ("fake_B" outputs) with remapped labels.

    Fix: `collect_ids` reads `self.data_flag`, but `__init__` never set the
    attribute, so every instantiation failed with AttributeError.  A
    backward-compatible `data_flag` parameter is now accepted (default None,
    which selects the 'Cycle' subdirectory as before) and stored before
    `collect_ids` runs.
    """

    def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None,
                 target_transform=None, data_flag=None):
        # The registry passes a root containing "cycle"; stripping it reaches
        # the underlying synthia directory (behavior kept from the original).
        self.root = root.replace('cycle', '')
        self.split = split
        self.remap_labels = remap_labels
        self.transform = transform
        self.target_transform = target_transform
        self.classes = classes
        self.num_cls = num_cls
        self.data_flag = data_flag  # must exist before collect_ids() is called
        self.ids = self.collect_ids()

    def collect_ids(self):
        """Filenames of translated images (*_fake_B.png / *_fake_B_1.png)."""
        if self.data_flag:
            path = os.path.join(self.root, self.data_flag)
        else:
            path = os.path.join(self.root, 'Cycle')
        splits = []
        for item in os.listdir(path):
            fip = os.path.join(path, item)
            if fip.endswith('_fake_B_1.png') or fip.endswith('_fake_B.png'):
                splits.append(fip.split('/')[-1])
        return splits

    def img_path(self, filename):
        # NOTE(review): ids are collected from the 'Cycle' subdirectory but
        # opened relative to root -- confirm the on-disk layout matches.
        return os.path.join(self.root, filename)

    def label_path(self, filename):
        # Case for loading images generated in the multi-source cycle:
        # fake_B_1 belongs to cyclesynthia and fake_B_2 to cyclegta5.
        if filename.endswith('_fake_B_1.png'):
            return os.path.join(self.root, 'GT', 'parsed_LABELS', filename.replace('_fake_B_1.png', '.png'))
        elif filename.endswith('_fake_B.png'):
            return os.path.join(self.root, 'GT', 'parsed_LABELS', filename.replace('_fake_B.png', '.png'))
        # Filenames matching neither suffix (filtered out by collect_ids) yield None.

    def __getitem__(self, index, debug=False):
        """Return (img, target); labels remapped to train ids when enabled."""
        id = self.ids[index]
        img_path = self.img_path(id)
        label_path = self.label_path(id)
        img = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = Image.open(label_path)
        if self.remap_labels:
            target = np.asarray(target)
            target = syn_relabel(target)
            target = Image.fromarray(target, 'L')
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.ids)
================================================
FILE: cycada/data/cyclesynthia_cyclegta5.py
================================================
import os.path
import numpy as np
import torch.utils.data as data
from PIL import Image
from .cityscapes import remap_labels_to_train_ids
from .data_loader import DatasetParams, register_data_params, register_dataset_obj
# Duplicated from cycada/data/cyclesynthia.py -- keep the two in sync.
# Sentinel train id for SYNTHIA classes with no counterpart in the 19-class setup.
ignore_label = 255

# Raw SYNTHIA label id -> train id (an index into `classes` below);
# ids 0, 13, 14 and 22 have no mapping and fall through to ignore_label.
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

# The 19 evaluation classes, ordered by train id (0 = 'road', ..., 18 = 'bicycle').
classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']
def syn_relabel(arr):
    """Remap raw SYNTHIA ids in `arr` to train ids (uint8); unmapped -> ignore_label."""
    out = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        out[arr == raw_id] = int(train_id)
    return out
@register_data_params('cyclesynthia_cyclegta5')
class SYNTHIAParams(DatasetParams):
    """Static dataset metadata for the combined SYNTHIA + GTA5 translated data."""
    num_channels = 3        # RGB input images
    image_size = 1024
    mean = 0.5              # normalization mean -- presumably for inputs scaled to [0, 1]; TODO confirm
    std = 0.5               # normalization std
    num_cls = 19            # number of semantic train classes
    target_transform = None
# In this class, we iteratively load transferred images from cyclesynthia and cyclegta5
@register_dataset_obj('cyclesynthia_cyclegta5')
class CycleSYNTHIACycleGTA5(data.Dataset):
def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None):
self.dataset_name = os.path.basename(root)
self.parent_path = root.replace(self.dataset_name, '')
self.syn_name = os.path.join(self.parent_path, 'synthia')
self.gta_name = os.path.join(self.parent_path, 'cyclegta5')
self.remap_labels = remap_labels
self.transform = transform
self.target_transform = target_transform
self.classes = classes
self.num_cls = num_cls
self.syn_ids = self.collect_ids('syn')
self.gta_ids = self.collect_ids('gta')
def collect_ids(self, datasets_name):
splits = []
if datasets_name == 'syn':
files = os.listdir(self.syn_name)
for item in files:
fip = os.path.join(self.syn_name, item)
if (fip.endswith('_fake_B_1.png') or fip.endswith('_fake_B.png')):
splits.append(fip.split('/')[-1])
elif datasets_name == 'gta':
files = os.listdir(self.gta_name)
for item in files:
fip = os.path.join(self.gta_name, item)
if (fip.endswith('_fake_B_2.png') or fip.endswith('_fake_B.png')):
splits.append(fip.split('/')[-1])
else:
print("Don't Recognize {}".format(datasets_name))
return splits
def img_path(self, prefix, filename):
return os.path.join(prefix, filename)
# Case for loading images generated in multi-source cycle
# In this case, you will generate fake_B_1 for cyclesynthia dataset and fake_B_2 for cyclegta5
def syn_label_path(self, filename):
if filename.endswith('_fake_B_1.png'):
return os.path.join("/nfs/project/libo_i/MADAN/data/synthia", 'GT', 'parsed_LABELS', filename.replace('_fake_B_1.png', '.png'))
elif filename.endswith('_fake_B.png'):
return os.path.join("/nfs/project/libo_i/MADAN/data/synthia", 'GT', 'parsed_LABELS', filename.replace('_fake_B.png', '.png'))
def gta_label_path(self, filename):
if filename.endswith('_fake_B_2.png'):
return os.path.join('/nfs/project/libo_i/MADAN/data/cyclegta5', 'labels', filename.replace('_fake_B_2.png', '.png'))
elif filename.endswith('_fake_B.png'):
return os.path.join('/nfs/project/libo_i/MADAN/data/cyclegta5', 'labels', filename.replace('_fake_B.png', '.png'))
def __getitem__(self, index, debug=False):
    """Return one (image, label) pair.

    Samples alternate between the two translated source domains: odd
    indices draw from cyclesynthia, even indices from cyclegta5; each
    domain's id list wraps with modulo. The two former copy-pasted
    branches are deduplicated; behavior is unchanged.
    """
    if index % 2:
        id = self.syn_ids[index % len(self.syn_ids)]
        img_path = self.img_path(self.syn_name, id)
        label_path = self.syn_label_path(id)
        relabel = syn_relabel
    else:
        id = self.gta_ids[index % len(self.gta_ids)]
        img_path = self.img_path(self.gta_name, id)
        label_path = self.gta_label_path(id)
        relabel = remap_labels_to_train_ids
    img = Image.open(img_path).convert('RGB')
    if self.transform is not None:
        img = self.transform(img)
    target = Image.open(label_path)
    if self.remap_labels:
        # Remap raw dataset label ids to the 19 Cityscapes train ids.
        target = np.asarray(target)
        target = relabel(target)
        target = Image.fromarray(target, 'L')
    if self.target_transform is not None:
        target = self.target_transform(target)
    return img, target
def __len__(self):
    # Total samples across both translated domains; __getitem__ alternates
    # between them and wraps with modulo, so every index in range is valid.
    return len(self.syn_ids) + len(self.gta_ids)
================================================
FILE: cycada/data/data_loader.py
================================================
from __future__ import print_function
import os
from os.path import join
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
from ..util import to_tensor_raw
def load_data(name, dset, batch=64, rootdir='', num_channels=3,
              image_size=32, download=True, kwargs=None):
    """Build a DataLoader for a registered dataset (or an ADDA pair).

    name: dataset name, or a 2-element list [src, tgt] for ADDA pairing.
    dset: split name; 'train' enables shuffling.
    kwargs: extra keyword arguments forwarded to DataLoader.
    Returns a DataLoader, or None when the dataset is empty.
    """
    # Fix: the old `kwargs={}` default was a shared mutable default
    # argument; use None and create a fresh dict per call.
    if kwargs is None:
        kwargs = {}
    is_train = (dset == 'train')
    if isinstance(name, list) and len(name) == 2:  # load adda data
        src_dataset = get_dataset(name[0], join(rootdir, name[0]), dset,
                                  image_size, num_channels, download=download)
        tgt_dataset = get_dataset(name[1], join(rootdir, name[1]), dset,
                                  image_size, num_channels, download=download)
        dataset = AddaDataset(src_dataset, tgt_dataset)
    else:
        dataset = get_dataset(name, rootdir, dset, image_size, num_channels,
                              download=download)
    if len(dataset) == 0:
        return None
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch,
                                         shuffle=is_train, **kwargs)
    return loader
def get_transform_dataset(dataset_name, rootdir, net_transform, downscale, resize=None, src_data_flag=None, small=False):
    """Build an FCN dataset wired with image/label transforms.

    Fix: removed an unused `os.environ['PYTHONPATH']` lookup whose result
    was never read and which raised KeyError whenever PYTHONPATH was unset.
    """
    transform, target_transform = get_transform2(dataset_name, net_transform, downscale, resize)
    return get_fcn_dataset(dataset_name, rootdir, transform=transform,
                           target_transform=target_transform,
                           data_flag=src_data_flag, small=small)
# Reference image size (pixels) for each supported dataset; used by
# get_transform2 for relative downscaling.
sizes = {'cyclesynthia_cyclegta5': 1280, 'cyclesynthia': 1280, 'cityscapes': 1280, 'gta5': 1280, 'cyclegta5': 1280, "synthia": 1280}


def get_orig_size(dataset_name):
    """Size of images in the dataset for relative scaling.

    Raises:
        Exception: if `dataset_name` is not a known dataset.
    """
    try:
        return sizes[dataset_name]
    except KeyError:
        # Fix: catch only KeyError; the old bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit.
        raise Exception('Unknown dataset size:', dataset_name)
def get_transform2(dataset_name, net_transform, downscale, resize):
    """Returns image and label transform to downscale, crop and prepare for net."""
    base_size = get_orig_size(dataset_name)
    img_steps = []
    label_steps = []
    if downscale is not None:
        scaled = base_size // downscale
        img_steps.append(transforms.Resize(scaled))
        label_steps.append(transforms.Resize(scaled, interpolation=Image.NEAREST))
    if resize is not None:
        # Fixed 1.8 width/height aspect for the resized crop.
        height = int(resize)
        width = int(int(resize) * 1.8)
        img_steps.append(transforms.Resize([height, width], interpolation=Image.BICUBIC))
        label_steps.append(transforms.Resize([height, width], interpolation=Image.NEAREST))
    # Network-specific normalization for images; raw tensor labels.
    img_steps.append(net_transform)
    label_steps.append(to_tensor_raw)
    return transforms.Compose(img_steps), transforms.Compose(label_steps)
def get_transform(params, image_size, num_channels):
    """Build the PIL-to-tensor transform for a dataset described by `params`."""
    steps = []
    # Resize only when the requested size differs from the native one.
    if image_size != params.image_size:
        steps.append(transforms.Resize(image_size))
    # Convert between grayscale and RGB when channel counts differ.
    if num_channels != params.num_channels:
        if num_channels == 1:
            steps.append(transforms.Lambda(lambda x: x.convert('L')))
        elif num_channels == 3:
            steps.append(transforms.Lambda(lambda x: x.convert('RGB')))
        else:
            print('NumChannels should be 1 or 3', num_channels)
            raise Exception
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((params.mean,), (params.std,)))
    return transforms.Compose(steps)
def get_target_transform(params):
    """Wrap `params.target_transform` (if any) with a uniformizing step."""
    def uniform(x):
        # Collapse 2-element list/array targets to their first column.
        if isinstance(x, (list, np.ndarray)) and len(x) == 2:
            return x[:, 0]
        return x

    t_uniform = transforms.Lambda(uniform)
    base = params.target_transform
    if base is None:
        return t_uniform
    return transforms.Compose([base, t_uniform])
class AddaDataset(data.Dataset):
    """Pairs a source and a target dataset for joint (ADDA-style) iteration.

    Indexing wraps around each underlying dataset independently; the
    reported length is that of the shorter dataset.
    """

    def __init__(self, src_data, tgt_data):
        self.src = src_data
        self.tgt = tgt_data

    def __getitem__(self, index):
        xs, ys = self.src[index % len(self.src)]
        xt, yt = self.tgt[index % len(self.tgt)]
        return (xs, ys), (xt, yt)

    def __len__(self):
        return min(len(self.src), len(self.tgt))
# Registry mapping dataset name -> DatasetParams subclass.
data_params = {}


def register_data_params(name):
    """Return a class decorator that files `cls` under `name` in data_params."""
    def _record(cls):
        data_params[name] = cls
        return cls
    return _record
# Registry mapping dataset name -> dataset class.
dataset_obj = {}


def register_dataset_obj(name):
    """Return a class decorator that files `cls` under `name` in dataset_obj."""
    def _record(cls):
        dataset_obj[name] = cls
        return cls
    return _record
class DatasetParams(object):
    """Default dataset metadata; subclasses (registered via
    @register_data_params) override these per dataset.

    Consumed by get_dataset / get_transform / get_target_transform.
    """
    # NOTE(review): mean/std look like MNIST statistics — confirm.
    num_channels = 1
    image_size = 16
    mean = 0.1307
    std = 0.3081
    num_cls = 10
    target_transform = None
def get_dataset(name, rootdir, dset, image_size, num_channels, download=True):
    """Instantiate the registered dataset `name` with its standard transforms."""
    print('get dataset:', name, rootdir, dset)
    params = data_params[name]
    image_transform = get_transform(params, image_size, num_channels)
    label_transform = get_target_transform(params)
    return dataset_obj[name](rootdir, train=(dset == 'train'),
                             transform=image_transform,
                             target_transform=label_transform,
                             download=download)
def get_fcn_dataset(name, rootdir, **kwargs):
    """Instantiate the registered segmentation dataset `name`."""
    factory = dataset_obj[name]
    return factory(rootdir, **kwargs)
================================================
FILE: cycada/data/gta5.py
================================================
import os.path
import numpy as np
import scipy.io
import torch.utils.data as data
from PIL import Image
from .cityscapes import id2label as LABEL2TRAIN, remap_labels_to_train_ids
from .data_loader import DatasetParams, register_data_params, register_dataset_obj
@register_data_params('gta5')
class GTA5Params(DatasetParams):
    """Dataset metadata for GTA5: 3-channel 1024px images, 19 classes."""
    num_channels = 3
    image_size = 1024
    mean = 0.5
    std = 0.5
    num_cls = 19
    target_transform = None
@register_dataset_obj('gta5')
class GTA5(data.Dataset):
    """GTA5 segmentation dataset; labels remapped to Cityscapes train ids."""

    def __init__(self, root, num_cls=19, split='train', remap_labels=True,
                 transform=None, target_transform=None, data_flag=None):
        self.root = root
        self.split = split
        self.remap_labels = remap_labels
        self.data_flag = data_flag
        self.ids = self.collect_ids()
        self.transform = transform
        self.target_transform = target_transform
        # Restrict the dataset's class-name list to the 19 train ids.
        mapping = scipy.io.loadmat(os.path.join(self.root, 'mapping.mat'))
        all_names = [entry[0] for entry in mapping['classes'][0]]
        self.classes = [all_names[old_id]
                        for old_id, new_id in LABEL2TRAIN.items()
                        if new_id != 255 and old_id > 0]
        self.num_cls = num_cls

    def collect_ids(self):
        """Image ids for the configured split, read from split.mat."""
        split_mat = scipy.io.loadmat(os.path.join(self.root, 'split.mat'))
        return split_mat['{}Ids'.format(self.split)].squeeze()

    def img_path(self, id):
        return os.path.join(self.root, 'images', '{:05d}.png'.format(id))

    def label_path(self, id):
        return os.path.join(self.root, 'labels', '{:05d}.png'.format(id))

    def __getitem__(self, index, debug=False):
        id = self.ids[index]
        image = Image.open(self.img_path(id)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        target = Image.open(self.label_path(id))
        if self.remap_labels:
            # Raw GTA5 label ids -> Cityscapes train ids.
            target = Image.fromarray(
                remap_labels_to_train_ids(np.asarray(target)), 'L')
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        return len(self.ids)
================================================
FILE: cycada/data/rotater.py
================================================
class Rotater(object):
    """Dataset wrapper yielding (image, target, degrees).

    Each image is rotated by (index mod orientations) * 360/orientations
    degrees before the optional transforms are applied.
    """

    def __init__(self, dataset, orientations=6, transform=None,
                 target_transform=None):
        self.dataset = dataset
        self.orientations = orientations
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image, target = self.dataset[index]
        degrees = 360 / self.orientations * (index % self.orientations)
        image = image.rotate(degrees)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target, degrees

    def __len__(self):
        return len(self.dataset)
================================================
FILE: cycada/data/synthia.py
================================================
import os.path
import numpy as np
import torch.utils.data as data
from PIL import Image
from .util import classes, ignore_label, id2label
from .data_loader import DatasetParams, register_data_params, register_dataset_obj
def syn_relabel(arr):
    """Remap raw SYNTHIA label ids to Cityscapes train ids (uint8 array).

    Pixels with no mapping stay at ignore_label.
    """
    out = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        out[arr == raw_id] = int(train_id)
    return out
@register_data_params('synthia')
class SYNTHIAParams(DatasetParams):
    """Dataset metadata for SYNTHIA: 3-channel 1024px images, 19 classes."""
    num_channels = 3
    image_size = 1024
    mean = 0.5
    std = 0.5
    num_cls = 19
    target_transform = None
@register_dataset_obj('synthia')
class SYNTHIA(data.Dataset):
    """SYNTHIA segmentation dataset; labels remapped to Cityscapes train ids.

    `small` selects the resolution folders: 0 -> 300x540, 1 -> 600x1080,
    anything else -> full resolution.
    """

    def __init__(self, root, num_cls=19, split='train', remap_labels=True,
                 transform=None, target_transform=None, data_flag=None, small=2):
        self.root = root
        self.split = split
        self.small = small
        self.remap_labels = remap_labels
        self.ids = self.collect_ids()
        self.transform = transform
        self.target_transform = target_transform
        self.classes = classes
        self.num_cls = num_cls
        self.data_flag = data_flag

    def collect_ids(self):
        """Read bare image filenames from the split's image-list file."""
        list_path = os.path.join(
            self.root, 'SYNTHIA_imagelist_{}.txt'.format(self.split))
        with open(list_path) as handle:
            return [line.strip('\n').split('/')[-1] for line in handle]

    def img_path(self, filename):
        folders = {0: 'RGB_300x540', 1: 'RGB_600x1080'}
        return os.path.join(self.root, folders.get(self.small, 'RGB'), filename)

    def label_path(self, filename):
        folders = {0: 'parsed_LABELS_300x540', 1: 'parsed_LABELS_600x1080'}
        return os.path.join(self.root, 'GT',
                            folders.get(self.small, 'parsed_LABELS'), filename)

    def __getitem__(self, index, debug=False):
        id = self.ids[index]
        img_path = self.img_path(id)
        label_path = self.label_path(id)
        if debug:
            print(self.__class__.__name__)
            print("IMG Path: {}".format(img_path))
            print("Label Path: {}".format(label_path))
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        target = Image.open(label_path)
        if self.remap_labels:
            # Raw SYNTHIA label ids -> Cityscapes train ids.
            target = Image.fromarray(syn_relabel(np.asarray(target)), 'L')
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        return len(self.ids)
================================================
FILE: cycada/data/util.py
================================================
import logging
import os.path
import requests
logger = logging.getLogger(__name__)
# Label value that segmentation losses/metrics ignore.
ignore_label = 255

# Raw SYNTHIA label id -> Cityscapes train id (0-18); unmapped raw ids
# collapse to ignore_label.
# NOTE(review): right-hand values index the `classes` list below
# (e.g. 3 -> 0 'road') — confirm against the SYNTHIA GT documentation.
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}
# The 19 Cityscapes evaluation class names, indexed by train id.
classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']
# Flat RGB palette — three ints per class, same order as `classes` —
# for colorizing predicted label maps.
palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
           220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
           0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
def maybe_download(url, dest):
    """Download `url` to `dest` only if `dest` does not already exist."""
    if os.path.exists(dest):
        return
    logger.info('Downloading %s to %s', url, dest)
    download(url, dest)
def download(url, dest):
    """Download the url to dest, overwriting dest if it already exists."""
    response = requests.get(url, stream=True)
    with open(dest, 'wb') as out:
        for chunk in response.iter_content(chunk_size=1024):
            # iter_content may yield empty keep-alive chunks; skip them.
            if not chunk:
                continue
            out.write(chunk)
================================================
FILE: cycada/logging.yml
================================================
---
version: 1
disable_existing_loggers: False
formatters:
simple:
format: "[%(asctime)s] %(levelname)-8s %(message)s"
color:
class: colorlog.ColoredFormatter
format: "[%(asctime)s] %(log_color)s%(levelname)-8s%(reset)s %(message)s"
log_colors:
DEBUG: "cyan"
INFO: "green"
WARNING: "yellow"
ERROR: "red"
CRITICAL: "red,bg_white"
handlers:
console:
class: cycada.util.TqdmHandler
level: INFO
formatter: color
file_handler:
class: logging.FileHandler
level: INFO
formatter: simple
encoding: utf8
root:
level: INFO
handlers: [console, file_handler]
================================================
FILE: cycada/models/MDAN.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
logger = logging.getLogger(__name__)
class GradientReversalLayer(torch.autograd.Function):
    """
    Implement the gradient reversal layer for the convenience of domain adaptation neural network.
    The forward part is the identity function while the backward part is the negative function.
    """
    # NOTE(review): this uses the legacy autograd API — instance methods
    # without `ctx`, and instances called directly (see MDANet.forward).
    # Modern PyTorch (>=1.3) requires @staticmethod forward/backward(ctx, ...)
    # invoked via .apply; confirm the torch version this targets.

    def forward(self, inputs):
        # Identity on the forward pass.
        return inputs

    def backward(self, grad_output):
        # Negate incoming gradients so upstream features are trained
        # adversarially against the domain classifiers.
        grad_input = grad_output.clone()
        grad_input = -grad_input
        return grad_input
class MDANet(nn.Module):
    """
    Multi-layer perceptron with adversarial regularizer by domain classification.
    """

    def __init__(self, configs):
        """configs: dict with keys "input_dim", "hidden_layers" (list of
        widths), "num_classes", and "num_domains" (number of source domains).
        """
        super(MDANet, self).__init__()
        # Pool incoming feature maps to 2x2 and reduce 4096 -> 512 channels
        # before the fully-connected trunk.
        self.pooling_layer = nn.AdaptiveAvgPool2d((2, 2))
        self.dim_reduction = nn.Conv2d(4096, 512, kernel_size=1)
        nn.init.xavier_normal_(self.dim_reduction.weight)
        nn.init.constant_(self.dim_reduction.bias, 0.1)
        self.input_dim = configs["input_dim"]
        self.num_hidden_layers = len(configs["hidden_layers"])
        self.num_neurons = [] + [self.input_dim] + configs["hidden_layers"]
        self.num_domains = configs["num_domains"]
        # Parameters of hidden, fully-connected layers, feature learning component.
        self.hiddens = nn.ModuleList([nn.Linear(self.num_neurons[i], self.num_neurons[i + 1])
                                      for i in range(self.num_hidden_layers)])
        # Parameter of the final softmax classification layer.
        self.softmax = nn.Linear(self.num_neurons[-1], configs["num_classes"])
        # Parameter of the domain classification layer, multiple sources single target domain adaptation.
        self.domains = nn.ModuleList([nn.Linear(self.num_neurons[-1], 2) for _ in range(self.num_domains)])
        # Gradient reversal layers, one per source domain (plain list: they
        # hold no parameters). Calling instances directly relies on the
        # legacy autograd Function API — see GradientReversalLayer.
        self.grls = [GradientReversalLayer() for _ in range(self.num_domains)]

    def forward(self, sinputs_syn, sinputs_gta, tinputs):
        """
        :param sinputs_syn: feature maps from the SYNTHIA source domain.
        :param sinputs_gta: feature maps from the GTA5 source domain.
        :param tinputs: Input from the target domain.
        :return: (logprobs, sdomains, tdomains) — class log-probabilities for
            the two source batches, plus per-domain-classifier source/target
            log-probabilities used by the adversarial loss.
        """
        sinputs_gta = self.pooling_layer(sinputs_gta)
        sinputs_syn = self.pooling_layer(sinputs_syn)
        tinputs = self.pooling_layer(tinputs)
        sinputs_gta = self.dim_reduction(sinputs_gta)
        sinputs_syn = self.dim_reduction(sinputs_syn)
        tinputs = self.dim_reduction(tinputs)
        b = sinputs_gta.size()[0]
        # Flatten to (batch, features) for the shared MLP trunk.
        syn_relu, gta_relu, th_relu = sinputs_syn.view(b, -1), sinputs_gta.view(b, -1), tinputs.view(b, -1)
        assert (syn_relu[0].size()[0] == self.input_dim)
        for hidden in self.hiddens:
            syn_relu = F.relu(hidden(syn_relu))
            gta_relu = F.relu(hidden(gta_relu))
        for hidden in self.hiddens:
            th_relu = F.relu(hidden(th_relu))
        # Classification probabilities on k source domains.
        logprobs = []
        logprobs.append(F.log_softmax(self.softmax(syn_relu), dim=1))
        logprobs.append(F.log_softmax(self.softmax(gta_relu), dim=1))
        # Domain classification accuracies.
        # Domain classifier 0: synthia-vs-target; classifier 1: gta-vs-target.
        sdomains, tdomains = [], []
        sdomains.append(F.log_softmax(self.domains[0](self.grls[0](syn_relu)), dim=1))
        tdomains.append(F.log_softmax(self.domains[0](self.grls[0](th_relu)), dim=1))
        sdomains.append(F.log_softmax(self.domains[1](self.grls[1](gta_relu)), dim=1))
        tdomains.append(F.log_softmax(self.domains[1](self.grls[1](th_relu)), dim=1))
        return logprobs, sdomains, tdomains

    def inference(self, inputs):
        # Test-time path: shared trunk + classifier only. Note this skips
        # the pooling/dim-reduction applied in forward, so `inputs` is
        # presumably already flattened to input_dim — TODO confirm callers.
        h_relu = inputs
        for hidden in self.hiddens:
            h_relu = F.relu(hidden(h_relu))
        # Classification probability.
        logprobs = F.log_softmax(self.softmax(h_relu), dim=1)
        return logprobs
================================================
FILE: cycada/models/__init__.py
================================================
from .models import get_model
from .task_net import LeNet
from .task_net import DTNClassifier
from .adda_net import AddaNet
from .fcn8s import VGG16_FCN8s, Discriminator
from .drn import drn26
================================================
FILE: cycada/models/adda_net.py
================================================
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from .util import init_weights
from .models import register_model, get_model
@register_model('AddaNet')
class AddaNet(nn.Module):
    """Defines an ADDA (adversarial discriminative domain adaptation) network.

    Wraps source and target copies of a base task network plus a small MLP
    domain discriminator operating on the task networks' outputs.
    """

    def __init__(self, num_cls=10, model='LeNet', src_weights_init=None,
                 weights_init=None):
        super(AddaNet, self).__init__()
        self.name = 'AddaNet'
        self.base_model = model
        self.num_cls = num_cls
        self.cls_criterion = nn.CrossEntropyLoss()
        self.gan_criterion = nn.CrossEntropyLoss()
        # Fix: `forward` reads `self.discrim_feat`, which was never assigned
        # anywhere in the class and raised AttributeError on the first call.
        # The discriminator is built over num_cls-dim inputs (see setup_net),
        # i.e. over scores, so False is the matching default.
        self.discrim_feat = False
        self.setup_net()
        if weights_init is not None:
            self.load(weights_init)
        elif src_weights_init is not None:
            self.load_src_net(src_weights_init)
        else:
            raise Exception('AddaNet must be initialized with weights.')

    def forward(self, x_s, x_t):
        """Pass source and target images through their respective networks.

        Returns (score_s, score_t, d_s, d_t): task scores for both domains
        and the discriminator outputs used for the adversarial loss.
        """
        score_s, x_s = self.src_net(x_s, with_ft=True)
        score_t, x_t = self.tgt_net(x_t, with_ft=True)
        if self.discrim_feat:
            d_s = self.discriminator(x_s)
            d_t = self.discriminator(x_t)
        else:
            d_s = self.discriminator(score_s)
            d_t = self.discriminator(score_t)
        return score_s, score_t, d_s, d_t

    def setup_net(self):
        """Setup source, target and discriminator networks."""
        self.src_net = get_model(self.base_model, num_cls=self.num_cls)
        self.tgt_net = get_model(self.base_model, num_cls=self.num_cls)
        # The discriminator consumes num_cls-dim score vectors.
        input_dim = self.num_cls
        self.discriminator = nn.Sequential(
            nn.Linear(input_dim, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 2),
        )
        self.image_size = self.src_net.image_size
        self.num_channels = self.src_net.num_channels

    def load(self, init_path):
        """Loads full src and tgt models."""
        net_init_dict = torch.load(init_path)
        self.load_state_dict(net_init_dict)

    def load_src_net(self, init_path):
        """Initialize both source and target nets with source weights."""
        self.src_net.load(init_path)
        self.tgt_net.load(init_path)

    def save(self, out_path):
        torch.save(self.state_dict(), out_path)

    def save_tgt_net(self, out_path):
        torch.save(self.tgt_net.state_dict(), out_path)
================================================
FILE: cycada/models/drn.py
================================================
import math
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torchvision
from .models import register_model
from ..util import safe_load_state_dict
# Public names exported by this module.
__all__ = ['DRN', 'drn26', 'drn42', 'drn58']

# Pretrained checkpoint URLs, keyed by model name (consumed in DRN.__init__).
model_urls = {
    'drn26': 'https://tigress-web.princeton.edu/~fy/drn/models/drn26-ddedf421.pth',
    'drn42': 'https://tigress-web.princeton.edu/~fy/drn/models/drn42-9d336e8c.pth',
    'drn58': 'https://tigress-web.princeton.edu/~fy/drn/models/drn58-0a53a92c.pth'
}
def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
    """3x3 convolution without bias (the following BatchNorm supplies the shift)."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=padding,
                     dilation=dilation,
                     bias=False)
class BasicBlock(nn.Module):
    """Two 3x3 conv/BN layers with an optional residual connection.

    `dilation` is a pair (first conv, second conv); `residual=False`
    disables the skip addition (used by the DRN de-residual stages).
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride,
                             padding=dilation[0], dilation=dilation[0])
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes,
                             padding=dilation[1], dilation=dilation[1])
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        # Project the shortcut when the main path changes shape.
        shortcut = self.downsample(x) if self.downsample is not None else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.residual:
            out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (4x channel expansion).

    The `residual` argument is accepted for interface parity with
    BasicBlock but (as in the original) is not used: the skip addition
    always happens.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation[1], bias=False,
                               dilation=dilation[1])
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Project the shortcut when the main path changes shape.
        shortcut = self.downsample(x) if self.downsample is not None else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)
class DRN(nn.Module):
    """Dilated Residual Network backbone (see the drn26/drn42/drn58
    factories below for the standard layer layouts)."""

    # ImageNet-style normalization matching the pretrained checkpoints.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, block, layers, num_cls=1000,
                 channels=(16, 32, 64, 128, 256, 512, 512, 512),
                 out_map=False, out_middle=False, pool_size=28,
                 weights_init=None, pretrained=True, finetune=False,
                 output_last_ft=False, modelname='drn26'):
        """
        block: residual block class for stages 3-6 (BasicBlock/Bottleneck).
        layers: blocks per stage (8 entries; 0 disables stages 6-8).
        out_map: return a dense, input-sized score map instead of pooled
            class scores.
        out_middle: also return the list of intermediate stage outputs.
        pretrained: load `weights_init` if given, else download
            model_urls[modelname].
        finetune: drop the checkpoint's final fc weights so num_cls may
            differ, then load the rest non-strictly.
        """
        if output_last_ft:
            print('DRN discrim feat not implemented, using scores')
        super(DRN, self).__init__()
        self.inplanes = channels[0]
        self.output_last_ft = output_last_ft
        self.out_map = out_map
        self.out_dim = channels[-1]
        self.out_middle = out_middle
        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.relu = nn.ReLU(inplace=True)
        # Stages 1-2 always use BasicBlock; 3-4 downsample by stride;
        # 5-6 keep resolution and dilate instead; 7-8 use BasicBlock with
        # residual connections disabled.
        self.layer1 = self._make_layer(BasicBlock, channels[0], layers[0], stride=1)
        self.layer2 = self._make_layer(BasicBlock, channels[1], layers[1], stride=2)
        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2)
        self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2)
        self.layer5 = self._make_layer(block, channels[4], layers[4], dilation=2,
                                       new_level=False)
        self.layer6 = None if layers[5] == 0 else \
            self._make_layer(block, channels[5], layers[5], dilation=4,
                             new_level=False)
        self.layer7 = None if layers[6] == 0 else \
            self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
                             new_level=False, residual=False)
        self.layer8 = None if layers[7] == 0 else \
            self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
                             new_level=False, residual=False)
        if num_cls > 0:
            self.avgpool = nn.AvgPool2d(pool_size)
            # 1x1 conv classifier head: serves both the dense-map path and
            # the pooled-classification path in forward.
            # self.fc = nn.Linear(self.out_dim, num_classes)
            self.fc = nn.Conv2d(self.out_dim, num_cls, kernel_size=1,
                                stride=1, padding=0, bias=True)
        # He-style init for convs; unit-gamma/zero-beta for BatchNorm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        if pretrained:
            if not weights_init is None:
                state_dict = torch.load(weights_init)
                print('Using state dict from', weights_init)
            else:
                state_dict = model_zoo.load_url(model_urls[modelname])
            if finetune:
                del state_dict['fc.weight']
                del state_dict['fc.bias']
                safe_load_state_dict(self, state_dict)
                print('Finetune: remove last layer')
            else:
                self.load_state_dict(state_dict)
                print('Loading full model')

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    new_level=True, residual=True):
        """Build one stage of `blocks` residual blocks; the first block may
        downsample and/or project channels via a 1x1 shortcut."""
        assert dilation == 1 or dilation % 2 == 0
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        # When entering a new dilation level, the first block's first conv
        # uses half the dilation for a gradual transition.
        layers.append(block(
            self.inplanes, planes, stride, downsample,
            dilation=(1, 1) if dilation == 1 else (
                dilation // 2 if new_level else dilation, dilation),
            residual=residual))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, residual=residual,
                                dilation=(dilation, dilation)))
        return nn.Sequential(*layers)

    def forward(self, x):
        _, _, h, w = x.size()
        y = list()  # intermediate stage outputs, returned when out_middle
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x)
        y.append(x)
        x = self.layer2(x)
        y.append(x)
        x = self.layer3(x)
        y.append(x)
        x = self.layer4(x)
        y.append(x)
        x = self.layer5(x)
        y.append(x)
        if self.layer6 is not None:
            x = self.layer6(x)
            y.append(x)
        if self.layer7 is not None:
            x = self.layer7(x)
            y.append(x)
        if self.layer8 is not None:
            x = self.layer8(x)
            y.append(x)
        if self.output_last_ft:
            ft_to_save = x
        if self.out_map:
            # Dense prediction: classify every location, then bilinearly
            # upsample the score map back to the input resolution.
            x = self.fc(x)
            x = nn.functional.interpolate(x, (h, w), mode='bilinear', align_corners=True)
        else:
            # Image classification: pool, classify, flatten to (batch, num_cls).
            x = self.avgpool(x)
            x = self.fc(x)
            x = x.view(x.size(0), -1)
        if self.out_middle:
            return x, y
        elif self.output_last_ft:
            return x, ft_to_save
        else:
            return x
@register_model('drn26')
def drn26(pretrained=True, finetune=False, out_map=True, **kwargs):
    """Construct a DRN-26 (dilated residual network).

    Fix: `pretrained` is now forwarded to the DRN constructor. Previously
    it was silently ignored and DRN's own default (pretrained=True) always
    won, so drn26(pretrained=False) still downloaded and loaded weights.
    """
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], modelname='drn26',
                out_map=out_map, pretrained=pretrained, finetune=finetune,
                **kwargs)
    return model
@register_model('drn42')
def drn42(pretrained=False, finetune=False, out_map=True, **kwargs):
    """Construct a DRN-42 (dilated residual network).

    Fix: `pretrained` is now forwarded to the DRN constructor. Previously
    it was silently ignored and DRN's own default (pretrained=True) always
    won, so drn42() downloaded weights despite pretrained=False here.
    """
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], modelname='drn42',
                out_map=out_map, pretrained=pretrained, finetune=finetune,
                **kwargs)
    return model
def drn58(pretrained=False, **kwargs):
    """Construct a DRN-58 (Bottleneck variant).

    Fix: pass pretrained=False to the constructor — DRN's own defaults
    (pretrained=True, modelname='drn26') made it try to load DRN-26
    weights into this Bottleneck architecture before the explicit DRN-58
    load below.
    """
    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], pretrained=False, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn58']))
    return model
================================================
FILE: cycada/models/fcn8s.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.nn import init
from torch.utils import model_zoo
from torchvision.models import vgg
from .models import register_model
def get_upsample_filter(size):
    """Make a 2D bilinear kernel suitable for upsampling"""
    factor = (size + 1) // 2
    # Kernel center falls between pixels for even sizes.
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    rows, cols = np.ogrid[:size, :size]
    kernel = ((1 - abs(rows - center) / factor)
              * (1 - abs(cols - center) / factor))
    return torch.from_numpy(kernel).float()
class Bilinear(nn.Module):
    """Fixed (non-learned) bilinear upsampling by integer `factor`,
    implemented as a transposed convolution with a frozen kernel."""

    def __init__(self, factor, num_channels):
        super().__init__()
        self.factor = factor
        kernel = get_upsample_filter(factor * 2)
        weight = torch.zeros(num_channels, num_channels, factor * 2, factor * 2)
        for c in range(num_channels):
            # Diagonal weight: each channel is upsampled independently.
            weight[c, c] = kernel
        self.register_buffer('w', weight)

    def forward(self, x):
        return F.conv_transpose2d(x, Variable(self.w), stride=self.factor)
@register_model('fcn8s')
class VGG16_FCN8s(nn.Module):
    """FCN-8s semantic segmentation network on a VGG-16 backbone.

    Skip connections from pool3 and pool4 are fused with 2x-upsampled head
    scores; the result is upsampled 8x and cropped to the input size.
    """

    # ImageNet normalization matching the torchvision VGG-16 weights.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, num_cls=19, pretrained=True, weights_init=None,
                 output_last_ft=False):
        """
        num_cls: number of output classes.
        pretrained: load `weights_init` if given, else torchvision VGG-16.
        output_last_ft: if True, forward also returns the features taken
            just before the final classifier conv.
        """
        super().__init__()
        self.output_last_ft = output_last_ft
        # NOTE(review): `batch_norm` is computed here but never used —
        # make_layers below hard-codes batch_norm=False; confirm intent.
        if weights_init:
            batch_norm = False
        else:
            batch_norm = True
        self.vgg = make_layers(vgg.cfg['D'], batch_norm=False)
        # Fully-convolutional VGG classifier head (fc6/fc7 as convs).
        self.vgg_head = nn.Sequential(
            nn.Conv2d(512, 4096, 7),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.5),
            nn.Conv2d(4096, 4096, 1),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=0.5),
            nn.Conv2d(4096, num_cls, 1)
        )
        # Fixed bilinear upsamplers: 2x for the two fusions, 8x for output.
        self.upscore2 = self.upscore_pool4 = Bilinear(2, num_cls)
        self.upscore8 = Bilinear(8, num_cls)
        # Skip-connection scoring convs, zero-initialized.
        self.score_pool4 = nn.Conv2d(512, num_cls, 1)
        for param in self.score_pool4.parameters():
            # init.constant(param, 0)
            init.constant_(param, 0)
        self.score_pool3 = nn.Conv2d(256, num_cls, 1)
        for param in self.score_pool3.parameters():
            # init.constant(param, 0)
            init.constant_(param, 0)
        if pretrained:
            if weights_init is not None:
                self.load_weights(torch.load(weights_init))
            else:
                self.load_base_weights()

    def load_base_vgg(self, weights_state_dict):
        # Load only the 'vgg.'-prefixed entries into the backbone.
        vgg_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg.')
        self.vgg.load_state_dict(vgg_state_dict)

    def load_vgg_head(self, weights_state_dict):
        # Load only the 'vgg_head.'-prefixed entries into the head.
        vgg_head_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg_head.')
        self.vgg_head.load_state_dict(vgg_head_state_dict)

    def get_dict_by_prefix(self, weights_state_dict, prefix):
        """Return the sub-state-dict whose keys start with `prefix`, stripped."""
        return {k[len(prefix):]: v
                for k, v in weights_state_dict.items()
                if k.startswith(prefix)}

    def load_weights(self, weights_state_dict):
        self.load_base_vgg(weights_state_dict)
        self.load_vgg_head(weights_state_dict)

    def split_vgg_head(self):
        # Expose the final classifier conv and the head-minus-classifier
        # separately (used by feature-level discriminator setups).
        self.classifier = list(self.vgg_head.children())[-1]
        self.vgg_head_feat = nn.Sequential(*list(self.vgg_head.children())[:-1])

    def forward(self, x):
        input = x
        # Large zero pad so convolutions see context at the borders; the
        # fixed crop offsets below (5/9/31) assume this 99-pixel pad.
        x = F.pad(x, (99, 99, 99, 99), mode='constant', value=0)
        intermediates = {}
        # Indices of the pool3/pool4 layers inside the VGG Sequential.
        fts_to_save = {16: 'pool3', 23: 'pool4'}
        for i, module in enumerate(self.vgg):
            x = module(x)
            if i in fts_to_save:
                intermediates[fts_to_save[i]] = x
        ft_to_save = 5  # Dropout before classifier
        last_ft = {}
        for i, module in enumerate(self.vgg_head):
            x = module(x)
            if i == ft_to_save:
                last_ft = x
        _, _, h, w = x.size()
        # 2x upsample head scores and fuse with (down-scaled) pool4 scores.
        upscore2 = self.upscore2(x)
        pool4 = intermediates['pool4']
        score_pool4 = self.score_pool4(0.01 * pool4)
        score_pool4c = _crop(score_pool4, upscore2, offset=5)
        fuse_pool4 = upscore2 + score_pool4c
        # 2x upsample again and fuse with (down-scaled) pool3 scores.
        upscore_pool4 = self.upscore_pool4(fuse_pool4)
        pool3 = intermediates['pool3']
        score_pool3 = self.score_pool3(0.0001 * pool3)
        score_pool3c = _crop(score_pool3, upscore_pool4, offset=9)
        fuse_pool3 = upscore_pool4 + score_pool3c
        # Final 8x upsample and crop back to the original input size.
        upscore8 = self.upscore8(fuse_pool3)
        score = _crop(upscore8, input, offset=31)
        if self.output_last_ft:
            return score, last_ft
        else:
            return score

    def load_base_weights(self):
        """This is complicated because we converted the base model to be fully
        convolutional, so some surgery needs to happen here."""
        base_state_dict = model_zoo.load_url(vgg.model_urls['vgg16'])
        vgg_state_dict = {k[len('features.'):]: v
                          for k, v in base_state_dict.items()
                          if k.startswith('features.')}
        self.vgg.load_state_dict(vgg_state_dict)
        # Reshape the fc6/fc7 linear weights into the conv head parameters,
        # relying on matching parameter order; skip the final classifier.
        vgg_head_params = self.vgg_head.parameters()
        for k, v in base_state_dict.items():
            if not k.startswith('classifier.'):
                continue
            if k.startswith('classifier.6.'):
                # skip final classifier output
                continue
            vgg_head_param = next(vgg_head_params)
            vgg_head_param.data = v.view(vgg_head_param.size())
class VGG16_FCN8s_caffe(VGG16_FCN8s):
    """FCN-8s variant using a Caffe-converted VGG-16 checkpoint.

    The transform scales pixels to a 0-255-like range (std = 1/255).
    NOTE(review): the final Lambda reverses dim 1, which on a (C, H, W)
    image tensor is the *height* axis, not channels — if an RGB->BGR swap
    was intended, the dim should presumably be 0; confirm.
    """

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.458, 0.408],
            std=[0.00392156862745098] * 3),
        torchvision.transforms.Lambda(
            lambda x: torch.stack(torch.unbind(x, 1)[::-1], 1))
    ])

    def load_base_weights(self):
        # Same fc->conv surgery as the parent class, but pulling the
        # Caffe-converted VGG-16 checkpoint instead of torchvision's.
        base_state_dict = model_zoo.load_url('https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg16-00b39a1b.pth')
        vgg_state_dict = {k[len('features.'):]: v
                          for k, v in base_state_dict.items()
                          if k.startswith('features.')}
        self.vgg.load_state_dict(vgg_state_dict)
        vgg_head_params = self.vgg_head.parameters()
        for k, v in base_state_dict.items():
            if not k.startswith('classifier.'):
                continue
            if k.startswith('classifier.6.'):
                # skip final classifier output
                continue
            vgg_head_param = next(vgg_head_params)
            vgg_head_param.data = v.view(vgg_head_param.size())
class Discriminator(nn.Module):
    """1x1-convolutional domain discriminator over feature maps.

    Args:
        input_dim: channels of the input feature map (4096 -> wide head).
        output_dim: number of domain classes (default 2).
        pretrained: if True and a weights path is given, load saved weights.
        weights_init: path to a saved state dict.
    """

    def __init__(self, input_dim=4096, output_dim=2, pretrained=False, weights_init=''):
        super().__init__()
        dim1 = 1024 if input_dim == 4096 else 512
        dim2 = dim1 // 2
        self.D = nn.Sequential(
            nn.Conv2d(input_dim, dim1, 1),
            nn.Dropout2d(p=0.5),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim1, dim2, 1),
            nn.Dropout2d(p=0.5),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim2, output_dim, 1)
        )
        # Bug fix: the old check (`weights_init is not None`) was True for
        # the default empty string, so pretrained=True crashed in
        # torch.load(''). Require a non-empty path.
        if pretrained and weights_init:
            self.load_weights(weights_init)

    def forward(self, x):
        # Per-location domain scores, shape (N, output_dim, H, W).
        d_score = self.D(x)
        return d_score

    def load_weights(self, weights):
        print('Loading discriminator weights')
        self.load_state_dict(torch.load(weights))
class Transform_Module(nn.Module):
    """Feature transform of 1x1 convolutions, initialized as the identity
    map (eye weights, zero bias)."""

    def __init__(self, input_dim=4096):
        super().__init__()
        self.transform = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, 1),
            nn.ReLU(inplace=True),
            # nn.Conv2d(input_dim, input_dim, 1),
            # nn.ReLU(inplace=True),
        )
        # Start as a (near-)identity transform so training begins from the
        # untransformed features.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init_eye(module.weight)
                module.bias.data.zero_()

    def forward(self, x):
        return self.transform(x)
def init_eye(tensor):
if isinstance(tensor, Variable):
init_eye(tensor.data)
return tensor
return tensor.copy_(torch.eye(tensor.size(0), tensor.size(1)))
def _crop(input, shape, offset=0):
_, _, h, w = shape.size()
return input[:, :, offset:offset + h, offset:offset + w].contiguous()
def make_layers(cfg, batch_norm=False):
    """This is almost verbatim from torchvision.models.vgg, except that the
    MaxPool2d modules are configured with ceil_mode=True so odd spatial
    sizes round up instead of truncating.
    """
    layers = []
    channels = 3
    for entry in cfg:
        if entry == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))
            continue
        layers.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(entry))
        layers.append(nn.ReLU(inplace=True))
        channels = entry
    return nn.Sequential(*layers)
================================================
FILE: cycada/models/models.py
================================================
import torch
# Global registry mapping model names to their classes.
models = {}


def register_model(name):
    """Class decorator that records ``cls`` under ``name`` in the registry."""
    def decorator(cls):
        models[name] = cls
        return cls
    return decorator


def get_model(name, num_cls=10, **args):
    """Instantiate the registered model ``name`` with ``num_cls`` outputs,
    moving it to the GPU when CUDA is available."""
    constructor = models[name]
    net = constructor(num_cls=num_cls, **args)
    if torch.cuda.is_available():
        net = net.cuda()
    return net
================================================
FILE: cycada/models/task_net.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
from .models import register_model
from .util import init_weights
import numpy as np
class TaskNet(nn.Module):
    # Default input geometry; subclasses override these class attributes.
    num_channels = 3
    image_size = 32
    name = 'TaskNet'
    "Basic class which does classification."

    def __init__(self, num_cls=10, weights_init=None):
        super(TaskNet, self).__init__()
        self.num_cls = num_cls
        self.setup_net()
        self.criterion = nn.CrossEntropyLoss()
        # Restore saved weights when a path is given; otherwise init fresh.
        if weights_init is None:
            init_weights(self)
        else:
            self.load(weights_init)

    def forward(self, x, with_ft=False):
        features = self.conv_params(x)
        features = features.view(features.size(0), -1)
        features = self.fc_params(features)
        score = self.classifier(features)
        # Optionally expose the penultimate features for adaptation losses.
        return (score, features) if with_ft else score

    def setup_net(self):
        """Method to be implemented in each class."""
        pass

    def load(self, init_path):
        self.load_state_dict(torch.load(init_path))

    def save(self, out_path):
        torch.save(self.state_dict(), out_path)
@register_model('LeNet')
class LeNet(TaskNet):
    "Network used for MNIST or USPS experiments."
    num_channels = 1
    image_size = 28
    name = 'LeNet'
    out_dim = 500  # dim of last feature layer

    def setup_net(self):
        # Spatial sizes for 28x28 input: conv5 -> 24, pool2 -> 12,
        # conv5 -> 8, pool2 -> 4; hence the 50*4*4 flattened input below.
        self.conv_params = nn.Sequential(
            nn.Conv2d(self.num_channels, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        )
        self.fc_params = nn.Linear(50*4*4, 500)
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(500, self.num_cls)
        )
@register_model('DTN')
class DTNClassifier(TaskNet):
    "Classifier used for SVHN->MNIST Experiment"
    num_channels = 3
    image_size = 32
    name = 'DTN'
    out_dim = 512  # dim of last feature layer

    def setup_net(self):
        # Three stride-2 convs halve spatial size each time:
        # 32 -> 16 -> 8 -> 4, ending with 256 channels (256*4*4 below).
        self.conv_params = nn.Sequential (
            nn.Conv2d(self.num_channels, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.Dropout2d(0.1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.Dropout2d(0.3),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.Dropout2d(0.5),
            nn.ReLU()
        )
        self.fc_params = nn.Sequential (
            nn.Linear(256*4*4, 512),
            nn.BatchNorm1d(512),
        )
        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(512, self.num_cls)
        )
================================================
FILE: cycada/models/util.py
================================================
import torch.nn as nn
from torch.nn import init
def init_weights(obj):
    """Xavier-initialize conv/linear weights (with zero bias) and reset
    any batch-norm layers in ``obj``."""
    for module in obj.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            init.xavier_normal_(module.weight)
            module.bias.data.zero_()
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
            module.reset_parameters()
================================================
FILE: cycada/tools/__init__.py
================================================
================================================
FILE: cycada/tools/train_adda_net.py
================================================
from __future__ import print_function
import os
from os.path import join
import numpy as np
# Import from torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
# Import from within Package
from ..models.models import get_model
from ..data.data_loader import load_data
from ..tools.test_task_net import test
from ..tools.util import make_variable
def train(loader_src, loader_tgt, net, opt_net, opt_dis, epoch):
    """Run one ADDA training epoch.

    Alternates between (1) updating the domain discriminator on source vs.
    target features and (2) adversarially updating the target network, the
    latter only on batches where the discriminator is already strong
    (accuracy > 0.6).

    Returns the index of the last batch on which the target network was
    updated, or -1 if it never was (signalling a useless discriminator).
    """
    log_interval = 100  # specifies how often to display
    N = min(len(loader_src.dataset), len(loader_tgt.dataset))
    joint_loader = zip(loader_src, loader_tgt)
    net.train()
    last_update = -1
    for batch_idx, ((data_s, _), (data_t, _)) in enumerate(joint_loader):
        # log basic adda train info
        # Bug fix: the progress percentage previously divided the *batch
        # index* by the *sample count* N; scale by the batch size so it
        # matches the [{}/{}] counter (same formula as train_task_net).
        info_str = "[Train Adda] Epoch: {} [{}/{} ({:.2f}%)]".format(
            epoch, batch_idx * len(data_t), N,
            100. * batch_idx * len(data_t) / N)
        ########################
        # Setup data variables #
        ########################
        data_s = make_variable(data_s, requires_grad=False)
        data_t = make_variable(data_t, requires_grad=False)
        ##########################
        # Optimize discriminator #
        ##########################
        # zero gradients for optimizer
        opt_dis.zero_grad()
        # extract and concat features
        score_s = net.src_net(data_s)
        score_t = net.tgt_net(data_t)
        f = torch.cat((score_s, score_t), 0)
        # predict with discriminator
        pred_concat = net.discriminator(f)
        # prepare real and fake labels: source=1, target=0
        target_dom_s = make_variable(torch.ones(len(data_s)).long(), requires_grad=False)
        target_dom_t = make_variable(torch.zeros(len(data_t)).long(), requires_grad=False)
        label_concat = torch.cat((target_dom_s, target_dom_t), 0)
        # compute loss for discriminator
        loss_dis = net.gan_criterion(pred_concat, label_concat)
        loss_dis.backward()
        # optimize discriminator
        opt_dis.step()
        # compute discriminator acc
        pred_dis = torch.squeeze(pred_concat.max(1)[1])
        acc = (pred_dis == label_concat).float().mean()
        # log discriminator update info
        info_str += " acc: {:0.1f} D: {:.3f}".format(acc.item() * 100, loss_dis.item())
        ###########################
        # Optimize target network #
        ###########################
        # only update net if discriminator is strong
        if acc.item() > 0.6:
            last_update = batch_idx
            # zero out optimizer gradients
            opt_dis.zero_grad()
            opt_net.zero_grad()
            # extract target features
            score_t = net.tgt_net(data_t)
            # predict with discriminator
            pred_tgt = net.discriminator(score_t)
            # label target samples as "source" (1) to fool the discriminator
            label_tgt = make_variable(torch.ones(pred_tgt.size(0)).long(), requires_grad=False)
            # compute loss for target network
            loss_gan_t = net.gan_criterion(pred_tgt, label_tgt)
            loss_gan_t.backward()
            # optimize tgt network
            opt_net.step()
            # log net update info
            info_str += " G: {:.3f}".format(loss_gan_t.item())
        ###########
        # Logging #
        ###########
        if batch_idx % log_interval == 0:
            print(info_str)
    return last_update
def train_adda(src, tgt, model, num_cls, num_epoch=200,
               batch=128, datadir="", outdir="",
               src_weights=None, weights=None, lr=1e-5, betas=(0.9, 0.999),
               weight_decay=0):
    """Main function for training ADDA."""
    # DataLoader workers / pinned memory only help when CUDA is present.
    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

    # Build the ADDA wrapper (source net, target net, discriminator).
    net = get_model('AddaNet', model=model, num_cls=num_cls,
                    src_weights_init=src_weights)
    print(net)
    print('Training Adda {} model for {}->{}'.format(model, src, tgt))

    # Source and target training streams.
    train_src_data = load_data(src, 'train', batch=batch,
                               rootdir=join(datadir, src), num_channels=net.num_channels,
                               image_size=net.image_size, download=True, kwargs=loader_kwargs)
    train_tgt_data = load_data(tgt, 'train', batch=batch,
                               rootdir=join(datadir, tgt), num_channels=net.num_channels,
                               image_size=net.image_size, download=True, kwargs=loader_kwargs)

    # Separate optimizers for the target network and the discriminator.
    opt_net = optim.Adam(net.tgt_net.parameters(), lr=lr,
                         weight_decay=weight_decay, betas=betas)
    opt_dis = optim.Adam(net.discriminator.parameters(), lr=lr,
                         weight_decay=weight_decay, betas=betas)

    # Train until done or until the discriminator never becomes useful.
    for epoch in range(num_epoch):
        err = train(train_src_data, train_tgt_data, net, opt_net, opt_dis, epoch)
        if err == -1:
            print("No suitable discriminator")
            break

    # Persist the trained model.
    os.makedirs(outdir, exist_ok=True)
    outfile = join(outdir, 'adda_{:s}_net_{:s}_{:s}.pth'.format(
        model, src, tgt))
    print('Saving to', outfile)
    net.save(outfile)
================================================
FILE: cycada/tools/train_task_net.py
================================================
from __future__ import print_function
import os
from os.path import join
import numpy as np
import argparse
# Import from torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
# Import from Cycada Package
from ..models.models import get_model
from ..data.data_loader import load_data
from .test_task_net import test
from .util import make_variable
def train_epoch(loader, net, opt_net, epoch):
    """One supervised training pass over ``loader``, logging loss and
    batch accuracy every ``log_interval`` batches."""
    log_interval = 100  # specifies how often to display
    net.train()
    for batch_idx, (data, target) in enumerate(loader):
        data = make_variable(data, requires_grad=False)
        target = make_variable(target, requires_grad=False)

        # Standard forward/backward/step cycle.
        opt_net.zero_grad()
        score = net(data)
        loss = net.criterion(score, target)
        loss.backward()
        opt_net.step()

        if batch_idx % log_interval != 0:
            continue
        print('[Train] Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(loader.dataset),
            100. * batch_idx / len(loader), loss.item()), end="")
        # Batch accuracy from the argmax prediction.
        pred = score.data.max(1)[1]
        correct = pred.eq(target.data).cpu().sum()
        acc = correct.item() / len(pred) * 100.0
        print(' Acc: {:.2f}'.format(acc))
def train(data, datadir, model, num_cls, outdir='',
          num_epoch=100, batch=128,
          lr=1e-4, betas=(0.9, 0.999), weight_decay=0):
    """Train a classification net and evaluate on test set."""
    # DataLoader options only matter when CUDA is available.
    loader_kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}

    # Build the network.
    net = get_model(model, num_cls=num_cls)
    print('-------Training net--------')
    print(net)

    # Train/test data streams.
    train_data = load_data(data, 'train', batch=batch,
                           rootdir=datadir, num_channels=net.num_channels,
                           image_size=net.image_size, download=True, kwargs=loader_kwargs)
    test_data = load_data(data, 'test', batch=batch,
                          rootdir=datadir, num_channels=net.num_channels,
                          image_size=net.image_size, download=True, kwargs=loader_kwargs)

    opt_net = optim.Adam(net.parameters(), lr=lr, betas=betas,
                         weight_decay=weight_decay)

    print('Training {} model for {}'.format(model, data))
    for epoch in range(num_epoch):
        train_epoch(train_data, net, opt_net, epoch)

    # Evaluate on the held-out split when one exists.
    if test_data is not None:
        print('Evaluating {}-{} model on {} test set'.format(model, data, data))
        test(test_data, net)

    # Persist the trained network.
    os.makedirs(outdir, exist_ok=True)
    outfile = join(outdir, '{:s}_net_{:s}.pth'.format(model, data))
    print('Saving to', outfile)
    net.save(outfile)
    return net
================================================
FILE: cycada/tools/util.py
================================================
from functools import partial
import torch
from torch.autograd import Variable
def make_variable(tensor, volatile=False, requires_grad=True):
    """Move ``tensor`` to the GPU (when available) and set autograd flags.

    ``volatile`` is kept for backward compatibility with its pre-0.4
    meaning — it forces ``requires_grad=False`` — but is no longer
    forwarded to ``Variable``, since PyTorch >= 0.4 removed the keyword
    and raises on it.
    """
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    if volatile:
        requires_grad = False
    return Variable(tensor, requires_grad=requires_grad)
def pairwise_distance(x, y):
    """Squared Euclidean distances between the rows of two matrices.

    Returns a (y.rows, x.rows) matrix whose (i, j) entry is
    ||x_j - y_i||^2 — note the transposed layout, which the MMD kernel
    code relies on.
    """
    if len(x.shape) != len(y.shape):
        raise ValueError('Both inputs should be matrices.')
    if x.shape[1] != y.shape[1]:
        raise ValueError('The number of features should be the same.')
    # Broadcast (n, f, 1) against (f, m) to get all pairwise differences.
    expanded = x.view(x.shape[0], x.shape[1], 1)
    diffs = expanded - torch.transpose(y, 0, 1)
    squared = torch.sum(diffs ** 2, 1)
    return torch.transpose(squared, 0, 1)
def gaussian_kernel_matrix(x, y, sigmas):
    """Sum of Gaussian kernels exp(-d2 / (2 * sigma)) over all bandwidths.

    ``d2`` is the squared pairwise distance matrix; note sigma enters
    linearly (not squared) in this parametrization.
    """
    beta = 1. / (2. * sigmas.view(sigmas.shape[0], 1))
    dist = pairwise_distance(x, y).contiguous()
    # One row of scaled distances per bandwidth, summed after exponentiation.
    scaled = torch.matmul(beta, dist.view(1, -1))
    return torch.sum(torch.exp(-scaled), 0).view_as(dist)
def maximum_mean_discrepancy(x, y, kernel=gaussian_kernel_matrix):
    """Biased MMD^2 estimate: E[k(x,x)] + E[k(y,y)] - 2 E[k(x,y)]."""
    cost = torch.mean(kernel(x, x)) + torch.mean(kernel(y, y))
    cost = cost - 2 * torch.mean(kernel(x, y))
    return cost
def mmd_loss(source_features, target_features):
    """Multi-bandwidth MMD loss between source and target feature batches."""
    sigmas = [
        1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
        1e3, 1e4, 1e5, 1e6
    ]
    # Bug fix: follow the input tensor's device instead of hard-coding
    # torch.cuda.FloatTensor, which crashed on CPU-only machines. The
    # redundant `loss_value = loss_value` no-op was also dropped.
    sigma_tensor = Variable(torch.tensor(
        sigmas, dtype=torch.float32, device=source_features.device))
    gaussian_kernel = partial(gaussian_kernel_matrix, sigmas=sigma_tensor)
    return maximum_mean_discrepancy(source_features, target_features,
                                    kernel=gaussian_kernel)
================================================
FILE: cycada/transforms.py
================================================
"""These random transforms extend the transforms provided in torchvision to
allow for transforming multiple images at the same time. This ensures that the
images receive the same transformation, e.g. the provided images are either all
mirrored or all left unchanged.
For example, this is useful in segmentation tasks, where a transformation to the
image necessitates that same transformation on the label.
"""
import numbers
import random
import torch
import torchvision
class RandomCrop(object):
    """Crops the given PIL.Image at a random location to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape
    (size, size). All tensors receive the SAME crop window; if they
    already match the target size, the input list is returned untouched.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, tensors):
        th, tw = self.size
        height = width = None
        # First pass: establish and verify the shared spatial size.
        for tensor in tensors:
            if height is None and width is None:
                _, height, width = tensor.size()
            elif tensor.size()[-2:] != (height, width):
                print(tensor.size(), (height, width))
                raise ValueError('Images must be same size')
        if width == tw and height == th:
            return tensors
        # Second pass: apply one shared random window to every tensor.
        left = random.randint(0, width - tw)
        top = random.randint(0, height - th)
        return [t[..., top:top + th, left:left + tw].contiguous()
                for t in tensors]
class HalfCrop(object):
    """Crops half of each tensor along the width, randomly taking either
    the left or the right side; the same side is used for every tensor in
    the batch.

    Bug fixes vs. the original: it read ``self.size`` which was never set
    (no __init__), referenced an undefined ``left_size`` (typo for
    ``left_side``), and used an invalid double-Ellipsis index. The half
    width is now derived from the tensors themselves.
    """

    def __call__(self, tensors):
        output = []
        side = random.randint(0, 1)  # 0 = keep left half, 1 = keep right half
        x1 = half_w = None
        for tensor in tensors:
            if x1 is None:
                half_w = tensor.size(-1) // 2
                x1 = side * half_w
            output.append(tensor[..., x1:x1 + half_w].contiguous())
        return output
class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability
    of 0.5; all tensors in the batch are flipped (or kept) together.
    """

    def __call__(self, tensors):
        if random.random() >= 0.5:
            return tensors
        flipped = []
        for tensor in tensors:
            reversed_idx = torch.arange(tensor.size(-1) - 1, -1, -1).long()
            flipped.append(tensor.index_select(-1, reversed_idx))
        return flipped
def augment_collate(batch, crop=None, halfcrop=None, flip=True, resize=None):
    """Collate function that applies joint augmentations before batching.

    Each element of ``batch`` is a sequence of tensors (e.g. image and
    label) that must receive identical spatial transforms, which the
    transforms in this module guarantee.
    NOTE(review): the ``resize`` parameter is accepted but currently
    unused — confirm whether resizing was ever wired up.
    """
    transforms = []
    if crop is not None:
        transforms.append(RandomCrop(crop))
    if halfcrop is not None:
        transforms.append(HalfCrop())
    if flip:
        transforms.append(RandomHorizontalFlip())
    transform = torchvision.transforms.Compose(transforms)
    batch = [transform(x) for x in batch]
    return torch.utils.data.dataloader.default_collate(batch)
================================================
FILE: cycada/util.py
================================================
import logging
import logging.config
import os.path
from collections import OrderedDict
import numpy as np
import torch
import yaml
from torch.nn.parameter import Parameter
from tqdm import tqdm
class TqdmHandler(logging.StreamHandler):
    """Logging handler that routes records through ``tqdm.write`` so log
    lines do not corrupt active progress bars."""

    def __init__(self):
        logging.StreamHandler.__init__(self)

    def emit(self, record):
        tqdm.write(self.format(record))
def config_logging(logfile=None):
    """Configure logging from the package's bundled logging.yml.

    When ``logfile`` is None the file handler is stripped from the config;
    otherwise the file handler is pointed at ``logfile``.
    """
    path = os.path.join(os.path.dirname(__file__), 'logging.yml')
    with open(path, 'r') as f:
        # Bug fix: yaml.load without an explicit Loader is unsafe on
        # untrusted input and raises a TypeError under PyYAML >= 6.
        config = yaml.safe_load(f.read())
    if logfile is None:
        del config['handlers']['file_handler']
        del config['root']['handlers'][-1]
    else:
        config['handlers']['file_handler']['filename'] = logfile
    logging.config.dictConfig(config)
def to_tensor_raw(im):
    """Convert a PIL image / array-like to an int64 tensor without any
    0-1 rescaling (label maps must keep their raw integer ids)."""
    arr = np.array(im, np.int64, copy=False)
    return torch.from_numpy(arr)
def safe_load_state_dict(net, state_dict):
    """Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants. Any params in :attr:`state_dict`
    that do not match the keys returned by :attr:`net`'s :func:`state_dict()`
    method or have differing sizes are skipped.

    Arguments:
        state_dict (dict): A dict containing parameters and
            persistent buffers.
    """
    own_state = net.state_dict()
    skipped = []
    for name, param in state_dict.items():
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        # Skip unknown keys and size mismatches instead of raising.
        if name not in own_state or own_state[name].size() != param.size():
            skipped.append(name)
            continue
        own_state[name].copy_(param)
    if skipped:
        logging.info('Skipped loading some parameters: {}'.format(skipped))
def step_lr(optimizer, mult):
    """Scale every param group's learning rate by ``mult`` in place."""
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * mult
================================================
FILE: cyclegan/.gitignore
================================================
.DS_Store
debug*
checkpoints/
results/
build/
dist/
*.png
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
*~
.idea
================================================
FILE: cyclegan/data/__init__.py
================================================
import sys
import torch.utils.data
from data.base_data_loader import BaseDataLoader
sys.path.append('/nfs/project/libo_i/MADAN')
from cycada.transforms import augment_collate
def CreateDataLoader(opt):
    """Build, announce, and initialize the custom dataset loader."""
    loader = CustomDatasetDataLoader()
    print(loader.name())
    loader.initialize(opt)
    return loader
def CreateDataset(opt):
    """Instantiate the dataset selected by ``opt.dataset_mode``.

    Imports are deliberately deferred so only the chosen dataset module
    (and its dependencies) gets loaded.
    """
    mode = opt.dataset_mode
    if mode == 'synthia_cityscapes':
        from data.synthia_cityscapes import SynthiaCityscapesDataset
        dataset = SynthiaCityscapesDataset()
    elif mode == 'gta5_cityscapes':
        from data.gta5_cityscapes import GTAVCityscapesDataset
        dataset = GTAVCityscapesDataset()
    elif mode == 'gta_synthia_cityscapes':
        from data.gta_synthia_cityscapes import GTASynthiaCityscapesDataset
        dataset = GTASynthiaCityscapesDataset()
    elif mode == 'merged_gta_synthia_cityscapes':
        from data.merged_gta_synthia_cityscapes import MergedGTASynthiaCityscapesDataset
        dataset = MergedGTASynthiaCityscapesDataset()
    else:
        raise ValueError("Dataset [%s] not recognized." % mode)
    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset
class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps the selected dataset in a torch DataLoader, capping iteration
    at ``opt.max_dataset_size`` samples."""

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self

    def __len__(self):
        # Never report more samples than the configured cap.
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for batch_index, data in enumerate(self.dataloader):
            if batch_index * self.opt.batchSize >= self.opt.max_dataset_size:
                break
            yield data
================================================
FILE: cyclegan/data/base_data_loader.py
================================================
class BaseDataLoader():
    """Minimal data-loader interface; subclasses override ``load_data``."""

    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    # Bug fix: load_data was declared without ``self`` and raised a
    # TypeError whenever invoked on an instance.
    def load_data(self):
        return None
================================================
FILE: cyclegan/data/base_dataset.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
class BaseDataset(data.Dataset):
    """Common base for the cyclegan datasets; subclasses implement
    ``initialize`` (and the usual Dataset protocol)."""

    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass
# TODO: add support for cropping
def get_transform(opt):
    """Image transform pipeline selected by ``opt.resize_or_crop``.

    NOTE(review): crop positions and flips here draw randomness
    independently from get_label_transform, so paired image/label
    augmentation can desynchronize — confirm against callers.
    """
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [int(opt.loadSize), int(opt.loadSize)]
        transform_list.append(transforms.Resize(osize, interpolation=Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'resize_only':
        # Bug fix: this branch computed the square `osize` but resized with
        # the bare int (shortest-side scaling), while get_label_transform
        # uses the square size — producing mismatched image/label shapes.
        # Use `osize` here too, and join the elif chain like other branches.
        osize = [int(opt.loadSize), int(opt.loadSize)]
        transform_list.append(transforms.Resize(osize, interpolation=Image.BICUBIC))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.BICUBIC))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())
    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
def get_label_transform(opt):
    """Segmentation-label transform mirroring get_transform's geometry.

    Uses NEAREST interpolation to keep label ids intact and converts to a
    raw int64 tensor (no 0-1 scaling).
    NOTE(review): RandomCrop/RandomHorizontalFlip here draw their own
    randomness, independent of the image transform — confirm image/label
    alignment before relying on it.
    """
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Resize(osize, interpolation=Image.NEAREST))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'resize_only':
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Resize(osize, interpolation=Image.NEAREST))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.NEAREST))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.NEAREST))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    # transform_list.append(transforms.RandomCrop(opt.fineSize))
    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())
    transform_list.append(transforms.Lambda(lambda img: to_tensor_raw(img)))
    return transforms.Compose(transform_list)
def __scale_width(img, target_width):
    """Proportionally resize ``img`` so its width equals ``target_width``
    (no-op when it already does)."""
    ow, oh = img.size
    if ow == target_width:
        return img
    new_h = int(target_width * oh / ow)
    return img.resize((target_width, new_h), Image.BICUBIC)
def to_tensor_raw(im):
    """Convert a PIL image / array-like to an int64 tensor without any
    0-1 rescaling (label maps must keep their raw integer ids)."""
    arr = np.array(im, np.int64, copy=False)
    return torch.from_numpy(arr)
================================================
FILE: cyclegan/data/cityscapes.py
================================================
import numpy as np
ignore_label = 255

# Raw cityscapes label id -> 19-class train id (everything else ignored).
id2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
            3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
            7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
            14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
            18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
            25: 12, 26: 13, 27: 14, 28: 15, 29: ignore_label,
            30: ignore_label, 31: 16, 32: 17, 33: 18}

# Flat RGB palette, three values per train id.
palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153,
           153, 153, 153, 153, 250, 170, 30, 220, 220, 0, 107, 142, 35, 152,
           251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0,
           70, 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]

# Class names indexed by train id.
classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
           'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
           'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle']


def remap_labels_to_train_ids(arr):
    """Map an array of raw cityscapes ids to train ids (uint8, 255=ignore)."""
    out = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        out[arr == raw_id] = int(train_id)
    return out
================================================
FILE: cyclegan/data/gta5_cityscapes.py
================================================
import os.path
import random
import numpy as np
from PIL import Image
from data.base_dataset import BaseDataset, get_label_transform, get_transform
from data.cityscapes import remap_labels_to_train_ids
from data.image_folder import make_cs_labels, make_dataset
ignore_label = 255

# NOTE(review): this mapping covers ids 0..22 — a SYNTHIA-style label
# space rather than GTA5's — converting them to cityscapes train ids
# (255 = ignore). It appears unused within this module: the dataset below
# remaps via data.cityscapes.remap_labels_to_train_ids. Confirm before
# removing.
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

# Cityscapes train-id class names; list index == train id.
classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']
# This dataset is used to conduct GTA->CityScapes images transfer procedure.
class GTAVCityscapesDataset(BaseDataset):
    """Unaligned GTA5 (A) <-> Cityscapes (B) image/label pairs for CycleGAN.

    Expects under ``opt.dataroot``: gta5/images, gta5/labels,
    cityscapes/leftImg8bit, cityscapes/gtFine.
    """

    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, 'gta5', 'images')
        self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
        self.dir_A_label = os.path.join(opt.dataroot, 'gta5', 'labels')
        self.dir_B_label = os.path.join(opt.dataroot, 'cityscapes', 'gtFine')
        self.A_paths = make_dataset(self.dir_A)
        self.B_paths = make_dataset(self.dir_B)
        # Sorting keeps image and label lists index-aligned.
        self.A_paths = sorted(self.A_paths)
        self.B_paths = sorted(self.B_paths)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        self.A_labels = make_dataset(self.dir_A_label)
        self.B_labels = make_cs_labels(self.dir_B_label)
        self.A_labels = sorted(self.A_labels)
        self.B_labels = sorted(self.B_labels)
        self.transform = get_transform(opt)
        self.label_transform = get_label_transform(opt)

    def __getitem__(self, index):
        # A cycles sequentially through GTA5; B is drawn at random unless
        # serial_batches is set (unaligned pairing).
        A_path = self.A_paths[index % self.A_size]
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_label_path = self.A_labels[index % self.A_size]
        B_label_path = self.B_labels[index_B]
        A_label = Image.open(A_label_path)
        B_label = Image.open(B_label_path)
        # Remap raw label ids to cityscapes train ids (255 = ignore).
        A_label = np.asarray(A_label)
        A_label = remap_labels_to_train_ids(A_label)
        A_label = Image.fromarray(A_label, 'L')
        B_label = np.asarray(B_label)
        B_label = remap_labels_to_train_ids(B_label)
        B_label = Image.fromarray(B_label, 'L')
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')
        # NOTE(review): transform and label_transform draw independent
        # randomness, so random crops/flips may desync image and label.
        A = self.transform(A_img)
        B = self.transform(B_img)
        A_label = self.label_transform(A_label)
        B_label = self.label_transform(B_label)
        # print(A_label.unique())
        # print(B_label.unique())
        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc
        if input_nc == 1:  # RGB to gray (ITU-R 601 luma weights)
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)
        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)
        return {'A': A, 'B': B,
                'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label, 'B_label': B_label}

    def __len__(self):
        # Length of the larger domain so both are fully traversed.
        return max(self.A_size, self.B_size)

    def name(self):
        return 'GTA5_Cityscapes'
================================================
FILE: cyclegan/data/gta_synthia_cityscapes.py
================================================
import os.path
import random
import numpy as np
from PIL import Image
from data.base_dataset import BaseDataset, get_label_transform, get_transform
from data.cityscapes import remap_labels_to_train_ids
from data.image_folder import make_cs_labels, make_dataset
ignore_label = 255

# SYNTHIA label id -> cityscapes train id (255 = ignore).
id2label = {0: ignore_label, 1: 10, 2: 2, 3: 0, 4: 1, 5: 4, 6: 8, 7: 5,
            8: 13, 9: 7, 10: 11, 11: 18, 12: 17, 13: ignore_label,
            14: ignore_label, 15: 6, 16: 9, 17: 12, 18: 14, 19: 15,
            20: 16, 21: 3, 22: ignore_label}

# Cityscapes train-id class names; list index == train id.
classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
           'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
           'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle']


def syn_relabel(arr):
    """Map an array of SYNTHIA ids to cityscapes train ids (uint8,
    255 = ignore)."""
    out = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for syn_id, train_id in id2label.items():
        out[arr == syn_id] = int(train_id)
    return out
# This dataset is used to conduct double cyclegan for both GTAV->CityScapes and Synthia->CityScapes
class GTASynthiaCityscapesDataset(BaseDataset):
    """Three-way dataset for the multi-source CycleGAN: SYNTHIA (A_1) and
    GTA5 (A_2) as sources, Cityscapes (B) as target."""

    def initialize(self, opt):
        # SYNTHIA as dataset 1
        # GTAV as dataset 2
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A_1 = os.path.join(opt.dataroot, 'synthia', 'RGB')
        self.dir_A_2 = os.path.join(opt.dataroot, 'gta5', 'images')
        self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
        self.dir_A_label_1 = os.path.join(opt.dataroot, 'synthia', 'GT', 'parsed_LABELS')
        self.dir_A_label_2 = os.path.join(opt.dataroot, 'gta5', 'labels')
        self.A_paths_1 = make_dataset(self.dir_A_1)
        self.A_paths_2 = make_dataset(self.dir_A_2)
        self.B_paths = make_dataset(self.dir_B)
        # Sorting keeps image and label lists index-aligned.
        self.A_paths_1 = sorted(self.A_paths_1)
        self.A_paths_2 = sorted(self.A_paths_2)
        self.B_paths = sorted(self.B_paths)
        self.A_size_1 = len(self.A_paths_1)
        self.A_size_2 = len(self.A_paths_2)
        self.B_size = len(self.B_paths)
        self.A_labels_1 = make_dataset(self.dir_A_label_1)
        self.A_labels_2 = make_dataset(self.dir_A_label_2)
        self.A_labels_1 = sorted(self.A_labels_1)
        self.A_labels_2 = sorted(self.A_labels_2)
        self.transform = get_transform(opt)
        self.label_transform = get_label_transform(opt)

    def __getitem__(self, index):
        # Sources cycle sequentially; B is drawn at random unless
        # serial_batches is set (unaligned pairing).
        A_path_1 = self.A_paths_1[index % self.A_size_1]
        A_path_2 = self.A_paths_2[index % self.A_size_2]
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_label_path_1 = self.A_labels_1[index % self.A_size_1]
        A_label_path_2 = self.A_labels_2[index % self.A_size_2]
        A_label_1 = Image.open(A_label_path_1)
        A_label_2 = Image.open(A_label_path_2)
        # remaping label for synthia
        A_label_1 = np.asarray(A_label_1)
        A_label_1 = syn_relabel(A_label_1)
        A_label_1 = Image.fromarray(A_label_1, 'L')
        # remaping label for gta5
        A_label_2 = np.asarray(A_label_2)
        A_label_2 = remap_labels_to_train_ids(A_label_2)
        A_label_2 = Image.fromarray(A_label_2, 'L')
        A_img_1 = Image.open(A_path_1).convert('RGB')
        A_img_2 = Image.open(A_path_2).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')
        # NOTE(review): transform and label_transform draw independent
        # randomness, so random crops/flips may desync image and label.
        A_1 = self.transform(A_img_1)
        A_2 = self.transform(A_img_2)
        B = self.transform(B_img)
        A_label_1 = self.label_transform(A_label_1)
        A_label_2 = self.label_transform(A_label_2)
        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc
        if input_nc == 1:  # RGB to gray (ITU-R 601 luma weights)
            tmp = A_1[0, ...] * 0.299 + A_1[1, ...] * 0.587 + A_1[2, ...] * 0.114
            A_1 = tmp.unsqueeze(0)
            tmp = A_2[0, ...] * 0.299 + A_2[1, ...] * 0.587 + A_2[2, ...] * 0.114
            A_2 = tmp.unsqueeze(0)
        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)
        return {'A_1': A_1, 'A_2': A_2, 'B': B, 'A_paths_1': A_path_1, 'A_paths_2': A_path_2, 'B_paths': B_path, 'A_label_1': A_label_1,
                'A_label_2': A_label_2}

    def __len__(self):
        # Length of the largest domain so all three are fully traversed.
        return max(self.A_size_1, self.B_size, self.A_size_2)

    def name(self):
        return 'GTA5_Synthia_Cityscapes'
================================================
FILE: cyclegan/data/image_folder.py
================================================
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
import numpy as np
from PIL import Image
import os
import os.path
# Recognized image file suffixes (both cases).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """True if ``filename`` ends with one of the known image extensions."""
    return any(map(filename.endswith, IMG_EXTENSIONS))


def make_cs_labels(dir):
    """Recursively collect Cityscapes fine-annotation label maps under ``dir``.

    Only files named ``*_gtFine_labelIds.png`` are kept. The returned list
    is deduplicated (order unspecified; callers sort it themselves).
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    found = []
    for folder, _, names in sorted(os.walk(dir)):
        for name in names:
            if not is_image_file(name):
                continue
            full_path = os.path.join(folder, name)
            if full_path.endswith("_gtFine_labelIds.png"):
                found.append(full_path)
    return list(set(found))


def make_dataset(dir):
    """Recursively collect every image file under ``dir``.

    The returned list is deduplicated (order unspecified; callers sort).
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    found = []
    for folder, _, names in sorted(os.walk(dir)):
        for name in names:
            if is_image_file(name):
                found.append(os.path.join(folder, name))
    return list(set(found))
def load_labels(dir, images):
    """Load per-image integer labels for ``images`` from ``dir``.

    Expects ``dir`` to contain a ``labels.txt`` with one ``<image_id> <label>``
    pair per line; image ids are matched against each path's basename without
    extension.

    Returns:
        list of int labels aligned with ``images``.

    Raises:
        NotImplementedError: if only a ``labels`` folder is present.
        Exception: if neither labels.txt nor a labels folder exists.
        KeyError: if an image id is missing from labels.txt.
    """
    txt_path = os.path.join(dir, 'labels.txt')
    if os.path.exists(txt_path):
        with open(txt_path, 'r') as f:
            lines = f.read().splitlines()
        # Bug fix: the original wrapped the parsed pairs in np.array, which
        # stringified the int labels; build a plain str->int dict instead.
        label_dict = {line.split(' ')[0]: int(line.split(' ')[1]) for line in lines}
        labels = []
        for image in images:
            # NOTE(review): '/'-based split assumes POSIX paths — confirm
            # this code never runs on Windows-style paths.
            im_id = image.split('/')[-1].split('.')[0]
            labels.append(label_dict[im_id])
        # Bug fix: the original built `labels` but never returned it.
        return labels
    elif os.path.isdir(os.path.join(dir, 'labels')):
        # Bug fix: the original constructed these exceptions without raising
        # them, so failures passed silently.
        raise NotImplementedError('Not yet implemented load_labels for image folder')
    else:
        raise Exception('load_labels expects %s to contain labels.txt or labels folder' % dir)
def default_loader(path):
    """Open ``path`` with PIL and force 3-channel RGB."""
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):
    """Dataset yielding images gathered recursively from a directory tree.

    Unlike torchvision's ImageFolder, files in ``root`` itself (not only in
    subdirectories) are also included.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if not imgs:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        return len(self.imgs)
================================================
FILE: cyclegan/data/synthia_cityscapes.py
================================================
import os.path
import random
import numpy as np
from PIL import Image
from data.base_dataset import BaseDataset, get_label_transform, get_transform
from data.image_folder import make_cs_labels, make_dataset
from data.cityscapes import remap_labels_to_train_ids
ignore_label = 255  # label value the segmentation loss skips

# Mapping from SYNTHIA ground-truth ids (0..22) to Cityscapes train ids; ids
# with no Cityscapes counterpart collapse to ignore_label.
id2label = {syn_id: train_id
            for syn_id, train_id in zip(range(23),
                                        (ignore_label, 10, 2, 0, 1, 4, 8, 5,
                                         13, 7, 11, 18, 17, ignore_label,
                                         ignore_label, 6, 9, 12, 14, 15, 16,
                                         3, ignore_label))}

# Cityscapes training class names ordered by train id (0..18).
classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
           'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
           'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle']


def syn_relabel(arr):
    """Translate SYNTHIA ids in ``arr`` to Cityscapes train ids (uint8).

    Pixels with ids outside ``id2label`` remain at ``ignore_label``.
    """
    out = np.empty(arr.shape, dtype=np.uint8)
    out.fill(ignore_label)
    for src_id, dst_id in id2label.items():
        out[arr == src_id] = int(dst_id)
    return out
class SynthiaCityscapesDataset(BaseDataset):
    """Unaligned SYNTHIA (domain A) <-> Cityscapes (domain B) dataset with
    per-pixel semantic labels for both domains."""

    def initialize(self, opt):
        # Framework entry point (called instead of __init__); `opt` carries
        # all paths and sampling options.
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, 'synthia', 'RGB')
        self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
        self.dir_A_label = os.path.join(opt.dataroot, 'synthia', 'GT', 'parsed_LABELS')
        self.dir_B_label = os.path.join(opt.dataroot, 'cityscapes', 'gtFine')
        self.A_paths = make_dataset(self.dir_A)
        self.B_paths = make_dataset(self.dir_B)
        self.A_paths = sorted(self.A_paths)
        self.B_paths = sorted(self.B_paths)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        # make_cs_labels keeps only "*_gtFine_labelIds.png" files.
        self.A_labels = make_dataset(self.dir_A_label)
        self.B_labels = make_cs_labels(self.dir_B_label)
        # NOTE(review): images and labels are paired by sorted index — this
        # assumes image and label directories contain matching file sets.
        self.A_labels = sorted(self.A_labels)
        self.B_labels = sorted(self.B_labels)
        self.transform = get_transform(opt)
        self.label_transform = get_label_transform(opt)

    def __getitem__(self, index):
        # A wraps around with modulo; B is drawn at random unless
        # serial_batches, so A/B pairs are unaligned.
        A_path = self.A_paths[index % self.A_size]
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_label_path = self.A_labels[index % self.A_size]
        B_label_path = self.B_labels[index_B]
        A_label = Image.open(A_label_path)
        B_label = Image.open(B_label_path)
        # Remap SYNTHIA ids to the 19 Cityscapes train ids (unknown -> 255).
        A_label = np.asarray(A_label)
        A_label = syn_relabel(A_label)
        A_label = Image.fromarray(A_label, 'L')
        # Remap raw Cityscapes label ids to train ids as well.
        B_label = np.asarray(B_label)
        B_label = remap_labels_to_train_ids(B_label)
        B_label = Image.fromarray(B_label, 'L')
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')
        A = self.transform(A_img)
        B = self.transform(B_img)
        A_label = self.label_transform(A_label)
        B_label = self.label_transform(B_label)
        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc
        if input_nc == 1:  # RGB to gray via the usual luma weights
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)
        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)
        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label, 'B_label': B_label}

    def __len__(self):
        # Length of the longer domain so every image is visited each epoch.
        return max(self.A_size, self.B_size)

    def name(self):
        return 'Synthia_Cityscapes'
================================================
FILE: cyclegan/environment.yml
================================================
# Conda environment spec for the CycleGAN/pix2pix code base.
# The `peterjc123` channel supplies the (legacy) PyTorch 0.3.x builds.
name: pytorch-CycleGAN-and-pix2pix
channels:
  - peterjc123
  - defaults
dependencies:
  - python=3.5.5
  - pytorch=0.3.1
  - scipy
  - pip:
    - dominate==2.3.1
    - git+https://github.com/pytorch/vision.git
    - Pillow==5.0.0
    - numpy==1.14.1
    - visdom==0.1.7
================================================
FILE: cyclegan/models/__init__.py
================================================
import logging
def create_model(opt):
    """Instantiate and initialize the model named by ``opt.model``.

    Imports are done lazily per branch so only the selected model's module
    (and its dependencies) is loaded.

    Raises:
        NotImplementedError: if ``opt.model`` names no known model.
    """
    requested = opt.model
    if requested == 'cycle_gan':
        # assert(opt.dataset_mode == 'unaligned')
        from .cycle_gan_model import CycleGANModel
        model = CycleGANModel()
    elif requested == 'test':
        from .test_model import TestModel
        model = TestModel()
    elif requested == 'multi_cycle_gan_semantic':
        from .multi_cycle_gan_semantic_model import CycleGANSemanticModel
        model = CycleGANSemanticModel()
    elif requested == 'cycle_gan_semantic_fcn':
        from .cycle_gan_semantic_model import CycleGANSemanticModel
        model = CycleGANSemanticModel()
    else:
        raise NotImplementedError('model [%s] not implemented.' % requested)
    model.initialize(opt)
    logging.info("model [%s] was created" % (model.name()))
    return model
================================================
FILE: cyclegan/models/base_model.py
================================================
import os
from collections import OrderedDict
import torch
from . import networks
class BaseModel():
    """Common scaffolding shared by all CycleGAN-style models: device setup,
    optimizer/scheduler bookkeeping, checkpoint save/load, and collection of
    losses and visual outputs for the training loop."""

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # Everything runs on the first listed GPU, or on CPU if none given.
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        # cudnn autotuning pays off only when input sizes are fixed;
        # 'scale_width' can produce varying shapes, so skip it then.
        if opt.resize_or_crop != 'scale_width':
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_paths = []

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # Load and print networks; create LR schedulers.
    def setup(self, opt):
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)
        self.print_networks(opt.verbose)

    # Switch all registered sub-networks to eval mode during test time.
    def eval(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    # Used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop.
    def test(self):
        with torch.no_grad():
            self.forward()

    # Get image paths of the most recently set input batch.
    def get_image_paths(self):
        return self.image_paths

    def optimize_parameters(self):
        pass

    # Update learning rate (called once every epoch).
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    # Return visualization images. train.py will display these images, and save the images to a html
    def get_current_visuals(self):
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    # Return training losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    # Save models to disk as "<epoch>_net_<name>.pth" under save_dir.
    def save_networks(self, which_epoch):
        for name in self.model_names:
            # Don't save semantic consistency networks (frozen classifiers)
            if isinstance(name, str) and ("PixelCLS" not in name):
                save_filename = '%s_net_%s.pth' % (which_epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # Save CPU-side weights (unwrapping DataParallel), then
                    # move the live net back onto its GPU.
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    # Drop InstanceNorm running stats that old (pre-0.4) checkpoints stored
    # but the current module no longer tracks. Walks `keys` (a dotted state
    # dict key split on '.') down the module tree recursively.
    # NOTE(review): upstream versions of this patch also handle the
    # 'num_batches_tracked' buffer — confirm whether that matters here.
    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # Load models from the disk.
    def load_networks(self, which_epoch):
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (which_epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    # Print each network's parameter count (and its structure if verbose).
    def print_networks(self, verbose):
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    # Set requires_grad=False on the given nets to avoid gradient computation.
    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
FILE: cyclegan/models/cycle_gan_model.py
================================================
import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class CycleGANModel(BaseModel):
    """Standard CycleGAN: generators G_A (A->B) and G_B (B->A) plus
    discriminators D_A/D_B, trained with adversarial, cycle-consistency and
    optional identity losses."""

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain:
            # LSGAN needs no sigmoid on the discriminator output.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if self.isTrain:
            # Image pools feed the discriminators a history of generated
            # images rather than only the latest batch.
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        # Swap the roles of A and B when translating in the other direction.
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        # Both translation directions plus their cycle reconstructions.
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detach so generator gradients are not computed here)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B: freeze discriminators while updating generators.
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
================================================
FILE: cyclegan/models/cycle_gan_semantic_model.py
================================================
import itertools
import sys
import torch
import torch.nn.functional as F
from util.image_pool import ImagePool
from . import networks
from .base_model import BaseModel
sys.path.append('/nfs/project/libo_iMADAN')
from cycada.models import get_model
class CycleGANSemanticModel(BaseModel):
    """CycleGAN extended with a frozen pixel-level classifier (FCN) that adds
    a semantic-consistency loss between real_A and its translation fake_B."""

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A',
                           'D_B', 'G_B', 'cycle_B', 'idt_B',
                           'sem_AB']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        # (PixelCLS is deliberately excluded — base_model.save_networks skips
        # names containing "PixelCLS".)
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid,
                                            opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid,
                                            opt.init_type, self.gpu_ids)
            # For the semantic consistency loss, load a pretrained FCN here.
            self.netPixelCLS = get_model(opt.weights_model_type, num_cls=opt.num_cls, pretrained=True, weights_init=opt.weights_init)
            # Specially initialize the pixel classification network.
            if len(self.gpu_ids) > 0:
                assert (torch.cuda.is_available())
                self.netPixelCLS.to(self.gpu_ids[0])
                self.netPixelCLS = torch.nn.DataParallel(self.netPixelCLS, self.gpu_ids)
        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # self.criterionCLS = torch.nn.modules.CrossEntropyLoss()
            self.criterionSemantic = torch.nn.KLDivLoss(reduction='batchmean')
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        # Labels are optional (only some datasets provide them).
        if 'A_label' in input and 'B_label' in input:
            self.input_A_label = input['A_label' if AtoB else 'B_label'].to(self.device)
            self.input_B_label = input['B_label' if AtoB else 'A_label'].to(self.device)
            # self.image_paths = input['B_paths'] # Hack!! forcing the labels to corresopnd to B domain

    def forward(self):
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)
        if self.isTrain:
            # Forward real_A and fake_B through the frozen classifier; the
            # two prediction maps feed the semantic loss in backward_G.
            self.pred_real_A = self.netPixelCLS(self.real_A)
            _, self.gt_pred_A = self.pred_real_A.max(1)
            self.pred_fake_B = self.netPixelCLS(self.fake_B)
            _, pfB = self.pred_fake_B.max(1)

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so no generator gradients flow here)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined Loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    # NOTE(review): this passes integer class labels as the KLDivLoss target,
    # but KLDivLoss expects a probability distribution (the commented-out
    # CrossEntropyLoss above suggests the original intent). It also appears
    # unused — no optimizer for netPixelCLS is ever stepped. Confirm before
    # relying on it.
    def backward_PixelCLS(self):
        label_A = self.input_A_label
        # forward only real source image through semantic classifier
        pred_A = self.netPixelCLS(self.real_A)
        self.loss_PixelCLS = self.criterionSemantic(F.log_softmax(pred_A, dim=1), label_A.long())
        self.loss_PixelCLS.backward()

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self, opt):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A))
        # NOTE(review): the A-direction GAN loss is doubled relative to B —
        # deliberate asymmetry? Confirm against the training recipe.
        self.loss_G_A = 2 * self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss standard cyclegan
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        # Semantic consistency: KL(real_A predictions || fake_B predictions),
        # i.e. real_A(syn)->fake_B->(frozen FCN) should match the FCN's view
        # of real_A. Normalized by the number of pixels and classes.
        if opt.semantic_loss:
            self.loss_sem_AB = opt.dynamic_weight * self.criterionSemantic(F.log_softmax(self.pred_fake_B, dim=1), F.softmax(self.pred_real_A,
                                                                                                                             dim=1))
            self.loss_sem_AB = opt.general_semantic_weight * torch.div(self.loss_sem_AB, self.pred_fake_B.shape[1] * self.pred_fake_B.shape[2]
                                                                       * self.pred_fake_B.shape[3])
            self.loss_G += self.loss_sem_AB
        self.loss_G.backward()

    def optimize_parameters(self, opt):
        # forward
        self.forward()
        # G_A and G_B: freeze discriminators while updating generators.
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        # self.optimizer_CLS.zero_grad()
        self.backward_G(opt)
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()
================================================
FILE: cyclegan/models/multi_cycle_gan_semantic_model.py
================================================
import itertools
import sys
import torch
import torch.nn.functional as F
from util.image_pool import ImagePool
from . import networks
from .base_model import BaseModel
sys.path.append('/nfs/project/libo_iMADAN')
from cycada.models import get_model
class CycleGANSemanticModel(BaseModel):
def name(self):
return 'CycleGANModel'
def initialize(self, opt):
BaseModel.initialize(self, opt)
self.semantic_loss = opt.semantic_loss
# specify the training losses you want to print out. The program will call base_model.get_current_losses
self.loss_names = ['D_A_1', 'G_A_1', 'cycle_A_1', 'idt_A_1',
'D_B_1', 'G_B_1', 'cycle_B_1', 'idt_B_1',
'D_A_2', 'G_A_2', 'cycle_A_2', 'idt_A_2',
'D_B_2', 'G_B_2', 'cycle_B_2', 'idt_B_2']
if opt.SAD:
self.loss_names.extend(['D_3_1', 'G_s1s2'])
if opt.CCD or opt.HF_CCD:
self.loss_names.extend(['D_21', 'G_s1s21'])
self.loss_names.extend(['D_12', 'G_s2s12'])
if self.semantic_loss:
self.loss_names.extend(['sem_syn', 'sem_gta'])
# specify the images you want to save/display. The program will call base_model.get_current_visuals
visual_names_A_1 = ['real_A_1', 'fake_B_1', 'rec_A_1']
visual_names_B_1 = ['real_B', 'fake_A_1', 'rec_B_1']
visual_names_A_2 = ['real_A_2', 'fake_B_2', 'rec_A_2']
visual_names_B_2 = ['fake_A_2', 'rec_B_2']
if self.isTrain and self.opt.lambda_identity > 0.0:
visual_names_A_1.append('idt_A_1')
visual_names_B_1.append('idt_B_1')
visual_names_A_2.append('idt_A_2')
visual_names_B_2.append('idt_B_2')
self.visual_names = visual_names_A_1 + visual_names_B_1 + visual_names_A_2 + visual_names_B_2
# specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
if self.isTrain:
# self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
if opt.Shared_DT:
self.model_names = ['G_A_1', 'G_B_1', 'D_A', 'D_B_1', 'D_B_2', 'G_A_2', 'G_B_2']
else:
self.model_names = ['G_A_1', 'G_B_1', 'D_A_1', 'D_B_1', 'G_A_2', 'G_B_2', 'D_A_2', 'D_B_2']
if opt.SAD:
self.model_names.append('D_3')
if opt.CCD or opt.HF_CCD:
self.model_names.append('D_12')
self.model_names.append('D_21')
else: # during test time, only load Gs
self.model_names = ['G_A_1', 'G_B_1', 'G_A_2', 'G_B_2']
# load/define networks
# The naming conversion is different from those used in the paper
# Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
self.netG_A_1 = networks.define_G(opt.input_nc, opt.output_nc,
opt.ngf, opt.which_model_netG, opt.norm,
not opt.no_dropout, opt.init_type, self.gpu_ids)
self.netG_B_1 = networks.define_G(opt.output_nc, opt.input_nc,
opt.ngf, opt.which_model_netG, opt.norm,
not opt.no_dropout, opt.init_type, self.gpu_ids)
self.netG_A_2 = networks.define_G(opt.input_nc, opt.output_nc,
opt.ngf, opt.which_model_netG, opt.norm,
not opt.no_dropout, opt.init_type, self.gpu_ids)
self.netG_B_2 = networks.define_G(opt.output_nc, opt.input_nc,
opt.ngf, opt.which_model_netG, opt.norm,
not opt.no_dropout, opt.init_type, self.gpu_ids)
if opt.semantic_loss:
self.netPixelCLS_SYN = get_model(opt.weights_model_type, num_cls=opt.num_cls, pretrained=True, weights_init=opt.weights_syn)
self.netPixelCLS_GTA = get_model(opt.weights_model_type, num_cls=opt.num_cls, pretrained=True, weights_init=opt.weights_gta)
if len(self.gpu_ids) > 0:
assert (torch.cuda.is_available())
self.netPixelCLS_SYN.to(self.gpu_ids[0])
self.netPixelCLS_SYN = torch.nn.DataParallel(self.netPixelCLS_SYN, self.gpu_ids)
self.netPixelCLS_GTA.to(self.gpu_ids[0])
self.netPixelCLS_GTA = torch.nn.DataParallel(self.netPixelCLS_GTA, self.gpu_ids)
if self.isTrain:
use_sigmoid = opt.no_lsgan
if opt.Shared_DT:
self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid,
opt.init_type, self.gpu_ids)
else:
self.netD_A_1 = networks.define_D(opt.output_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid,
opt.init_type, self.gpu_ids)
self.netD_A_2 = networks.define_D(opt.output_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid,
opt.init_type, self.gpu_ids)
self.netD_B_1 = networks.define_D(opt.input_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid,
opt.init_type, self.gpu_ids)
self.netD_B_2 = networks.define_D(opt.input_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid,
opt.init_type, self.gpu_ids)
if opt.SAD:
self.netD_3 = networks.define_D(opt.input_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid,
opt.init_type, self.gpu_ids)
if opt.CCD or opt.HF_CCD:
self.netD_12 = networks.define_D(opt.input_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid,
opt.init_type, self.gpu_ids)
self.netD_21 = networks.define_D(opt.input_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid,
opt.init_type, self.gpu_ids)
if self.isTrain:
self.fake_A_1_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
self.fake_B_1_pool = ImagePool(opt.pool_size)
self.fake_A_2_pool = ImagePool(opt.pool_size)
self.fake_B_2_pool = ImagePool(opt.pool_size)
self.fake_A_21_pool = ImagePool(opt.pool_size)
self.fake_A_12_pool = ImagePool(opt.pool_size)
# define loss functions
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
self.criterionCycle = torch.nn.L1Loss()
self.criterionIdt = torch.nn.L1Loss()
self.criterionSemantic = torch.nn.KLDivLoss(reduction='batchmean')
# initialize optimizers
if opt.Shared_DT:
self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B_1.parameters(),
self.netD_B_2.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
else:
self.optimizer_D_1 = torch.optim.Adam(itertools.chain(self.netD_A_1.parameters(), self.netD_B_1.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D_2 = torch.optim.Adam(itertools.chain(self.netD_A_2.parameters(), self.netD_B_2.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_G_1 = torch.optim.Adam(itertools.chain(self.netG_A_1.parameters(), self.netG_B_1.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_G_2 = torch.optim.Adam(itertools.chain(self.netG_A_2.parameters(), self.netG_B_2.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
if opt.SAD:
self.optimizer_D_3 = torch.optim.Adam(self.netD_3.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
if opt.CCD or opt.HF_CCD:
self.optimizer_D_21 = torch.optim.Adam(self.netD_21.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D_12 = torch.optim.Adam(self.netD_12.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers = []
self.optimizers.append(self.optimizer_G_1)
self.optimizers.append(self.optimizer_G_2)
if opt.Shared_DT:
self.optimizers.append(self.optimizer_D)
else:
self.optimizers.append(self.optimizer_D_1)
self.optimizers.append(self.optimizer_D_2)
if opt.SAD:
self.optimizers.append(self.optimizer_D_3)
if opt.CCD or opt.HF_CCD:
self.optimizers.append(self.optimizer_D_12)
self.optimizers.append(self.optimizer_D_21)
def set_input(self, input):
    """Unpack one data batch onto the model device.

    Expects two source-domain images ('A_1', 'A_2') and one target-domain
    image ('B'), plus their file paths.  Semantic labels are stored only
    when all three label entries are present in the batch.
    """
    device = self.device
    self.real_A_1 = input['A_1'].to(device)
    self.real_A_2 = input['A_2'].to(device)
    self.real_B = input['B'].to(device)
    self.image_paths_1 = input['A_paths_1']
    self.image_paths_2 = input['A_paths_2']
    self.image_paths = self.image_paths_1 + self.image_paths_2
    has_labels = all(key in input for key in ('A_label_1', 'B_label', 'A_label_2'))
    if has_labels:
        self.input_A_label_1 = input['A_label_1'].to(device)
        self.input_A_label_2 = input['A_label_2'].to(device)
        self.input_B_label = input['B_label'].to(device)
def forward(self, opt):
    """Run both CycleGAN branches (and the optional CCD / semantic
    forward passes) on the batch previously stored by set_input."""
    # cycle for data input #1: A_1 -> B -> A_1 and B -> A_1 -> B
    self.fake_B_1 = self.netG_A_1(self.real_A_1)
    self.rec_A_1 = self.netG_B_1(self.fake_B_1)
    self.fake_A_1 = self.netG_B_1(self.real_B)
    self.rec_B_1 = self.netG_A_1(self.fake_A_1)
    # cycle for data input #2: A_2 -> B -> A_2 and B -> A_2 -> B
    self.fake_B_2 = self.netG_A_2(self.real_A_2)
    self.rec_A_2 = self.netG_B_2(self.fake_B_2)
    self.fake_A_2 = self.netG_B_2(self.real_B)
    self.rec_B_2 = self.netG_A_2(self.fake_A_2)
    if opt.CCD:
        # generate s21 for d21 branch: domain-2 image pushed through B into domain 1
        self.fake_A_21 = self.netG_B_1(self.fake_B_2)
        # generate s12 for d12 branch: domain-1 image pushed through B into domain 2
        self.fake_A_12 = self.netG_B_2(self.fake_B_1)
    if self.isTrain and self.semantic_loss:
        # Forward all four images through classifier.
        # backward_G later compares the fake-B predictions against the
        # corresponding real-A predictions (KL divergence).
        self.pred_real_A_SYN = self.netPixelCLS_SYN(self.real_A_1)
        _, self.gt_pred_A_SYN = self.pred_real_A_SYN.max(1)
        self.pred_fake_B_SYN = self.netPixelCLS_SYN(self.fake_B_1)
        _, pfB_SYN = self.pred_fake_B_SYN.max(1)
        self.pred_real_A_GTA = self.netPixelCLS_GTA(self.real_A_2)
        _, self.gt_pred_A_GTA = self.pred_real_A_GTA.max(1)
        self.pred_fake_B_GTA = self.netPixelCLS_GTA(self.fake_B_2)
        _, pfB_GTA = self.pred_fake_B_GTA.max(1)
def backward_D_basic(self, netD, real, fake, SAD=False):
# Real
if SAD == False:
pred_real = netD(real)
else:
pred_real = netD(real.detach())
loss_D_real = self.criterionGAN(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss
loss_D = (loss_D_real + loss_D_fake) * 0.5
# backward
loss_D.backward()
return loss_D
def backward_D_A(self, Shared_DT):
    """Update the target-side discriminator(s) against pooled fakes.

    With Shared_DT a single netD_A scores both A_1->B and A_2->B fakes;
    otherwise each branch has its own discriminator.
    """
    netD_for_1 = self.netD_A if Shared_DT else self.netD_A_1
    netD_for_2 = self.netD_A if Shared_DT else self.netD_A_2
    # data 1 A1->B
    pooled_fake_B_1 = self.fake_B_1_pool.query(self.fake_B_1)
    self.loss_D_A_1 = self.backward_D_basic(netD_for_1, self.real_B, pooled_fake_B_1)
    # data 2 A2->B
    pooled_fake_B_2 = self.fake_B_2_pool.query(self.fake_B_2)
    self.loss_D_A_2 = self.backward_D_basic(netD_for_2, self.real_B, pooled_fake_B_2)
def backward_D_B(self):
    """Update the source-side discriminators: one per domain, each
    scoring real A_i against pooled B->A_i fakes."""
    for i in (1, 2):
        pool = getattr(self, 'fake_A_%d_pool' % i)
        pooled_fake = pool.query(getattr(self, 'fake_A_%d' % i))
        loss = self.backward_D_basic(getattr(self, 'netD_B_%d' % i),
                                     getattr(self, 'real_A_%d' % i),
                                     pooled_fake)
        setattr(self, 'loss_D_B_%d' % i, loss)
def backward_D(self, which_D):
    """Update one auxiliary discriminator, selected by name.

    'SAD'    -> netD_3 separates the two translated domains (both inputs
                are fakes, hence SAD=True).
    'CCD_21' -> netD_21 compares real A_1 with the cross-cycle fake A_21.
    'CCD_12' -> netD_12 compares real A_2 with the cross-cycle fake A_12.
    """
    if which_D == 'SAD':
        pooled = self.fake_B_1_pool.query(self.fake_B_1)
        self.loss_D_3_1 = self.backward_D_basic(self.netD_3, self.fake_B_2, pooled, SAD=True)
        return
    if which_D == 'CCD_21':
        pooled = self.fake_A_21_pool.query(self.fake_A_21)
        self.loss_D_21 = self.backward_D_basic(self.netD_21, self.real_A_1, pooled)
        return
    if which_D == 'CCD_12':
        pooled = self.fake_A_12_pool.query(self.fake_A_12)
        self.loss_D_12 = self.backward_D_basic(self.netD_12, self.real_A_2, pooled)
        return
    raise Exception("Invalid Choice {}".format(which_D))
def backward_G(self, opt):
    """Accumulate and back-propagate all generator losses: the two
    standard CycleGAN branches plus the optional SAD / CCD adversarial
    terms and the dynamic semantic (KL) loss."""
    lambda_idt = self.opt.lambda_identity
    lambda_A = self.opt.lambda_A
    lambda_B = self.opt.lambda_B
    # Identity loss: G_A_i(B) should reproduce B, G_B_i(A_i) should reproduce A_i.
    if lambda_idt > 0:
        self.idt_A_1 = self.netG_A_1(self.real_B)
        self.loss_idt_A_1 = self.criterionIdt(self.idt_A_1, self.real_B) * lambda_B * lambda_idt
        self.idt_A_2 = self.netG_A_2(self.real_B)
        self.loss_idt_A_2 = self.criterionIdt(self.idt_A_2, self.real_B) * lambda_B * lambda_idt
        self.idt_B_1 = self.netG_B_1(self.real_A_1)
        self.loss_idt_B_1 = self.criterionIdt(self.idt_B_1, self.real_A_1) * lambda_A * lambda_idt
        self.idt_B_2 = self.netG_B_2(self.real_A_2)
        self.loss_idt_B_2 = self.criterionIdt(self.idt_B_2, self.real_A_2) * lambda_A * lambda_idt
    else:
        self.loss_idt_A_1 = 0
        self.loss_idt_A_2 = 0
        self.loss_idt_B_1 = 0
        self.loss_idt_B_2 = 0
    # GAN loss D_A(G_A(A_i)); the A->B direction is weighted 2x.
    if opt.Shared_DT:
        self.loss_G_A_1 = 2 * self.criterionGAN(self.netD_A(self.fake_B_1), True)
        self.loss_G_A_2 = 2 * self.criterionGAN(self.netD_A(self.fake_B_2), True)
    else:
        self.loss_G_A_1 = 2 * self.criterionGAN(self.netD_A_1(self.fake_B_1), True)
        self.loss_G_A_2 = 2 * self.criterionGAN(self.netD_A_2(self.fake_B_2), True)
    # GAN loss D_B(G_B(B))
    self.loss_G_B_1 = self.criterionGAN(self.netD_B_1(self.fake_A_1), True)
    self.loss_G_B_2 = self.criterionGAN(self.netD_B_2(self.fake_A_2), True)
    # Forward cycle loss
    self.loss_cycle_A_1 = self.criterionCycle(self.rec_A_1, self.real_A_1) * lambda_A
    self.loss_cycle_A_2 = self.criterionCycle(self.rec_A_2, self.real_A_2) * lambda_A
    # Backward cycle loss
    self.loss_cycle_B_1 = self.criterionCycle(self.rec_B_1, self.real_B) * lambda_B
    self.loss_cycle_B_2 = self.criterionCycle(self.rec_B_2, self.real_B) * lambda_B
    # combined loss standard cyclegan
    self.loss_G_1 = self.loss_G_A_1 + self.loss_G_B_1 + self.loss_cycle_A_1 + self.loss_cycle_B_1 + self.loss_idt_A_1 + self.loss_idt_B_1
    self.loss_G_2 = self.loss_G_A_2 + self.loss_G_B_2 + self.loss_cycle_A_2 + self.loss_cycle_B_2 + self.loss_idt_A_2 + self.loss_idt_B_2
    self.loss_G = self.loss_G_1 + self.loss_G_2
    if opt.SAD:
        # D3 loss: only active once training passes SAD_frozen_epoch
        # (-1 keeps the term disabled for the whole run).
        if opt.SAD_frozen_epoch != -1 and opt.current_epoch > opt.SAD_frozen_epoch:
            self.loss_G_s1s2 = self.criterionGAN(self.netD_3(self.fake_B_1), True)
        else:
            self.loss_G_s1s2 = 0
        self.loss_G += self.loss_G_s1s2
    if opt.CCD:
        # D21 loss, gated by CCD_frozen_epoch exactly like SAD above.
        if opt.CCD_frozen_epoch != -1 and opt.current_epoch > opt.CCD_frozen_epoch:
            self.loss_G_s1s21 = self.criterionGAN(self.netD_21(self.fake_A_21), True)
            self.loss_G += self.loss_G_s1s21 * opt.D1D2_weight
        else:
            self.loss_G_s1s21 = 0
        # D12 loss, same gating.
        if opt.CCD_frozen_epoch != -1 and opt.current_epoch > opt.CCD_frozen_epoch:
            self.loss_G_s2s12 = self.criterionGAN(self.netD_12(self.fake_A_12), True)
            self.loss_G += self.loss_G_s2s12 * opt.D1D2_weight
        else:
            self.loss_G_s2s12 = 0
    if opt.semantic_loss:
        # criterionSemantic is KLDivLoss(reduction='batchmean'): it takes
        # log-probabilities for the fake-B predictions and probabilities
        # for the real-A predictions.
        self.loss_sem_syn = opt.dynamic_weight * self.criterionSemantic(F.log_softmax(self.pred_fake_B_SYN, dim=1),
                                                                        F.softmax(self.pred_real_A_SYN, dim=1))
        self.loss_sem_gta = opt.dynamic_weight * self.criterionSemantic(F.log_softmax(self.pred_fake_B_GTA, dim=1),
                                                                        F.softmax(self.pred_real_A_GTA, dim=1))
        # Extra normalization by C*H*W of the prediction map, so the
        # configured weight is comparable across resolutions.
        self.loss_G += opt.general_semantic_weight * torch.div(self.loss_sem_syn, self.pred_fake_B_SYN.shape[1] * self.pred_fake_B_SYN.shape[2]
                                                               * self.pred_fake_B_SYN.shape[3])
        self.loss_G += opt.general_semantic_weight * torch.div(self.loss_sem_gta, self.pred_fake_B_GTA.shape[1] * self.pred_fake_B_GTA.shape[2]
                                                               * self.pred_fake_B_GTA.shape[3])
    self.loss_G.backward()
def backward_HF_CCD(self, opt):
    """Half-freeze CCD generator step: re-runs the cross-cycle fakes and
    back-propagates only the CCD adversarial terms (the caller freezes
    the B->A generators beforehand)."""
    self.fake_B_1 = self.netG_A_1(self.real_A_1)
    self.fake_B_2 = self.netG_A_2(self.real_A_2)
    # generate s21 for d21 branch
    self.fake_A_21 = self.netG_B_1(self.fake_B_2)
    # generate s12 for d12 branch
    self.fake_A_12 = self.netG_B_2(self.fake_B_1)
    # Both CCD terms share one gate: active only after CCD_frozen_epoch
    # (-1 keeps them disabled).
    ccd_active = opt.CCD_frozen_epoch != -1 and opt.current_epoch > opt.CCD_frozen_epoch
    if ccd_active:
        # D12 loss
        self.loss_G_s2s12 = self.criterionGAN(self.netD_12(self.fake_A_12), True)
        # D21 loss
        self.loss_G_s1s21 = self.criterionGAN(self.netD_21(self.fake_A_21), True)
    else:
        self.loss_G_s2s12 = 0
        self.loss_G_s1s21 = 0
    self.loss_G_HF = self.loss_G_s1s21 * opt.CCD_weight + self.loss_G_s2s12 * opt.CCD_weight
    # When the gate is closed loss_G_HF is the plain int 0 -> nothing to backprop.
    if isinstance(self.loss_G_HF, torch.Tensor):
        self.loss_G_HF.backward()
def optimize_parameters(self, opt):
    """One full optimization step: generators first (with every
    discriminator frozen), then each active discriminator in turn."""
    # forward
    self.forward(opt)
    # G_A and G_B
    # set D to false, back prop G's gradients
    if opt.Shared_DT:
        self.set_requires_grad([self.netD_A, self.netD_B_1, self.netD_B_2], False)
    else:
        self.set_requires_grad([self.netD_A_1, self.netD_B_1], False)
        self.set_requires_grad([self.netD_A_2, self.netD_B_2], False)
    if opt.SAD:
        self.set_requires_grad([self.netD_3], False)
    if opt.CCD or opt.HF_CCD:
        self.set_requires_grad([self.netD_21], False)
        self.set_requires_grad([self.netD_12], False)
    self.set_requires_grad([self.netG_A_1, self.netG_B_1], True)
    self.set_requires_grad([self.netG_A_2, self.netG_B_2], True)
    self.optimizer_G_1.zero_grad()
    self.optimizer_G_2.zero_grad()
    # self.optimizer_CLS.zero_grad()
    self.backward_G(opt)
    self.optimizer_G_1.step()
    self.optimizer_G_2.step()
    if opt.HF_CCD:
        # Half-freeze extra step: B->A generators are frozen, so only the
        # A->B generators receive the CCD gradients from backward_HF_CCD.
        self.optimizer_G_1.zero_grad()
        self.optimizer_G_2.zero_grad()
        self.set_requires_grad([self.netG_A_1, self.netG_A_2], True)
        self.set_requires_grad([self.netG_B_1, self.netG_B_2], False)
        self.backward_HF_CCD(opt)
        self.optimizer_G_1.step()
        self.optimizer_G_2.step()
    # D_A and D_B: unfreeze and update the main discriminators.
    if opt.Shared_DT:
        self.set_requires_grad([self.netD_A, self.netD_B_1, self.netD_B_2], True)
    else:
        self.set_requires_grad([self.netD_A_1, self.netD_B_1], True)
        self.set_requires_grad([self.netD_A_2, self.netD_B_2], True)
    if opt.Shared_DT:
        self.optimizer_D.zero_grad()
    else:
        self.optimizer_D_1.zero_grad()
        self.optimizer_D_2.zero_grad()
    self.backward_D_B()
    self.backward_D_A(opt.Shared_DT)
    if opt.Shared_DT:
        self.optimizer_D.step()
    else:
        self.optimizer_D_1.step()
        self.optimizer_D_2.step()
    if opt.SAD:
        # SAD discriminator update.
        self.set_requires_grad([self.netD_3], True)
        self.optimizer_D_3.zero_grad()
        self.backward_D('SAD')
        self.optimizer_D_3.step()
    if opt.CCD or opt.HF_CCD:
        # CCD discriminator updates, one per cross-domain direction.
        self.set_requires_grad([self.netD_21], True)
        self.optimizer_D_21.zero_grad()
        self.backward_D('CCD_21')
        self.optimizer_D_21.step()
        self.set_requires_grad([self.netD_12], True)
        self.optimizer_D_12.zero_grad()
        self.backward_D('CCD_12')
        self.optimizer_D_12.step()
================================================
FILE: cyclegan/models/networks.py
================================================
import functools
import torch
import torch.nn as nn
from torch.nn import init
from torch.optim import lr_scheduler
###############################################################################
# Helper Functions
###############################################################################
def get_norm_layer(norm_type='instance'):
    """Map a normalization name to a layer constructor.

    'batch'    -> BatchNorm2d with learnable affine parameters
    'instance' -> InstanceNorm2d without affine parameters
    'none'     -> None
    """
    factories = {
        'batch': functools.partial(nn.BatchNorm2d, affine=True),
        'instance': functools.partial(nn.InstanceNorm2d, affine=False),
        'none': None,
    }
    if norm_type not in factories:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return factories[norm_type]
def get_scheduler(optimizer, opt):
    """Wrap `optimizer` in the LR scheduler selected by opt.lr_policy.

    'lambda'  -> constant LR for `niter` epochs, then linear decay to 0
                 over `niter_decay` epochs.
    'step'    -> decay by 10x every opt.lr_decay_iters epochs.
    'plateau' -> reduce on metric plateau.

    Raises:
        NotImplementedError: for an unknown policy name.
    """
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    else:
        # Bug fix: this used to *return* the exception object (with a
        # printf-style argument that was never formatted) instead of raising.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
def init_weights(net, init_type='normal', gain=0.02):
    """Initialize every Conv/Linear weight in `net` with the chosen
    scheme and every BatchNorm2d weight with N(1, gain); all biases are
    zeroed."""
    def init_func(m):
        cls_name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in cls_name or 'Linear' in cls_name):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in cls_name:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)
    print('initialize network with %s' % init_type)
    net.apply(init_func)
def init_net(net, init_type='normal', gpu_ids=[]):
    """Move `net` onto the first requested GPU (wrapped in DataParallel)
    when gpu_ids is non-empty, then initialize its weights.

    NOTE(review): the mutable default `gpu_ids=[]` is never mutated here,
    so it is harmless, but callers should still pass their own list.
    """
    if gpu_ids:
        assert torch.cuda.is_available()
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type)
    return net
def print_network(net):
    """Print the module structure followed by its total parameter count."""
    total = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):
    """Build and initialize a generator of the requested architecture."""
    norm_layer = get_norm_layer(norm_type=norm)
    builders = {
        'resnet_9blocks': lambda: ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9),
        'resnet_6blocks': lambda: ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6),
        'unet_128': lambda: UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout),
        'unet_256': lambda: UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout),
    }
    if which_model_netG not in builders:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    netG = builders[which_model_netG]()
    return init_net(netG, init_type, gpu_ids)
def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):
    """Build and initialize a discriminator of the requested architecture.

    'basic' is a fixed 3-layer PatchGAN; 'n_layers' makes the depth
    configurable; 'pixel' classifies each pixel independently.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    elif which_model_netD == 'pixel':
        netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)
    return init_net(netD, init_type, gpu_ids)
def define_C(output_nc, ndf, init_type='normal', gpu_ids=[]):
    """Build and initialize the small 10-way CNN classifier (Classifier)."""
    return init_net(Classifier(output_nc, ndf), init_type, gpu_ids)
##############################################################################
# Classes
##############################################################################
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
    """GAN objective that builds target tensors matching the
    discriminator output shape.

    use_lsgan=True uses MSE (LSGAN); otherwise BCE, in which case the
    discriminator must end in a sigmoid.
    """
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
        super(GANLoss, self).__init__()
        # Buffers follow the module across .to(device) calls.
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        label = self.real_label if target_is_real else self.fake_label
        return label.expand_as(input)

    def __call__(self, input, target_is_real):
        return self.loss(input, self.get_target_tensor(input, target_is_real))
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    """Johnson-style generator: 7x7 stem, two stride-2 downsamples,
    `n_blocks` residual blocks, two upsamples, and a 7x7 tanh head."""
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        assert (n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # InstanceNorm has no learned shift, so convs keep their bias.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True),
        ]
        n_downsampling = 2
        # Downsample: channels double, spatial size halves.
        for i in range(n_downsampling):
            ch = ngf * (2 ** i)
            layers += [
                nn.Conv2d(ch, ch * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(ch * 2),
                nn.ReLU(True),
            ]
        bottleneck = ngf * (2 ** n_downsampling)
        layers += [ResnetBlock(bottleneck, padding_type=padding_type, norm_layer=norm_layer,
                               use_dropout=use_dropout, use_bias=use_bias)
                   for _ in range(n_blocks)]
        # Upsample back to the input resolution.
        for i in range(n_downsampling):
            ch = ngf * (2 ** (n_downsampling - i))
            layers += [
                nn.ConvTranspose2d(ch, ch // 2, kernel_size=3, stride=2,
                                   padding=1, output_padding=1, bias=use_bias),
                norm_layer(ch // 2),
                nn.ReLU(True),
            ]
        layers += [nn.ReflectionPad2d(3),
                   nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                   nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: two 3x3 convs (effective padding 1) with a skip
    connection, optional dropout between them."""
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        def padding():
            # Returns ([extra padding modules], conv padding argument).
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1)], 0
            if padding_type == 'replicate':
                return [nn.ReplicationPad2d(1)], 0
            if padding_type == 'zero':
                return [], 1
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        pads, p = padding()
        layers = pads + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                         norm_layer(dim),
                         nn.ReLU(True)]
        if use_dropout:
            layers += [nn.Dropout(0.5)]
        pads, p = padding()
        layers += pads + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                          norm_layer(dim)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
    """U-Net generator built recursively from the innermost block out.

    With num_downs == 7 a 128x128 input reaches 1x1 at the bottleneck.
    """
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # Innermost (bottleneck) block.
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                        norm_layer=norm_layer, innermost=True)
        # Intermediate ngf*8 levels; dropout (if any) lives here.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout)
        # Progressively widen back out to ngf channels.
        for outer in (ngf * 4, ngf * 2, ngf):
            block = UnetSkipConnectionBlock(outer, outer * 2, input_nc=None, submodule=block,
                                            norm_layer=norm_layer)
        # Outermost level maps the real input/output channel counts.
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        return self.model(input)
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample, run `submodule`, upsample, and (except
    at the outermost level) concatenate the input along channels.

        X ----------------- identity ----------------- X
        |-- downsampling -- |submodule| -- upsampling --|
    """
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        uprelu = nn.ReLU(True)
        if outermost:
            # No norm at either end; tanh squashes to image range.
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2, padding=1)
            parts = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # Bottleneck: no submodule, upsample straight from inner_nc.
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2, padding=1, bias=use_bias)
            parts = [nn.LeakyReLU(0.2, True), downconv,
                     uprelu, upconv, norm_layer(outer_nc)]
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2, padding=1, bias=use_bias)
            parts = [nn.LeakyReLU(0.2, True), downconv, norm_layer(inner_nc),
                     submodule, uprelu, upconv, norm_layer(outer_nc)]
            if use_dropout:
                parts.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*parts)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # Skip connection: stack input and processed features channel-wise.
        return torch.cat([x, self.model(x)], 1)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: `n_layers` stride-2 convs, one stride-1
    conv, then a 1-channel patch prediction map."""
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw, padw = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        prev_mult = 1
        # Width doubles each layer, capped at 8x ndf.
        for n in range(1, n_layers):
            mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * prev_mult, ndf * mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ]
            prev_mult = mult
        mult = min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * prev_mult, ndf * mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * mult, 1, kernel_size=kw, stride=1, padding=padw),
        ]
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
class PixelDiscriminator(nn.Module):
    """1x1-conv discriminator: classifies every pixel independently,
    producing an input-sized 1-channel prediction map."""
    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
        ]
        if use_sigmoid:
            layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, input):
        return self.net(input)
class Classifier(nn.Module):
    """Small 10-way CNN classifier: four stride-2 3x3 convs (no padding)
    followed by two linear layers.

    NOTE(review): the flatten in forward assumes the conv stack collapses
    the input to 1x1 spatially (true for 32x32 inputs) -- confirm for
    other sizes.  Passing an InstanceNorm partial (affine=False) as
    norm_layer would clash with the explicit affine=True below.
    """
    def __init__(self, input_nc, ndf, norm_layer=nn.BatchNorm2d):
        super(Classifier, self).__init__()
        kw = 3
        conv_layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2),
            nn.LeakyReLU(0.2, True),
        ]
        width = 1
        for n in range(3):
            prev_width = width
            width = min(2 ** n, 8)
            conv_layers += [
                nn.Conv2d(ndf * prev_width, ndf * width, kernel_size=kw, stride=2),
                norm_layer(ndf * width, affine=True),
                nn.LeakyReLU(0.2, True),
            ]
        self.before_linear = nn.Sequential(*conv_layers)
        self.after_linear = nn.Sequential(
            nn.Linear(ndf * width, 1024),
            nn.Linear(1024, 10),
        )

    def forward(self, x):
        bs = x.size(0)
        features = self.before_linear(x).view(bs, -1)
        return self.after_linear(features)
        # return nn.functional.log_softmax(out, dim=1)
================================================
FILE: cyclegan/models/test_model.py
================================================
from . import networks
from .base_model import BaseModel
class TestModel(BaseModel):
    """Inference-only model: loads the single direction-A generator that
    matches the chosen dataset and stylizes real_A into the target domain."""
    def name(self):
        return 'TestModel'

    def initialize(self, opt):
        assert (not opt.isTrain)
        BaseModel.initialize(self, opt)
        # No losses at test time; base_model.get_current_losses reads this.
        self.loss_names = []
        # base_model.get_current_visuals reads this list.
        self.visual_names = ['real_A']
        # Pick the generator suffix used during training: '1' for the
        # SYNTHIA branch, '2' for the GTA5 branch.  Any other dataset
        # mode leaves the model without a generator (as before).
        if opt.dataset_mode == 'synthia_cityscapes':
            suffix = '1'
        elif opt.dataset_mode == 'gta5_cityscapes':
            suffix = '2'
        else:
            suffix = None
        if suffix is not None:
            self.model_names = ['G_A_' + suffix]
            self.visual_names.append('fake_B_' + suffix)
            setattr(self, 'netG_A_' + suffix,
                    networks.define_G(opt.input_nc, opt.output_nc,
                                      opt.ngf, opt.which_model_netG,
                                      opt.norm, not opt.no_dropout,
                                      opt.init_type,
                                      self.gpu_ids))

    def set_input(self, input):
        # we need to use single_dataset mode
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        if hasattr(self, 'netG_A_1'):
            self.fake_B_1 = self.netG_A_1(self.real_A)
        elif hasattr(self, 'netG_A_2'):
            self.fake_B_2 = self.netG_A_2(self.real_A)
================================================
FILE: cyclegan/options/__init__.py
================================================
================================================
FILE: cyclegan/options/base_options.py
================================================
import argparse
import os
import torch
from util import util
class BaseOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.initialized = False
def initialize(self):
self.parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
self.parser.add_argument('--loadSize', type=int, default=600, help='scale images to this size')
self.parser.add_argument('--fineSize', type=int, default=600, help='then crop to this size')
self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
self.parser.add_argument('--which_model_netD', type=str, default='n_layers', help='selects model to use for netD')
self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG')
self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
self.parser.add_argument('--name', type=str, default='experiment_name',
help='name of the experiment. It decides where to store samples and models')
self.parser.add_argument('--dataset_mode', type=str, default='unaligned',
help='chooses how datasets are loaded. [unaligned | aligned | single]')
self.parser.add_argument('--model', type=str, default='cycle_gan',
help='chooses which model to use. cycle_gan, pix2pix, test')
self.parser.add_argument('--weights_model_type', type=str, default='drn26',
help='chooses which model to use. drn26, fcn8s')
self.parser.add_argument('--num_cls', default=19, type=int)
self.parser.add_argument('--max_epoch', default=20, type=int)
self.parser.add_argument('--current_epoch', default=0, type=int)
self.parser.add_argument('--weights_init', type=str)
self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
self.parser.add_argument('--nThreads', default=16, type=int, help='# threads for loading data')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
self.parser.add_argument('--serial_batches', action='store_true',
help='if true, takes images in order to make batches, otherwise takes them randomly')
self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
self.parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, '
'only a subset is loaded.')
self.parser.add_argument('--resize_or_crop', type=str, default='scale_width_and_crop',
help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
self.parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
self.parser.add_argument('--suffix', default='', type=str,
help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
self.parser.add_argument('--out_all', action='store_true', help='output all stylized images(fake_B_{})')
self.parser.add_argument('--SAD', action='store_true', help='Sub-domain Aggregation Discriminator module')
self.parser.add_argument('--CCD', action='store_true', help='Cross-domain Cycle Discriminator module')
self.parser.add_argument('--CCD_weight', type=float, default=1, help='weight for cross domain cycle discriminator loss')
self.parser.add_argument('--HF_CCD', action='store_true', help='Half Freeze Cross-domain Cycle Discriminator module')
self.parser.add_argument('--CCD_frozen_epoch', type=int, default=-1)
self.parser.add_argument('--SAD_frozen_epoch', type=int, default=-1)
self.parser.add_argument('--Shared_DT', type=bool, default=True, help="Through ")
self.parser.add_argument('--model_type', type=str, default='fcn8s', help="choose to load which type of model (fcn8s, drn26, deeplabv2)")
self.parser.add_argument('--semantic_loss', action='store_true', help='use semantic loss')
self.parser.add_argument('--general_semantic_weight', type=float, default=0.2, help='weight for semantic loss')
self.parser.add_argument('--weights_syn', type=str, default='', help='init weights for synthia')
self.parser.add_argument('--weights_gta', type=str, default='', help='init weights for gta')
self.parser.add_argument('--inference_script', type=str, default='', help='inference script')
self.parser.add_argument('--dynamic_weight', type=float, default=10, help='Weight for Dynamic Semantic Loss(KL div) loss')
self.initialized = True
def parse(self):
    """Parse CLI arguments, post-process them, print and persist them.

    Returns the populated options namespace (also stored on ``self.opt``).
    Side effects: selects the first listed GPU via ``torch.cuda.set_device``,
    prints every option to stdout, and writes them all to
    ``<checkpoints_dir>/<name>/opt.txt``.
    """
    if not self.initialized:
        self.initialize()
    opt = self.parser.parse_args()
    opt.isTrain = self.isTrain  # train or test

    # Convert the comma-separated gpu_ids string into a list of ints,
    # dropping negative ids ("-1" conventionally means CPU-only).
    # (The original loop bound the builtin name `id`; avoided here.)
    opt.gpu_ids = [gpu_id for gpu_id in map(int, opt.gpu_ids.split(',')) if gpu_id >= 0]
    # set gpu ids
    if len(opt.gpu_ids) > 0:
        torch.cuda.set_device(opt.gpu_ids[0])

    args = vars(opt)

    print('------------ Options -------------')
    for k, v in sorted(args.items()):
        print('%s: %s' % (str(k), str(v)))
    print('-------------- End ----------------')

    # Optionally append a formatted suffix (e.g. "{model}_{which_model_netG}")
    # to the experiment name.  The original re-tested `opt.suffix != ''`
    # inside this branch, whose else-arm was unreachable; the truthiness
    # test above already guarantees a non-empty suffix.
    if opt.suffix:
        opt.name = opt.name + '_' + opt.suffix.format(**args)

    # save to the disk
    expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
    util.mkdirs(expr_dir)
    file_name = os.path.join(expr_dir, 'opt.txt')
    with open(file_name, 'wt') as opt_file:
        opt_file.write('------------ Options -------------\n')
        for k, v in sorted(args.items()):
            opt_file.write('%s: %s\n' % (str(k), str(v)))
        opt_file.write('-------------- End ----------------\n')

    self.opt = opt
    return self.opt
================================================
FILE: cyclegan/options/test_options.py
================================================
from .base_options import BaseOptions
class TestOptions(BaseOptions):
    """Command-line options used at test/inference time (isTrain = False)."""

    def initialize(self):
        # Register all shared options first, then the test-only flags.
        BaseOptions.initialize(self)
        add = self.parser.add_argument
        add('--ntest', type=int, default=float("inf"), help='# of test examples.')
        add('--results_dir', type=str, default='./results/', help='saves results here.')
        add('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        add('--phase', type=str, default='test', help='train, val, test, etc')
        add('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        add('--how_many', type=int, default=50, help='how many test images to run')
        self.isTrain = False
================================================
FILE: cyclegan/options/train_options.py
================================================
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
    """Command-line options used at training time (isTrain = True).

    Extends BaseOptions with display/logging frequencies, checkpointing
    cadence, Adam hyper-parameters, CycleGAN loss weights, and the
    learning-rate schedule.
    """

    def initialize(self):
        # Register the shared options first, then the train-only flags.
        BaseOptions.initialize(self)
        # Visualization / logging cadence (all counted in samples, since the
        # training loop advances counters by opt.batchSize per iteration).
        self.parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
        self.parser.add_argument('--display_ncols', type=int, default=4,
                                 help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
        self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        self.parser.add_argument('--epoch_count', type=int, default=1,
                                 help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        # Total training length is niter + niter_decay epochs: constant LR for
        # niter epochs, then linear decay to zero over niter_decay epochs.
        self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
        self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
        # CycleGAN loss weights.
        self.parser.add_argument('--lambda_A', type=float, default=1.0, help='weight for cycle loss (A -> B -> A)')
        self.parser.add_argument('--lambda_B', type=float, default=1.0, help='weight for cycle loss (B -> A -> B)')
        self.parser.add_argument('--lambda_identity', type=float, default=0,
                                 help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the '
                                      'identity mapping loss.'
                                      'For example, if the weight of the identity loss should be 10 times smaller than the weight of the '
                                      'reconstruction loss, please set lambda_identity = 0.1')
        self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
        self.parser.add_argument('--no_html', action='store_true',
                                 help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
        self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
        self.isTrain = True
================================================
FILE: cyclegan/test.py
================================================
import os
import sys
import torch
from models import create_model
from options.test_options import TestOptions
from util import html
from util.visualizer import save_images
from data import CreateDataLoader
import logging
sys.path.append("/nfs/project/libo_i/MADAN")
if __name__ == '__main__':
    # Configure the root logger so logging.info() is actually emitted
    # (without this, INFO-level records are silently discarded).
    logging.basicConfig(level=logging.INFO)

    opt = TestOptions().parse()
    # Inference-time overrides: deterministic order, no augmentation, no visdom.
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.display_id = -1  # no visdom display

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)
    model.setup(opt)

    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

    # test
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        # Log the tensor shapes of the first batch so input sizes are visible.
        if i == 0:
            for key, value in data.items():
                if isinstance(value, torch.Tensor):
                    # BUG FIX: the original `logging.info(key, value.size())`
                    # passed the size as a %-format argument with no placeholder
                    # in the message, raising an error inside logging's handler.
                    logging.info('%s %s', key, value.size())
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        # When --out_all is set, keep only the stylized outputs (fake_B_*)
        # and drop every other visual before saving.
        if opt.out_all:
            remove_list = [label for label in visuals if 'fake_B' not in label]
            for rm_item in remove_list:
                del visuals[rm_item]
        img_path = model.get_image_paths()
        if i % 5 == 0:
            logging.info('processing (%04d)-th image...' % (i * opt.batchSize))
        if 'mul' in opt.model:
            save_images(webpage.get_image_dir(), visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize, multi_flag=True)
        else:
            save_images(webpage.get_image_dir(), visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
================================================
FILE: cyclegan/train.py
================================================
import subprocess
import sys
import time
sys.path.append("/nfs/project/libo_i/MADAN/cyclegan")
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
import torch
import logging
if __name__ == '__main__':
    # Parse training options, build the data loader, model and visualizer.
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    logging.info('#training images = %d' % dataset_size)
    model = create_model(opt)
    model.setup(opt)
    visualizer = Visualizer(opt)
    total_steps = 0
    # Outer loop: epochs epoch_count .. niter + niter_decay; the LR schedule
    # is stepped once per epoch at the bottom of the loop.
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        # Expose the current epoch on opt so model.optimize_parameters(opt)
        # can read it (e.g. for epoch-gated behavior).
        opt.current_epoch = epoch
        logging.info("Current epoch update to {}".format(opt.current_epoch))
        for i, data in enumerate(dataset):
            # Log the tensor shapes of the very first batch only.
            if total_steps == 0:
                for item in data.items():
                    if isinstance(item[1], torch.Tensor):
                        logging.info(item[1].size())
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
                # t_data: how long data loading took for this batch.  It is
                # only rebound here, but total_steps == 0 satisfies this on
                # the first iteration, so it is always bound before use below.
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            # Counters advance by batch size, so the *_freq options below are
            # effectively counted in samples, not iterations.
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters(opt)
            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
            if total_steps % opt.print_freq == 0:
                losses = model.get_current_losses()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)
            if total_steps % opt.save_latest_freq == 0:
                logging.info('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.save_networks('latest')
            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:
            logging.info('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.save_networks('latest')
            model.save_networks(epoch)
        # NOTE(review): opt.max_epoch is not declared by TrainOptions or the
        # visible part of BaseOptions — confirm it exists, otherwise this line
        # raises AttributeError (opt.niter + opt.niter_decay may be intended).
        logging.info('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.max_epoch, time.time() - epoch_start_time))
        model.update_learning_rate()
================================================
FILE: cyclegan/util/__init__.py
================================================
================================================
FILE: cyclegan/util/get_data.py
================================================
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename
class GetData(object):
    """
    Download CycleGAN or Pix2Pix Data.
    Args:
        technique : str
            One of: 'cyclegan' or 'pix2pix'.
        verbose : bool
            If True, print additional information.
    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.
    """

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        # NOTE(review): an unrecognized technique leaves self.url as None and
        # only fails later inside requests.get; kept for compatibility.
        self.url = url_dict.get(technique.lower())
        self._verbose = verbose

    def _print(self, text):
        # Print helper gated on the verbosity flag.
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        # Scrape the dataset index page for downloadable archive names
        # (.zip / tar.gz anchors).
        soup = BeautifulSoup(r.text, 'lxml')
        options = [h.text for h in soup.find_all('a', href=True)
                   if h.text.endswith(('.zip', 'tar.gz'))]
        return options

    def _present_options(self):
        # Interactively ask the user to pick one of the listed datasets.
        # Added a timeout so a dead server cannot hang the process forever.
        r = requests.get(self.url, timeout=30)
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        return options[int(choice)]

    def _download_data(self, dataset_url, save_path):
        """Download one archive into save_path, unpack it, and delete the archive.

        Raises ValueError if the URL is neither a .tar.gz nor a .zip archive.
        """
        if not isdir(save_path):
            os.makedirs(save_path)
        base = basename(dataset_url)
        temp_save_path = join(save_path, base)
        # IMPROVEMENT: stream the download in chunks instead of buffering the
        # entire archive in memory (datasets can be hundreds of MB), and bound
        # the connection with a timeout.
        with open(temp_save_path, "wb") as f:
            r = requests.get(dataset_url, stream=True, timeout=30)
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)
        # IMPROVEMENT: context managers guarantee the archive handle is closed
        # even if extraction fails part-way through.
        if base.endswith('.tar.gz'):
            self._print("Unpacking Data...")
            with tarfile.open(temp_save_path) as archive:
                archive.extractall(save_path)
        elif base.endswith('.zip'):
            self._print("Unpacking Data...")
            with ZipFile(temp_save_path, 'r') as archive:
                archive.extractall(save_path)
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))
        os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """
        Download a dataset.
        Args:
            save_path : str
                A directory to save the data to.
            dataset : str, optional
                A specific dataset to download.
                Note: this must include the file extension.
                If None, options will be presented for you
                to choose from.
        Returns:
            save_path_full : str
                The absolute path to the downloaded data.
        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset
        # Target directory is the archive name without its extension(s).
        save_path_full = join(save_path, selected_dataset.split('.')[0])
        if isdir(save_path_full):
            # Already downloaded: warn and skip instead of re-fetching.
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = "{0}/{1}".format(self.url, selected_dataset)
            self._download_data(url, save_path=save_path)
        return abspath(save_path_full)
================================================
FILE: cyclegan/util/html.py
================================================
import dominate
from dominate.tags import *
import os
class HTML:
    """Minimal helper that builds an image-table HTML page via dominate.

    Creates ``web_dir`` and ``web_dir/images`` on construction; images are
    referenced from the page through the relative ``images/`` subdirectory.
    """

    def __init__(self, web_dir, title, reflesh=0):
        # NOTE: 'reflesh' is a long-standing typo for 'refresh'; the name is
        # kept for backward compatibility with existing keyword callers.
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)
        self.doc = dominate.document(title=title)
        if reflesh > 0:
            # Emit a <meta> auto-refresh tag with the given period.
            with self.doc.head:
                meta(http_equiv="reflesh", content=str(reflesh))

    def get_image_dir(self):
        """Return the directory where page images should be stored."""
        return self.img_dir

    def add_header(self, str):
        """Append an <h3> header with the given text.

        (The parameter name shadows the builtin ``str``; kept unchanged to
        preserve the public keyword interface.)
        """
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        """Append a new fixed-layout table and remember it in self.t."""
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=400):
        """Append one table row: each image shown at ``width`` px, linked and captioned."""
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        """Render the document to <web_dir>/index.html."""
        html_file = '%s/index.html' % self.web_dir
        # BUG FIX: use a context manager so the file handle is closed even if
        # rendering raises (the original used bare open()/close()).
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())
if __name__ == '__main__':
    # Smoke test: build a small page referencing four placeholder images.
    page = HTML('web/', 'test_html')
    page.add_header('hello world')
    ims = ['image_%d.png' % n for n in range(4)]
    txts = ['text_%d' % n for n in range(4)]
    links = ['image_%d.png' % n for n in range(4)]
    page.add_images(ims, txts, links)
    page.save()
================================================
FILE: cyclegan/util/image_pool.py
================================================
import random
import torch
class ImagePool():
    """History buffer of previously generated images.

    Once the buffer is full, each queried image is, with probability 0.5,
    swapped for a randomly chosen stored image.  A pool_size of 0 disables
    the buffer entirely and queries pass straight through.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            # Buffer state only exists when the pool is enabled.
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch mixing incoming images with buffered ones."""
        if self.pool_size == 0:
            # Pool disabled: return the batch untouched.
            return images
        batch = []
        for img in images:
            img = torch.unsqueeze(img.data, 0)
            if self.num_imgs < self.pool_size:
                # Still filling the pool: store and also return the new image.
                self.num_imgs = self.num_imgs + 1
                self.images.append(img)
                batch.append(img)
            elif random.uniform(0, 1) > 0.5:
                # Swap: return a buffered image, keep the new one in its place.
                idx = random.randint(0, self.pool_size - 1)  # randint is inclusive
                old = self.images[idx].clone()
                self.images[idx] = img
                batch.append(old)
            else:
                # Keep: return the new image without storing it.
                batch.append(img)
        return torch.cat(batch, 0)
================================================
FILE: cyclegan/util/util.py
================================================
from __future__ import print_function
import os
import numpy as np
import torch
from PIL import Image
# Converts a Tensor into an image array (numpy)
# |imtype|: the desired type of the converted numpy array
def tensor2im(input_image, imtype=np.uint8):
    """Map a CHW tensor with values in [-1, 1] to an HWC image array of ``imtype``.

    Non-tensor inputs are returned unchanged; a single-channel input is
    tiled to three channels before conversion.
    """
    if not isinstance(input_image, torch.Tensor):
        return input_image
    arr = input_image.data.cpu().float().numpy()
    if arr.shape[0] == 1:
        # Grayscale -> replicate the channel so the result looks like RGB.
        arr = np.tile(arr, (3, 1, 1))
    # CHW -> HWC, then rescale [-1, 1] -> [0, 255].
    arr = (arr.transpose(1, 2, 0) + 1) / 2.0 * 255.0
    return arr.astype(imtype)
# def tensor2im(input_image, imtype=np.uint8):
# if isinstance(input_image, torch.Tensor):
# image_tensor = input_image.data
# else:
# return input_image
# image_numpy = image_tensor[0].cpu().float().numpy()
# if image_numpy.shape[0] == 1:
# image_numpy = np.tile(image_numpy, (3, 1, 1))
# image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
# return image_numpy.astype(imtype)
def diagnose_network(net, name='network'):
    """Print ``name`` followed by the mean absolute gradient over all parameters.

    Parameters without gradients are skipped; if none carry gradients the
    printed mean is 0.0.
    """
    total = 0.0
    seen = 0
    for p in net.parameters():
        if p.grad is not None:
            total += torch.mean(torch.abs(p.grad.data))
            seen += 1
    if seen > 0:
        total = total / seen
    print(name)
    print(total)
def save_image(image_numpy, image_path):
    """Write an HWC numpy image array to ``image_path`` via PIL."""
    Image.fromarray(image_numpy).save(image_path)
def print_numpy(x, val=True, shp=False):
    """Print summary statistics of array ``x`` (and its shape when ``shp``)."""
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        flat = x.flatten()
        stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)
def mkdirs(paths):
    """Create one directory, or every directory in a list/tuple of paths.

    The original guard 'isinstance(paths, list) and not isinstance(paths, str)'
    had a redundant second clause (a list is never a str); the check is also
    generalized here to accept tuples, which previously crashed in mkdir.
    """
    if isinstance(paths, (list, tuple)):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)

def mkdir(path):
    """Create ``path`` (including parents) if it does not already exist."""
    if not os.path.exists(path):
        os.makedirs(path)
================================================
FILE: cyclegan/util/visualizer.py
================================================
import ntpath
import os
import time
import numpy as np
from . import html, util
# save image to the disk
def save_images(image_dir, visuals, image_path, aspect_ratio=1.0, width=256, multi_flag=False):
    """Save each visual tensor in ``visuals`` as <name>_<label>.png under image_dir.

    ``aspect_ratio`` and ``width`` are accepted for interface compatibility but
    are not applied here: images are saved at their native size.  When
    ``multi_flag`` is True, a visual is only saved for the image whose 1-based
    batch index appears in the visual's label.
    Removed dead code: the enumerate index and the unused 'h, w, _ = im.shape'
    unpacking from the original.
    """
    for i in range(len(image_path)):
        short_path = ntpath.basename(image_path[i])
        name = os.path.splitext(short_path)[0]
        for label, im_data in visuals.items():
            # align visual names and real image name
            if multi_flag is True and (str(i + 1) not in label):
                continue
            im = util.tensor2im(im_data[i, :, :, :])
            image_name = '%s_%s.png' % (name, label)
            save_path = os.path.join(image_dir, image_name)
            util.save_image(im, save_path)
class Visualizer():
def __init__(self, opt):
    """Set up display/logging backends from the parsed options.

    Depending on ``opt``: connects to a visdom server (display_id > 0),
    prepares the HTML output directory (training with html enabled), and
    appends a timestamped session header to the per-experiment loss log.
    """
    self.display_id = opt.display_id
    self.use_html = opt.isTrain and not opt.no_html
    self.win_size = opt.display_winsize
    self.name = opt.name
    self.opt = opt
    self.saved = False  # whether current results were already saved this cycle
    if self.display_id > 0:
        # Lazy import: visdom is only needed when browser display is enabled.
        import visdom
        self.ncols = opt.display_ncols
        self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port)
    if self.use_html:
        self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
        self.img_dir = os.path.join(self.web_dir, 'images')
        print('create web directory %s...' % self.web_dir)
        util.mkdirs([self.web_dir, self.img_dir])
    # Open the loss log in append mode so restarts keep earlier history.
    self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
    with open(self.log_name, "a") as log_file:
        now = time.strftime("%c")
        log_file.write('================ Training Loss (%s) ================\n' % now)
def reset(self):
    """Clear the saved flag so the next results can be written again."""
    self.saved = False
# |visuals|: dictionary of images to display or save
def display_current_results(self, visuals, epoch, save_result):
if self.display_id > 0: # show images in the browser
ncols = self.ncols
if ncols > 0:
ncols = min(ncols, len(visuals))
h, w = next(iter(visuals.values())).shape[:2]
table_css = """<style>
table
gitextract_rf2a9vy3/
├── .gitignore
├── .idea/
│ ├── MADAN.iml
│ ├── deployment.xml
│ ├── inspectionProfiles/
│ │ └── profiles_settings.xml
│ ├── misc.xml
│ ├── modules.xml
│ ├── remote-mappings.xml
│ └── vcs.xml
├── LICENSE
├── README.md
├── cycada/
│ ├── __init__.py
│ ├── data/
│ │ ├── __init__.py
│ │ ├── adda_datasets.py
│ │ ├── bdds.py
│ │ ├── cityscapes.py
│ │ ├── cityscapes_labels.py
│ │ ├── cyclegan.py
│ │ ├── cyclegta5.py
│ │ ├── cyclesynthia.py
│ │ ├── cyclesynthia_cyclegta5.py
│ │ ├── data_loader.py
│ │ ├── gta5.py
│ │ ├── rotater.py
│ │ ├── synthia.py
│ │ └── util.py
│ ├── logging.yml
│ ├── models/
│ │ ├── MDAN.py
│ │ ├── __init__.py
│ │ ├── adda_net.py
│ │ ├── drn.py
│ │ ├── fcn8s.py
│ │ ├── models.py
│ │ ├── task_net.py
│ │ └── util.py
│ ├── tools/
│ │ ├── __init__.py
│ │ ├── train_adda_net.py
│ │ ├── train_task_net.py
│ │ └── util.py
│ ├── transforms.py
│ └── util.py
├── cyclegan/
│ ├── .gitignore
│ ├── data/
│ │ ├── __init__.py
│ │ ├── base_data_loader.py
│ │ ├── base_dataset.py
│ │ ├── cityscapes.py
│ │ ├── gta5_cityscapes.py
│ │ ├── gta_synthia_cityscapes.py
│ │ ├── image_folder.py
│ │ └── synthia_cityscapes.py
│ ├── environment.yml
│ ├── models/
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── cycle_gan_model.py
│ │ ├── cycle_gan_semantic_model.py
│ │ ├── multi_cycle_gan_semantic_model.py
│ │ ├── networks.py
│ │ └── test_model.py
│ ├── options/
│ │ ├── __init__.py
│ │ ├── base_options.py
│ │ ├── test_options.py
│ │ └── train_options.py
│ ├── test.py
│ ├── train.py
│ └── util/
│ ├── __init__.py
│ ├── get_data.py
│ ├── html.py
│ ├── image_pool.py
│ ├── util.py
│ └── visualizer.py
├── requirements.txt
├── scripts/
│ ├── ADDA/
│ │ ├── adda_cyclegta2cs_feat.sh
│ │ ├── adda_cyclegta2cs_score.sh
│ │ ├── adda_cyclesyn2cs_feat.sh
│ │ ├── adda_cyclesyn2cs_score.sh
│ │ └── adda_templates.sh
│ ├── CycleGAN/
│ │ ├── cyclegan_gta2cityscapes.sh
│ │ ├── cyclegan_gta_synthia2cityscapes.sh
│ │ ├── cyclegan_synthia2cityscapes.sh
│ │ ├── test_templates.sh
│ │ └── test_templates_cycle.sh
│ ├── FCN/
│ │ ├── train_fcn8s_cyclesgta5.sh
│ │ └── train_fcn8s_cyclesynthia.sh
│ ├── eval_fcn.py
│ ├── train_fcn.py
│ ├── train_fcn_adda.py
│ └── train_fcn_mdan.py
└── tools/
├── __init__.py
└── eval_templates.sh
SYMBOL INDEX (386 symbols across 52 files)
FILE: cycada/data/adda_datasets.py
class AddaDataLoader (line 9) | class AddaDataLoader(object):
method __init__ (line 10) | def __init__(self, net_transform, dataset, rootdir, downscale, crop_si...
method __iter__ (line 31) | def __iter__(self):
method __next__ (line 34) | def __next__(self):
method next (line 37) | def next(self):
method __len__ (line 51) | def __len__(self):
method set_loader_src (line 54) | def set_loader_src(self):
method set_loader_tgt (line 69) | def set_loader_tgt(self):
FILE: cycada/data/bdds.py
class BDDS (line 10) | class BDDS(data.Dataset):
method __init__ (line 11) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
method collect_ids (line 22) | def collect_ids(self):
method img_path (line 32) | def img_path(self, filename):
method label_path (line 35) | def label_path(self, filename):
method __getitem__ (line 38) | def __getitem__(self, index, debug=False):
method __len__ (line 51) | def __len__(self):
FILE: cycada/data/cityscapes.py
function remap_labels_to_train_ids (line 10) | def remap_labels_to_train_ids(arr):
class CityScapesParams (line 18) | class CityScapesParams(DatasetParams):
class Cityscapes (line 28) | class Cityscapes(data.Dataset):
method __init__ (line 29) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
method collect_ids (line 43) | def collect_ids(self):
method img_path (line 52) | def img_path(self, id):
method label_path (line 58) | def label_path(self, id):
method __getitem__ (line 64) | def __getitem__(self, index, debug=False):
method __len__ (line 78) | def __len__(self):
FILE: cycada/data/cityscapes_labels.py
function label_img_to_color (line 7) | def label_img_to_color(img):
FILE: cycada/data/cyclegan.py
class CycleGANDataset (line 10) | class CycleGANDataset(data.Dataset):
method __init__ (line 11) | def __init__(self, root, regexp, transform=None, target_transform=None,
method find_images (line 19) | def find_images(self, regexp='*.png'):
method __getitem__ (line 28) | def __getitem__(self, index):
method __len__ (line 39) | def __len__(self):
class Svhn2MNIST (line 44) | class Svhn2MNIST(CycleGANDataset):
method __init__ (line 45) | def __init__(self, root, train=True, transform=None, target_transform=...
class Svhn2MNISTParams (line 56) | class Svhn2MNISTParams(DatasetParams):
class Usps2Mnist (line 76) | class Usps2Mnist(CycleGANDataset):
method __init__ (line 77) | def __init__(self, root, train=True, transform=None, target_transform=...
class Usps2MnistParams (line 88) | class Usps2MnistParams(DatasetParams):
class Mnist2Usps (line 100) | class Mnist2Usps(CycleGANDataset):
method __init__ (line 101) | def __init__(self, root, train=True, transform=None, target_transform=...
class Mnist2UspsParams (line 112) | class Mnist2UspsParams(DatasetParams):
FILE: cycada/data/cyclegta5.py
class CycleGTA5 (line 12) | class CycleGTA5(GTA5):
method collect_ids (line 13) | def collect_ids(self):
method __getitem__ (line 29) | def __getitem__(self, index, debug=False):
FILE: cycada/data/cyclesynthia.py
function syn_relabel (line 55) | def syn_relabel(arr):
class SYNTHIAParams (line 63) | class SYNTHIAParams(DatasetParams):
class CycleSYNTHIA (line 73) | class CycleSYNTHIA(data.Dataset):
method __init__ (line 75) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
method collect_ids (line 85) | def collect_ids(self):
method img_path (line 99) | def img_path(self, filename):
method label_path (line 102) | def label_path(self, filename):
method __getitem__ (line 110) | def __getitem__(self, index, debug=False):
method __len__ (line 126) | def __len__(self):
FILE: cycada/data/cyclesynthia_cyclegta5.py
function syn_relabel (line 56) | def syn_relabel(arr):
class SYNTHIAParams (line 64) | class SYNTHIAParams(DatasetParams):
class CycleSYNTHIACycleGTA5 (line 74) | class CycleSYNTHIACycleGTA5(data.Dataset):
method __init__ (line 76) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
method collect_ids (line 89) | def collect_ids(self, datasets_name):
method img_path (line 110) | def img_path(self, prefix, filename):
method syn_label_path (line 115) | def syn_label_path(self, filename):
method gta_label_path (line 121) | def gta_label_path(self, filename):
method __getitem__ (line 127) | def __getitem__(self, index, debug=False):
method __len__ (line 166) | def __len__(self):
FILE: cycada/data/data_loader.py
function load_data (line 15) | def load_data(name, dset, batch=64, rootdir='', num_channels=3,
function get_transform_dataset (line 34) | def get_transform_dataset(dataset_name, rootdir, net_transform, downscal...
function get_orig_size (line 43) | def get_orig_size(dataset_name):
function get_transform2 (line 51) | def get_transform2(dataset_name, net_transform, downscale, resize):
function get_transform (line 72) | def get_transform(params, image_size, num_channels):
function get_target_transform (line 98) | def get_target_transform(params):
class AddaDataset (line 108) | class AddaDataset(data.Dataset):
method __init__ (line 110) | def __init__(self, src_data, tgt_data):
method __getitem__ (line 114) | def __getitem__(self, index):
method __len__ (line 121) | def __len__(self):
function register_data_params (line 128) | def register_data_params(name):
function register_dataset_obj (line 139) | def register_dataset_obj(name):
class DatasetParams (line 147) | class DatasetParams(object):
function get_dataset (line 157) | def get_dataset(name, rootdir, dset, image_size, num_channels, download=...
function get_fcn_dataset (line 167) | def get_fcn_dataset(name, rootdir, **kwargs):
FILE: cycada/data/gta5.py
class GTA5Params (line 13) | class GTA5Params(DatasetParams):
class GTA5 (line 23) | class GTA5(data.Dataset):
method __init__ (line 25) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
method collect_ids (line 41) | def collect_ids(self):
method img_path (line 46) | def img_path(self, id):
method label_path (line 50) | def label_path(self, id):
method __getitem__ (line 54) | def __getitem__(self, index, debug=False):
method __len__ (line 71) | def __len__(self):
FILE: cycada/data/rotater.py
class Rotater (line 1) | class Rotater(object):
method __init__ (line 3) | def __init__(self, dataset, orientations=6, transform=None,
method __getitem__ (line 10) | def __getitem__(self, index):
method __len__ (line 21) | def __len__(self):
FILE: cycada/data/synthia.py
function syn_relabel (line 9) | def syn_relabel(arr):
class SYNTHIAParams (line 16) | class SYNTHIAParams(DatasetParams):
class SYNTHIA (line 26) | class SYNTHIA(data.Dataset):
method __init__ (line 28) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
method collect_ids (line 40) | def collect_ids(self):
method img_path (line 48) | def img_path(self, filename):
method label_path (line 56) | def label_path(self, filename):
method __getitem__ (line 64) | def __getitem__(self, index, debug=False):
method __len__ (line 87) | def __len__(self):
FILE: cycada/data/util.py
function maybe_download (line 57) | def maybe_download(url, dest):
function download (line 66) | def download(url, dest):
FILE: cycada/models/MDAN.py
class GradientReversalLayer (line 13) | class GradientReversalLayer(torch.autograd.Function):
method forward (line 19) | def forward(self, inputs):
method backward (line 22) | def backward(self, grad_output):
class MDANet (line 28) | class MDANet(nn.Module):
method __init__ (line 33) | def __init__(self, configs):
method forward (line 54) | def forward(self, sinputs_syn, sinputs_gta, tinputs):
method inference (line 94) | def inference(self, inputs):
FILE: cycada/models/adda_net.py
class AddaNet (line 10) | class AddaNet(nn.Module):
method __init__ (line 12) | def __init__(self, num_cls=10, model='LeNet', src_weights_init=None,
method forward (line 30) | def forward(self, x_s, x_t):
method setup_net (line 44) | def setup_net(self):
method load (line 61) | def load(self, init_path):
method load_src_net (line 66) | def load_src_net(self, init_path):
method save (line 72) | def save(self, out_path):
method save_tgt_net (line 75) | def save_tgt_net(self, out_path):
FILE: cycada/models/drn.py
function conv3x3 (line 20) | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
class BasicBlock (line 25) | class BasicBlock(nn.Module):
method __init__ (line 28) | def __init__(self, inplanes, planes, stride=1, downsample=None,
method forward (line 42) | def forward(self, x):
class Bottleneck (line 61) | class Bottleneck(nn.Module):
method __init__ (line 64) | def __init__(self, inplanes, planes, stride=1, downsample=None,
method forward (line 79) | def forward(self, x):
class DRN (line 102) | class DRN(nn.Module):
method __init__ (line 110) | def __init__(self, block, layers, num_cls=1000,
method _make_layer (line 175) | def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
method forward (line 199) | def forward(self, x):
function drn26 (line 252) | def drn26(pretrained=True, finetune=False, out_map=True, **kwargs):
function drn42 (line 267) | def drn42(pretrained=False, finetune=False, out_map=True, **kwargs):
function drn58 (line 275) | def drn58(pretrained=False, **kwargs):
FILE: cycada/models/fcn8s.py
function get_upsample_filter (line 14) | def get_upsample_filter(size):
class Bilinear (line 27) | class Bilinear(nn.Module):
method __init__ (line 29) | def __init__(self, factor, num_channels):
method forward (line 38) | def forward(self, x):
class VGG16_FCN8s (line 43) | class VGG16_FCN8s(nn.Module):
method __init__ (line 51) | def __init__(self, num_cls=19, pretrained=True, weights_init=None,
method load_base_vgg (line 86) | def load_base_vgg(self, weights_state_dict):
method load_vgg_head (line 90) | def load_vgg_head(self, weights_state_dict):
method get_dict_by_prefix (line 94) | def get_dict_by_prefix(self, weights_state_dict, prefix):
method load_weights (line 99) | def load_weights(self, weights_state_dict):
method split_vgg_head (line 103) | def split_vgg_head(self):
method forward (line 107) | def forward(self, x):
method load_base_weights (line 142) | def load_base_weights(self):
class VGG16_FCN8s_caffe (line 161) | class VGG16_FCN8s_caffe(VGG16_FCN8s):
method load_base_weights (line 171) | def load_base_weights(self):
class Discriminator (line 188) | class Discriminator(nn.Module):
method __init__ (line 189) | def __init__(self, input_dim=4096, output_dim=2, pretrained=False, wei...
method forward (line 206) | def forward(self, x):
method load_weights (line 210) | def load_weights(self, weights):
class Transform_Module (line 215) | class Transform_Module(nn.Module):
method __init__ (line 216) | def __init__(self, input_dim=4096):
method forward (line 229) | def forward(self, x):
function init_eye (line 234) | def init_eye(tensor):
function _crop (line 241) | def _crop(input, shape, offset=0):
function make_layers (line 246) | def make_layers(cfg, batch_norm=False):
FILE: cycada/models/models.py
function register_model (line 4) | def register_model(name):
function get_model (line 11) | def get_model(name, num_cls=10, **args):
FILE: cycada/models/task_net.py
class TaskNet (line 8) | class TaskNet(nn.Module):
method __init__ (line 15) | def __init__(self, num_cls=10, weights_init=None):
method forward (line 25) | def forward(self, x, with_ft=False):
method setup_net (line 35) | def setup_net(self):
method load (line 39) | def load(self, init_path):
method save (line 43) | def save(self, out_path):
class LeNet (line 47) | class LeNet(TaskNet):
method setup_net (line 55) | def setup_net(self):
class DTNClassifier (line 76) | class DTNClassifier(TaskNet):
method setup_net (line 84) | def setup_net(self):
FILE: cycada/models/util.py
function init_weights (line 4) | def init_weights(obj):
FILE: cycada/tools/train_adda_net.py
function train (line 21) | def train(loader_src, loader_tgt, net, opt_net, opt_dis, epoch):
function train_adda (line 118) | def train_adda(src, tgt, model, num_cls, num_epoch=200,
FILE: cycada/tools/train_task_net.py
function train_epoch (line 22) | def train_epoch(loader, net, opt_net, epoch):
function train (line 55) | def train(data, datadir, model, num_cls, outdir='',
FILE: cycada/tools/util.py
function make_variable (line 7) | def make_variable(tensor, volatile=False, requires_grad=True):
function pairwise_distance (line 15) | def pairwise_distance(x, y):
function gaussian_kernel_matrix (line 30) | def gaussian_kernel_matrix(x, y, sigmas):
function maximum_mean_discrepancy (line 40) | def maximum_mean_discrepancy(x, y, kernel=gaussian_kernel_matrix):
function mmd_loss (line 48) | def mmd_loss(source_features, target_features):
FILE: cycada/transforms.py
class RandomCrop (line 17) | class RandomCrop(object):
method __init__ (line 23) | def __init__(self, size):
method __call__ (line 29) | def __call__(self, tensors):
class HalfCrop (line 48) | class HalfCrop(object):
method __call__ (line 54) | def __call__(self, tensors):
class RandomHorizontalFlip (line 65) | class RandomHorizontalFlip(object):
method __call__ (line 69) | def __call__(self, tensors):
function augment_collate (line 79) | def augment_collate(batch, crop=None, halfcrop=None, flip=True, resize=N...
FILE: cycada/util.py
class TqdmHandler (line 13) | class TqdmHandler(logging.StreamHandler):
method __init__ (line 15) | def __init__(self):
method emit (line 18) | def emit(self, record):
function config_logging (line 23) | def config_logging(logfile=None):
function to_tensor_raw (line 35) | def to_tensor_raw(im):
function safe_load_state_dict (line 39) | def safe_load_state_dict(net, state_dict):
function step_lr (line 66) | def step_lr(optimizer, mult):
FILE: cyclegan/data/__init__.py
function CreateDataLoader (line 10) | def CreateDataLoader(opt):
function CreateDataset (line 17) | def CreateDataset(opt):
class CustomDatasetDataLoader (line 39) | class CustomDatasetDataLoader(BaseDataLoader):
method name (line 40) | def name(self):
method initialize (line 43) | def initialize(self, opt):
method load_data (line 52) | def load_data(self):
method __len__ (line 55) | def __len__(self):
method __iter__ (line 58) | def __iter__(self):
FILE: cyclegan/data/base_data_loader.py
class BaseDataLoader (line 1) | class BaseDataLoader():
method __init__ (line 2) | def __init__(self):
method initialize (line 5) | def initialize(self, opt):
method load_data (line 9) | def load_data():
FILE: cyclegan/data/base_dataset.py
class BaseDataset (line 8) | class BaseDataset(data.Dataset):
method __init__ (line 9) | def __init__(self):
method name (line 12) | def name(self):
method initialize (line 15) | def initialize(self, opt):
function get_transform (line 20) | def get_transform(opt):
function get_label_transform (line 46) | def get_label_transform(opt):
function __scale_width (line 71) | def __scale_width(img, target_width):
function to_tensor_raw (line 80) | def to_tensor_raw(im):
FILE: cyclegan/data/cityscapes.py
function remap_labels_to_train_ids (line 18) | def remap_labels_to_train_ids(arr):
FILE: cyclegan/data/gta5_cityscapes.py
class GTAVCityscapesDataset (line 57) | class GTAVCityscapesDataset(BaseDataset):
method initialize (line 58) | def initialize(self, opt):
method __getitem__ (line 83) | def __getitem__(self, index):
method __len__ (line 134) | def __len__(self):
method name (line 137) | def name(self):
FILE: cyclegan/data/gta_synthia_cityscapes.py
function syn_relabel (line 56) | def syn_relabel(arr):
class GTASynthiaCityscapesDataset (line 63) | class GTASynthiaCityscapesDataset(BaseDataset):
method initialize (line 64) | def initialize(self, opt):
method __getitem__ (line 98) | def __getitem__(self, index):
method __len__ (line 160) | def __len__(self):
method name (line 163) | def name(self):
FILE: cyclegan/data/image_folder.py
function is_image_file (line 21) | def is_image_file(filename):
function make_cs_labels (line 24) | def make_cs_labels(dir):
function make_dataset (line 37) | def make_dataset(dir):
function load_labels (line 49) | def load_labels(dir, images):
function default_loader (line 64) | def default_loader(path):
class ImageFolder (line 68) | class ImageFolder(data.Dataset):
method __init__ (line 70) | def __init__(self, root, transform=None, return_paths=False,
method __getitem__ (line 84) | def __getitem__(self, index):
method __len__ (line 94) | def __len__(self):
FILE: cyclegan/data/synthia_cityscapes.py
function syn_relabel (line 57) | def syn_relabel(arr):
class SynthiaCityscapesDataset (line 64) | class SynthiaCityscapesDataset(BaseDataset):
method initialize (line 65) | def initialize(self, opt):
method __getitem__ (line 90) | def __getitem__(self, index):
method __len__ (line 137) | def __len__(self):
method name (line 140) | def name(self):
FILE: cyclegan/models/__init__.py
function create_model (line 3) | def create_model(opt):
FILE: cyclegan/models/base_model.py
class BaseModel (line 9) | class BaseModel():
method name (line 10) | def name(self):
method initialize (line 13) | def initialize(self, opt):
method set_input (line 26) | def set_input(self, input):
method forward (line 29) | def forward(self):
method setup (line 33) | def setup(self, opt):
method eval (line 42) | def eval(self):
method test (line 50) | def test(self):
method get_image_paths (line 55) | def get_image_paths(self):
method optimize_parameters (line 58) | def optimize_parameters(self):
method update_learning_rate (line 62) | def update_learning_rate(self):
method get_current_visuals (line 69) | def get_current_visuals(self):
method get_current_losses (line 77) | def get_current_losses(self):
method save_networks (line 86) | def save_networks(self, which_epoch):
method __patch_instance_norm_state_dict (line 100) | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i...
method load_networks (line 111) | def load_networks(self, which_epoch):
method print_networks (line 130) | def print_networks(self, verbose):
method set_requires_grad (line 144) | def set_requires_grad(self, nets, requires_grad=False):
FILE: cyclegan/models/cycle_gan_model.py
class CycleGANModel (line 8) | class CycleGANModel(BaseModel):
method name (line 9) | def name(self):
method initialize (line 12) | def initialize(self, opt):
method set_input (line 64) | def set_input(self, input):
method forward (line 70) | def forward(self):
method backward_D_basic (line 77) | def backward_D_basic(self, netD, real, fake):
method backward_D_A (line 90) | def backward_D_A(self):
method backward_D_B (line 94) | def backward_D_B(self):
method backward_G (line 98) | def backward_G(self):
method optimize_parameters (line 126) | def optimize_parameters(self):
FILE: cyclegan/models/cycle_gan_semantic_model.py
class CycleGANSemanticModel (line 15) | class CycleGANSemanticModel(BaseModel):
method name (line 16) | def name(self):
method initialize (line 19) | def initialize(self, opt):
method set_input (line 90) | def set_input(self, input):
method forward (line 101) | def forward(self):
method backward_D_basic (line 117) | def backward_D_basic(self, netD, real, fake):
method backward_PixelCLS (line 130) | def backward_PixelCLS(self):
method backward_D_A (line 137) | def backward_D_A(self):
method backward_D_B (line 141) | def backward_D_B(self):
method backward_G (line 145) | def backward_G(self, opt):
method optimize_parameters (line 182) | def optimize_parameters(self, opt):
FILE: cyclegan/models/multi_cycle_gan_semantic_model.py
class CycleGANSemanticModel (line 15) | class CycleGANSemanticModel(BaseModel):
method name (line 16) | def name(self):
method initialize (line 19) | def initialize(self, opt):
method set_input (line 193) | def set_input(self, input):
method forward (line 206) | def forward(self, opt):
method backward_D_basic (line 242) | def backward_D_basic(self, netD, real, fake, SAD=False):
method backward_D_A (line 259) | def backward_D_A(self, Shared_DT):
method backward_D_B (line 273) | def backward_D_B(self):
method backward_D (line 282) | def backward_D(self, which_D):
method backward_G (line 301) | def backward_G(self, opt):
method backward_HF_CCD (line 383) | def backward_HF_CCD(self, opt):
method optimize_parameters (line 409) | def optimize_parameters(self, opt):
FILE: cyclegan/models/networks.py
function get_norm_layer (line 14) | def get_norm_layer(norm_type='instance'):
function get_scheduler (line 26) | def get_scheduler(optimizer, opt):
function init_weights (line 42) | def init_weights(net, init_type='normal', gain=0.02):
function init_net (line 66) | def init_net(net, init_type='normal', gpu_ids=[]):
function print_network (line 75) | def print_network(net):
function define_G (line 83) | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', u...
function define_D (line 100) | def define_D(input_nc, ndf, which_model_netD,
function define_C (line 117) | def define_C(output_nc, ndf, init_type='normal', gpu_ids=[]):
class GANLoss (line 135) | class GANLoss(nn.Module):
method __init__ (line 136) | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_...
method get_target_tensor (line 145) | def get_target_tensor(self, input, target_is_real):
method __call__ (line 152) | def __call__(self, input, target_is_real):
class ResnetGenerator (line 161) | class ResnetGenerator(nn.Module):
method __init__ (line 162) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
method forward (line 205) | def forward(self, input):
class ResnetBlock (line 210) | class ResnetBlock(nn.Module):
method __init__ (line 211) | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
method build_conv_block (line 215) | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout,...
method forward (line 247) | def forward(self, x):
class UnetGenerator (line 256) | class UnetGenerator(nn.Module):
method __init__ (line 257) | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
method forward (line 273) | def forward(self, input):
class UnetSkipConnectionBlock (line 280) | class UnetSkipConnectionBlock(nn.Module):
method __init__ (line 281) | def __init__(self, outer_nc, inner_nc, input_nc=None,
method forward (line 326) | def forward(self, x):
class NLayerDiscriminator (line 334) | class NLayerDiscriminator(nn.Module):
method __init__ (line 335) | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNo...
method forward (line 377) | def forward(self, input):
class PixelDiscriminator (line 381) | class PixelDiscriminator(nn.Module):
method __init__ (line 382) | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_si...
method forward (line 402) | def forward(self, input):
class Classifier (line 406) | class Classifier(nn.Module):
method __init__ (line 407) | def __init__(self, input_nc, ndf, norm_layer=nn.BatchNorm2d):
method forward (line 436) | def forward(self, x):
FILE: cyclegan/models/test_model.py
class TestModel (line 5) | class TestModel(BaseModel):
method name (line 6) | def name(self):
method initialize (line 9) | def initialize(self, opt):
method set_input (line 37) | def set_input(self, input):
method forward (line 42) | def forward(self):
FILE: cyclegan/options/base_options.py
class BaseOptions (line 8) | class BaseOptions():
method __init__ (line 9) | def __init__(self):
method initialize (line 13) | def initialize(self):
method parse (line 77) | def parse(self):
FILE: cyclegan/options/test_options.py
class TestOptions (line 4) | class TestOptions(BaseOptions):
method initialize (line 5) | def initialize(self):
FILE: cyclegan/options/train_options.py
class TrainOptions (line 4) | class TrainOptions(BaseOptions):
method initialize (line 5) | def initialize(self):
FILE: cyclegan/util/get_data.py
class GetData (line 11) | class GetData(object):
method __init__ (line 29) | def __init__(self, technique='cyclegan', verbose=True):
method _print (line 37) | def _print(self, text):
method _get_options (line 42) | def _get_options(r):
method _present_options (line 48) | def _present_options(self):
method _download_data (line 58) | def _download_data(self, dataset_url, save_path):
method get (line 81) | def get(self, save_path, dataset=None):
FILE: cyclegan/util/html.py
class HTML (line 6) | class HTML:
method __init__ (line 7) | def __init__(self, web_dir, title, reflesh=0):
method get_image_dir (line 22) | def get_image_dir(self):
method add_header (line 25) | def add_header(self, str):
method add_table (line 29) | def add_table(self, border=1):
method add_images (line 33) | def add_images(self, ims, txts, links, width=400):
method save (line 45) | def save(self):
FILE: cyclegan/util/image_pool.py
class ImagePool (line 5) | class ImagePool():
method __init__ (line 6) | def __init__(self, pool_size):
method query (line 12) | def query(self, images):
FILE: cyclegan/util/util.py
function tensor2im (line 12) | def tensor2im(input_image, imtype=np.uint8):
function diagnose_network (line 35) | def diagnose_network(net, name='network'):
function save_image (line 48) | def save_image(image_numpy, image_path):
function print_numpy (line 53) | def print_numpy(x, val=True, shp=False):
function mkdirs (line 63) | def mkdirs(paths):
function mkdir (line 71) | def mkdir(path):
FILE: cyclegan/util/visualizer.py
function save_images (line 10) | def save_images(image_dir, visuals, image_path, aspect_ratio=1.0, width=...
class Visualizer (line 26) | class Visualizer():
method __init__ (line 27) | def __init__(self, opt):
method reset (line 49) | def reset(self):
method display_current_results (line 53) | def display_current_results(self, visuals, epoch, save_result):
method plot_current_losses (line 119) | def plot_current_losses(self, epoch, counter_ratio, opt, losses):
method print_current_losses (line 135) | def print_current_losses(self, epoch, i, losses, t, t_data):
FILE: scripts/eval_fcn.py
function fmt_array (line 26) | def fmt_array(arr, fmt=','):
function fast_hist (line 31) | def fast_hist(a, b, n):
function result_stats (line 36) | def result_stats(hist):
function main (line 57) | def main(path, dataset, datadir, model, gpu, num_cls, batch_size, loadSi...
FILE: scripts/train_fcn.py
function to_tensor_raw (line 25) | def to_tensor_raw(im):
function roundrobin_infinite (line 29) | def roundrobin_infinite(*loaders):
function supervised_loss (line 43) | def supervised_loss(score, label, weights=None):
function main (line 78) | def main(output, dataset, datadir, batch_size, lr, step, iterations,
FILE: scripts/train_fcn_adda.py
function check_label (line 25) | def check_label(label, num_cls):
function forward_pass (line 40) | def forward_pass(net, discriminator, im, requires_grad=False, discrim_fe...
function supervised_loss (line 53) | def supervised_loss(score, label, weights=None):
function discriminator_loss (line 59) | def discriminator_loss(score, target_val, lsgan=False):
function fast_hist (line 69) | def fast_hist(a, b, n):
function seg_accuracy (line 74) | def seg_accuracy(score, label, num_cls):
function main (line 112) | def main(output, dataset, datadir, lr, momentum, snapshot, downscale, cl...
FILE: scripts/train_fcn_mdan.py
function to_tensor_raw (line 29) | def to_tensor_raw(im):
function roundrobin_infinite (line 33) | def roundrobin_infinite(*loaders):
function multi_source_infinite (line 47) | def multi_source_infinite(loaders, target_loader):
function supervised_loss (line 64) | def supervised_loss(score, label, weights=None):
function main (line 96) | def main(output, dataset, target_name, datadir, batch_size, lr, iterations,
Condensed preview — 88 files, each showing its path, character count, and a content snippet. Download the .json file or copy it to your clipboard to get the full structured content (265K chars).
[
{
"path": ".gitignore",
"chars": 1463,
"preview": ".DS_Store\n# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm\n# "
},
{
"path": ".idea/MADAN.iml",
"chars": 558,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n <component name=\"NewModuleRootManager"
},
{
"path": ".idea/deployment.xml",
"chars": 482,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"PublishConfigData\" autoUpload=\"Always\" s"
},
{
"path": ".idea/inspectionProfiles/profiles_settings.xml",
"chars": 174,
"preview": "<component name=\"InspectionProjectProfileManager\">\n <settings>\n <option name=\"USE_PROJECT_PROFILE\" value=\"false\" />\n"
},
{
"path": ".idea/misc.xml",
"chars": 462,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"JavaScriptSettings\">\n <option name=\"l"
},
{
"path": ".idea/modules.xml",
"chars": 262,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"ProjectModuleManager\">\n <modules>\n "
},
{
"path": ".idea/remote-mappings.xml",
"chars": 473,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"RemoteMappingsManager\">\n <list>\n "
},
{
"path": ".idea/vcs.xml",
"chars": 180,
"preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n <component name=\"VcsDirectoryMappings\">\n <mapping dire"
},
{
"path": "LICENSE",
"chars": 1066,
"preview": "MIT License\n\nCopyright (c) 2019 liljprime\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\n"
},
{
"path": "README.md",
"chars": 4148,
"preview": "# MADAN\n\nA Pytorch Code for [Multi-source Domain Adaptation for Semantic Segmentation](https://arxiv.org/abs/1910.12181)"
},
{
"path": "cycada/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cycada/data/__init__.py",
"chars": 122,
"preview": "from . import gta5, cityscapes, cyclegta5, synthia, cyclesynthia, cyclesynthia_cyclegta5, bdds\nfrom . import adda_datase"
},
{
"path": "cycada/data/adda_datasets.py",
"chars": 3239,
"preview": "import os.path\n\nimport torch.utils.data\n\nfrom .data_loader import get_transform_dataset\nfrom ..transforms import augment"
},
{
"path": "cycada/data/bdds.py",
"chars": 1516,
"preview": "import os.path\n\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\nfrom .util import classes, igno"
},
{
"path": "cycada/data/cityscapes.py",
"chars": 2253,
"preview": "import os.path\nimport sys\n\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\nfrom .util import cl"
},
{
"path": "cycada/data/cityscapes_labels.py",
"chars": 772,
"preview": "# function for colorizing a label image:\n# camera-ready\n\nimport numpy as np\n\n\ndef label_img_to_color(img):\n\tlabel_to_col"
},
{
"path": "cycada/data/cyclegan.py",
"chars": 3660,
"preview": "import os\nfrom os.path import join\nimport glob\nfrom PIL import Image\n\nimport torch.utils.data as data\nfrom .data_loader "
},
{
"path": "cycada/data/cyclegta5.py",
"chars": 1805,
"preview": "import os.path\n\nimport numpy as np\nfrom PIL import Image\n\nfrom .cityscapes import remap_labels_to_train_ids\nfrom .data_l"
},
{
"path": "cycada/data/cyclesynthia.py",
"chars": 3323,
"preview": "import os.path\n\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\n\nfrom .data_loader import Datas"
},
{
"path": "cycada/data/cyclesynthia_cyclegta5.py",
"chars": 5214,
"preview": "import os.path\n\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\n\nfrom .cityscapes import remap_"
},
{
"path": "cycada/data/data_loader.py",
"chars": 5111,
"preview": "from __future__ import print_function\n\nimport os\nfrom os.path import join\n\nimport numpy as np\nimport torch\nimport torch."
},
{
"path": "cycada/data/gta5.py",
"chars": 2074,
"preview": "import os.path\n\nimport numpy as np\nimport scipy.io\nimport torch.utils.data as data\nfrom PIL import Image\n\nfrom .cityscap"
},
{
"path": "cycada/data/rotater.py",
"chars": 755,
"preview": "class Rotater(object):\n\n def __init__(self, dataset, orientations=6, transform=None,\n target_transfor"
},
{
"path": "cycada/data/synthia.py",
"chars": 2520,
"preview": "import os.path\n\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\nfrom .util import classes, igno"
},
{
"path": "cycada/data/util.py",
"chars": 1812,
"preview": "import logging\nimport os.path\n\nimport requests\n\nlogger = logging.getLogger(__name__)\n\nignore_label = 255\nid2label = {0: "
},
{
"path": "cycada/logging.yml",
"chars": 722,
"preview": "---\nversion: 1\ndisable_existing_loggers: False\nformatters:\n simple:\n format: \"[%(asctime)s] %(levelname)-8s %("
},
{
"path": "cycada/models/MDAN.py",
"chars": 3563,
"preview": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nimport logging\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functi"
},
{
"path": "cycada/models/__init__.py",
"chars": 193,
"preview": "from .models import get_model\nfrom .task_net import LeNet\nfrom .task_net import DTNClassifier\nfrom .adda_net import Adda"
},
{
"path": "cycada/models/adda_net.py",
"chars": 2499,
"preview": "\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\nfrom .util import init_weights\nfrom .mo"
},
{
"path": "cycada/models/drn.py",
"chars": 8303,
"preview": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\nimport torchvision\n\nfrom .mode"
},
{
"path": "cycada/models/fcn8s.py",
"chars": 7917,
"preview": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torchvision\nfrom torch import nn\nfrom torch.autog"
},
{
"path": "cycada/models/models.py",
"chars": 308,
"preview": "import torch\n\nmodels = {}\ndef register_model(name):\n def decorator(cls):\n models[name] = cls\n return cl"
},
{
"path": "cycada/models/task_net.py",
"chars": 3070,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nfrom .models import register_model \nfrom .util import init_"
},
{
"path": "cycada/models/util.py",
"chars": 349,
"preview": "import torch.nn as nn\nfrom torch.nn import init\n\ndef init_weights(obj):\n for m in obj.modules():\n if isinstanc"
},
{
"path": "cycada/tools/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cycada/tools/train_adda_net.py",
"chars": 5678,
"preview": "from __future__ import print_function\n\nimport os\nfrom os.path import join\nimport numpy as np\n\n# Import from torch\nimport"
},
{
"path": "cycada/tools/train_task_net.py",
"chars": 3334,
"preview": "from __future__ import print_function\n\nimport os\nfrom os.path import join\nimport numpy as np\nimport argparse\n\n# Import f"
},
{
"path": "cycada/tools/util.py",
"chars": 1560,
"preview": "from functools import partial\n\nimport torch\nfrom torch.autograd import Variable\n\n\ndef make_variable(tensor, volatile=Fal"
},
{
"path": "cycada/transforms.py",
"chars": 2714,
"preview": "\"\"\"These random transforms extend the transforms provided in torchvision to\nallow for transforming multiple images at th"
},
{
"path": "cycada/util.py",
"chars": 2005,
"preview": "import logging\nimport logging.config\nimport os.path\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch"
},
{
"path": "cyclegan/.gitignore",
"chars": 736,
"preview": ".DS_Store\ndebug*\ncheckpoints/\nresults/\nbuild/\ndist/\n*.png\ntorch.egg-info/\n*/**/__pycache__\ntorch/version.py\ntorch/csrc/g"
},
{
"path": "cyclegan/data/__init__.py",
"chars": 1819,
"preview": "import sys\n\nimport torch.utils.data\nfrom data.base_data_loader import BaseDataLoader\n\nsys.path.append('/nfs/project/libo"
},
{
"path": "cyclegan/data/base_data_loader.py",
"chars": 171,
"preview": "class BaseDataLoader():\n def __init__(self):\n pass\n\n def initialize(self, opt):\n self.opt = opt\n "
},
{
"path": "cyclegan/data/base_dataset.py",
"chars": 2925,
"preview": "import numpy as np\nimport torch\nimport torch.utils.data as data\nimport torchvision.transforms as transforms\nfrom PIL imp"
},
{
"path": "cyclegan/data/cityscapes.py",
"chars": 1216,
"preview": "import numpy as np\n\nignore_label = 255\nid2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,\n"
},
{
"path": "cyclegan/data/gta5_cityscapes.py",
"chars": 3753,
"preview": "import os.path\nimport random\n\nimport numpy as np\nfrom PIL import Image\nfrom data.base_dataset import BaseDataset, get_la"
},
{
"path": "cyclegan/data/gta_synthia_cityscapes.py",
"chars": 4626,
"preview": "import os.path\nimport random\n\nimport numpy as np\nfrom PIL import Image\nfrom data.base_dataset import BaseDataset, get_la"
},
{
"path": "cyclegan/data/image_folder.py",
"chars": 3052,
"preview": "###############################################################################\n# Code from\n# https://github.com/pytorch"
},
{
"path": "cyclegan/data/synthia_cityscapes.py",
"chars": 3775,
"preview": "import os.path\nimport random\n\nimport numpy as np\nfrom PIL import Image\nfrom data.base_dataset import BaseDataset, get_la"
},
{
"path": "cyclegan/environment.yml",
"chars": 249,
"preview": "name: pytorch-CycleGAN-and-pix2pix\nchannels:\n- peterjc123\n- defaults\ndependencies:\n- python=3.5.5\n- pytorch=0.3.1\n- scip"
},
{
"path": "cyclegan/models/__init__.py",
"chars": 745,
"preview": "import logging\n\ndef create_model(opt):\n\tmodel = None\n\tif opt.model == 'cycle_gan':\n\t\t# assert(opt.dataset_mode == 'unali"
},
{
"path": "cyclegan/models/base_model.py",
"chars": 4958,
"preview": "import os\nfrom collections import OrderedDict\n\nimport torch\n\nfrom . import networks\n\n\nclass BaseModel():\n\tdef name(self)"
},
{
"path": "cyclegan/models/cycle_gan_model.py",
"chars": 6300,
"preview": "import torch\nimport itertools\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import netw"
},
{
"path": "cyclegan/models/cycle_gan_semantic_model.py",
"chars": 8035,
"preview": "import itertools\nimport sys\n\nimport torch\nimport torch.nn.functional as F\nfrom util.image_pool import ImagePool\n\nfrom . "
},
{
"path": "cyclegan/models/multi_cycle_gan_semantic_model.py",
"chars": 19776,
"preview": "import itertools\nimport sys\n\nimport torch\nimport torch.nn.functional as F\nfrom util.image_pool import ImagePool\n\nfrom . "
},
{
"path": "cyclegan/models/networks.py",
"chars": 15300,
"preview": "import functools\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\nfrom torch.optim import lr_scheduler\n\n\n##"
},
{
"path": "cyclegan/models/test_model.py",
"chars": 1795,
"preview": "from . import networks\nfrom .base_model import BaseModel\n\n\nclass TestModel(BaseModel):\n\tdef name(self):\n\t\treturn 'TestMo"
},
{
"path": "cyclegan/options/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cyclegan/options/base_options.py",
"chars": 7391,
"preview": "import argparse\nimport os\n\nimport torch\nfrom util import util\n\n\nclass BaseOptions():\n\tdef __init__(self):\n\t\tself.parser "
},
{
"path": "cyclegan/options/test_options.py",
"chars": 845,
"preview": "from .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n def initialize(self):\n BaseOptions.in"
},
{
"path": "cyclegan/options/train_options.py",
"chars": 3352,
"preview": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n\tdef initialize(self):\n\t\tBaseOptions.initialize"
},
{
"path": "cyclegan/test.py",
"chars": 1722,
"preview": "import os\nimport sys\n\nimport torch\nfrom models import create_model\nfrom options.test_options import TestOptions\nfrom uti"
},
{
"path": "cyclegan/train.py",
"chars": 2354,
"preview": "import subprocess\nimport sys\nimport time\n\nsys.path.append(\"/nfs/project/libo_i/MADAN/cyclegan\")\nfrom options.train_optio"
},
{
"path": "cyclegan/util/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cyclegan/util/get_data.py",
"chars": 3511,
"preview": "from __future__ import print_function\nimport os\nimport tarfile\nimport requests\nfrom warnings import warn\nfrom zipfile im"
},
{
"path": "cyclegan/util/html.py",
"chars": 1912,
"preview": "import dominate\nfrom dominate.tags import *\nimport os\n\n\nclass HTML:\n def __init__(self, web_dir, title, reflesh=0):\n "
},
{
"path": "cyclegan/util/image_pool.py",
"chars": 1072,
"preview": "import random\nimport torch\n\n\nclass ImagePool():\n def __init__(self, pool_size):\n self.pool_size = pool_size\n "
},
{
"path": "cyclegan/util/util.py",
"chars": 1896,
"preview": "from __future__ import print_function\n\nimport os\n\nimport numpy as np\nimport torch\nfrom PIL import Image\n\n\n# Converts a T"
},
{
"path": "cyclegan/util/visualizer.py",
"chars": 5282,
"preview": "import ntpath\nimport os\nimport time\n\nimport numpy as np\nfrom . import html, util\n\n\n# save image to the disk\ndef save_ima"
},
{
"path": "requirements.txt",
"chars": 142,
"preview": "scipy\ntorchvision\ntensorboardX\ntensorflow\nclick\ntqdm\nrequests\ncolorlog\npyyaml\ntorch>=1.1.0\ntorchvision>=0.3.0\ndominate>="
},
{
"path": "scripts/ADDA/adda_cyclegta2cs_feat.sh",
"chars": 1568,
"preview": "#!/usr/bin/env bash\n\ngpu=0,1,2,3\n\n######################\n# loss weight params #\n######################\nlr=2e-5\nmomentum="
},
{
"path": "scripts/ADDA/adda_cyclegta2cs_score.sh",
"chars": 1600,
"preview": "#!/usr/bin/env bash\n\ngpu=0,1,2,3\n\n######################\n# loss weight params #\n######################\nlr=2e-5\nmomentum="
},
{
"path": "scripts/ADDA/adda_cyclesyn2cs_feat.sh",
"chars": 1379,
"preview": "#!/usr/bin/env bash\n\ngpu=0,1,2,3\n\n######################\n# loss weight params #\n######################\nlr=1e-5\nmomentum="
},
{
"path": "scripts/ADDA/adda_cyclesyn2cs_score.sh",
"chars": 1379,
"preview": "#!/usr/bin/env bash\n\ngpu=0,1,2,3\n\n######################\n# loss weight params #\n######################\nlr=1e-5\nmomentum="
},
{
"path": "scripts/ADDA/adda_templates.sh",
"chars": 1274,
"preview": "#!/usr/bin/env bash\n\ngpu=0,1,2,3\n\n######################\n# loss weight params #\n######################\nlr=1e-5\nmomentum="
},
{
"path": "scripts/CycleGAN/cyclegan_gta2cityscapes.sh",
"chars": 532,
"preview": "#!/usr/bin/env bash\ncd /nfs/project/libo_i/MADAN/cyclegan\n\nsudo python3 train.py --name cyclegan_gta2cityscapes \\\n --"
},
{
"path": "scripts/CycleGAN/cyclegan_gta_synthia2cityscapes.sh",
"chars": 710,
"preview": "#!/usr/bin/env bash\ncd /nfs/project/libo_i/MADAN/cyclegan\n\npython3 train.py --name cyclegan_gta_synthia2cityscapes_noIde"
},
{
"path": "scripts/CycleGAN/cyclegan_synthia2cityscapes.sh",
"chars": 742,
"preview": "#!/usr/bin/env bash\ncd /root/MADAN/cyclegan\n\npython3 train.py --name cycada_gta_synthia2cityscapes_noIdentity_D12D21D3_S"
},
{
"path": "scripts/CycleGAN/test_templates.sh",
"chars": 444,
"preview": "#!/usr/bin/env bash\n\nhow_many=100000\n\ncd /root/MADAN/cyclegan\nname=$1\nepoch=$2\n\npython3 test.py --name ${name} --resize_"
},
{
"path": "scripts/CycleGAN/test_templates_cycle.sh",
"chars": 988,
"preview": "#!/usr/bin/env bash\n# Sequentially load two generators(GTA, Synthia) and finish\nhow_many=100000\n\ncd /root/MADAN/cyclegan"
},
{
"path": "scripts/FCN/train_fcn8s_cyclesgta5.sh",
"chars": 617,
"preview": "#!/usr/bin/env bash\ngpu=0,1,2,3\ndata=cyclegta5\nmodel=fcn8s\n\nexport LC_ALL=C.UTF-8\nexport LANG=C.UTF-8\n\ndatadir=/root/MAD"
},
{
"path": "scripts/FCN/train_fcn8s_cyclesynthia.sh",
"chars": 637,
"preview": "#!/usr/bin/env bash\ngpu=0,1,2,3\ndata=cyclesynthia\nmodel=fcn8s\n\nexport LC_ALL=C.UTF-8\nexport LANG=C.UTF-8\n\ndatadir=/root/"
},
{
"path": "scripts/eval_fcn.py",
"chars": 4985,
"preview": "import os\nimport sys\n\nfrom torchvision.transforms import transforms\n\nsys.path.append('/nfs/project/libo_iMADAN')\nimport "
},
{
"path": "scripts/train_fcn.py",
"chars": 6789,
"preview": "import logging\nimport os.path\nimport sys\nfrom collections import deque\n\nimport click\nimport numpy as np\nimport torch\nimp"
},
{
"path": "scripts/train_fcn_adda.py",
"chars": 14835,
"preview": "import logging\nimport os\nimport os.path\nimport sys\nfrom collections import deque\nfrom datetime import datetime\n\nimport c"
},
{
"path": "scripts/train_fcn_mdan.py",
"chars": 10341,
"preview": "import itertools\nimport json\nimport logging\nimport os.path\nimport subprocess\nimport sys\nfrom collections import deque\n\ni"
},
{
"path": "tools/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "tools/eval_templates.sh",
"chars": 363,
"preview": "#!/usr/bin/env bash\nexport LC_ALL=C.UTF-8\nexport LANG=C.UTF-8\ncd /nfs/project/libo_i/MADAN\n\nckpt_path=$1\ndatadir=/nfs/pr"
}
]
About this extraction
This page contains the full source code of the Luodian/MADAN GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 88 files (237.8 KB), approximately 68.9k tokens, and a symbol index with 386 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — a free GitHub repo-to-text converter for AI. Built by Nikandr Surkov.