Full Code of Luodian/MADAN for AI

master 7a2918da44f5 cached
88 files
237.8 KB
68.9k tokens
386 symbols
1 requests
Download .txt
Showing preview only (259K chars total). Download the full file or copy to clipboard to get everything.
Repository: Luodian/MADAN
Branch: master
Commit: 7a2918da44f5
Files: 88
Total size: 237.8 KB

Directory structure:
gitextract_rf2a9vy3/

├── .gitignore
├── .idea/
│   ├── MADAN.iml
│   ├── deployment.xml
│   ├── inspectionProfiles/
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── remote-mappings.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── cycada/
│   ├── __init__.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── adda_datasets.py
│   │   ├── bdds.py
│   │   ├── cityscapes.py
│   │   ├── cityscapes_labels.py
│   │   ├── cyclegan.py
│   │   ├── cyclegta5.py
│   │   ├── cyclesynthia.py
│   │   ├── cyclesynthia_cyclegta5.py
│   │   ├── data_loader.py
│   │   ├── gta5.py
│   │   ├── rotater.py
│   │   ├── synthia.py
│   │   └── util.py
│   ├── logging.yml
│   ├── models/
│   │   ├── MDAN.py
│   │   ├── __init__.py
│   │   ├── adda_net.py
│   │   ├── drn.py
│   │   ├── fcn8s.py
│   │   ├── models.py
│   │   ├── task_net.py
│   │   └── util.py
│   ├── tools/
│   │   ├── __init__.py
│   │   ├── train_adda_net.py
│   │   ├── train_task_net.py
│   │   └── util.py
│   ├── transforms.py
│   └── util.py
├── cyclegan/
│   ├── .gitignore
│   ├── data/
│   │   ├── __init__.py
│   │   ├── base_data_loader.py
│   │   ├── base_dataset.py
│   │   ├── cityscapes.py
│   │   ├── gta5_cityscapes.py
│   │   ├── gta_synthia_cityscapes.py
│   │   ├── image_folder.py
│   │   └── synthia_cityscapes.py
│   ├── environment.yml
│   ├── models/
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── cycle_gan_model.py
│   │   ├── cycle_gan_semantic_model.py
│   │   ├── multi_cycle_gan_semantic_model.py
│   │   ├── networks.py
│   │   └── test_model.py
│   ├── options/
│   │   ├── __init__.py
│   │   ├── base_options.py
│   │   ├── test_options.py
│   │   └── train_options.py
│   ├── test.py
│   ├── train.py
│   └── util/
│       ├── __init__.py
│       ├── get_data.py
│       ├── html.py
│       ├── image_pool.py
│       ├── util.py
│       └── visualizer.py
├── requirements.txt
├── scripts/
│   ├── ADDA/
│   │   ├── adda_cyclegta2cs_feat.sh
│   │   ├── adda_cyclegta2cs_score.sh
│   │   ├── adda_cyclesyn2cs_feat.sh
│   │   ├── adda_cyclesyn2cs_score.sh
│   │   └── adda_templates.sh
│   ├── CycleGAN/
│   │   ├── cyclegan_gta2cityscapes.sh
│   │   ├── cyclegan_gta_synthia2cityscapes.sh
│   │   ├── cyclegan_synthia2cityscapes.sh
│   │   ├── test_templates.sh
│   │   └── test_templates_cycle.sh
│   ├── FCN/
│   │   ├── train_fcn8s_cyclesgta5.sh
│   │   └── train_fcn8s_cyclesynthia.sh
│   ├── eval_fcn.py
│   ├── train_fcn.py
│   ├── train_fcn_adda.py
│   └── train_fcn_mdan.py
└── tools/
    ├── __init__.py
    └── eval_templates.sh

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.DS_Store
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn.  Uncomment if using
# auto-import.
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

/models/


================================================
FILE: .idea/MADAN.iml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="jdk" jdkName="Remote Python 3.5.2 (sftp://luban@10.84.217.43:8022/usr/bin/python3)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="TestRunnerService">
    <option name="projectConfiguration" value="Twisted Trial" />
    <option name="PROJECT_TEST_RUNNER" value="Twisted Trial" />
  </component>
</module>

================================================
FILE: .idea/deployment.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData" autoUpload="Always" serverName="MV_CyCADA_22G">
    <serverData>
      <paths name="MV_CyCADA_22G">
        <serverdata>
          <mappings>
            <mapping deploy="/nfs/project/libo_i/MADAN" local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
    </serverData>
    <option name="myAutoUpload" value="ALWAYS" />
  </component>
</project>

================================================
FILE: .idea/inspectionProfiles/profiles_settings.xml
================================================
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>

================================================
FILE: .idea/misc.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="JavaScriptSettings">
    <option name="languageLevel" value="ES6" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.5.2 (sftp://luban@10.84.217.43:8022/usr/bin/python3)" project-jdk-type="Python SDK" />
  <component name="PythonCompatibilityInspectionAdvertiser">
    <option name="version" value="3" />
  </component>
</project>

================================================
FILE: .idea/modules.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/MADAN.iml" filepath="$PROJECT_DIR$/.idea/MADAN.iml" />
    </modules>
  </component>
</project>

================================================
FILE: .idea/remote-mappings.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="RemoteMappingsManager">
    <list>
      <list>
        <remote-mappings server-id="python@sftp://luban@10.84.217.43:8022/usr/bin/python3">
          <settings>
            <list>
              <mapping local-root="$PROJECT_DIR$" remote-root="/nfs/project/libo_i/MADAN" />
            </list>
          </settings>
        </remote-mappings>
      </list>
    </list>
  </component>
</project>

================================================
FILE: .idea/vcs.xml
================================================
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="$PROJECT_DIR$" vcs="Git" />
  </component>
</project>

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2019 liljprime

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# MADAN

A PyTorch implementation of [Multi-source Domain Adaptation for Semantic Segmentation](https://arxiv.org/abs/1910.12181).

If you use this code in your research please consider citing:

```
@InProceedings{zhao2019madan,
   title = {Multi-source Domain Adaptation for Semantic Segmentation},
   author = {Zhao, Sicheng and Li, Bo and Yue, Xiangyu and Gu, Yang and Xu, Pengfei and Hu, Runbo and Chai, Hua and Keutzer, Kurt},
   booktitle = {Advances in Neural Information Processing Systems},
   year = {2019}
}
```

## Quick Look

Our multi-source domain adaptation builds on [CyCADA](https://github.com/jhoffman/cycada_release) and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). Since we focus on the semantic segmentation task, we removed the digit classification part of CyCADA.

We add the following modules, which bring significant improvements:

1. Dynamic Semantic Consistency Module
2. Adversarial Aggregation Module
   1. Sub-domain Aggregation Discriminator
   2. Cross-domain Cycle Discriminator

We also implement [MDAN](https://openreview.net/pdf?id=ryDNZZZAW) for the semantic segmentation task in PyTorch as a baseline for comparison.

## Overall Structure

![image-20190608104531451](http://ww4.sinaimg.cn/large/006tNc79ly1g3tjype7qlj31vo0u0hb1.jpg)

## Setup

Check out this repo:

```bash
git clone https://github.com/pikachusocute/MADAN.git
```

Install Python3 requirements

```bash
pip3 install -r requirements.txt
```

## Dynamic Adversarial Image Generation

Following CyCADA, the first step is to train the Image Adaptation module to transfer source images (GTA, SYNTHIA, or multi-source) into "source as target" images.

![image-20190608111738818](http://ww4.sinaimg.cn/large/006tNc79ly1g3tkvxw9rrj31r40e8kjl.jpg)

In the following, we refer to the Image Adaptation module from GTA to Cityscapes as GTA->Cityscapes.

#### GTA->Cityscapes

```bash
cd scripts/CycleGAN
bash cyclegan_gta2cityscapes.sh
```

During training, snapshot files are stored in `cyclegan/checkpoints/[EXP_NAME]`.

Usually, after 20 epochs of training, there will be a file `20_net_G_A.pth` in that folder.

Then we run the test process.

```bash
bash scripts/CycleGAN/test_templates.sh [EXP_NAME] 20 cycle_gan_semantic_fcn gta5_cityscapes
```

In the multi-source case, both `20_net_G_A_1.pth` and `20_net_G_A_2.pth` exist, so we use a different script to run the test process.

![image](https://tva1.sinaimg.cn/large/006y8mN6ly1g9cqt9m2kmj31j80skgsh.jpg)

```bash
bash scripts/CycleGAN/test_templates_cycle.sh [EXP_NAME] 20 test synthia_cityscapes gta5_cityscapes
```

The new dataset will be generated at `~/cyclegan/results/[EXP_NAME]/train_20`.

After obtaining the new source-stylized dataset, we train a segmenter on it.

## Pixel Level Adaptation

In this part, we train a new segmenter on the new dataset.

```bash
ln -s ~/cyclegan/results/[EXP_NAME]/train_20 ~/data/cyclegta5/[EXP_NAME]_TRAIN_60
```

Then set `dataflag = [EXP_NAME]_TRAIN_60` so that the dataset paths can be found, and run the following to train the segmenter for pixel-level adaptation.

```bash
bash scripts/FCN/train_fcn8s_cyclesgta5.sh
```

## Feature Level Adaptation

For adaptation, we use

```bash
bash scripts/ADDA/adda_cyclegta2cs_score.sh
```

Before running, make sure you have set the desired `src`, `tgt`, and `datadir`. In this step, load your `base_model` trained on the synthetic dataset and perform feature-level adaptation to the real-scene dataset.

### Our Model

We release our adaptation models in `./models` (download links below); you can use `tools/eval_templates.sh` to evaluate them.

1. [CycleGTA5_Dynamic_Semantic_Consistency](https://drive.google.com/file/d/1moGF7L2hkTHUPUzqsSxPwKNlHCHQm4Ms/view?usp=sharing)
2. [CycleSYNTHIA_Dynamic_Semantic_Consistency](https://drive.google.com/file/d/19V5J1zyF3ct3247gSSr-u3WVkDJqQvUk/view?usp=sharing)
3. [Multi_Source_SAD_CCD](https://drive.google.com/file/d/1xgmLwhsbwv-isy7R5FkNevVSH9gcMxuq/view?usp=sharing)

### Transferred Dataset

We will release our transferred dataset soon; our `CycleGTA5_Dynamic_Semantic_Consistency` model is trained on it to perform pixel-level adaptation.


================================================
FILE: cycada/__init__.py
================================================


================================================
FILE: cycada/data/__init__.py
================================================
from . import gta5, cityscapes, cyclegta5, synthia, cyclesynthia, cyclesynthia_cyclegta5, bdds
from . import adda_datasets

================================================
FILE: cycada/data/adda_datasets.py
================================================
import os.path

import torch.utils.data

from .data_loader import get_transform_dataset
from ..transforms import augment_collate


class AddaDataLoader(object):
	"""Iterate over paired batches drawn from a source and a target dataset.

	Each step yields ``(img_src, img_tgt, label_src, label_tgt)``, one batch
	from each domain. A domain's ``DataLoader`` is rebuilt every
	``len(iterator)`` steps, right after it has produced a full pass. That
	way neither side runs out while the other still has batches.

	Args:
		net_transform: image transform passed to ``get_transform_dataset``.
		dataset: ``(source_name, target_name)``. Each name is also the
			sub-directory of ``rootdir`` that holds that dataset.
		rootdir: directory containing both dataset directories.
		downscale: passed to ``get_transform_dataset`` for both datasets.
		crop_size: ``crop`` argument of ``augment_collate``.
		resize: passed to ``get_transform_dataset`` and ``augment_collate``.
		batch_size, shuffle, num_workers: ``DataLoader`` settings shared by
			both domains.
		half_crop: ``halfcrop`` argument of ``augment_collate``.
		src_data_flag: passed to the source dataset only.
		small: passed to both datasets.

	Raises:
		AssertionError: if ``dataset`` does not name exactly two datasets.
	"""
	
	def __init__(self, net_transform, dataset, rootdir, downscale, crop_size=None, resize=None,
	             batch_size=1, shuffle=False, num_workers=2, half_crop=None, src_data_flag=None, small=False):
		self.dataset = dataset
		self.downscale = downscale
		self.resize = resize
		self.crop_size = crop_size
		self.half_crop = half_crop
		self.batch_size = batch_size
		self.shuffle = shuffle
		self.num_workers = num_workers
		assert len(self.dataset) == 2, 'Requires two datasets: source, target'
		sourcedir = os.path.join(rootdir, self.dataset[0])
		targetdir = os.path.join(rootdir, self.dataset[1])
		self.source = get_transform_dataset(self.dataset[0], sourcedir, net_transform, downscale, resize, src_data_flag=src_data_flag, small=small)
		self.target = get_transform_dataset(self.dataset[1], targetdir, net_transform, downscale, resize, small=small)
		print('Source length:', len(self.source), 'Target length:', len(self.target))
		# Size of the larger domain. Not consulted by this class (``__len__``
		# uses the smaller one); kept as a public attribute for callers.
		self.n = max(len(self.source), len(self.target))
		self.num = 0  # number of paired batches handed out so far
		self.set_loader_src()
		self.set_loader_tgt()
	
	def __iter__(self):
		return self
	
	def __next__(self):
		return self.next()
	
	def next(self):
		"""Return the next ``(img_src, img_tgt, label_src, label_tgt)`` tuple."""
		# ``__init__`` already created fresh iterators. Restarting at num == 0
		# would throw them (and any worker processes) away before they produced
		# a single batch, so restarts only happen after a completed pass.
		if self.num > 0 and self.num % len(self.iters_src) == 0:
			print('restarting source dataset')
			self.set_loader_src()
		if self.num > 0 and self.num % len(self.iters_tgt) == 0:
			print('restarting target dataset')
			self.set_loader_tgt()
		
		img_src, label_src = next(self.iters_src)
		img_tgt, label_tgt = next(self.iters_tgt)
		
		self.num += 1
		return img_src, img_tgt, label_src, label_tgt
	
	def __len__(self):
		"""Length of the smaller of the two datasets (in samples)."""
		return min(len(self.source), len(self.target))
	
	def _collate_fn(self):
		"""Return the batching function shared by both loaders.

		Uses ``augment_collate`` (with ``flip=True``) when cropping or resizing
		is requested, and the default PyTorch collation otherwise.
		"""
		if self.crop_size is not None or self.resize is not None:
			return lambda batch: augment_collate(batch, resize=self.resize, crop=self.crop_size,
			                                     halfcrop=self.half_crop, flip=True)
		return torch.utils.data.dataloader.default_collate
	
	def _make_loader(self, dataset):
		"""Build a ``DataLoader`` over ``dataset`` with this object's batching settings."""
		return torch.utils.data.DataLoader(dataset,
		                                   batch_size=self.batch_size, shuffle=self.shuffle,
		                                   num_workers=self.num_workers,
		                                   collate_fn=self._collate_fn(), pin_memory=True)
	
	def set_loader_src(self):
		"""(Re)create the source ``DataLoader`` and a fresh iterator over it."""
		self.loader_src = self._make_loader(self.source)
		self.iters_src = iter(self.loader_src)
	
	def set_loader_tgt(self):
		"""(Re)create the target ``DataLoader`` and a fresh iterator over it."""
		self.loader_tgt = self._make_loader(self.target)
		self.iters_tgt = iter(self.loader_tgt)


================================================
FILE: cycada/data/bdds.py
================================================
import os.path

import numpy as np
import torch.utils.data as data
from PIL import Image
from .util import classes, ignore_label, id2label
from .data_loader import register_dataset_obj

@register_dataset_obj('bdds')
class BDDS(data.Dataset):
	"""Segmentation samples laid out as ``<root>/images/<split>/<file>``.

	The label for ``<file>`` is ``<root>/labels/<split>/<stem>_train_id.png``,
	where ``<stem>`` is the file name without its last four characters.
	"""
	
	def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None, data_flag=None):
		self.root, self.split = root, split
		self.remap_labels = remap_labels  # stored only; __getitem__ does not consult it
		self.transform, self.target_transform = transform, target_transform
		self.classes = classes
		self.data_flag = data_flag
		self.num_cls = num_cls
		self.ids = self.collect_ids()
	
	def collect_ids(self):
		"""Return every entry of the image directory, in ``os.listdir`` order."""
		image_dir = os.path.join(self.root, "images", self.split)
		return [os.path.join(image_dir, entry).split('/')[-1] for entry in os.listdir(image_dir)]
	
	def img_path(self, filename):
		"""Full path of the image called ``filename``."""
		return os.path.join(self.root, "images", self.split, filename)
	
	def label_path(self, filename):
		"""Full path of the label image paired with ``filename``."""
		stem = filename[:-4]
		return os.path.join(self.root, 'labels', self.split, stem + "_train_id.png")
	
	def __getitem__(self, index, debug=False):
		name = self.ids[index]
		image_file, label_file = self.img_path(name), self.label_path(name)
		
		image = Image.open(image_file).convert('RGB')
		if self.transform is not None:
			image = self.transform(image)
		
		label = Image.open(label_file)
		if self.target_transform is not None:
			label = self.target_transform(label)
		return image, label
	
	def __len__(self):
		return len(self.ids)


================================================
FILE: cycada/data/cityscapes.py
================================================
import os.path
import sys

import numpy as np
import torch.utils.data as data
from PIL import Image
from .util import classes, ignore_label, id2label
from .data_loader import DatasetParams, register_data_params, register_dataset_obj

def remap_labels_to_train_ids(arr):
	"""Translate raw label ids in ``arr`` to train ids using ``id2label``.

	Every element whose value is not a key of ``id2label`` is set to
	``ignore_label``. Returns a new array shaped like ``arr``.
	"""
	remapped = ignore_label * np.ones(arr.shape, dtype=np.uint8)
	for raw_id, train_id in id2label.items():
		mask = arr == raw_id
		remapped[mask] = int(train_id)
	return remapped


@register_data_params('cityscapes')
class CityScapesParams(DatasetParams):
	"""Dataset parameters registered under the name ``'cityscapes'``."""
	num_cls = 19
	num_channels = 3
	image_size = 1024
	mean = 0.5
	std = 0.5
	target_transform = None


@register_dataset_obj('cityscapes')
class Cityscapes(data.Dataset):
	"""Cityscapes images from ``leftImg8bit/<split>`` with ``gtFine`` label ids.

	Sample ids are ``<city>_<seq>_<frame>``. When ``remap_labels`` is true,
	raw label ids are converted by ``remap_labels_to_train_ids`` before
	``target_transform`` is applied.
	"""
	
	def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None,
	             target_transform=None):
		self.root = root
		sys.path.append(root)
		self.split = split
		self.remap_labels = remap_labels
		# Needs root and split, which are set above.
		self.ids = self.collect_ids()
		self.transform, self.target_transform = transform, target_transform
		self.num_cls = num_cls
		
		self.id2label = id2label
		self.classes = classes
	
	def collect_ids(self):
		"""Return ``<city>_<seq>_<frame>`` for every ``.png`` under the split's image directory."""
		image_root = os.path.join(self.root, 'leftImg8bit', self.split)
		return ['_'.join(fname.split('_')[:3])
		        for _dirpath, _dirnames, fnames in os.walk(image_root)
		        for fname in fnames
		        if fname.endswith('.png')]
	
	def _city_path(self, fmt, id):
		# The city name (sub-directory) is the first '_'-separated token of the id.
		city = id.split('_')[0]
		return os.path.join(self.root, fmt.format(self.split, city, id))
	
	def img_path(self, id):
		"""Full path of the ``leftImg8bit`` image for ``id``."""
		return self._city_path('leftImg8bit/{}/{}/{}_leftImg8bit.png', id)
	
	def label_path(self, id):
		"""Full path of the ``gtFine`` label-id image for ``id``."""
		return self._city_path('gtFine/{}/{}/{}_gtFine_labelIds.png', id)
	
	def __getitem__(self, index, debug=False):
		sample_id = self.ids[index]
		img = Image.open(self.img_path(sample_id)).convert('RGB')
		if self.transform is not None:
			img = self.transform(img)
		
		target = Image.open(self.label_path(sample_id)).convert('L')
		if self.remap_labels:
			train_ids = remap_labels_to_train_ids(np.asarray(target))
			target = Image.fromarray(np.uint8(train_ids), 'L')
		if self.target_transform is not None:
			target = self.target_transform(target)
		return img, target
	
	def __len__(self):
		return len(self.ids)


================================================
FILE: cycada/data/cityscapes_labels.py
================================================
# function for colorizing a label image:
# camera-ready

import numpy as np


def label_img_to_color(img):
	"""Colorize a 2-D map of train ids 0-18 with the Cityscapes palette.

	Args:
		img: 2-D array-like of integer label ids.

	Returns:
		A float64 array of shape ``(H, W, 3)`` holding RGB values in [0, 255].
		Pixels whose label has no palette entry (for example an ignore value
		such as 255) are left black instead of raising ``KeyError``.

	Raises:
		ValueError: if ``img`` is not two-dimensional.
	"""
	label_to_color = {
		0: [128, 64, 128],
		1: [244, 35, 232],
		2: [70, 70, 70],
		3: [102, 102, 156],
		4: [190, 153, 153],
		5: [153, 153, 153],
		6: [250, 170, 30],
		7: [220, 220, 0],
		8: [107, 142, 35],
		9: [152, 251, 152],
		10: [70, 130, 180],
		11: [220, 20, 60],
		12: [255, 0, 0],
		13: [0, 0, 142],
		14: [0, 0, 70],
		15: [0, 60, 100],
		16: [0, 80, 100],
		17: [0, 0, 230],
		18: [119, 11, 32]
	}
	
	img = np.asarray(img)
	img_height, img_width = img.shape  # ValueError for non-2-D input
	
	img_color = np.zeros((img_height, img_width, 3))
	# One vectorized mask assignment per class, instead of a Python-level loop
	# over every pixel (millions of iterations for a full-resolution frame).
	for label, color in label_to_color.items():
		img_color[img == label] = color
	
	return img_color


================================================
FILE: cycada/data/cyclegan.py
================================================
import os
from os.path import join
import glob
from PIL import Image

import torch.utils.data as data
from .data_loader import DatasetParams
from .data_loader import register_dataset_obj, register_data_params

class CycleGANDataset(data.Dataset):
    """Images written by a CycleGAN run, labelled by their file-name prefix.

    Every file in ``root`` matching the glob pattern ``regexp`` (despite the
    name, a shell-style pattern, not a regular expression) is one example.
    Its label is the integer before the first ``_`` of the file name, e.g.
    ``3_fake_B.png`` -> ``3``.
    """

    def __init__(self, root, regexp, transform=None, target_transform=None,
            download=False):
        # ``download`` is accepted for interface compatibility and ignored.
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        self.image_paths, self.labels = self.find_images(regexp)

    def find_images(self, regexp='*.png'):
        """Return ``(image_paths, labels)`` for files in ``root`` matching ``regexp``.

        Paths are sorted lexicographically. Raises ``ValueError`` if a matching
        file name does not start with an integer.
        """
        # ``glob`` already returns paths prefixed with ``root``. Joining them
        # with ``root`` again (as before) duplicated the prefix whenever
        # ``root`` was a relative path.
        image_paths = sorted(glob.glob(join(self.root, regexp)))
        labels = [int(os.path.basename(path).split('_')[0]) for path in image_paths]
        return image_paths, labels

    def __getitem__(self, index):
        im = Image.open(self.image_paths[index]) #.convert('L')
        target = self.labels[index]

        if self.transform is not None:
            im = self.transform(im)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return im, target

    def __len__(self):
        return len(self.image_paths)


@register_dataset_obj('svhn2mnist')
class Svhn2MNIST(CycleGANDataset):
    """CycleGAN outputs matching ``*_fake_B.png``; only a train split exists."""

    def __init__(self, root, train=True, transform=None, target_transform=None, 
            download=False):
        if train:
            super(Svhn2MNIST, self).__init__(
                root, '*_fake_B.png',
                transform=transform,
                target_transform=target_transform,
                download=download)
            return
        # No held-out split: report it and expose an empty dataset.
        print('No test set for svhn2mnist.')
        self.image_paths = []

@register_data_params('svhn2mnist')
class Svhn2MNISTParams(DatasetParams):
    """Dataset parameters registered under the name ``'svhn2mnist'``."""
    num_cls = 10
    num_channels = 3
    image_size = 32
    # Alternative values tried earlier, kept for reference:
    #   mean 0.1307 / std 0.3081;
    #   measured with pixels scaled to [0, 1]: mean 0.127 (ep50),
    #   mean 0.21 (ep100 -- more white pixels...) with std 0.29 or 0.2.
    mean = 0.5
    std = 0.5
    target_transform = None

@register_dataset_obj('usps2mnist')
class Usps2Mnist(CycleGANDataset):
    """CycleGAN outputs matching ``*_fake_A.png``; only a train split exists."""

    def __init__(self, root, train=True, transform=None, target_transform=None, 
            download=False):
        if train:
            super(Usps2Mnist, self).__init__(
                root, '*_fake_A.png',
                transform=transform,
                target_transform=target_transform,
                download=download)
            return
        # No held-out split: report it and expose an empty dataset.
        print('No test set for usps2mnist.')
        self.image_paths = []

@register_data_params('usps2mnist')
class Usps2MnistParams(DatasetParams):
    """Dataset parameters registered under the name ``'usps2mnist'``."""
    num_cls = 10
    num_channels = 3
    image_size = 16
    # Previously tried: mean 0.1307 / std 0.3081.
    mean = 0.5
    std = 0.5
    target_transform = None


@register_dataset_obj('mnist2usps')
class Mnist2Usps(CycleGANDataset):
    """CycleGAN outputs matching ``*_fake_B.png``; only a train split exists."""

    def __init__(self, root, train=True, transform=None, target_transform=None, 
            download=False):
        if train:
            super(Mnist2Usps, self).__init__(
                root, '*_fake_B.png',
                transform=transform,
                target_transform=target_transform,
                download=download)
            return
        # No held-out split: report it and expose an empty dataset.
        print('No test set for mnist2usps.')
        self.image_paths = []

@register_data_params('mnist2usps')
class Mnist2UspsParams(DatasetParams):
    """Dataset parameters registered under the name ``'mnist2usps'``."""
    num_cls = 10
    num_channels = 3
    image_size = 16  # NOTE(original author): "this seems wrong..." -- unverified
    # Previously tried: mean 0.25 / std 0.37 and mean 0.1307 / std 0.3081.
    mean = 0.5
    std = 0.5
    target_transform = None


================================================
FILE: cycada/data/cyclegta5.py
================================================
import os.path

import numpy as np
from PIL import Image

from .cityscapes import remap_labels_to_train_ids
from .data_loader import register_dataset_obj
from .gta5 import GTA5  # , LABEL2TRAIN


@register_dataset_obj('cyclegta5')
class CycleGTA5(GTA5):
	"""GTA5 variant that reads (possibly CycleGAN-translated) images.

	Images come from ``<root>/<data_flag>`` when ``data_flag`` is set and from
	``<root>/images`` otherwise. Labels always come from
	``<root>/labels_600x1080``. For translated images, the generator suffix
	(``_fake_B.png`` or ``_fake_B_2.png``) is replaced by ``.png`` to find the
	matching label file; any other file name is looked up unchanged.

	NOTE(review): ``root``, ``data_flag``, ``remap_labels``, ``transform`` and
	``target_transform`` are presumably set by ``GTA5.__init__``, which is not
	visible here -- confirm.
	"""
	
	# Generator output suffixes that are stripped to recover the label name.
	_FAKE_SUFFIXES = ('_fake_B.png', '_fake_B_2.png')
	
	def _image_dir(self):
		"""Directory the images are listed from and loaded from."""
		if self.data_flag:
			return os.path.join(self.root, self.data_flag)
		return os.path.join(self.root, "images")
	
	def _label_name(self, filename):
		"""Map an image file name to its file name inside ``labels_600x1080``.

		When ``data_flag`` is set, a known generator suffix is replaced by
		``.png``. Names without such a suffix used to raise UnboundLocalError;
		they are now looked up unchanged, as when ``data_flag`` is unset.
		"""
		if self.data_flag:
			for suffix in self._FAKE_SUFFIXES:
				if filename.endswith(suffix):
					return filename[:-len(suffix)] + '.png'
		return filename
	
	def collect_ids(self):
		"""Return the sorted names of the entries of the image directory.

		Entries whose path does not resolve (e.g. broken symlinks) are skipped.
		"""
		path = self._image_dir()
		return sorted(item for item in os.listdir(path)
		              if os.path.exists(os.path.join(path, item)))
	
	def __getitem__(self, index, debug=False):
		filename = self.ids[index]
		img_path = os.path.join(self._image_dir(), filename)
		label_path = os.path.join(self.root, 'labels_600x1080', self._label_name(filename))
		
		img = Image.open(img_path).convert('RGB')
		target = Image.open(label_path)
		# Resize the image to the label's size so image and label pixels align.
		img = img.resize(target.size, resample=Image.BILINEAR)
		if self.transform is not None:
			img = self.transform(img)
		if self.remap_labels:
			target = np.asarray(target)
			target = remap_labels_to_train_ids(target)
			target = Image.fromarray(target, 'L')
		if self.target_transform is not None:
			target = self.target_transform(target)
		return img, target


================================================
FILE: cycada/data/cyclesynthia.py
================================================
import os.path

import numpy as np
import torch.utils.data as data
from PIL import Image

from .data_loader import DatasetParams, register_data_params, register_dataset_obj

# Train id given to every label id that has no counterpart in ``classes``.
ignore_label = 255
# Label id (presumably the raw SYNTHIA class id -- confirm against the
# dataset docs) -> train id. Train ids 0-18 appear to index ``classes``;
# ids mapped to ``ignore_label`` have no matching class.
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

# The 19 class names, presumably indexed by train id.
classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']


def syn_relabel(arr):
	"""Convert the label ids in ``arr`` to train ids via ``id2label``.

	Ids that are not keys of ``id2label`` become ``ignore_label``. Returns a
	new ``uint8`` array shaped like ``arr``.
	"""
	relabeled = ignore_label * np.ones(arr.shape, dtype=np.uint8)
	for source_id, train_id in id2label.items():
		hits = arr == source_id
		relabeled[hits] = int(train_id)
	return relabeled


@register_data_params('cyclesynthia')
class SYNTHIAParams(DatasetParams):
	"""Dataset parameters registered under the name ``'cyclesynthia'``."""
	num_cls = 19
	num_channels = 3
	image_size = 1024
	mean = 0.5
	std = 0.5
	target_transform = None


@register_dataset_obj('cyclesynthia')
class CycleSYNTHIA(data.Dataset):
	"""CycleGAN-translated SYNTHIA images paired with the original SYNTHIA labels.

	Images are listed and loaded from ``<root>/<data_flag>``, or from
	``<root>/Cycle`` when no ``data_flag`` is given. Only files ending in
	``_fake_B.png`` or ``_fake_B_1.png`` are used. Labels are read from
	``<root>/GT/parsed_LABELS`` under the image name with that suffix replaced
	by ``.png``. ``'cycle'`` is removed from the last component of ``root``,
	so a ``.../cyclesynthia`` root reads from ``.../synthia``.
	"""
	
	# Generator output suffixes that identify translated images.
	_FAKE_SUFFIXES = ('_fake_B_1.png', '_fake_B.png')
	
	def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None,
	             data_flag=None):
		"""
		Args:
			root: dataset directory; ``'cycle'`` is stripped from its last component.
			num_cls: number of classes.
			split: stored only; the image directory does not depend on it.
			remap_labels: convert label ids with ``syn_relabel`` in ``__getitem__``.
			transform / target_transform: optional callables for image / label.
			data_flag: sub-directory of the stripped root holding the translated
				images; ``None`` or ``''`` means ``'Cycle'``. (This attribute was
				previously read by ``collect_ids`` without ever being set, so
				construction always raised AttributeError.)
		"""
		self.root = self._strip_cycle(root)
		self.split = split
		self.remap_labels = remap_labels
		self.transform = transform
		self.target_transform = target_transform
		self.classes = classes
		self.num_cls = num_cls
		# Must be assigned before ``collect_ids``, which reads it.
		self.data_flag = data_flag
		self.ids = self.collect_ids()
	
	@staticmethod
	def _strip_cycle(root):
		"""Remove ``'cycle'`` from the last path component of ``root``.

		Only the final component is rewritten, so parent directories whose
		names contain ``cycle`` (e.g. ``bicycle``) are left intact.
		"""
		head, tail = os.path.split(root.rstrip(os.sep))
		if not tail:
			return root
		return os.path.join(head, tail.replace('cycle', ''))
	
	def _image_dir(self):
		"""Directory the translated images are listed from and loaded from."""
		return os.path.join(self.root, self.data_flag or 'Cycle')
	
	def collect_ids(self):
		"""Return the names of the translated images in the image directory (``os.listdir`` order)."""
		return [name for name in os.listdir(self._image_dir())
		        if name.endswith(self._FAKE_SUFFIXES)]
	
	def img_path(self, filename):
		"""Full path of the translated image ``filename``.

		Uses the same directory ``collect_ids`` lists; previously this joined
		only ``root``, so the listed images could not be found.
		"""
		return os.path.join(self._image_dir(), filename)
	
	def label_path(self, filename):
		"""Full path of the label for ``filename``, or ``None`` for an unknown suffix."""
		# Case for loading images generated in multi-source cycle.
		# In this case, you will generate fake_B_1 for cyclesynthia dataset and fake_B_2 for cyclegta5.
		for suffix in self._FAKE_SUFFIXES:
			if filename.endswith(suffix):
				return os.path.join(self.root, 'GT', 'parsed_LABELS', filename[:-len(suffix)] + '.png')
		return None
	
	def __getitem__(self, index, debug=False):
		id = self.ids[index]
		img_path = self.img_path(id)
		label_path = self.label_path(id)
		img = Image.open(img_path).convert('RGB')
		if self.transform is not None:
			img = self.transform(img)
		target = Image.open(label_path)
		if self.remap_labels:
			target = np.asarray(target)
			target = syn_relabel(target)
			target = Image.fromarray(target, 'L')
		if self.target_transform is not None:
			target = self.target_transform(target)
		return img, target
	
	def __len__(self):
		return len(self.ids)


================================================
FILE: cycada/data/cyclesynthia_cyclegta5.py
================================================
import os.path

import numpy as np
import torch.utils.data as data
from PIL import Image

from .cityscapes import remap_labels_to_train_ids
from .data_loader import DatasetParams, register_data_params, register_dataset_obj

# Value written for pixels that have no training class; syn_relabel also
# uses it as the fill value for ids missing from id2label.
ignore_label = 255
# Raw SYNTHIA label id (key) -> index into `classes` (value), or
# ignore_label for ids that have no training class.
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

# The 19 training classes; position i is train id i (the values used in
# id2label). Presumably the Cityscapes train-id order — confirm.
classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']


def syn_relabel(arr):
	"""Translate a raw SYNTHIA label array into training ids.

	Every pixel whose value is a key of ``id2label`` gets the mapped value.
	All other pixels become ``ignore_label``. The result is a new ``uint8``
	array with the same shape as ``arr``.
	"""
	train_ids = np.full(arr.shape, ignore_label, dtype=np.uint8)
	for raw_id in id2label:
		train_ids[arr == raw_id] = int(id2label[raw_id])
	return train_ids


@register_data_params('cyclesynthia_cyclegta5')
class SYNTHIAParams(DatasetParams):
	"""Loader parameters registered for the mixed CycleSYNTHIA + CycleGTA5 dataset."""
	# Label space.
	num_cls = 19
	target_transform = None
	# Input geometry.
	num_channels = 3
	image_size = 1024
	# Normalisation statistics (consumed by data_loader.get_transform).
	mean = 0.5
	std = 0.5

# In this class, we alternately load translated images from cyclesynthia and cyclegta5
@register_dataset_obj('cyclesynthia_cyclegta5')
class CycleSYNTHIACycleGTA5(data.Dataset):
	"""Mixed dataset alternating CycleGAN-translated SYNTHIA and GTA5 samples.

	Odd indices return a translated SYNTHIA image, with labels remapped by
	``syn_relabel``. Even indices return a translated GTA5 image, with labels
	remapped by ``remap_labels_to_train_ids``.

	Directory layout, relative to the *parent* directory of ``root``:
	  * ``synthia/``: translated images plus ``GT/parsed_LABELS/<stem>.png``;
	  * ``cyclegta5/``: translated images plus ``labels/<stem>.png``.

	Images produced by the multi-source cycle model end in ``_fake_B_1.png``
	(SYNTHIA) or ``_fake_B_2.png`` (GTA5). Other translated images end in
	``_fake_B.png``. ``<stem>`` is the name without that suffix.
	"""

	# Translated-image suffixes accepted for each source.
	_SYN_SUFFIXES = ('_fake_B_1.png', '_fake_B.png')
	_GTA_SUFFIXES = ('_fake_B_2.png', '_fake_B.png')

	def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None,
	             data_flag=None, small=False):
		"""
		Args:
			root: dataset directory. Only its parent is used, to locate the
				``synthia`` and ``cyclegta5`` folders.
			num_cls: number of training classes.
			split: stored but unused.
			remap_labels: remap raw label ids to training ids.
			transform / target_transform: applied to the image / label.
			data_flag, small: accepted and stored but unused.
				``get_transform_dataset`` always passes them to the dataset
				constructor, which previously raised TypeError here.
		"""
		self.dataset_name = os.path.basename(root)
		# dirname instead of root.replace(dataset_name, ''): replace would also
		# strip any other occurrence of the name elsewhere in the path.
		self.parent_path = os.path.dirname(root)
		self.syn_name = os.path.join(self.parent_path, 'synthia')
		self.gta_name = os.path.join(self.parent_path, 'cyclegta5')
		self.split = split
		self.data_flag = data_flag
		self.small = small
		self.remap_labels = remap_labels
		self.transform = transform
		self.target_transform = target_transform
		self.classes = classes
		self.num_cls = num_cls
		self.syn_ids = self.collect_ids('syn')
		self.gta_ids = self.collect_ids('gta')

	def collect_ids(self, datasets_name):
		"""Return translated image names for source ``'syn'`` or ``'gta'``.

		For any other name, prints a warning and returns an empty list.
		"""
		if datasets_name == 'syn':
			folder, suffixes = self.syn_name, self._SYN_SUFFIXES
		elif datasets_name == 'gta':
			folder, suffixes = self.gta_name, self._GTA_SUFFIXES
		else:
			print("Don't Recognize {}".format(datasets_name))
			return []
		return [name for name in os.listdir(folder) if name.endswith(suffixes)]

	def img_path(self, prefix, filename):
		return os.path.join(prefix, filename)

	@staticmethod
	def _stem(filename, suffixes):
		"""Strip the first matching suffix from ``filename``.

		Raises:
			ValueError: if no suffix matches.
		"""
		for suffix in suffixes:
			if filename.endswith(suffix):
				return filename[:-len(suffix)]
		raise ValueError('Unrecognised translated image name: {}'.format(filename))

	# Case for loading images generated in multi-source cycle.
	# In this case, you will generate fake_B_1 for cyclesynthia dataset and fake_B_2 for cyclegta5.
	# Label folders are derived from parent_path instead of a hard-coded
	# absolute location, so the dataset can live anywhere.
	def syn_label_path(self, filename):
		stem = self._stem(filename, self._SYN_SUFFIXES)
		return os.path.join(self.syn_name, 'GT', 'parsed_LABELS', stem + '.png')

	def gta_label_path(self, filename):
		stem = self._stem(filename, self._GTA_SUFFIXES)
		return os.path.join(self.gta_name, 'labels', stem + '.png')

	def _load_pair(self, img_path, label_path, relabel):
		"""Load one image/label pair, then apply the label remap and the transforms."""
		img = Image.open(img_path).convert('RGB')
		if self.transform is not None:
			img = self.transform(img)
		target = Image.open(label_path)
		if self.remap_labels:
			target = Image.fromarray(relabel(np.asarray(target)), 'L')
		if self.target_transform is not None:
			target = self.target_transform(target)
		return img, target

	def __getitem__(self, index, debug=False):
		# Odd indices draw from SYNTHIA and even ones from GTA5. index // 2 is
		# the position within the chosen source. Using index % len directly
		# would only ever reach ids with the same parity as index whenever
		# len is even, silently skipping half of that source.
		position = index // 2
		if index % 2:
			sample_id = self.syn_ids[position % len(self.syn_ids)]
			img_path = self.img_path(self.syn_name, sample_id)
			label_path = self.syn_label_path(sample_id)
			relabel = syn_relabel
		else:
			sample_id = self.gta_ids[position % len(self.gta_ids)]
			img_path = self.img_path(self.gta_name, sample_id)
			label_path = self.gta_label_path(sample_id)
			relabel = remap_labels_to_train_ids

		if debug:
			print(self.__class__.__name__)
			print("IMG Path: {}".format(img_path))
			print("Label Path: {}".format(label_path))

		return self._load_pair(img_path, label_path, relabel)

	def __len__(self):
		return len(self.syn_ids) + len(self.gta_ids)


================================================
FILE: cycada/data/data_loader.py
================================================
from __future__ import print_function

import os
from os.path import join

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms

from ..util import to_tensor_raw


def load_data(name, dset, batch=64, rootdir='', num_channels=3,
              image_size=32, download=True, kwargs=None):
	"""Build a DataLoader for one dataset or for an ADDA (source, target) pair.

	Args:
		name: registered dataset name, or a 2-element list ``[source, target]``.
			For a pair, each dataset is read from ``rootdir/<name>`` and the
			two are zipped with ``AddaDataset``.
		dset: split name; only ``'train'`` enables shuffling.
		batch: DataLoader batch size.
		rootdir: dataset root directory.
		num_channels, image_size, download: forwarded to ``get_dataset``.
		kwargs: extra keyword arguments for ``torch.utils.data.DataLoader``.
			Defaults to ``None`` (no extras) instead of a shared mutable ``{}``.

	Returns:
		The DataLoader, or ``None`` if the dataset is empty.
	"""
	if kwargs is None:
		kwargs = {}
	is_train = (dset == 'train')
	if isinstance(name, list) and len(name) == 2:  # load adda data
		src_dataset = get_dataset(name[0], join(rootdir, name[0]), dset,
		                          image_size, num_channels, download=download)
		tgt_dataset = get_dataset(name[1], join(rootdir, name[1]), dset,
		                          image_size, num_channels, download=download)
		dataset = AddaDataset(src_dataset, tgt_dataset)
	else:
		dataset = get_dataset(name, rootdir, dset, image_size, num_channels,
		                      download=download)
	if len(dataset) == 0:
		return None
	return torch.utils.data.DataLoader(dataset, batch_size=batch,
	                                   shuffle=is_train, **kwargs)


def get_transform_dataset(dataset_name, rootdir, net_transform, downscale, resize=None, src_data_flag=None, small=False):
	"""Build a segmentation dataset with matching image and label transforms.

	The transforms come from ``get_transform2``. The dataset registered as
	``dataset_name`` is constructed with ``rootdir``, those transforms,
	``data_flag=src_data_flag`` and ``small``.
	"""
	# No environment lookups here: an unused os.environ['PYTHONPATH'] read
	# raised KeyError whenever PYTHONPATH was not set.
	transform, target_transform = get_transform2(dataset_name, net_transform, downscale, resize)
	return get_fcn_dataset(dataset_name, rootdir, transform=transform, target_transform=target_transform, data_flag=src_data_flag, small=small)


# Reference image size per dataset, used for relative scaling in get_transform2.
sizes = {'cyclesynthia_cyclegta5': 1280, 'cyclesynthia': 1280, 'cityscapes': 1280, 'gta5': 1280, 'cyclegta5': 1280, "synthia": 1280}


def get_orig_size(dataset_name):
	"""Size of images in the dataset for relative scaling.

	Raises:
		Exception: with args ``('Unknown dataset size:', dataset_name)`` when
			``dataset_name`` is not a key of ``sizes``.
	"""
	try:
		return sizes[dataset_name]
	except KeyError:
		# Catch only the missing-key case, so that KeyboardInterrupt,
		# SystemExit and unrelated errors propagate unchanged.
		raise Exception('Unknown dataset size:', dataset_name) from None


def get_transform2(dataset_name, net_transform, downscale, resize):
	"""Returns image and label transform to downscale, crop and prepare for net.

	Steps, in order:
	  1. if ``downscale`` is given, resize to ``orig_size // downscale``
	     (default interpolation for images, nearest for labels);
	  2. if ``resize`` is given, resize to ``[resize, int(resize * 1.8)]``
	     (bicubic for images, nearest for labels);
	  3. apply ``net_transform`` to images and ``to_tensor_raw`` to labels.
	"""
	base_size = get_orig_size(dataset_name)
	img_ops, label_ops = [], []

	if downscale is not None:
		scaled = base_size // downscale
		img_ops.append(transforms.Resize(scaled))
		label_ops.append(transforms.Resize(scaled, interpolation=Image.NEAREST))

	if resize is not None:
		height = int(resize)
		width = int(height * 1.8)
		img_ops.append(transforms.Resize([height, width], interpolation=Image.BICUBIC))
		label_ops.append(transforms.Resize([height, width], interpolation=Image.NEAREST))

	img_ops.append(net_transform)
	label_ops.append(to_tensor_raw)
	return transforms.Compose(img_ops), transforms.Compose(label_ops)


def get_transform(params, image_size, num_channels):
	"""Build the image transform for a classification-style dataset.

	The transform resizes to ``image_size`` when it differs from
	``params.image_size``. It converts between grayscale and RGB when
	``num_channels`` differs from ``params.num_channels``. Finally it converts
	to a tensor normalised with ``params.mean`` / ``params.std``.

	Raises:
		ValueError: if a channel conversion is required and ``num_channels``
			is neither 1 nor 3. ValueError subclasses the Exception raised
			previously, so existing handlers still catch it.
	"""
	# Transforms for PIL Images: Gray <-> RGB
	Gray2RGB = transforms.Lambda(lambda x: x.convert('RGB'))
	RGB2Gray = transforms.Lambda(lambda x: x.convert('L'))

	transform = []
	# Does size request match original size?
	if image_size != params.image_size:
		transform.append(transforms.Resize(image_size))

	# Does number of channels requested match original?
	if num_channels != params.num_channels:
		if num_channels == 1:
			transform.append(RGB2Gray)
		elif num_channels == 3:
			transform.append(Gray2RGB)
		else:
			# Put the reason in the exception itself instead of printing it
			# and raising a message-less Exception.
			raise ValueError('NumChannels should be 1 or 3, got {}'.format(num_channels))

	transform += [transforms.ToTensor(),
	              transforms.Normalize((params.mean,), (params.std,))]

	return transforms.Compose(transform)


def get_target_transform(params):
	"""Label transform: ``params.target_transform`` (if any), then a pair squeeze.

	The squeeze returns ``x[:, 0]`` when ``x`` is a list or ndarray of length
	2, and ``x`` unchanged otherwise.
	"""
	def _first_column_if_pair(x):
		if isinstance(x, (list, np.ndarray)) and len(x) == 2:
			return x[:, 0]
		return x

	t_uniform = transforms.Lambda(_first_column_if_pair)
	base = params.target_transform
	if base is not None:
		return transforms.Compose([base, t_uniform])
	return t_uniform


class AddaDataset(data.Dataset):
	"""Pairs source and target samples by index for ADDA training.

	Each dataset is indexed modulo its own length. The reported length is
	that of the shorter dataset.
	"""

	def __init__(self, src_data, tgt_data):
		self.src, self.tgt = src_data, tgt_data

	def __getitem__(self, index):
		xs, ys = self.src[index % len(self.src)]
		xt, yt = self.tgt[index % len(self.tgt)]
		return (xs, ys), (xt, yt)

	def __len__(self):
		src_len, tgt_len = len(self.src), len(self.tgt)
		return src_len if src_len < tgt_len else tgt_len


def _registrar(registry, name):
	"""Return a class decorator that stores the class in ``registry[name]`` and returns it unchanged."""
	def decorator(cls):
		registry[name] = cls
		return cls
	return decorator


# Dataset name -> DatasetParams subclass.
data_params = {}


def register_data_params(name):
	"""Class decorator registering a DatasetParams subclass under ``name``."""
	return _registrar(data_params, name)


# Dataset name -> dataset class.
dataset_obj = {}


def register_dataset_obj(name):
	"""Class decorator registering a dataset class under ``name``."""
	return _registrar(dataset_obj, name)


class DatasetParams(object):
	"Class variables defined."
	# Label space and optional label transform.
	num_cls = 10
	target_transform = None
	# Native input geometry.
	num_channels = 1
	image_size = 16
	# Normalisation statistics.
	mean = 0.1307
	std = 0.3081


def get_dataset(name, rootdir, dset, image_size, num_channels, download=True):
	"""Instantiate the dataset registered as ``name``.

	``dset == 'train'`` selects the training split. Image and target
	transforms are derived from the registered DatasetParams.
	"""
	print('get dataset:', name, rootdir, dset)
	params = data_params[name]
	ctor_kwargs = dict(
		train=(dset == 'train'),
		transform=get_transform(params, image_size, num_channels),
		target_transform=get_target_transform(params),
		download=download,
	)
	return dataset_obj[name](rootdir, **ctor_kwargs)


def get_fcn_dataset(name, rootdir, **kwargs):
	"""Construct the segmentation dataset registered as ``name``, forwarding ``kwargs``."""
	dataset_cls = dataset_obj[name]
	return dataset_cls(rootdir, **kwargs)


================================================
FILE: cycada/data/gta5.py
================================================
import os.path

import numpy as np
import scipy.io
import torch.utils.data as data
from PIL import Image

from .cityscapes import id2label as LABEL2TRAIN, remap_labels_to_train_ids
from .data_loader import DatasetParams, register_data_params, register_dataset_obj


@register_data_params('gta5')
class GTA5Params(DatasetParams):
	"""Loader parameters registered for the GTA5 dataset."""
	num_cls = 19
	target_transform = None
	num_channels = 3
	image_size = 1024
	mean = 0.5
	std = 0.5


@register_dataset_obj('gta5')
class GTA5(data.Dataset):
	"""GTA5 images with labels optionally remapped to training ids.

	``root`` must contain:
	  * ``images/`` and ``labels/``, with PNG files named by a zero-padded 5-digit id;
	  * ``split.mat``, holding a ``<split>Ids`` array;
	  * ``mapping.mat``, holding a ``classes`` table.
	"""

	def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None, data_flag=None, small=None):
		"""
		Args:
			small: accepted and stored but unused. ``get_transform_dataset``
				always passes ``small=`` to the dataset constructor, which
				previously raised TypeError for this class.
		"""
		self.root = root
		self.split = split
		self.remap_labels = remap_labels
		self.data_flag = data_flag
		self.small = small
		self.ids = self.collect_ids()
		self.transform = transform
		self.target_transform = target_transform
		m = scipy.io.loadmat(os.path.join(self.root, 'mapping.mat'))
		full_classes = [x[0] for x in m['classes'][0]]
		# Keep the names of raw classes that have a real train id (not 255).
		# Raw id 0 is skipped. Order follows LABEL2TRAIN's iteration order.
		self.classes = [full_classes[old_id] for old_id, new_id in LABEL2TRAIN.items()
		                if new_id != 255 and old_id > 0]
		self.num_cls = num_cls

	def collect_ids(self):
		"""Return the id array for ``self.split`` from ``split.mat``."""
		splits = scipy.io.loadmat(os.path.join(self.root, 'split.mat'))
		return splits['{}Ids'.format(self.split)].squeeze()

	def img_path(self, id):
		return os.path.join(self.root, 'images', '{:05d}.png'.format(id))

	def label_path(self, id):
		return os.path.join(self.root, 'labels', '{:05d}.png'.format(id))

	def __getitem__(self, index, debug=False):
		"""Return ``(image, target)``; ``debug`` is unused."""
		id = self.ids[index]
		img_path = self.img_path(id)
		label_path = self.label_path(id)

		img = Image.open(img_path).convert('RGB')
		if self.transform is not None:
			img = self.transform(img)
		target = Image.open(label_path)
		if self.remap_labels:
			target = Image.fromarray(remap_labels_to_train_ids(np.asarray(target)), 'L')
		if self.target_transform is not None:
			target = self.target_transform(target)
		return img, target

	def __len__(self):
		return len(self.ids)


================================================
FILE: cycada/data/rotater.py
================================================
class Rotater(object):
    """Dataset wrapper that rotates each image by an index-dependent angle.

    Sample ``index`` is rotated by ``360 / orientations * (index % orientations)``
    degrees via the image's ``rotate`` method. The optional transforms are
    applied afterwards. Items are returned as ``(image, target, degrees)``.
    """

    def __init__(self, dataset, orientations=6, transform=None,
                 target_transform=None):
        self.dataset = dataset
        self.orientations = orientations
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image, label = self.dataset[index]
        step = index % self.orientations
        # Keep the original evaluation order (divide, then multiply) so the
        # float angle is bit-identical.
        angle = 360 / self.orientations * step
        image = image.rotate(angle)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label, angle

    def __len__(self):
        return len(self.dataset)


================================================
FILE: cycada/data/synthia.py
================================================
import os.path

import numpy as np
import torch.utils.data as data
from PIL import Image
from .util import classes, ignore_label, id2label
from .data_loader import DatasetParams, register_data_params, register_dataset_obj

def syn_relabel(arr):
	"""Return a uint8 copy of ``arr`` with raw SYNTHIA ids mapped via ``id2label``.

	Pixels whose value is not a key of ``id2label`` are set to ``ignore_label``.
	"""
	relabeled = np.full(arr.shape, ignore_label, dtype=np.uint8)
	for raw_id, train_id in id2label.items():
		hit = arr == raw_id
		relabeled[hit] = int(train_id)
	return relabeled

@register_data_params('synthia')
class SYNTHIAParams(DatasetParams):
	"""Loader parameters registered for the SYNTHIA dataset."""
	num_cls = 19
	target_transform = None
	num_channels = 3
	image_size = 1024
	mean = 0.5
	std = 0.5


@register_dataset_obj('synthia')
class SYNTHIA(data.Dataset):
	"""SYNTHIA images with labels optionally remapped to training ids.

	The sample list is read from ``root/SYNTHIA_imagelist_<split>.txt``. Only
	the last ``/``-separated component of each line is kept as the file name.

	``small`` selects the stored resolution:
	  * 0 -> ``RGB_300x540`` / ``GT/parsed_LABELS_300x540``;
	  * 1 -> ``RGB_600x1080`` / ``GT/parsed_LABELS_600x1080``;
	  * anything else -> ``RGB`` / ``GT/parsed_LABELS``.
	"""

	def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None, data_flag=None, small=2):
		self.root = root
		self.split = split
		self.small = small
		self.remap_labels = remap_labels
		self.ids = self.collect_ids()
		self.transform = transform
		self.target_transform = target_transform
		self.classes = classes
		self.num_cls = num_cls
		self.data_flag = data_flag

	def collect_ids(self):
		"""Return the image file names listed in the split file, skipping blank lines."""
		list_file = os.path.join(self.root, 'SYNTHIA_imagelist_{}.txt'.format(self.split))
		splits = []
		with open(list_file) as f:
			for line in f:
				# strip() also removes '\r' from Windows line endings.
				# Blank lines are skipped: they would otherwise yield '' ids
				# that resolve to a directory instead of an image.
				line = line.strip()
				if line:
					splits.append(line.split('/')[-1])
		return splits

	def _resolution_suffix(self):
		"""Folder-name suffix for the resolution selected by ``self.small``."""
		if self.small == 0:
			return '_300x540'
		elif self.small == 1:
			return '_600x1080'
		return ''

	def img_path(self, filename):
		return os.path.join(self.root, 'RGB' + self._resolution_suffix(), filename)

	def label_path(self, filename):
		return os.path.join(self.root, 'GT', 'parsed_LABELS' + self._resolution_suffix(), filename)

	def __getitem__(self, index, debug=False):
		id = self.ids[index]
		img_path = self.img_path(id)
		label_path = self.label_path(id)

		if debug:
			print(self.__class__.__name__)
			print("IMG Path: {}".format(img_path))
			print("Label Path: {}".format(label_path))

		img = Image.open(img_path).convert('RGB')
		if self.transform is not None:
			img = self.transform(img)
		target = Image.open(label_path)

		if self.remap_labels:
			target = np.asarray(target)
			target = syn_relabel(target)
			target = Image.fromarray(target, 'L')
		if self.target_transform is not None:
			target = self.target_transform(target)
		return img, target

	def __len__(self):
		return len(self.ids)


================================================
FILE: cycada/data/util.py
================================================
import logging
import os.path

import requests

logger = logging.getLogger(__name__)

# Label value for pixels that carry no training class.
ignore_label = 255
# Raw SYNTHIA label id (key) -> index into `classes` (value), or ignore_label
# for ids without a training class. Used by synthia.syn_relabel.
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

# The 19 training classes; position i is train id i.
classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']

# Flat RGB palette: 57 values = 19 (R, G, B) triplets, presumably one per
# entry of `classes` in the same order (e.g. for PIL Image.putpalette) —
# confirm at the use site.
palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
           220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
           0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]

def maybe_download(url, dest):
    """Fetch ``url`` into ``dest`` unless ``dest`` already exists.

    No integrity check is performed: an existing ``dest`` is trusted as-is.
    """
    if os.path.exists(dest):
        return
    logger.info('Downloading %s to %s', url, dest)
    download(url, dest)


def download(url, dest):
    """Download the url to dest, overwriting dest if it already exists.

    The body is streamed into ``dest + '.part'`` and moved onto ``dest`` only
    after the whole response has been written. An interrupted download
    therefore never leaves a truncated ``dest`` that ``maybe_download``
    would later treat as complete.

    Raises:
        requests.HTTPError: if the server answers with an error status.
            Previously the error page itself was saved as ``dest``.
    """
    tmp_path = os.fspath(dest) + '.part'
    response = requests.get(url, stream=True)
    try:
        response.raise_for_status()
        try:
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:  # skip empty chunks
                        f.write(chunk)
        except BaseException:
            # Do not leave a partial temp file behind on failure/interrupt.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    finally:
        # Release the streamed connection in every case.
        response.close()
    os.replace(tmp_path, dest)



================================================
FILE: cycada/logging.yml
================================================
---
# Logging configuration in the logging.config.dictConfig schema (presumably
# loaded by cycada.util — confirm at the call site).
version: 1
# Do not disable loggers that already exist when this config is applied.
disable_existing_loggers: False
formatters:
    # Plain format; used by file_handler.
    simple:
        format: "[%(asctime)s] %(levelname)-8s %(message)s"
    # Colored level names; requires the third-party `colorlog` package.
    color:
        class: colorlog.ColoredFormatter
        format: "[%(asctime)s] %(log_color)s%(levelname)-8s%(reset)s %(message)s"
        log_colors:
            DEBUG: "cyan"
            INFO: "green"
            WARNING: "yellow"
            ERROR: "red"
            CRITICAL: "red,bg_white"

handlers:
    # Console output through cycada.util.TqdmHandler (defined outside this
    # file; presumably tqdm-aware so progress bars are not garbled).
    console:
        class: cycada.util.TqdmHandler
        level: INFO
        formatter: color

    # NOTE(review): logging.FileHandler needs a `filename`, which is not set
    # here; presumably the loader injects it before applying this config.
    file_handler:
        class: logging.FileHandler
        level: INFO
        formatter: simple
        encoding: utf8

root:
    level: INFO
    handlers: [console, file_handler]



================================================
FILE: cycada/models/MDAN.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class _GradientReversalFunction(torch.autograd.Function):
	"""Autograd function: identity in the forward pass, negated gradient in the backward pass."""

	@staticmethod
	def forward(ctx, inputs):
		# view_as gives autograd a distinct output tensor (sharing storage)
		# rather than returning the input object itself.
		return inputs.view_as(inputs)

	@staticmethod
	def backward(ctx, grad_output):
		return grad_output.neg()


class GradientReversalLayer(nn.Module):
	"""
	Implement the gradient reversal layer for the convenience of domain adaptation neural network.
	The forward part is the identity function while the backward part is the negative function.

	This used to subclass ``torch.autograd.Function`` with instance
	``forward``/``backward`` methods. That legacy autograd API is rejected
	by current PyTorch when called. It is now a parameter-free module that
	wraps a static-method Function, so ``GradientReversalLayer()(x)`` keeps
	working for existing callers.
	"""

	def forward(self, inputs):
		return _GradientReversalFunction.apply(inputs)


class MDANet(nn.Module):
	"""Multi-source domain-adversarial head with one domain classifier per source.

	Per-domain feature maps go through the following stages:
	  1. pooled to 2x2 and projected from 4096 to 512 channels by a 1x1 conv;
	  2. flattened, so ``configs["input_dim"]`` must equal 512 * 2 * 2 (asserted);
	  3. passed through a shared stack of fully-connected ReLU layers;
	  4. fed to a shared class classifier (``self.softmax``) for the two
	     source domains, and to per-source domain classifiers
	     (``self.domains[k]``) behind gradient reversal layers, for both
	     source k and the target.

	``configs`` keys: ``input_dim``, ``hidden_layers`` (list of widths),
	``num_domains`` and ``num_classes``. ``forward`` uses ``domains[0]`` and
	``domains[1]``, so ``num_domains`` must be at least 2.
	"""

	def __init__(self, configs):
		super(MDANet, self).__init__()
		# Construction order is kept as-is: it fixes parameter initialisation
		# under a given seed and the state_dict layout.
		self.pooling_layer = nn.AdaptiveAvgPool2d((2, 2))
		self.dim_reduction = nn.Conv2d(4096, 512, kernel_size=1)
		nn.init.xavier_normal_(self.dim_reduction.weight)
		nn.init.constant_(self.dim_reduction.bias, 0.1)

		self.input_dim = configs["input_dim"]
		self.num_hidden_layers = len(configs["hidden_layers"])
		self.num_neurons = [self.input_dim] + configs["hidden_layers"]
		self.num_domains = configs["num_domains"]
		# Shared feature-learning layers: num_neurons[i] -> num_neurons[i + 1].
		self.hiddens = nn.ModuleList([
			nn.Linear(fan_in, fan_out)
			for fan_in, fan_out in zip(self.num_neurons[:-1], self.num_neurons[1:])
		])
		# Shared class classifier (outputs are turned into log-probabilities in forward).
		self.softmax = nn.Linear(self.num_neurons[-1], configs["num_classes"])
		# One binary source-vs-target classifier per source domain.
		self.domains = nn.ModuleList([nn.Linear(self.num_neurons[-1], 2) for _ in range(self.num_domains)])
		# Gradient reversal layers, one per source domain.
		self.grls = [GradientReversalLayer() for _ in range(self.num_domains)]

	def forward(self, sinputs_syn, sinputs_gta, tinputs):
		"""
		:param sinputs_syn: feature map from the first source (SYNTHIA), 4096 channels.
		:param sinputs_gta: feature map from the second source (GTA5), 4096 channels.
		:param tinputs:     feature map from the target domain, 4096 channels.
		:return: (logprobs, sdomains, tdomains). logprobs holds class
		         log-probabilities for [syn, gta]. sdomains[k] / tdomains[k]
		         hold the domain log-probabilities of source k / the target
		         under domain classifier k.
		"""
		reduced = [self.dim_reduction(self.pooling_layer(fmap))
		           for fmap in (sinputs_syn, sinputs_gta, tinputs)]
		batch = reduced[1].size()[0]
		syn_feat, gta_feat, tgt_feat = [r.view(batch, -1) for r in reduced]
		assert (syn_feat[0].size()[0] == self.input_dim)

		for layer in self.hiddens:
			syn_feat = F.relu(layer(syn_feat))
			gta_feat = F.relu(layer(gta_feat))
			tgt_feat = F.relu(layer(tgt_feat))

		source_feats = (syn_feat, gta_feat)
		logprobs = [F.log_softmax(self.softmax(feat), dim=1) for feat in source_feats]

		sdomains, tdomains = [], []
		for k, src_feat in enumerate(source_feats):
			classifier, grl = self.domains[k], self.grls[k]
			sdomains.append(F.log_softmax(classifier(grl(src_feat)), dim=1))
			tdomains.append(F.log_softmax(classifier(grl(tgt_feat)), dim=1))

		return logprobs, sdomains, tdomains

	def inference(self, inputs):
		"""Class log-probabilities for already-flattened features.

		NOTE(review): unlike ``forward``, this skips pooling and
		``dim_reduction``. ``inputs`` presumably is (batch, input_dim) —
		confirm against callers.
		"""
		feats = inputs
		for layer in self.hiddens:
			feats = F.relu(layer(feats))
		return F.log_softmax(self.softmax(feats), dim=1)


================================================
FILE: cycada/models/__init__.py
================================================
from .models import get_model
from .task_net import LeNet
from .task_net import DTNClassifier
from .adda_net import AddaNet
from .fcn8s import VGG16_FCN8s, Discriminator
from .drn import drn26


================================================
FILE: cycada/models/adda_net.py
================================================

import numpy as np
import torch
import torch.nn as nn
from torch.nn import init
from .util import init_weights
from .models import register_model, get_model 

@register_model('AddaNet')
class AddaNet(nn.Module):
    """Defines an ADDA network: a source net, a target net and a discriminator.

    Weights are required. ``weights_init`` (a full AddaNet state dict) takes
    precedence. Otherwise ``src_weights_init`` is loaded into both the source
    and the target net. If neither is given, construction raises Exception.

    ``discrim_feat`` selects what the discriminator sees in ``forward``: the
    features returned by the task nets (True) or their class scores (False,
    the default). The discriminator's first layer has ``num_cls`` inputs, so
    ``True`` only works when that feature size also equals ``num_cls``.
    """
    def __init__(self, num_cls=10, model='LeNet', src_weights_init=None,
            weights_init=None, discrim_feat=False):
        super(AddaNet, self).__init__()
        self.name = 'AddaNet'
        self.base_model = model
        self.num_cls = num_cls
        # Read by forward(). It was never initialised before, so forward()
        # raised AttributeError unless a caller assigned it by hand.
        self.discrim_feat = discrim_feat
        self.cls_criterion = nn.CrossEntropyLoss()
        self.gan_criterion = nn.CrossEntropyLoss()

        self.setup_net()
        if weights_init is not None:
            self.load(weights_init)
        elif src_weights_init is not None:
            self.load_src_net(src_weights_init)
        else:
            raise Exception('AddaNet must be initialized with weights.')

    def forward(self, x_s, x_t):
        """Pass source and target images through their
        respective networks.

        Returns (score_s, score_t, d_s, d_t): class scores of each net and
        the discriminator outputs on features or scores (see ``discrim_feat``).
        """
        score_s, x_s = self.src_net(x_s, with_ft=True)
        score_t, x_t = self.tgt_net(x_t, with_ft=True)

        if self.discrim_feat:
            d_s = self.discriminator(x_s)
            d_t = self.discriminator(x_t)
        else:
            d_s = self.discriminator(score_s)
            d_t = self.discriminator(score_t)
        return score_s, score_t, d_s, d_t

    def setup_net(self):
        """Setup source, target and discriminator networks."""
        self.src_net = get_model(self.base_model, num_cls=self.num_cls)
        self.tgt_net = get_model(self.base_model, num_cls=self.num_cls)

        # Discriminator input is sized for class scores (num_cls values).
        input_dim = self.num_cls
        self.discriminator = nn.Sequential(
                nn.Linear(input_dim, 500),
                nn.ReLU(),
                nn.Linear(500, 500),
                nn.ReLU(),
                nn.Linear(500, 2),
                )

        self.image_size = self.src_net.image_size
        self.num_channels = self.src_net.num_channels

    def load(self, init_path):
        "Loads full src and tgt models."
        # torch.load unpickles the file; only load trusted checkpoints.
        net_init_dict = torch.load(init_path)
        self.load_state_dict(net_init_dict)

    def load_src_net(self, init_path):
        """Initialize source and target with source
        weights."""
        self.src_net.load(init_path)
        self.tgt_net.load(init_path)

    def save(self, out_path):
        torch.save(self.state_dict(), out_path)

    def save_tgt_net(self, out_path):
        torch.save(self.tgt_net.state_dict(), out_path)



================================================
FILE: cycada/models/drn.py
================================================
import math

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torchvision

from .models import register_model
from ..util import safe_load_state_dict

# Names exported by `from ... import *`.
__all__ = ['DRN', 'drn26', 'drn42', 'drn58']

# Checkpoint URLs keyed by DRN's `modelname` argument. DRN downloads from
# here when pretrained=True and no weights_init path is given.
model_urls = {
	'drn26': 'https://tigress-web.princeton.edu/~fy/drn/models/drn26-ddedf421.pth',
	'drn42': 'https://tigress-web.princeton.edu/~fy/drn/models/drn42-9d336e8c.pth',
	'drn58': 'https://tigress-web.princeton.edu/~fy/drn/models/drn58-0a53a92c.pth'
}


def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
	"""3x3 ``nn.Conv2d`` without bias; stride, padding and dilation are passed through."""
	options = dict(kernel_size=3, stride=stride, padding=padding,
	               bias=False, dilation=dilation)
	return nn.Conv2d(in_planes, out_planes, **options)


class BasicBlock(nn.Module):
	"""Two 3x3 conv + BatchNorm layers with an optional residual shortcut.

	``dilation`` is a pair. Element 0 sets padding and dilation of the first
	conv (which also carries ``stride``); element 1 sets those of the second.
	``downsample``, if given, maps the input onto the shortcut branch. With
	``residual=False`` the shortcut is not added (DRN does this for layer7/8).
	"""
	expansion = 1

	def __init__(self, inplanes, planes, stride=1, downsample=None,
	             dilation=(1, 1), residual=True):
		super(BasicBlock, self).__init__()
		first_dil, second_dil = dilation[0], dilation[1]
		# Module creation order is kept: it determines random init order.
		self.conv1 = conv3x3(inplanes, planes, stride,
		                     padding=first_dil, dilation=first_dil)
		self.bn1 = nn.BatchNorm2d(planes)
		self.relu = nn.ReLU(inplace=True)
		self.conv2 = conv3x3(planes, planes,
		                     padding=second_dil, dilation=second_dil)
		self.bn2 = nn.BatchNorm2d(planes)
		self.downsample = downsample
		self.stride = stride
		self.residual = residual

	def forward(self, x):
		out = self.relu(self.bn1(self.conv1(x)))
		out = self.bn2(self.conv2(out))

		shortcut = x if self.downsample is None else self.downsample(x)
		if self.residual:
			out += shortcut
		return self.relu(out)


class Bottleneck(nn.Module):
	"""1x1 -> 3x3 -> 1x1 bottleneck block with output width ``planes * 4``.

	The 3x3 conv carries ``stride`` and uses ``dilation[1]`` for padding and
	dilation; ``dilation[0]`` is unused. ``downsample``, if given, maps the
	input onto the shortcut branch. ``residual`` now behaves as in
	BasicBlock: with ``residual=False`` the shortcut is not added. It used to
	be accepted and silently ignored. The default (True) is unchanged.
	"""
	expansion = 4

	def __init__(self, inplanes, planes, stride=1, downsample=None,
	             dilation=(1, 1), residual=True):
		super(Bottleneck, self).__init__()
		# Module creation order is kept: it determines random init order.
		self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
		self.bn1 = nn.BatchNorm2d(planes)
		self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
		                       padding=dilation[1], bias=False,
		                       dilation=dilation[1])
		self.bn2 = nn.BatchNorm2d(planes)
		self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
		self.bn3 = nn.BatchNorm2d(planes * 4)
		self.relu = nn.ReLU(inplace=True)
		self.downsample = downsample
		self.stride = stride
		self.residual = residual

	def forward(self, x):
		residual = x

		out = self.conv1(x)
		out = self.bn1(out)
		out = self.relu(out)

		out = self.conv2(out)
		out = self.bn2(out)
		out = self.relu(out)

		out = self.conv3(out)
		out = self.bn3(out)

		if self.downsample is not None:
			residual = self.downsample(x)

		if self.residual:
			out += residual
		out = self.relu(out)

		return out


class DRN(nn.Module):
	"""Dilated Residual Network backbone with a 1x1-conv classifier head.

	Stages 1-2 always use BasicBlock; stages 3-6 use ``block``. Stages 5 and
	6 use dilation 2 and 4 instead of further striding. Stages 6-8 are
	skipped when their entry in ``layers`` is 0; stages 7-8 are BasicBlocks
	without residual connections.

	Args:
		block: residual block class for stages 3-6 (BasicBlock or Bottleneck).
		layers: number of blocks in each of the 8 stages.
		num_cls: output channels of the 1x1 ``fc`` conv. ``fc`` and
			``avgpool`` are only created when ``num_cls > 0``.
		channels: base width of each of the 8 stages.
		out_map: if True, ``forward`` returns per-pixel scores upsampled to
			the input size; otherwise pooled scores.
		out_middle: if True, ``forward`` also returns the list of stage outputs.
		pool_size: kernel size of ``avgpool`` (used when ``out_map`` is False).
		weights_init: path to a state dict, loaded instead of
			``model_urls[modelname]`` when ``pretrained`` is True.
		pretrained: whether to load weights at construction.
		finetune: if True, drop ``fc.*`` from the loaded state dict and load
			the rest with ``safe_load_state_dict``; otherwise load strictly.
		output_last_ft: if True (and ``out_middle`` is False), ``forward``
			also returns the last feature map before ``fc``.
		modelname: key into ``model_urls``.
	"""
	# Default input preprocessing. The mean/std values are presumably the
	# usual ImageNet statistics — confirm against the checkpoints.
	transform = torchvision.transforms.Compose([
		torchvision.transforms.ToTensor(),
		torchvision.transforms.Normalize(
			mean=[0.485, 0.456, 0.406],
			std=[0.229, 0.224, 0.225]),
	])
	
	def __init__(self, block, layers, num_cls=1000,
	             channels=(16, 32, 64, 128, 256, 512, 512, 512),
	             out_map=False, out_middle=False, pool_size=28,
	             weights_init=None, pretrained=True, finetune=False,
	             output_last_ft=False, modelname='drn26'):
		# NOTE(review): despite this message, forward() does return the
		# pre-fc feature map alongside the scores when output_last_ft is set
		# (and out_middle is not).
		if output_last_ft:
			print('DRN discrim feat not implemented, using scores')
		
		super(DRN, self).__init__()
		# Running input width for _make_layer; updated as stages are built.
		self.inplanes = channels[0]
		self.output_last_ft = output_last_ft
		self.out_map = out_map
		self.out_dim = channels[-1]
		self.out_middle = out_middle
		self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1, padding=3,
		                       bias=False)
		self.bn1 = nn.BatchNorm2d(channels[0])
		self.relu = nn.ReLU(inplace=True)
		
		self.layer1 = self._make_layer(BasicBlock, channels[0], layers[0], stride=1)
		self.layer2 = self._make_layer(BasicBlock, channels[1], layers[1], stride=2)
		
		self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2)
		self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2)
		# From stage 5 on, resolution is kept and the receptive field grows
		# through dilation instead of stride.
		self.layer5 = self._make_layer(block, channels[4], layers[4], dilation=2,
		                               new_level=False)
		# Optional stages: a 0 in `layers` leaves the stage as None, and
		# forward() skips None stages.
		self.layer6 = None if layers[5] == 0 else \
			self._make_layer(block, channels[5], layers[5], dilation=4,
			                 new_level=False)
		self.layer7 = None if layers[6] == 0 else \
			self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
			                 new_level=False, residual=False)
		self.layer8 = None if layers[7] == 0 else \
			self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
			                 new_level=False, residual=False)
		
		# NOTE(review): forward() always uses self.fc (and self.avgpool when
		# out_map is False), so num_cls <= 0 fails there with AttributeError.
		if num_cls > 0:
			self.avgpool = nn.AvgPool2d(pool_size)
			# self.fc = nn.Linear(self.out_dim, num_classes)
			self.fc = nn.Conv2d(self.out_dim, num_cls, kernel_size=1,
			                    stride=1, padding=0, bias=True)
		# Conv weights ~ N(0, sqrt(2 / n)) with n = k_h * k_w * out_channels;
		# BatchNorm initialised to weight 1, bias 0.
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
				m.weight.data.normal_(0, math.sqrt(2. / n))
			elif isinstance(m, nn.BatchNorm2d):
				m.weight.data.fill_(1)
				m.bias.data.zero_()
		
		# Loaded weights overwrite the random initialisation above.
		if pretrained:
			if not weights_init is None:
				state_dict = torch.load(weights_init)
				print('Using state dict from', weights_init)
			else:
				state_dict = model_zoo.load_url(model_urls[modelname])
			
			if finetune:
				# Drop the classifier so a different num_cls can be used.
				del state_dict['fc.weight']
				del state_dict['fc.bias']
				safe_load_state_dict(self, state_dict)
				print('Finetune: remove last layer')
			else:
				self.load_state_dict(state_dict)
				print('Loading full model')
	
	def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
	                new_level=True, residual=True):
		"""Build one stage of ``blocks`` blocks of width ``planes``.

		The first block gets a 1x1 conv + BatchNorm shortcut when ``stride``
		or the channel count changes. It uses dilation
		``(dilation // 2 if new_level else dilation, dilation)``; the later
		blocks use ``(dilation, dilation)``. Updates ``self.inplanes`` to the
		stage's output width.
		"""
		assert dilation == 1 or dilation % 2 == 0
		downsample = None
		if stride != 1 or self.inplanes != planes * block.expansion:
			downsample = nn.Sequential(
				nn.Conv2d(self.inplanes, planes * block.expansion,
				          kernel_size=1, stride=stride, bias=False),
				nn.BatchNorm2d(planes * block.expansion),
			)
		
		layers = []
		layers.append(block(
			self.inplanes, planes, stride, downsample,
			dilation=(1, 1) if dilation == 1 else (
				dilation // 2 if new_level else dilation, dilation),
			residual=residual))
		self.inplanes = planes * block.expansion
		for i in range(1, blocks):
			layers.append(block(self.inplanes, planes, residual=residual,
			                    dilation=(dilation, dilation)))
		
		return nn.Sequential(*layers)
	
	def forward(self, x):
		"""Run the network on a 4-D (N, C, H, W) batch ``x``.

		Returns:
			``(scores, stage_outputs)`` if ``out_middle``;
			``(scores, last_features)`` if ``output_last_ft``;
			``scores`` otherwise.
			With ``out_map`` the scores are (N, num_cls, H, W), bilinearly
			upsampled to the input size. Without it they are average-pooled
			and flattened per sample; this gives (N, num_cls) only if the
			pooled map is 1x1 — TODO confirm for the input sizes used.
		"""
		_, _, h, w = x.size()
		# Outputs of every executed stage, returned when out_middle is set.
		y = list()
		
		x = self.conv1(x)
		x = self.bn1(x)
		x = self.relu(x)
		x = self.layer1(x)
		y.append(x)
		x = self.layer2(x)
		y.append(x)
		
		x = self.layer3(x)
		y.append(x)
		
		x = self.layer4(x)
		y.append(x)
		
		x = self.layer5(x)
		y.append(x)
		
		if self.layer6 is not None:
			x = self.layer6(x)
			y.append(x)
		
		if self.layer7 is not None:
			x = self.layer7(x)
			y.append(x)
		
		if self.layer8 is not None:
			x = self.layer8(x)
			y.append(x)
		
		if self.output_last_ft:
			ft_to_save = x
		
		if self.out_map:
			x = self.fc(x)
			x = nn.functional.interpolate(x, (h, w), mode='bilinear', align_corners=True)
		else:
			x = self.avgpool(x)
			x = self.fc(x)
			x = x.view(x.size(0), -1)
		
		if self.out_middle:
			return x, y
		elif self.output_last_ft:
			return x, ft_to_save
		else:
			return x


@register_model('drn26')
def drn26(pretrained=True, finetune=False, out_map=True, **kwargs):
	"""Build DRN-26: BasicBlock stages with depths [1, 1, 2, 2, 2, 2, 1, 1].

	``finetune``, ``out_map`` and any extra kwargs are forwarded to ``DRN``
	together with ``modelname='drn26'``. ``DRN`` uses that name as the key
	into ``model_urls`` when it loads weights.

	NOTE(review): ``pretrained`` is accepted for API compatibility but is not
	forwarded to ``DRN``; whether weights are loaded is decided inside
	``DRN.__init__`` — confirm.
	"""
	stage_depths = [1, 1, 2, 2, 2, 2, 1, 1]
	return DRN(BasicBlock, stage_depths, modelname='drn26',
	           out_map=out_map, finetune=finetune, **kwargs)


@register_model('drn42')
def drn42(pretrained=False, finetune=False, out_map=True, **kwargs):
	"""Build DRN-42: BasicBlock stages with depths [1, 1, 3, 4, 6, 3, 1, 1].

	``finetune``, ``out_map`` and any extra kwargs are forwarded to ``DRN``
	together with ``modelname='drn42'``.

	NOTE(review): ``pretrained`` is accepted but not forwarded to ``DRN`` —
	confirm how pretrained loading is meant to be controlled.
	"""
	stage_depths = [1, 1, 3, 4, 6, 3, 1, 1]
	return DRN(BasicBlock, stage_depths, modelname='drn42',
	           out_map=out_map, finetune=finetune, **kwargs)


def drn58(pretrained=False, **kwargs):
	"""Build DRN-58: Bottleneck stages with depths [1, 1, 3, 4, 6, 3, 1, 1].

	When ``pretrained`` is True, the ``model_urls['drn58']`` checkpoint is
	downloaded and loaded strictly into the model.
	"""
	model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], **kwargs)
	if not pretrained:
		return model
	model.load_state_dict(model_zoo.load_url(model_urls['drn58']))
	return model


================================================
FILE: cycada/models/fcn8s.py
================================================
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.nn import init
from torch.utils import model_zoo
from torchvision.models import vgg

from .models import register_model


def get_upsample_filter(size):
	"""Make a 2D bilinear kernel suitable for upsampling.

	Returns a ``size`` x ``size`` float32 tensor equal to the outer product of
	a 1D triangular profile ``1 - |i - center| / factor``, where
	``factor = (size + 1) // 2``.
	"""
	factor = (size + 1) // 2
	center = factor - 1 if size % 2 == 1 else factor - 0.5
	profile = 1 - np.abs(np.arange(size) - center) / factor
	return torch.from_numpy(np.outer(profile, profile)).float()


class Bilinear(nn.Module):
	"""Fixed (non-learned) bilinear upsampling by an integer ``factor``.

	Implemented as a transposed convolution with stride ``factor``. Its weight
	is a ``(num_channels, num_channels, 2*factor, 2*factor)`` buffer holding
	the bilinear kernel on the channel diagonal only, so each channel is
	upsampled independently.
	"""
	
	def __init__(self, factor, num_channels):
		super().__init__()
		self.factor = factor
		kernel_size = factor * 2
		kernel = get_upsample_filter(kernel_size)
		weight = torch.zeros(num_channels, num_channels, kernel_size, kernel_size)
		diagonal = torch.arange(num_channels)
		weight[diagonal, diagonal] = kernel
		self.register_buffer('w', weight)
	
	def forward(self, x):
		# ``w`` is a buffer, so it is never updated by the optimizer.
		return F.conv_transpose2d(x, self.w, stride=self.factor)


@register_model('fcn8s')
class VGG16_FCN8s(nn.Module):
	"""FCN-8s semantic segmentation network on a VGG16 backbone.

	``self.vgg`` is the VGG16 convolutional trunk (built by ``make_layers``).
	``self.vgg_head`` replaces VGG16's fully connected layers with 7x7 and
	1x1 convolutions that end in ``num_cls`` score maps.

	The coarse scores are:
	1. upsampled 2x and fused with scores from pool4;
	2. upsampled 2x again and fused with scores from pool3;
	3. finally upsampled 8x and cropped back to the input size.
	"""
	transform = torchvision.transforms.Compose([
		torchvision.transforms.ToTensor(),
		torchvision.transforms.Normalize(
			mean=[0.485, 0.456, 0.406],
			std=[0.229, 0.224, 0.225]),
	])
	
	# The URL torchvision historically used for the ImageNet VGG16
	# checkpoint. It is used when ``vgg.model_urls`` is unavailable (newer
	# torchvision releases removed it).
	_vgg16_url = 'https://download.pytorch.org/models/vgg16-397923af.pth'
	
	def __init__(self, num_cls=19, pretrained=True, weights_init=None,
	             output_last_ft=False):
		"""
		Args:
			num_cls: number of output classes.
			pretrained: when True, load weights during construction.
			weights_init: path to a checkpoint with ``vgg.*`` and
				``vgg_head.*`` entries. If None (and ``pretrained``), ImageNet
				VGG16 weights are downloaded instead.
			output_last_ft: when True, ``forward`` returns
				``(score, last_ft)``, where ``last_ft`` is the ``vgg_head``
				activation right before the final classifier conv.
		"""
		super().__init__()
		self.output_last_ft = output_last_ft
		# The trunk is always built without batch norm. A ``batch_norm`` flag
		# used to be computed here but was never used. Without batch norm,
		# the layer indices match the plain VGG16 checkpoints loaded below.
		self.vgg = make_layers(self._vgg16_cfg(), batch_norm=False)
		self.vgg_head = nn.Sequential(
			nn.Conv2d(512, 4096, 7),
			nn.ReLU(inplace=True),
			nn.Dropout2d(p=0.5),
			nn.Conv2d(4096, 4096, 1),
			nn.ReLU(inplace=True),
			nn.Dropout2d(p=0.5),
			nn.Conv2d(4096, num_cls, 1)
		)
		# Fixed bilinear upsamplers. One 2x module is shared by both 2x steps
		# because it only holds a buffer, not learnable parameters.
		self.upscore2 = self.upscore_pool4 = Bilinear(2, num_cls)
		self.upscore8 = Bilinear(8, num_cls)
		# Skip-score convs start at zero, so the skip branches initially
		# contribute nothing to the fused scores.
		self.score_pool4 = nn.Conv2d(512, num_cls, 1)
		for param in self.score_pool4.parameters():
			init.constant_(param, 0)
		self.score_pool3 = nn.Conv2d(256, num_cls, 1)
		for param in self.score_pool3.parameters():
			init.constant_(param, 0)
		
		if pretrained:
			if weights_init is not None:
				self.load_weights(torch.load(weights_init))
			else:
				self.load_base_weights()
	
	@staticmethod
	def _vgg16_cfg():
		"""Return torchvision's VGG16 ('D') layer configuration.

		Older torchvision releases expose the table as ``vgg.cfg``; newer ones
		renamed it to ``vgg.cfgs``.
		"""
		cfgs = getattr(vgg, 'cfgs', None)
		if cfgs is None:
			cfgs = vgg.cfg
		return cfgs['D']
	
	def load_base_vgg(self, weights_state_dict):
		"""Load the ``vgg.*`` entries of ``weights_state_dict`` into the trunk."""
		vgg_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg.')
		self.vgg.load_state_dict(vgg_state_dict)
	
	def load_vgg_head(self, weights_state_dict):
		"""Load the ``vgg_head.*`` entries of ``weights_state_dict`` into the head."""
		vgg_head_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg_head.')
		self.vgg_head.load_state_dict(vgg_head_state_dict)
	
	def get_dict_by_prefix(self, weights_state_dict, prefix):
		"""Return the entries whose key starts with ``prefix``, with the prefix stripped."""
		return {k[len(prefix):]: v
		        for k, v in weights_state_dict.items()
		        if k.startswith(prefix)}
	
	def load_weights(self, weights_state_dict):
		"""Load trunk and head weights. Skip-score layers are left untouched."""
		self.load_base_vgg(weights_state_dict)
		self.load_vgg_head(weights_state_dict)
	
	def split_vgg_head(self):
		"""Expose the head's last conv as ``classifier`` and the rest as ``vgg_head_feat``."""
		self.classifier = list(self.vgg_head.children())[-1]
		self.vgg_head_feat = nn.Sequential(*list(self.vgg_head.children())[:-1])
	
	def forward(self, x):
		"""Return per-pixel class scores cropped to the input's spatial size.

		When ``output_last_ft`` is set, returns ``(score, last_ft)``.
		"""
		inp = x
		# Pad 99 px on every side; presumably the fixed crop offsets below
		# (5, 9, 31) are chosen for this padding.
		x = F.pad(x, (99, 99, 99, 99), mode='constant', value=0)
		intermediates = {}
		fts_to_save = {16: 'pool3', 23: 'pool4'}
		for i, module in enumerate(self.vgg):
			x = module(x)
			if i in fts_to_save:
				intermediates[fts_to_save[i]] = x
		
		ft_to_save = 5  # Dropout before classifier
		last_ft = None
		for i, module in enumerate(self.vgg_head):
			x = module(x)
			if i == ft_to_save:
				last_ft = x
		
		upscore2 = self.upscore2(x)
		# The pool features are scaled down before scoring (0.01 for pool4,
		# 0.0001 for pool3).
		score_pool4 = self.score_pool4(0.01 * intermediates['pool4'])
		fuse_pool4 = upscore2 + _crop(score_pool4, upscore2, offset=5)
		upscore_pool4 = self.upscore_pool4(fuse_pool4)
		score_pool3 = self.score_pool3(0.0001 * intermediates['pool3'])
		fuse_pool3 = upscore_pool4 + _crop(score_pool3, upscore_pool4, offset=9)
		upscore8 = self.upscore8(fuse_pool3)
		score = _crop(upscore8, inp, offset=31)
		if self.output_last_ft:
			return score, last_ft
		return score
	
	def load_base_weights(self):
		"""Initialize from ImageNet VGG16 weights.

		This is complicated because we converted the base model to be fully
		convolutional, so some surgery needs to happen here:
		- the ``features.*`` trunk weights are loaded directly;
		- the ``classifier.*`` fully connected weights are reshaped, in order,
		  into the ``vgg_head`` conv parameters;
		- the final ImageNet classifier (``classifier.6.*``) is skipped.
		"""
		model_urls = getattr(vgg, 'model_urls', None) or {}
		url = model_urls.get('vgg16', self._vgg16_url)
		base_state_dict = model_zoo.load_url(url)
		vgg_state_dict = {k[len('features.'):]: v
		                  for k, v in base_state_dict.items()
		                  if k.startswith('features.')}
		self.vgg.load_state_dict(vgg_state_dict)
		vgg_head_params = self.vgg_head.parameters()
		for k, v in base_state_dict.items():
			if not k.startswith('classifier.'):
				continue
			if k.startswith('classifier.6.'):
				# skip final classifier output
				continue
			vgg_head_param = next(vgg_head_params)
			vgg_head_param.data = v.view(vgg_head_param.size())


class VGG16_FCN8s_caffe(VGG16_FCN8s):
	"""VGG16_FCN8s variant initialized from a Caffe-converted VGG16 checkpoint.

	Its ``transform`` subtracts the given mean and scales by 255 (std is
	1/255). It then reverses the channel order (RGB -> BGR), presumably to
	match the preprocessing the Caffe weights expect.
	"""
	transform = torchvision.transforms.Compose([
		torchvision.transforms.ToTensor(),
		torchvision.transforms.Normalize(
			mean=[0.485, 0.458, 0.408],
			std=[0.00392156862745098] * 3),
		# Reverse the channel axis. Dim -3 is the channel dimension for both
		# (C, H, W) images and (N, C, H, W) batches. The previous
		# ``torch.unbind(x, 1)`` reversed the H axis of a ToTensor output,
		# i.e. it flipped the image upside down instead of swapping RGB/BGR.
		torchvision.transforms.Lambda(
			lambda x: torch.flip(x, [-3]))
	])
	
	def load_base_weights(self):
		"""Initialize trunk and head from the Caffe-converted VGG16 checkpoint.

		- The ``features.*`` entries go into ``self.vgg``.
		- The ``classifier.*`` entries (except ``classifier.6.*``) are reshaped,
		  in order, into the ``vgg_head`` parameters.
		"""
		base_state_dict = model_zoo.load_url('https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg16-00b39a1b.pth')
		vgg_state_dict = {k[len('features.'):]: v
		                  for k, v in base_state_dict.items()
		                  if k.startswith('features.')}
		self.vgg.load_state_dict(vgg_state_dict)
		vgg_head_params = self.vgg_head.parameters()
		for k, v in base_state_dict.items():
			if not k.startswith('classifier.'):
				continue
			if k.startswith('classifier.6.'):
				# skip final classifier output
				continue
			vgg_head_param = next(vgg_head_params)
			vgg_head_param.data = v.view(vgg_head_param.size())


class Discriminator(nn.Module):
	"""Domain discriminator made of 1x1 convolutions.

	It produces ``output_dim`` logits at every spatial location of the input
	feature map.

	Args:
		input_dim: number of input channels.
		output_dim: number of output logits per location.
		pretrained: if True and ``weights_init`` is a non-empty path, load
			weights from it.
		weights_init: path to a state dict saved from this module.
	"""
	def __init__(self, input_dim=4096, output_dim=2, pretrained=False, weights_init=''):
		super().__init__()
		# Hidden widths: 1024/512 for 4096-channel inputs, otherwise 512/256.
		dim1 = 1024 if input_dim == 4096 else 512
		dim2 = dim1 // 2
		self.D = nn.Sequential(
			nn.Conv2d(input_dim, dim1, 1),
			nn.Dropout2d(p=0.5),
			nn.ReLU(inplace=True),
			nn.Conv2d(dim1, dim2, 1),
			nn.Dropout2d(p=0.5),
			nn.ReLU(inplace=True),
			nn.Conv2d(dim2, output_dim, 1)
		)
		
		# The default ``weights_init`` is '' (not a loadable path), so test
		# truthiness instead of ``is not None``. The old check made
		# ``pretrained=True`` without a path call ``torch.load('')``.
		if pretrained and weights_init:
			self.load_weights(weights_init)
	
	def forward(self, x):
		"""Return the per-location domain logits for feature map ``x``."""
		d_score = self.D(x)
		return d_score
	
	def load_weights(self, weights):
		"""Load a state dict from the file ``weights`` into this module."""
		print('Loading discriminator weights')
		# Deserialize on CPU so GPU-saved checkpoints also load on CPU-only
		# hosts. load_state_dict copies the values into the existing params.
		self.load_state_dict(torch.load(weights, map_location='cpu'))


class Transform_Module(nn.Module):
	"""Feature transform made of a 1x1 conv followed by ReLU.

	The conv is initialized to the identity (identity weights, zero bias), so
	the module initially computes ``relu(x)``.
	"""
	def __init__(self, input_dim=4096):
		super().__init__()
		self.transform = nn.Sequential(
			nn.Conv2d(input_dim, input_dim, 1),
			nn.ReLU(inplace=True),
		)
		for m in self.modules():
			if isinstance(m, nn.Conv2d):
				# The weight has shape (out, in, 1, 1). The eye is reshaped to
				# that shape and copied without autograd tracking. Calling
				# init_eye here recursed forever and could not copy a 2-D eye
				# into a 4-D conv weight.
				with torch.no_grad():
					eye = torch.eye(m.out_channels, m.in_channels,
					                dtype=m.weight.dtype, device=m.weight.device)
					m.weight.copy_(eye.view_as(m.weight))
					m.bias.zero_()
	
	def forward(self, x):
		"""Apply the 1x1 conv + ReLU transform to ``x``."""
		t_x = self.transform(x)
		return t_x


def init_eye(tensor):
	"""Fill ``tensor`` in place with an identity matrix and return it.

	The first two dimensions receive ``torch.eye(size(0), size(1))``. Any
	trailing dimensions (e.g. the 1x1 spatial dims of a conv weight) are
	broadcast, so the result is an identity map only when they have size 1.
	The copy runs under ``torch.no_grad()``, so parameters that require grad
	are accepted.

	The old ``isinstance(tensor, Variable)`` branch recursed forever: in
	current PyTorch every Tensor passes that check.
	"""
	rows, cols = tensor.size(0), tensor.size(1)
	eye = torch.eye(rows, cols, dtype=tensor.dtype, device=tensor.device)
	# Append singleton dims so ``eye`` broadcasts against e.g. (out, in, kh, kw).
	eye = eye.view(rows, cols, *([1] * (tensor.dim() - 2)))
	with torch.no_grad():
		tensor.copy_(eye)
	return tensor


def _crop(input, shape, offset=0):
	_, _, h, w = shape.size()
	return input[:, :, offset:offset + h, offset:offset + w].contiguous()


def make_layers(cfg, batch_norm=False):
	"""This is almost verbatim from torchvision.models.vgg, except that the
	MaxPool2d modules are configured with ceil_mode=True.

	Each integer in ``cfg`` adds a 3x3 conv (padding 1) with that many output
	channels, optionally followed by BatchNorm2d, then ReLU. Each ``'M'`` adds
	a 2x2/stride-2 max pool. The input has 3 channels.
	"""
	modules = []
	channels = 3
	for entry in cfg:
		if entry == 'M':
			modules.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True))
			continue
		modules.append(nn.Conv2d(channels, entry, kernel_size=3, padding=1))
		if batch_norm:
			modules.append(nn.BatchNorm2d(entry))
		modules.append(nn.ReLU(inplace=True))
		channels = entry
	return nn.Sequential(*modules)


================================================
FILE: cycada/models/models.py
================================================
import torch

# Registry mapping a model name to the class/factory registered under it.
models = {}


def register_model(name):
    """Return a decorator that records its target in ``models`` under ``name``.

    The decorated object is returned unchanged.
    """
    def _register(entry):
        models[name] = entry
        return entry
    return _register


def get_model(name, num_cls=10, **args):
    """Instantiate the model registered as ``name``.

    ``num_cls`` and any extra keyword arguments are passed to the registered
    constructor. The result is moved to the GPU with ``.cuda()`` when CUDA is
    available.

    Raises:
        KeyError: if no model was registered under ``name``.
    """
    net = models[name](num_cls=num_cls, **args)
    return net.cuda() if torch.cuda.is_available() else net


================================================
FILE: cycada/models/task_net.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
from .models import register_model 
from .util import init_weights
import numpy as np

class TaskNet(nn.Module):
    """Basic class which does classification.

    Subclasses implement ``setup_net`` to create ``conv_params``,
    ``fc_params`` and ``classifier``. ``forward`` runs ``conv_params``,
    flattens each sample, then applies ``fc_params`` and ``classifier``.
    """

    num_channels = 3
    image_size = 32
    name = 'TaskNet'

    def __init__(self, num_cls=10, weights_init=None):
        """
        Args:
            num_cls: number of output classes.
            weights_init: optional path to a state dict written by ``save``.
                When None, the layers are initialized with ``init_weights``.
        """
        super(TaskNet, self).__init__()
        self.num_cls = num_cls
        self.setup_net()
        self.criterion = nn.CrossEntropyLoss()
        if weights_init is not None:
            self.load(weights_init)
        else:
            init_weights(self)

    def forward(self, x, with_ft=False):
        """Return class scores.

        With ``with_ft``, also return the ``fc_params`` output that was fed to
        the classifier.
        """
        x = self.conv_params(x)
        x = x.view(x.size(0), -1)
        x = self.fc_params(x)
        score = self.classifier(x)
        if with_ft:
            return score, x
        return score

    def setup_net(self):
        """Method to be implemented in each class."""
        pass

    def load(self, init_path):
        """Load a state dict written by ``save`` from ``init_path``."""
        # Deserialize on CPU so GPU-saved checkpoints also load on CPU-only
        # machines. load_state_dict then copies the values into the existing
        # parameters, whatever device those live on.
        net_init_dict = torch.load(init_path, map_location='cpu')
        self.load_state_dict(net_init_dict)

    def save(self, out_path):
        """Write this network's state dict to ``out_path``."""
        torch.save(self.state_dict(), out_path)

@register_model('LeNet')
class LeNet(TaskNet):
    """Network used for MNIST or USPS experiments."""

    num_channels = 1
    image_size = 28
    name = 'LeNet'
    out_dim = 500 # dim of last feature layer

    def setup_net(self):
        """Two conv/pool stages, a 500-unit hidden layer, then the classifier."""
        conv_stages = [
            nn.Conv2d(self.num_channels, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        ]
        self.conv_params = nn.Sequential(*conv_stages)

        # A 28x28 input is 4x4 after two (conv5 -> pool2) stages: 28->24->12->8->4.
        flat_dim = 50 * 4 * 4
        self.fc_params = nn.Linear(flat_dim, self.out_dim)

        head = [nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(self.out_dim, self.num_cls)]
        self.classifier = nn.Sequential(*head)


@register_model('DTN')
class DTNClassifier(TaskNet):
    """Classifier used for SVHN->MNIST Experiment"""

    num_channels = 3
    image_size = 32
    name = 'DTN'
    out_dim = 512 # dim of last feature layer

    def setup_net(self):
        """Three stride-2 conv stages, a 512-d BN'd feature layer, then the classifier."""
        stages = []
        in_channels = self.num_channels
        # (output channels, Dropout2d probability) for each conv stage.
        for out_channels, drop in ((64, 0.1), (128, 0.3), (256, 0.5)):
            stages += [
                nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=2, padding=2),
                nn.BatchNorm2d(out_channels),
                nn.Dropout2d(drop),
                nn.ReLU(),
            ]
            in_channels = out_channels
        self.conv_params = nn.Sequential(*stages)

        # A 32x32 input is 4x4 after three stride-2 convs: 32->16->8->4.
        self.fc_params = nn.Sequential(
            nn.Linear(256 * 4 * 4, self.out_dim),
            nn.BatchNorm1d(self.out_dim),
        )

        self.classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(self.out_dim, self.num_cls),
        )


================================================
FILE: cycada/models/util.py
================================================
import torch.nn as nn
from torch.nn import init

def init_weights(obj):
    """Re-initialize the layers of module ``obj`` in place.

    - Conv2d/Linear weights get Xavier-normal init, and their biases are
      zeroed when present.
    - BatchNorm1d/2d layers are reset with ``reset_parameters()``.
    """
    for m in obj.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            init.xavier_normal_(m.weight)
            # Layers built with bias=False have ``bias`` set to None.
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
            m.reset_parameters()


================================================
FILE: cycada/tools/__init__.py
================================================


================================================
FILE: cycada/tools/train_adda_net.py
================================================
from __future__ import print_function

import os
from os.path import join
import numpy as np

# Import from torch
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Import from within Package 
from ..models.models import get_model
from ..data.data_loader import load_data
from ..tools.test_task_net import test
from ..tools.util import make_variable

def train(loader_src, loader_tgt, net, opt_net, opt_dis, epoch):
    """Run one epoch of ADDA adversarial training.

    Batches from ``loader_src`` and ``loader_tgt`` are consumed in lockstep
    (``zip``), so the epoch ends with the shorter loader. For each pair:

    1. the discriminator is updated to label ``net.src_net`` outputs as 1
       and ``net.tgt_net`` outputs as 0;
    2. if its accuracy on that batch exceeds 0.6, ``net.tgt_net`` is
       updated so that the discriminator labels target outputs as 1.

    Args:
        loader_src: source-domain loader yielding ``(data, label)``; labels
            are ignored.
        loader_tgt: target-domain loader yielding ``(data, label)``; labels
            are ignored.
        net: object exposing ``src_net``, ``tgt_net``, ``discriminator`` and
            ``gan_criterion``.
        opt_net: optimizer stepped after the target-network loss.
        opt_dis: optimizer stepped after the discriminator loss.
        epoch: epoch number, used only in log messages.

    Returns:
        The index of the last batch on which the target network was updated,
        or -1 if it never was.
    """
    log_interval = 100  # specifies how often to display

    # Sample count of the smaller dataset; used for progress reporting.
    N = min(len(loader_src.dataset), len(loader_tgt.dataset))
    joint_loader = zip(loader_src, loader_tgt)

    net.train()

    last_update = -1
    for batch_idx, ((data_s, _), (data_t, _)) in enumerate(joint_loader):

        # log basic adda train info. Progress is counted in samples, so the
        # percentage must be samples / N (it used to be batches / N).
        seen = batch_idx * len(data_t)
        info_str = "[Train Adda] Epoch: {} [{}/{} ({:.2f}%)]".format(
            epoch, seen, N, 100. * seen / N)

        ########################
        # Setup data variables #
        ########################
        data_s = make_variable(data_s, requires_grad=False)
        data_t = make_variable(data_t, requires_grad=False)

        ##########################
        # Optimize discriminator #
        ##########################

        # zero gradients for optimizer
        opt_dis.zero_grad()

        # extract and concat features
        score_s = net.src_net(data_s)
        score_t = net.tgt_net(data_t)
        f = torch.cat((score_s, score_t), 0)

        # predict with discriminator
        pred_concat = net.discriminator(f)

        # prepare real and fake labels: source=1, target=0
        target_dom_s = make_variable(torch.ones(len(data_s)).long(), requires_grad=False)
        target_dom_t = make_variable(torch.zeros(len(data_t)).long(), requires_grad=False)
        label_concat = torch.cat((target_dom_s, target_dom_t), 0)

        # compute loss for discriminator
        loss_dis = net.gan_criterion(pred_concat, label_concat)
        loss_dis.backward()

        # optimize discriminator
        opt_dis.step()

        # compute discriminator acc
        pred_dis = torch.squeeze(pred_concat.max(1)[1])
        acc = (pred_dis == label_concat).float().mean()

        # log discriminator update info
        info_str += " acc: {:0.1f} D: {:.3f}".format(acc.item()*100, loss_dis.item())

        ###########################
        # Optimize target network #
        ###########################

        # only update net if discriminator is strong
        if acc.item() > 0.6:

            last_update = batch_idx

            # zero out optimizer gradients
            opt_dis.zero_grad()
            opt_net.zero_grad()

            # extract target features
            score_t = net.tgt_net(data_t)

            # predict with discriminator
            pred_tgt = net.discriminator(score_t)

            # create fake label: ask the discriminator to call target "source"
            label_tgt = make_variable(torch.ones(pred_tgt.size(0)).long(), requires_grad=False)

            # compute loss for target network
            loss_gan_t = net.gan_criterion(pred_tgt, label_tgt)
            loss_gan_t.backward()

            # optimize tgt network
            opt_net.step()

            # log net update info
            info_str += " G: {:.3f}".format(loss_gan_t.item())

        ###########
        # Logging #
        ###########
        if batch_idx % log_interval == 0:
            print(info_str)

    return last_update


def train_adda(src, tgt, model, num_cls, num_epoch=200,
        batch=128, datadir="", outdir="", 
        src_weights=None, weights=None, lr=1e-5, betas=(0.9,0.999),
        weight_decay=0):
    """Main function for training ADDA.

    - Builds ``AddaNet`` for ``model``, passing ``src_weights`` as
      ``src_weights_init``.
    - Trains the target network and discriminator with Adam for up to
      ``num_epoch`` epochs on the ``src``/``tgt`` training splits.
    - Saves the result to ``<outdir>/adda_<model>_net_<src>_<tgt>.pth``.

    Training stops early when an epoch never updates the target network
    (``train`` returns -1).

    NOTE(review): ``weights`` is accepted but not used here.
    """
    # DataLoader options: extra worker + pinned memory only when CUDA exists.
    loader_kwargs = ({'num_workers': 1, 'pin_memory': True}
                     if torch.cuda.is_available() else {})

    net = get_model('AddaNet', model=model, num_cls=num_cls,
            src_weights_init=src_weights)
    print(net)
    print('Training Adda {} model for {}->{}'.format(model, src, tgt))

    def _train_split(name):
        # Training split of dataset ``name``, sized for the network's input.
        return load_data(name, 'train', batch=batch,
            rootdir=join(datadir, name), num_channels=net.num_channels,
            image_size=net.image_size, download=True, kwargs=loader_kwargs)

    train_src_data = _train_split(src)
    train_tgt_data = _train_split(tgt)

    adam_args = dict(lr=lr, weight_decay=weight_decay, betas=betas)
    opt_net = optim.Adam(net.tgt_net.parameters(), **adam_args)
    opt_dis = optim.Adam(net.discriminator.parameters(), **adam_args)

    for epoch in range(num_epoch):
        if train(train_src_data, train_tgt_data, net, opt_net, opt_dis, epoch) == -1:
            print("No suitable discriminator")
            break

    os.makedirs(outdir, exist_ok=True)
    outfile = join(outdir, 'adda_{:s}_net_{:s}_{:s}.pth'.format(
        model, src, tgt))
    print('Saving to', outfile)
    net.save(outfile)



================================================
FILE: cycada/tools/train_task_net.py
================================================
from __future__ import print_function

import os
from os.path import join
import numpy as np
import argparse

# Import from torch
import torch
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

# Import from Cycada Package 
from ..models.models import get_model
from ..data.data_loader import load_data
from .test_task_net import test
from .util import make_variable

def train_epoch(loader, net, opt_net, epoch):
    """Train ``net`` for one epoch over ``loader`` using ``net.criterion``.

    Every 100 batches, prints the loss and the accuracy on the current batch.
    """
    log_interval = 100 # specifies how often to display
    net.train()
    for batch_idx, (data, target) in enumerate(loader):
        data = make_variable(data, requires_grad=False)
        target = make_variable(target, requires_grad=False)

        # standard step: clear grads, forward, loss, backward, update
        opt_net.zero_grad()
        score = net(data)
        loss = net.criterion(score, target)
        loss.backward()
        opt_net.step()

        if batch_idx % log_interval != 0:
            continue
        print('[Train] Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(loader.dataset),
            100. * batch_idx / len(loader), loss.item()), end="")
        predictions = score.data.max(1)[1]
        n_correct = predictions.eq(target.data).cpu().sum()
        acc = n_correct.item() / len(predictions) * 100.0
        print('  Acc: {:.2f}'.format(acc))


def train(data, datadir, model, num_cls, outdir='', 
        num_epoch=100, batch=128, 
        lr=1e-4, betas=(0.9, 0.999), weight_decay=0):
    """Train a classification net and evaluate on test set.

    - Builds ``model`` via ``get_model`` and trains it with Adam for
      ``num_epoch`` epochs on the ``train`` split of ``data``.
    - If a ``test`` split is available, evaluates on it.
    - Saves the weights to ``<outdir>/<model>_net_<data>.pth``.

    Returns:
        The trained network.
    """
    # DataLoader options: extra worker + pinned memory only when CUDA exists.
    loader_kwargs = ({'num_workers': 1, 'pin_memory': True}
                     if torch.cuda.is_available() else {})

    net = get_model(model, num_cls=num_cls)
    print('-------Training net--------')
    print(net)

    split_args = dict(batch=batch, rootdir=datadir,
                      num_channels=net.num_channels, image_size=net.image_size,
                      download=True, kwargs=loader_kwargs)
    train_data = load_data(data, 'train', **split_args)
    test_data = load_data(data, 'test', **split_args)

    opt_net = optim.Adam(net.parameters(), lr=lr, betas=betas,
            weight_decay=weight_decay)

    print('Training {} model for {}'.format(model, data))
    for epoch in range(num_epoch):
        train_epoch(train_data, net, opt_net, epoch)

    if test_data is not None:
        print('Evaluating {}-{} model on {} test set'.format(model, data, data))
        test(test_data, net)

    os.makedirs(outdir, exist_ok=True)
    outfile = join(outdir, '{:s}_net_{:s}.pth'.format(model, data))
    print('Saving to', outfile)
    net.save(outfile)

    return net


================================================
FILE: cycada/tools/util.py
================================================
from functools import partial

import torch
from torch.autograd import Variable


def make_variable(tensor, volatile=False, requires_grad=True):
	"""Return ``tensor`` (on the GPU when available) as a fresh leaf tensor.

	Args:
		tensor: input tensor; it is not modified.
		volatile: legacy flag. PyTorch no longer supports volatile variables;
			here it only forces ``requires_grad=False``. Use
			``torch.no_grad()`` for inference instead.
		requires_grad: whether the returned tensor requires gradients.

	This replaces the deprecated ``Variable(tensor, volatile=...)`` call.
	That call returned a detached alias with the requested ``requires_grad``
	and emitted a warning for ``volatile``.
	"""
	if torch.cuda.is_available():
		tensor = tensor.cuda()
	if volatile:
		requires_grad = False
	return tensor.detach().requires_grad_(requires_grad)


def pairwise_distance(x, y):
	"""Squared Euclidean distances between the rows of ``x`` and ``y``.

	Args:
		x: (n, d) tensor.
		y: (m, d) tensor.

	Returns:
		(m, n) tensor whose ``[j, i]`` entry is ``||x[i] - y[j]||^2``. Note
		the transposed layout.

	Raises:
		ValueError: if either input is not 2-D or their feature sizes differ.
	"""
	# Previously only the ranks were compared, so equal-rank non-matrix inputs
	# got past the check and failed later with an unrelated error.
	if x.dim() != 2 or y.dim() != 2:
		raise ValueError('Both inputs should be matrices.')
	
	if x.shape[1] != y.shape[1]:
		raise ValueError('The number of features should be the same.')
	
	# (n, d, 1) - (d, m) broadcasts to (n, d, m); summing over d gives (n, m).
	diff = x.unsqueeze(2) - torch.transpose(y, 0, 1)
	output = torch.sum(diff ** 2, 1)
	return torch.transpose(output, 0, 1)


def gaussian_kernel_matrix(x, y, sigmas):
	"""Sum of Gaussian kernels over several bandwidths.

	Each entry is ``sum_s exp(-d / (2 * s))`` over the values ``s`` in the 1-D
	tensor ``sigmas``, where ``d`` is the matching entry of
	``pairwise_distance(x, y)`` (a squared distance). The result has the same
	shape as that distance matrix.
	"""
	beta = 1. / (2. * sigmas.view(sigmas.shape[0], 1))
	dist = pairwise_distance(x, y).contiguous()
	# (num_sigmas, 1) @ (1, num_entries) -> one row of exponents per bandwidth.
	exponents = torch.matmul(beta, dist.view(1, -1))
	return torch.exp(-exponents).sum(0).view_as(dist)


def maximum_mean_discrepancy(x, y, kernel=gaussian_kernel_matrix):
	"""Squared MMD estimate between samples ``x`` and ``y``.

	Computed as ``mean k(x, x) + mean k(y, y) - 2 * mean k(x, y)``, where
	each mean is over the full kernel matrix returned by ``kernel``.
	"""
	within_x = torch.mean(kernel(x, x))
	within_y = torch.mean(kernel(y, y))
	across = torch.mean(kernel(x, y))
	return within_x + within_y - 2 * across


def mmd_loss(source_features, target_features):
	"""Multi-bandwidth Gaussian MMD between two feature batches.

	Uses ``gaussian_kernel_matrix`` with a fixed list of bandwidths spanning
	1e-6 to 1e6.

	The bandwidth tensor is created on the device (and floating dtype) of
	``source_features``. It used to be a hard-coded ``torch.cuda.FloatTensor``,
	which failed on CPU-only machines and for non-float32 features.
	"""
	sigmas = [
		1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
		1e3, 1e4, 1e5, 1e6
	]
	dtype = source_features.dtype if source_features.is_floating_point() else torch.float32
	sigma_tensor = torch.tensor(sigmas, dtype=dtype, device=source_features.device)
	gaussian_kernel = partial(gaussian_kernel_matrix, sigmas=sigma_tensor)
	return maximum_mean_discrepancy(source_features, target_features, kernel=gaussian_kernel)


================================================
FILE: cycada/transforms.py
================================================
"""These random transforms extend the transforms provided in torchvision to
allow for transforming multiple images at the same time. This ensures that the
images receive the same transformation, e.g. the provided images are either all
mirrored or all left unchanged.

For example, this is useful in segmentation tasks, where a transformation to the
image necessitates that same transformation on the label.
"""

import numbers
import random

import torch
import torchvision


class RandomCrop(object):
	"""Randomly crop several tensors with the same window.

	``size`` is ``(target_height, target_width)``, or a single number for a
	square crop. All tensors must share their last two (height, width)
	dimensions, so a (C, H, W) image and an (H, W) label map can be cropped
	together, in either order.
	"""
	
	def __init__(self, size):
		if isinstance(size, numbers.Number):
			self.size = (int(size), int(size))
		else:
			self.size = size
	
	def __call__(self, tensors):
		"""Return the cropped tensors as a list.

		Returns ``tensors`` itself when it is empty or already ``size``.

		Raises:
			ValueError: if the tensors differ in height/width or are smaller
				than ``size``.
		"""
		th, tw = self.size
		h = w = None
		for tensor in tensors:
			# Use the last two dims so 2-D tensors may come first.
			hw = tuple(tensor.size()[-2:])
			if h is None:
				h, w = hw
			elif hw != (h, w):
				print(tensor.size(), (h, w))
				raise ValueError('Images must be same size')
		if h is None:
			# Nothing to crop.
			return tensors
		if w == tw and h == th:
			return tensors
		if h < th or w < tw:
			raise ValueError('Crop size {} is larger than image size {}'.format(
				(th, tw), (h, w)))
		# x1 is drawn before y1, matching the previous draw order so seeded
		# runs stay reproducible.
		x1 = random.randint(0, w - tw)
		y1 = random.randint(0, h - th)
		return [tensor[..., y1:y1 + th, x1:x1 + tw].contiguous() for tensor in tensors]


class HalfCrop(object):
	"""Crop every tensor to the left or right half of its width.

	The same random side is used for all tensors in one call. Works on
	tensors whose last dimension is the width, e.g. (C, H, W) images and
	(H, W) label maps.

	Args:
		size: optional ``(height, width)``, or a single number for a square.
			Only the width is used: the kept strip is ``width // 2`` columns.
			When omitted, the width of the first tensor is used.

	The previous version could not run: it had no ``__init__`` (so
	``self.size`` was undefined), misspelled ``left_side``, and used invalid
	``[..., ...]`` indexing.
	"""
	
	def __init__(self, size=None):
		if isinstance(size, numbers.Number):
			size = (int(size), int(size))
		self.size = size
	
	def __call__(self, tensors):
		tensors = list(tensors)
		if not tensors:
			return []
		if self.size is not None:
			width = self.size[1]
		else:
			width = tensors[0].size(-1)
		half = width // 2
		# 0 keeps the left half, 1 keeps the right half.
		side = random.randint(0, 1)
		x1 = side * half
		return [tensor[..., x1:x1 + half].contiguous() for tensor in tensors]


class RandomHorizontalFlip(object):
	"""With probability 0.5, mirror every tensor along its last (width) dim.

	The same decision applies to all tensors in one call. When no flip
	happens, ``tensors`` is returned unchanged.

	Uses ``Tensor.flip``, which works on any device. The previous version
	built a CPU index tensor and called ``index_select``, which fails for
	CUDA inputs.
	"""
	
	def __call__(self, tensors):
		if random.random() >= 0.5:
			return tensors
		return [tensor.flip(-1) for tensor in tensors]


def augment_collate(batch, crop=None, halfcrop=None, flip=True, resize=None):
	"""Jointly augment the tensors of each sample, then collate the batch.

	Args:
		batch: list of samples. Each sample is a sequence of tensors (e.g.
			image and label) that receive the same random transformation.
		crop: if not None, the size given to ``RandomCrop``.
		halfcrop: if not None, a ``HalfCrop()`` step is added (the value
			itself is not used).
		flip: whether to add ``RandomHorizontalFlip``.
		resize: accepted but currently unused.

	Returns:
		The result of ``default_collate`` on the transformed samples.
	"""
	steps = [RandomCrop(crop)] if crop is not None else []
	if halfcrop is not None:
		steps.append(HalfCrop())
	if flip:
		steps.append(RandomHorizontalFlip())
	
	joint_transform = torchvision.transforms.Compose(steps)
	transformed = [joint_transform(sample) for sample in batch]
	return torch.utils.data.dataloader.default_collate(transformed)


================================================
FILE: cycada/util.py
================================================
import logging
import logging.config
import os.path
from collections import OrderedDict

import numpy as np
import torch
import yaml
from torch.nn.parameter import Parameter
from tqdm import tqdm


class TqdmHandler(logging.StreamHandler):
    """Logging handler that emits formatted records through ``tqdm.write``.

    ``tqdm.write`` is tqdm's helper for printing alongside active progress bars.
    """

    def __init__(self):
        super().__init__()

    def emit(self, record):
        tqdm.write(self.format(record))


def config_logging(logfile=None):
    """Configure logging from the ``logging.yml`` file next to this module.

    Args:
        logfile: if None, the ``file_handler`` handler and the last entry of
            the root logger's handler list are removed. (Presumably that entry
            is the file handler — TODO confirm against logging.yml.)
            Otherwise ``file_handler`` writes to ``logfile``.
    """
    path = os.path.join(os.path.dirname(__file__), 'logging.yml')
    with open(path, 'r') as f:
        # ``yaml.load`` without an explicit Loader can construct arbitrary
        # objects and raises TypeError on PyYAML >= 6. The config only needs
        # plain mappings/lists/strings.
        config = yaml.safe_load(f.read())
    if logfile is None:
        del config['handlers']['file_handler']
        del config['root']['handlers'][-1]
    else:
        config['handlers']['file_handler']['filename'] = logfile
    logging.config.dictConfig(config)


def to_tensor_raw(im):
    """Convert an image (e.g. a PIL label map) or array-like to an int64 tensor.

    Values are converted, not rescaled. ``np.asarray`` copies only when
    needed. The old ``np.array(..., copy=False)`` raises ValueError under
    NumPy 2 whenever a copy is required, e.g. for any dtype conversion.
    """
    return torch.from_numpy(np.asarray(im, dtype=np.int64))


def safe_load_state_dict(net, state_dict):
    """Copies parameters and buffers from :attr:`state_dict` into
    this module and its descendants. Any params in :attr:`state_dict`
    that do not match the keys returned by :attr:`net`'s :func:`state_dict()`
    method or have differing sizes are skipped.

    Skipped names are reported with ``logging.info``.

    Arguments:
        state_dict (dict): A dict containing parameters and
            persistent buffers.
    """
    own_state = net.state_dict()
    skipped = []
    for name, value in state_dict.items():
        if isinstance(value, Parameter):
            # backwards compatibility for serialized parameters
            value = value.data
        compatible = name in own_state and own_state[name].size() == value.size()
        if compatible:
            own_state[name].copy_(value)
        else:
            skipped.append(name)

    if skipped:
        logging.info('Skipped loading some parameters: {}'.format(skipped))

def step_lr(optimizer, mult):
    """Scale the learning rate of every param group of ``optimizer`` by ``mult``."""
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * mult


================================================
FILE: cyclegan/.gitignore
================================================
.DS_Store
debug*
checkpoints/
results/
build/
dist/
*.png
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
*~
.idea


================================================
FILE: cyclegan/data/__init__.py
================================================
import sys

import torch.utils.data
from data.base_data_loader import BaseDataLoader

sys.path.append('/nfs/project/libo_i/MADAN')
from cycada.transforms import augment_collate


def CreateDataLoader(opt):
	"""Create a ``CustomDatasetDataLoader``, print its name, and initialize it with ``opt``."""
	loader = CustomDatasetDataLoader()
	print(loader.name())
	loader.initialize(opt)
	return loader


def CreateDataset(opt):
	"""Instantiate and initialize the dataset selected by ``opt.dataset_mode``.

	The dataset module is imported lazily for the selected mode only.

	NOTE(review): no ``merged_gta_synthia_cityscapes`` module appears in the
	repository's ``cyclegan/data`` listing — confirm it exists.

	Raises:
		ValueError: if ``opt.dataset_mode`` is not recognized.
	"""
	mode = opt.dataset_mode
	if mode == 'synthia_cityscapes':
		from data.synthia_cityscapes import SynthiaCityscapesDataset as dataset_cls
	elif mode == 'gta5_cityscapes':
		from data.gta5_cityscapes import GTAVCityscapesDataset as dataset_cls
	elif mode == 'gta_synthia_cityscapes':
		from data.gta_synthia_cityscapes import GTASynthiaCityscapesDataset as dataset_cls
	elif mode == 'merged_gta_synthia_cityscapes':
		from data.merged_gta_synthia_cityscapes import MergedGTASynthiaCityscapesDataset as dataset_cls
	else:
		raise ValueError("Dataset [%s] not recognized." % mode)
	
	dataset = dataset_cls()
	print("dataset [%s] was created" % (dataset.name()))
	dataset.initialize(opt)
	return dataset


class CustomDatasetDataLoader(BaseDataLoader):
	"""Wraps the dataset selected by ``opt.dataset_mode`` in a torch DataLoader.

	Iteration stops once ``opt.max_dataset_size`` samples' worth of batches
	have been yielded.
	"""
	def name(self):
		return 'CustomDatasetDataLoader'
	
	def initialize(self, opt):
		super().initialize(opt)
		self.dataset = CreateDataset(opt)
		self.dataloader = torch.utils.data.DataLoader(
			self.dataset,
			batch_size=opt.batchSize,
			shuffle=not opt.serial_batches,
			num_workers=int(opt.nThreads))
	
	def load_data(self):
		return self
	
	def __len__(self):
		# Number of samples (not batches), capped at max_dataset_size.
		return min(len(self.dataset), self.opt.max_dataset_size)
	
	def __iter__(self):
		for batch_index, batch in enumerate(self.dataloader):
			if batch_index * self.opt.batchSize >= self.opt.max_dataset_size:
				return
			yield batch


================================================
FILE: cyclegan/data/base_data_loader.py
================================================
class BaseDataLoader():
    """Minimal base class for data loaders; subclasses override the hooks."""

    def __init__(self):
        pass

    def initialize(self, opt):
        """Remember the options object for use by subclasses."""
        self.opt = opt

    def load_data(self):
        """Return the iterable data source; the base class provides none.

        Declared with ``self`` so it can be called on an instance (a
        zero-argument definition raised TypeError when called that way).
        """
        return None


================================================
FILE: cyclegan/data/base_dataset.py
================================================
import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image


class BaseDataset(data.Dataset):
	"""Base class for the CycleGAN datasets; subclasses override name()/initialize()."""

	def __init__(self):
		super().__init__()

	def name(self):
		"""Identifier printed when the dataset is created."""
		return 'BaseDataset'

	def initialize(self, opt):
		"""Hook for subclasses to read options and build file lists; no-op here."""
		return None


# TODO: add (more) crop handling.
def get_transform(opt):
	"""Build the torchvision pipeline applied to RGB input images.

	Geometric step chosen by ``opt.resize_or_crop``:
	  'resize_and_crop'      -> Resize([loadSize, loadSize]) then RandomCrop(fineSize)
	  'resize_only'          -> Resize([loadSize, loadSize])
	  'crop'                 -> RandomCrop(fineSize)
	  'scale_width'          -> Resize(loadSize)
	  'scale_width_and_crop' -> Resize(loadSize) then RandomCrop(fineSize)
	Any other value applies no geometric step. RandomHorizontalFlip is added
	when ``opt.isTrain`` and not ``opt.no_flip``. The image is then converted
	to a tensor and normalized with mean 0.5 / std 0.5 on each of 3 channels.

	NOTE(review): the random crop/flip are drawn independently from those in
	get_label_transform, so an image and its label map are not guaranteed to
	receive the same crop or flip — confirm this is acceptable.
	"""
	mode = opt.resize_or_crop
	transform_list = []
	if mode == 'resize_and_crop':
		osize = [int(opt.loadSize), int(opt.loadSize)]
		transform_list.append(transforms.Resize(osize, interpolation=Image.BICUBIC))
		transform_list.append(transforms.RandomCrop(opt.fineSize))
	elif mode == 'resize_only':
		# Resize to the square size, exactly as get_label_transform does for
		# 'resize_only'; passing the scalar loadSize here would resize by the
		# image's edge instead and leave images and labels with different shapes.
		osize = [int(opt.loadSize), int(opt.loadSize)]
		transform_list.append(transforms.Resize(osize, interpolation=Image.BICUBIC))
	elif mode == 'crop':
		transform_list.append(transforms.RandomCrop(opt.fineSize))
	elif mode == 'scale_width':
		transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.BICUBIC))
	elif mode == 'scale_width_and_crop':
		transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.BICUBIC))
		transform_list.append(transforms.RandomCrop(opt.fineSize))

	if opt.isTrain and not opt.no_flip:
		transform_list.append(transforms.RandomHorizontalFlip())

	transform_list += [transforms.ToTensor(),
	                   transforms.Normalize((0.5, 0.5, 0.5),
	                                        (0.5, 0.5, 0.5))]
	return transforms.Compose(transform_list)


def get_label_transform(opt):
	"""Build the pipeline applied to label maps (nearest-neighbour resizing).

	Mirrors get_transform's geometric steps for ``opt.resize_or_crop``, but
	uses NEAREST interpolation and finishes with ``to_tensor_raw`` instead of
	ToTensor/Normalize, so label ids are kept as raw integers.

	NOTE(review): RandomCrop / RandomHorizontalFlip here sample independently
	of the image pipeline, so image and label may not be spatially aligned —
	confirm this is intended.
	"""
	mode = opt.resize_or_crop
	steps = []
	if mode in ('resize_and_crop', 'resize_only'):
		steps.append(transforms.Resize([opt.loadSize, opt.loadSize], interpolation=Image.NEAREST))
		if mode == 'resize_and_crop':
			steps.append(transforms.RandomCrop(opt.fineSize))
	elif mode == 'crop':
		steps.append(transforms.RandomCrop(opt.fineSize))
	elif mode in ('scale_width', 'scale_width_and_crop'):
		steps.append(transforms.Resize(opt.loadSize, interpolation=Image.NEAREST))
		if mode == 'scale_width_and_crop':
			steps.append(transforms.RandomCrop(opt.fineSize))

	if opt.isTrain and not opt.no_flip:
		steps.append(transforms.RandomHorizontalFlip())

	steps.append(transforms.Lambda(lambda img: to_tensor_raw(img)))
	return transforms.Compose(steps)


def __scale_width(img, target_width):
	"""Resize ``img`` to ``target_width`` keeping its aspect ratio (bicubic).

	Returns ``img`` unchanged when it already has that width.
	"""
	width, height = img.size
	if width != target_width:
		new_height = int(target_width * height / width)
		return img.resize((target_width, new_height), Image.BICUBIC)
	return img


def to_tensor_raw(im):
	"""Convert a label image or array to an int64 torch tensor, values unchanged.

	Args:
		im: a PIL image or array-like of integer label ids.
	Returns:
		An int64 tensor with the shape of ``np.asarray(im)``.
	"""
	# np.asarray only copies when it must. np.array(..., copy=False) raises
	# ValueError on NumPy >= 2 whenever the int64 conversion forces a copy,
	# which is always the case for uint8 label images.
	return torch.from_numpy(np.asarray(im, dtype=np.int64))


================================================
FILE: cyclegan/data/cityscapes.py
================================================
import numpy as np

ignore_label = 255
id2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
            3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
            7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3, 13: 4,
            14: ignore_label, 15: ignore_label, 16: ignore_label, 17: 5,
            18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14,
            28: 15, 29: ignore_label, 30: ignore_label, 31: 16, 32: 17, 33: 18}
palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30,
           220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70,
           0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign',
           'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle']


def remap_labels_to_train_ids(arr):
	"""Map raw Cityscapes-style label ids in ``arr`` to train ids via ``id2label``.

	Pixels whose id is not a key of ``id2label`` become ``ignore_label``.
	Returns a new uint8 array with the shape of ``arr``; ``arr`` is untouched.
	"""
	train_ids = np.full(arr.shape, ignore_label, dtype=np.uint8)
	for raw_id in id2label:
		train_ids[arr == raw_id] = int(id2label[raw_id])
	return train_ids


================================================
FILE: cyclegan/data/gta5_cityscapes.py
================================================
import os.path
import random

import numpy as np
from PIL import Image
from data.base_dataset import BaseDataset, get_label_transform, get_transform
from data.cityscapes import remap_labels_to_train_ids
from data.image_folder import make_cs_labels, make_dataset

ignore_label = 255
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']


# This dataset is used to conduct GTA->CityScapes images transfer procedure.
class GTAVCityscapesDataset(BaseDataset):
	"""Unpaired GTA5 (domain A) / Cityscapes (domain B) images with label maps.

	Directories read under ``opt.dataroot``: gta5/images, gta5/labels,
	cityscapes/leftImg8bit and cityscapes/gtFine (label files found by
	make_cs_labels). Both label sets are remapped with
	remap_labels_to_train_ids imported from data.cityscapes (not this
	module's own id2label table).
	"""
	def initialize(self, opt):
		"""Collect and sort image/label paths and build the image/label transforms."""
		self.opt = opt
		self.root = opt.dataroot
		self.dir_A = os.path.join(opt.dataroot, 'gta5', 'images')
		self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
		self.dir_A_label = os.path.join(opt.dataroot, 'gta5', 'labels')
		self.dir_B_label = os.path.join(opt.dataroot, 'cityscapes', 'gtFine')
		
		self.A_paths = make_dataset(self.dir_A)
		self.B_paths = make_dataset(self.dir_B)
		
		# Images and labels are paired purely by position in these sorted lists.
		# NOTE(review): assumes every image and its label sort to the same index
		# (same file counts, matching name stems) — confirm for both datasets.
		self.A_paths = sorted(self.A_paths)
		self.B_paths = sorted(self.B_paths)
		self.A_size = len(self.A_paths)
		self.B_size = len(self.B_paths)
		
		self.A_labels = make_dataset(self.dir_A_label)
		self.B_labels = make_cs_labels(self.dir_B_label)
		
		self.A_labels = sorted(self.A_labels)
		self.B_labels = sorted(self.B_labels)
		
		self.transform = get_transform(opt)
		self.label_transform = get_label_transform(opt)
	
	def __getitem__(self, index):
		"""Return one unpaired sample as a dict.

		Keys: 'A', 'B' (transformed image tensors), 'A_paths', 'B_paths'
		(source image paths) and 'A_label', 'B_label' (label tensors from
		label_transform). The A sample is taken at ``index % A_size``; the B
		sample at ``index % B_size`` when ``opt.serial_batches`` is set,
		otherwise uniformly at random.
		"""
		A_path = self.A_paths[index % self.A_size]
		if self.opt.serial_batches:
			index_B = index % self.B_size
		else:
			index_B = random.randint(0, self.B_size - 1)
		B_path = self.B_paths[index_B]
		
		A_label_path = self.A_labels[index % self.A_size]
		B_label_path = self.B_labels[index_B]
		
		A_label = Image.open(A_label_path)
		B_label = Image.open(B_label_path)
		
		# Remap raw label ids to train ids, then back to 8-bit 'L' images so
		# the PIL-based label transforms can be applied.
		A_label = np.asarray(A_label)
		A_label = remap_labels_to_train_ids(A_label)
		
		A_label = Image.fromarray(A_label, 'L')
		B_label = np.asarray(B_label)
		B_label = remap_labels_to_train_ids(B_label)
		B_label = Image.fromarray(B_label, 'L')
		
		A_img = Image.open(A_path).convert('RGB')
		B_img = Image.open(B_path).convert('RGB')
		
		# NOTE(review): image and label transforms draw their random crop/flip
		# independently, so a label may not be aligned with its image — confirm.
		A = self.transform(A_img)
		B = self.transform(B_img)
		
		A_label = self.label_transform(A_label)
		B_label = self.label_transform(B_label)
		
		# print(A_label.unique())
		# print(B_label.unique())
		
		# Channel counts are swapped when translating in the B->A direction.
		if self.opt.which_direction == 'BtoA':
			input_nc = self.opt.output_nc
			output_nc = self.opt.input_nc
		else:
			input_nc = self.opt.input_nc
			output_nc = self.opt.output_nc
		
		if input_nc == 1:  # RGB to gray
			tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
			A = tmp.unsqueeze(0)
		
		if output_nc == 1:  # RGB to gray
			tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
			B = tmp.unsqueeze(0)
		return {'A': A, 'B': B,
		        'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label, 'B_label': B_label}
	
	def __len__(self):
		"""Size of the larger domain; indices into the smaller one wrap or are sampled."""
		return max(self.A_size, self.B_size)
	
	def name(self):
		return 'GTA5_Cityscapes'


================================================
FILE: cyclegan/data/gta_synthia_cityscapes.py
================================================
import os.path
import random

import numpy as np
from PIL import Image
from data.base_dataset import BaseDataset, get_label_transform, get_transform
from data.cityscapes import remap_labels_to_train_ids
from data.image_folder import make_cs_labels, make_dataset

ignore_label = 255
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']


def syn_relabel(arr):
	"""Map raw SYNTHIA label ids in ``arr`` to train ids via this module's ``id2label``.

	Ids missing from ``id2label`` become ``ignore_label``. Returns a new uint8
	array with the shape of ``arr``.
	"""
	relabeled = np.full(arr.shape, ignore_label, dtype=np.uint8)
	for synthia_id, train_id in id2label.items():
		relabeled[arr == synthia_id] = int(train_id)
	return relabeled

# This dataset is used to conduct double cyclegan for both GTAV->CityScapes and Synthia->CityScapes
class GTASynthiaCityscapesDataset(BaseDataset):
	"""Two labelled source domains (SYNTHIA = A_1, GTA5 = A_2) plus unlabelled Cityscapes (B).

	Directories read under ``opt.dataroot``: synthia/RGB,
	synthia/GT/parsed_LABELS, gta5/images, gta5/labels and
	cityscapes/leftImg8bit. No Cityscapes labels are loaded.
	"""
	def initialize(self, opt):
		"""Collect and sort image/label paths for both sources and the target."""
		# SYNTHIA as dataset 1
		# GTAV as dataset 2
		self.opt = opt
		self.root = opt.dataroot
		self.dir_A_1 = os.path.join(opt.dataroot, 'synthia', 'RGB')
		self.dir_A_2 = os.path.join(opt.dataroot, 'gta5', 'images')
		self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
		self.dir_A_label_1 = os.path.join(opt.dataroot, 'synthia', 'GT', 'parsed_LABELS')
		self.dir_A_label_2 = os.path.join(opt.dataroot, 'gta5', 'labels')
		
		self.A_paths_1 = make_dataset(self.dir_A_1)
		self.A_paths_2 = make_dataset(self.dir_A_2)
		self.B_paths = make_dataset(self.dir_B)
		
		# Images and labels are paired purely by position after sorting.
		# NOTE(review): assumes matching file counts / name stems — confirm.
		self.A_paths_1 = sorted(self.A_paths_1)
		self.A_paths_2 = sorted(self.A_paths_2)
		
		self.B_paths = sorted(self.B_paths)
		
		self.A_size_1 = len(self.A_paths_1)
		self.A_size_2 = len(self.A_paths_2)
		
		self.B_size = len(self.B_paths)
		
		self.A_labels_1 = make_dataset(self.dir_A_label_1)
		self.A_labels_2 = make_dataset(self.dir_A_label_2)
		
		self.A_labels_1 = sorted(self.A_labels_1)
		self.A_labels_2 = sorted(self.A_labels_2)
		
		self.transform = get_transform(opt)
		self.label_transform = get_label_transform(opt)
	
	def __getitem__(self, index):
		"""Return one sample from each source domain plus one target image.

		Keys: 'A_1', 'A_2', 'B' (image tensors), 'A_paths_1', 'A_paths_2',
		'B_paths' (source paths) and 'A_label_1', 'A_label_2' (label tensors).
		Each source is indexed with ``index`` modulo its own size; B is
		chosen serially or at random depending on ``opt.serial_batches``.
		"""
		A_path_1 = self.A_paths_1[index % self.A_size_1]
		A_path_2 = self.A_paths_2[index % self.A_size_2]
		
		if self.opt.serial_batches:
			index_B = index % self.B_size
		else:
			index_B = random.randint(0, self.B_size - 1)
		
		B_path = self.B_paths[index_B]
		
		A_label_path_1 = self.A_labels_1[index % self.A_size_1]
		A_label_path_2 = self.A_labels_2[index % self.A_size_2]
		
		A_label_1 = Image.open(A_label_path_1)
		A_label_2 = Image.open(A_label_path_2)
		
		# remapping labels for synthia (this module's id2label table)
		A_label_1 = np.asarray(A_label_1)
		A_label_1 = syn_relabel(A_label_1)
		A_label_1 = Image.fromarray(A_label_1, 'L')
		
		# remapping labels for gta5 (Cityscapes id mapping from data.cityscapes)
		
		A_label_2 = np.asarray(A_label_2)
		A_label_2 = remap_labels_to_train_ids(A_label_2)
		A_label_2 = Image.fromarray(A_label_2, 'L')
		
		A_img_1 = Image.open(A_path_1).convert('RGB')
		A_img_2 = Image.open(A_path_2).convert('RGB')
		
		B_img = Image.open(B_path).convert('RGB')
		
		# NOTE(review): image and label transforms sample their random crop/flip
		# independently, so labels may not align with images — confirm.
		A_1 = self.transform(A_img_1)
		A_2 = self.transform(A_img_2)
		
		B = self.transform(B_img)
		
		A_label_1 = self.label_transform(A_label_1)
		A_label_2 = self.label_transform(A_label_2)
		
		# Channel counts are swapped when translating in the B->A direction.
		if self.opt.which_direction == 'BtoA':
			input_nc = self.opt.output_nc
			output_nc = self.opt.input_nc
		else:
			input_nc = self.opt.input_nc
			output_nc = self.opt.output_nc
		
		if input_nc == 1:  # RGB to gray
			tmp = A_1[0, ...] * 0.299 + A_1[1, ...] * 0.587 + A_1[2, ...] * 0.114
			A_1 = tmp.unsqueeze(0)
			
			tmp = A_2[0, ...] * 0.299 + A_2[1, ...] * 0.587 + A_2[2, ...] * 0.114
			A_2 = tmp.unsqueeze(0)
		
		if output_nc == 1:  # RGB to gray
			tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
			B = tmp.unsqueeze(0)
		
		return {'A_1': A_1, 'A_2': A_2, 'B': B, 'A_paths_1': A_path_1, 'A_paths_2': A_path_2, 'B_paths': B_path, 'A_label_1': A_label_1,
		        'A_label_2': A_label_2}
	
	def __len__(self):
		"""Size of the largest of the three image sets."""
		return max(self.A_size_1, self.B_size, self.A_size_2)
	
	def name(self):
		return 'GTA5_Synthia_Cityscapes'


================================================
FILE: cyclegan/data/image_folder.py
================================================
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################

import torch.utils.data as data

import numpy as np
from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True if ``filename`` ends with one of IMG_EXTENSIONS (case-sensitive)."""
    return filename.endswith(tuple(IMG_EXTENSIONS))

def make_cs_labels(dir):
    """Recursively collect Cityscapes ``*_gtFine_labelIds.png`` files under ``dir``.

    Args:
        dir: root directory to walk.
    Returns:
        Sorted list of unique file paths, identical across runs.
    Raises:
        AssertionError: if ``dir`` is not a directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    images = set()
    for root, _, fnames in os.walk(dir):
        for fname in fnames:
            # The '.png' suffix already makes this an image file.
            if fname.endswith("_gtFine_labelIds.png"):
                images.add(os.path.join(root, fname))
    # Sort rather than return set order: string hashing is randomized per
    # interpreter run, so set iteration order is not reproducible.
    return sorted(images)

def make_dataset(dir):
    """Recursively collect image files (per is_image_file) under ``dir``.

    Args:
        dir: root directory to walk.
    Returns:
        Sorted list of unique image paths. A fixed order matters because
        callers such as ImageFolder index this list directly.
    Raises:
        AssertionError: if ``dir`` is not a directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    images = set()
    for root, _, fnames in os.walk(dir):
        for fname in fnames:
            if is_image_file(fname):
                images.add(os.path.join(root, fname))
    # Sort rather than return set order, which changes between runs because
    # of string hash randomization.
    return sorted(images)

def load_labels(dir, images):
    """Look up an integer label for each image from ``dir/labels.txt``.

    ``labels.txt`` holds one ``<image_id> <int label>`` pair per line, where
    ``image_id`` is an image's file name up to its first '.'. Blank lines
    are ignored.

    Args:
        dir: directory expected to contain ``labels.txt``.
        images: iterable of image paths to label.
    Returns:
        List of int labels in the order of ``images``.
    Raises:
        NotImplementedError: if ``dir`` has a ``labels`` folder instead.
        FileNotFoundError: if neither ``labels.txt`` nor ``labels`` exists.
        KeyError: if an image id is missing from ``labels.txt``.
    """
    label_file = os.path.join(dir, 'labels.txt')
    if os.path.exists(label_file):
        # Build the dict directly: routing the pairs through np.array would
        # coerce the integer labels to strings.
        label_dict = {}
        with open(label_file, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                label_dict[fields[0]] = int(fields[1])
        return [label_dict[os.path.basename(image).split('.')[0]] for image in images]
    if os.path.isdir(os.path.join(dir, 'labels')):
        raise NotImplementedError('Not yet implemented load_labels for image folder')
    raise FileNotFoundError('load_labels expects %s to contain labels.txt or labels folder' % dir)

def default_loader(path):
    """Open the image file at ``path`` and return it converted to RGB."""
    image = Image.open(path)
    return image.convert('RGB')


class ImageFolder(data.Dataset):
    """Dataset over every image file found recursively under ``root``.

    Each item is ``loader(path)``, passed through ``transform`` when one is
    given, and returned as ``(img, path)`` when ``return_paths`` is True.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if not imgs:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample, path) if self.return_paths else sample

    def __len__(self):
        return len(self.imgs)


================================================
FILE: cyclegan/data/synthia_cityscapes.py
================================================
import os.path
import random

import numpy as np
from PIL import Image
from data.base_dataset import BaseDataset, get_label_transform, get_transform
from data.image_folder import make_cs_labels, make_dataset

from data.cityscapes import remap_labels_to_train_ids

ignore_label = 255
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

classes = ['road',
           'sidewalk',
           'building',
           'wall',
           'fence',
           'pole',
           'traffic light',
           'traffic sign',
           'vegetation',
           'terrain',
           'sky',
           'person',
           'rider',
           'car',
           'truck',
           'bus',
           'train',
           'motorcycle',
           'bicycle']


def syn_relabel(arr):
	"""Map raw SYNTHIA label ids in ``arr`` to train ids via this module's ``id2label``.

	Ids missing from ``id2label`` become ``ignore_label``. Returns a new uint8
	array with the shape of ``arr``.
	"""
	relabeled = np.full(arr.shape, ignore_label, dtype=np.uint8)
	for synthia_id, train_id in id2label.items():
		relabeled[arr == synthia_id] = int(train_id)
	return relabeled


class SynthiaCityscapesDataset(BaseDataset):
	"""Unpaired SYNTHIA (domain A) / Cityscapes (domain B) images with label maps.

	Directories read under ``opt.dataroot``: synthia/RGB,
	synthia/GT/parsed_LABELS, cityscapes/leftImg8bit and cityscapes/gtFine.
	SYNTHIA labels are remapped with syn_relabel, Cityscapes labels with
	remap_labels_to_train_ids.
	"""
	def initialize(self, opt):
		"""Collect and sort image/label paths and build the image/label transforms."""
		self.opt = opt
		self.root = opt.dataroot
		self.dir_A = os.path.join(opt.dataroot, 'synthia', 'RGB')
		self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
		self.dir_A_label = os.path.join(opt.dataroot, 'synthia', 'GT', 'parsed_LABELS')
		self.dir_B_label = os.path.join(opt.dataroot, 'cityscapes', 'gtFine')
		
		self.A_paths = make_dataset(self.dir_A)
		self.B_paths = make_dataset(self.dir_B)
		
		# Images and labels are paired purely by position after sorting.
		# NOTE(review): assumes matching file counts / name stems — confirm.
		self.A_paths = sorted(self.A_paths)
		self.B_paths = sorted(self.B_paths)
		self.A_size = len(self.A_paths)
		self.B_size = len(self.B_paths)
		
		self.A_labels = make_dataset(self.dir_A_label)
		self.B_labels = make_cs_labels(self.dir_B_label)
		
		self.A_labels = sorted(self.A_labels)
		self.B_labels = sorted(self.B_labels)
		
		self.transform = get_transform(opt)
		self.label_transform = get_label_transform(opt)
	
	def __getitem__(self, index):
		"""Return one unpaired sample as a dict.

		Keys: 'A', 'B' (image tensors), 'A_paths', 'B_paths' (source paths),
		'A_label', 'B_label' (label tensors). A is taken at
		``index % A_size``; B at ``index % B_size`` when
		``opt.serial_batches`` is set, otherwise at random.
		"""
		A_path = self.A_paths[index % self.A_size]
		if self.opt.serial_batches:
			index_B = index % self.B_size
		else:
			index_B = random.randint(0, self.B_size - 1)
		B_path = self.B_paths[index_B]
		
		A_label_path = self.A_labels[index % self.A_size]
		B_label_path = self.B_labels[index_B]
		
		A_label = Image.open(A_label_path)
		B_label = Image.open(B_label_path)
		
		# Remap to train ids, then back to 8-bit 'L' images for the PIL transforms.
		A_label = np.asarray(A_label)
		A_label = syn_relabel(A_label)
		
		A_label = Image.fromarray(A_label, 'L')
		B_label = np.asarray(B_label)
		B_label = remap_labels_to_train_ids(B_label)
		B_label = Image.fromarray(B_label, 'L')
		
		A_img = Image.open(A_path).convert('RGB')
		B_img = Image.open(B_path).convert('RGB')
		
		# NOTE(review): image and label transforms sample their random crop/flip
		# independently, so labels may not align with images — confirm.
		A = self.transform(A_img)
		B = self.transform(B_img)
		
		A_label = self.label_transform(A_label)
		B_label = self.label_transform(B_label)
		
		# Channel counts are swapped when translating in the B->A direction.
		if self.opt.which_direction == 'BtoA':
			input_nc = self.opt.output_nc
			output_nc = self.opt.input_nc
		else:
			input_nc = self.opt.input_nc
			output_nc = self.opt.output_nc
		
		if input_nc == 1:  # RGB to gray
			tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
			A = tmp.unsqueeze(0)
		
		if output_nc == 1:  # RGB to gray
			tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
			B = tmp.unsqueeze(0)
		return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label, 'B_label': B_label}
	
	def __len__(self):
		"""Size of the larger domain; indices into the smaller one wrap or are sampled."""
		return max(self.A_size, self.B_size)
	
	def name(self):
		return 'Synthia_Cityscapes'


================================================
FILE: cyclegan/environment.yml
================================================
name: pytorch-CycleGAN-and-pix2pix
channels:
- peterjc123
- defaults
dependencies:
- python=3.5.5
- pytorch=0.3.1
- scipy
- pip:
  - dominate==2.3.1
  - git+https://github.com/pytorch/vision.git
  - Pillow==5.0.0
  - numpy==1.14.1
  - visdom==0.1.7


================================================
FILE: cyclegan/models/__init__.py
================================================
import logging

def create_model(opt):
	"""Instantiate, initialise and return the model named by ``opt.model``.

	Model modules are imported lazily, only for the selected name. Both
	'multi_cycle_gan_semantic' and 'cycle_gan_semantic_fcn' resolve to a class
	called CycleGANSemanticModel, taken from different modules.

	Raises:
		NotImplementedError: if ``opt.model`` is not a known model name.
	"""
	kind = opt.model
	if kind == 'cycle_gan':
		# assert(opt.dataset_mode == 'unaligned')
		from .cycle_gan_model import CycleGANModel as model_cls
	elif kind == 'test':
		from .test_model import TestModel as model_cls
	elif kind == 'multi_cycle_gan_semantic':
		from .multi_cycle_gan_semantic_model import CycleGANSemanticModel as model_cls
	elif kind == 'cycle_gan_semantic_fcn':
		from .cycle_gan_semantic_model import CycleGANSemanticModel as model_cls
	else:
		raise NotImplementedError('model [%s] not implemented.' % kind)

	model = model_cls()
	model.initialize(opt)
	logging.info("model [%s] was created" % (model.name()))
	return model


================================================
FILE: cyclegan/models/base_model.py
================================================
import os
from collections import OrderedDict

import torch

from . import networks


class BaseModel():
	"""Shared plumbing for the models in this package.

	Subclasses fill ``loss_names``, ``model_names`` and ``visual_names`` and
	define matching attributes: ``net<name>`` for each model name,
	``loss_<name>`` for each loss name and ``<name>`` for each visual name.
	The helpers below look these up with getattr.
	"""
	
	def name(self):
		return 'BaseModel'
	
	def initialize(self, opt):
		"""Store options, pick the device and reset the bookkeeping lists.

		The device is ``cuda:<gpu_ids[0]>`` when ``opt.gpu_ids`` is non-empty,
		otherwise CPU. Checkpoints live in ``<checkpoints_dir>/<name>``.
		"""
		self.opt = opt
		self.gpu_ids = opt.gpu_ids
		self.isTrain = opt.isTrain
		self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
		self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
		# cudnn autotuning is enabled except for 'scale_width', presumably
		# because that mode can yield varying input sizes.
		if opt.resize_or_crop != 'scale_width':
			torch.backends.cudnn.benchmark = True
		self.loss_names = []
		self.model_names = []
		self.visual_names = []
		self.image_paths = []
	
	def set_input(self, input):
		"""Store the raw input batch; subclasses unpack it."""
		self.input = input
	
	def forward(self):
		"""Forward pass; implemented by subclasses."""
		pass
	
	# load and print networks; create schedulers
	def setup(self, opt):
		"""Create LR schedulers when training, load checkpoints when testing or
		resuming (``opt.continue_train``), then print network summaries."""
		if self.isTrain:
			self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
		
		if not self.isTrain or opt.continue_train:
			self.load_networks(opt.which_epoch)
		self.print_networks(opt.verbose)
	
	# make models eval mode during test time
	def eval(self):
		"""Put every network in ``model_names`` into eval mode."""
		for name in self.model_names:
			if isinstance(name, str):
				net = getattr(self, 'net' + name)
				net.eval()
	
	# used in test time, wrapping `forward` in no_grad() so we don't save
	# intermediate steps for backprop
	def test(self):
		with torch.no_grad():
			self.forward()
	
	# get image paths
	def get_image_paths(self):
		return self.image_paths
	
	def optimize_parameters(self):
		"""One training step; implemented by subclasses."""
		pass
	
	# update learning rate (called once every epoch)
	def update_learning_rate(self):
		"""Step every scheduler and print the first optimizer's current LR."""
		for scheduler in self.schedulers:
			scheduler.step()
		lr = self.optimizers[0].param_groups[0]['lr']
		print('learning rate = %.7f' % lr)
	
	# return visualization images. train.py will display these images, and save the images to a html
	def get_current_visuals(self):
		"""Return an OrderedDict mapping each visual name to its attribute value."""
		visual_ret = OrderedDict()
		for name in self.visual_names:
			if isinstance(name, str):
				visual_ret[name] = getattr(self, name)
		return visual_ret
	
	# return training losses/errors. train.py will print out these errors as debugging information
	def get_current_losses(self):
		"""Return an OrderedDict mapping each loss name to ``float(loss_<name>)``."""
		errors_ret = OrderedDict()
		for name in self.loss_names:
			if isinstance(name, str):
				# float(...) works for both scalar tensor and float number
				errors_ret[name] = float(getattr(self, 'loss_' + name))
		return errors_ret
	
	# save models to the disk
	def save_networks(self, which_epoch):
		"""Save each network in ``model_names`` except names containing 'PixelCLS'.

		Files are ``<save_dir>/<which_epoch>_net_<name>.pth``. With GPUs the
		state dict is taken from ``net.module`` on CPU and the net is moved
		back to ``gpu_ids[0]`` afterwards.
		"""
		for name in self.model_names:
			# Don't save semantic consistency networks
			if isinstance(name, str) and ("PixelCLS" not in name):
				save_filename = '%s_net_%s.pth' % (which_epoch, name)
				save_path = os.path.join(self.save_dir, save_filename)
				net = getattr(self, 'net' + name)
				
				if len(self.gpu_ids) > 0 and torch.cuda.is_available():
					torch.save(net.module.cpu().state_dict(), save_path)
					net.cuda(self.gpu_ids[0])
				else:
					torch.save(net.cpu().state_dict(), save_path)
	
	def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
		"""Drop checkpoint entries that an InstanceNorm layer holds no buffer for.

		Walks ``keys`` (a state-dict key split on '.') down from ``module``.
		At the leaf, if the owning module's class name starts with
		'InstanceNorm' and the key is ``running_mean``, ``running_var`` or
		``num_batches_tracked`` whose attribute is None or absent, the entry
		is removed from ``state_dict`` so ``load_state_dict`` accepts it.
		"""
		key = keys[i]
		if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
			if module.__class__.__name__.startswith('InstanceNorm') and \
				key in ('running_mean', 'running_var', 'num_batches_tracked'):
				# Default None: the attribute may not exist at all on older
				# PyTorch versions (e.g. num_batches_tracked).
				if getattr(module, key, None) is None:
					state_dict.pop('.'.join(keys))
		else:
			self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
	
	# load models from the disk
	def load_networks(self, which_epoch):
		"""Load ``<save_dir>/<which_epoch>_net_<name>.pth`` into each network.

		DataParallel wrappers are unwrapped first, and InstanceNorm entries
		the module cannot hold are dropped before ``load_state_dict``.

		NOTE(review): unlike save_networks this does not skip names containing
		'PixelCLS'; such a checkpoint must exist if the name is listed — confirm.
		"""
		for name in self.model_names:
			if isinstance(name, str):
				load_filename = '%s_net_%s.pth' % (which_epoch, name)
				load_path = os.path.join(self.save_dir, load_filename)
				
				net = getattr(self, 'net' + name)
				if isinstance(net, torch.nn.DataParallel):
					net = net.module
				print('loading the model from %s' % load_path)
				# if you are using PyTorch newer than 0.4 (e.g., built from
				# GitHub source), you can remove str() on self.device
				state_dict = torch.load(load_path, map_location=str(self.device))
				# patch InstanceNorm checkpoints prior to 0.4
				for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
					self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
				net.load_state_dict(state_dict)
	
	# print network information
	def print_networks(self, verbose):
		"""Print the parameter count (millions) of each network, and the
		network itself when ``verbose``."""
		print('---------- Networks initialized -------------')
		for name in self.model_names:
			if isinstance(name, str):
				net = getattr(self, 'net' + name)
				num_params = 0
				for param in net.parameters():
					num_params += param.numel()
				if verbose:
					print(net)
				print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
		print('-----------------------------------------------')
	
	# set requires_grad=False to avoid computation
	def set_requires_grad(self, nets, requires_grad=False):
		"""Set ``requires_grad`` on every parameter of ``nets`` (a net or list; None entries skipped)."""
		if not isinstance(nets, list):
			nets = [nets]
		for net in nets:
			if net is not None:
				for param in net.parameters():
					param.requires_grad = requires_grad


================================================
FILE: cyclegan/models/cycle_gan_model.py
================================================
import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks


class CycleGANModel(BaseModel):
    """Plain CycleGAN with generators G_A (A->B), G_B (B->A) and discriminators D_A, D_B.

    Generator loss = GAN losses + L1 cycle losses (weighted by lambda_A /
    lambda_B) + optional L1 identity losses (scaled by lambda_identity).
    """
    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        """Build the generators; when training also the discriminators, image
        pools, loss criteria and the two Adam optimizers (G and D)."""
        BaseModel.initialize(self, opt)

        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')

        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        if self.isTrain:
            # no_lsgan selects the sigmoid discriminator output.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Move the 'A'/'B' batches to the device, swapping them for 'BtoA'."""
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Translate in both directions and reconstruct: A->B->A and B->A->B."""
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)

        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        """Backpropagate the averaged real/fake GAN loss for ``netD`` and return it.

        ``fake`` is detached so this loss does not flow into the generators.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """D_A compares real_B with a fake_B drawn through fake_B_pool."""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """D_B compares real_A with a fake_A drawn through fake_A_pool."""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Compute and backpropagate the combined generator loss (stored in loss_G)."""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """One training step: generator update with Ds frozen, then discriminator update.

        The D step reuses fake_A/fake_B from this step's forward pass, so the
        statement order below matters.
        """
        # forward
        self.forward()
        # G_A and G_B
        # Freeze Ds so the generator backward pass leaves their grads untouched.
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()


================================================
FILE: cyclegan/models/cycle_gan_semantic_model.py
================================================
import itertools
import sys

import torch
import torch.nn.functional as F
from util.image_pool import ImagePool

from . import networks
from .base_model import BaseModel

sys.path.append('/nfs/project/libo_iMADAN')
from cycada.models import get_model


class CycleGANSemanticModel(BaseModel):
	"""CycleGAN (A <-> B) with an additional semantic-consistency loss.

	On top of the standard CycleGAN objective, a pretrained pixel classifier
	``netPixelCLS`` (built with ``get_model`` from ``opt.weights_model_type`` and
	``opt.weights_init``) is applied to the real source image ``real_A`` and to
	its translation ``fake_B``. When ``opt.semantic_loss`` is set, the generator
	is also penalised by the KL divergence between the two per-pixel class
	distributions. The classifier's parameters belong to no optimizer here.
	"""
	
	def name(self):
		return 'CycleGANModel'
	
	def initialize(self, opt):
		"""Build the generators and, in training mode, discriminators, classifier, losses and optimizers."""
		BaseModel.initialize(self, opt)
		
		# Losses reported by base_model.get_current_losses (read as 'loss_' + name).
		self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A',
		                   'D_B', 'G_B', 'cycle_B', 'idt_B',
		                   'sem_AB']
		
		# Images returned by base_model.get_current_visuals.
		visual_names_A = ['real_A', 'fake_B', 'rec_A']
		visual_names_B = ['real_B', 'fake_A', 'rec_B']
		if self.isTrain and self.opt.lambda_identity > 0.0:
			visual_names_A.append('idt_A')
			visual_names_B.append('idt_B')
		self.visual_names = visual_names_A + visual_names_B
		
		# Networks handled by base_model.save_networks / load_networks; only the
		# generators are needed at test time.
		if self.isTrain:
			self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
		else:
			self.model_names = ['G_A', 'G_B']
		
		# Code (paper) naming: G_A (G), G_B (F), D_A (D_Y), D_B (D_X).
		self.netG_A = networks.define_G(opt.input_nc, opt.output_nc,
		                                opt.ngf, opt.which_model_netG, opt.norm,
		                                not opt.no_dropout, opt.init_type, self.gpu_ids)
		self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
		                                opt.ngf, opt.which_model_netG, opt.norm,
		                                not opt.no_dropout, opt.init_type, self.gpu_ids)
		
		if self.isTrain:
			use_sigmoid = opt.no_lsgan
			self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
			                                opt.which_model_netD,
			                                opt.n_layers_D, opt.norm, use_sigmoid,
			                                opt.init_type, self.gpu_ids)
			self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
			                                opt.which_model_netD,
			                                opt.n_layers_D, opt.norm, use_sigmoid,
			                                opt.init_type, self.gpu_ids)
			
			# Pixel classifier used by the semantic-consistency loss.
			self.netPixelCLS = get_model(opt.weights_model_type, num_cls=opt.num_cls, pretrained=True, weights_init=opt.weights_init)
			if len(self.gpu_ids) > 0:
				assert (torch.cuda.is_available())
				self.netPixelCLS.to(self.gpu_ids[0])
				self.netPixelCLS = torch.nn.DataParallel(self.netPixelCLS, self.gpu_ids)
			
			# Buffers of previously generated images fed to the discriminators.
			self.fake_A_pool = ImagePool(opt.pool_size)
			self.fake_B_pool = ImagePool(opt.pool_size)
			
			self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
			self.criterionCycle = torch.nn.L1Loss()
			self.criterionIdt = torch.nn.L1Loss()
			self.criterionSemantic = torch.nn.KLDivLoss(reduction='batchmean')
			
			self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
			                                    lr=opt.lr, betas=(opt.beta1, 0.999))
			self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
			                                    lr=opt.lr, betas=(opt.beta1, 0.999))
			self.optimizers = [self.optimizer_G, self.optimizer_D]
	
	def set_input(self, input):
		"""Unpack a batch dict, swapping A and B when ``opt.which_direction`` is not 'AtoB'.

		Label maps are stored only when both 'A_label' and 'B_label' are present.
		"""
		AtoB = self.opt.which_direction == 'AtoB'
		self.real_A = input['A' if AtoB else 'B'].to(self.device)
		self.real_B = input['B' if AtoB else 'A'].to(self.device)
		self.image_paths = input['A_paths' if AtoB else 'B_paths']
		if 'A_label' in input and 'B_label' in input:
			self.input_A_label = input['A_label' if AtoB else 'B_label'].to(self.device)
			self.input_B_label = input['B_label' if AtoB else 'A_label'].to(self.device)
	
	def forward(self):
		"""Run both translation cycles and, in training mode, the classifier on real_A and fake_B."""
		self.fake_B = self.netG_A(self.real_A)
		self.rec_A = self.netG_B(self.fake_B)
		
		self.fake_A = self.netG_B(self.real_B)
		self.rec_B = self.netG_A(self.fake_A)
		
		if self.isTrain:
			# The prediction on the real source image is only a fixed target (and
			# its argmax), so no autograd graph is kept for it. Gradients reaching
			# the generators are unchanged because real_A does not depend on them.
			with torch.no_grad():
				self.pred_real_A = self.netPixelCLS(self.real_A)
			_, self.gt_pred_A = self.pred_real_A.max(1)
			# This prediction keeps its graph: the semantic loss back-propagates
			# through it into netG_A.
			self.pred_fake_B = self.netPixelCLS(self.fake_B)
	
	def backward_D_basic(self, netD, real, fake):
		"""Back-propagate and return the discriminator loss 0.5 * (GAN(real, True) + GAN(fake, False)).

		``fake`` is detached, so this step does not send gradients to the generator.
		"""
		loss_D_real = self.criterionGAN(netD(real), True)
		loss_D_fake = self.criterionGAN(netD(fake.detach()), False)
		loss_D = (loss_D_real + loss_D_fake) * 0.5
		loss_D.backward()
		return loss_D
	
	def backward_PixelCLS(self):
		"""Back-propagate a supervised pixel-classification loss of netPixelCLS on real_A.

		Not called by optimize_parameters, and no optimizer in this class owns
		netPixelCLS's parameters. Requires set_input to have received 'A_label'.
		"""
		label_A = self.input_A_label
		pred_A = self.netPixelCLS(self.real_A)
		# Hard class-index labels need cross-entropy. KLDivLoss expects a target
		# distribution shaped like the input and cannot consume integer labels.
		# NOTE(review): default ignore_index (-100) matches the CrossEntropyLoss
		# that was previously commented out; confirm how "ignore" pixels are encoded.
		self.loss_PixelCLS = F.cross_entropy(pred_A, label_A.long())
		self.loss_PixelCLS.backward()
	
	def backward_D_A(self):
		"""Discriminator D_A: real_B versus a pooled fake_B."""
		fake_B = self.fake_B_pool.query(self.fake_B)
		self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
	
	def backward_D_B(self):
		"""Discriminator D_B: real_A versus a pooled fake_A."""
		fake_A = self.fake_A_pool.query(self.fake_A)
		self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
	
	def backward_G(self, opt):
		"""Compute the generator loss and back-propagate it.

		Standard CycleGAN terms (the G_A adversarial term is weighted x2), plus,
		when ``opt.semantic_loss`` is set:
		``general_semantic_weight * dynamic_weight * KL(pred_fake_B || pred_real_A) / (C*H*W)``.
		``loss_sem_AB`` is always assigned (0 when disabled) because 'sem_AB' is
		always listed in ``loss_names``.
		"""
		lambda_idt = self.opt.lambda_identity
		lambda_A = self.opt.lambda_A
		lambda_B = self.opt.lambda_B
		# Identity loss
		if lambda_idt > 0:
			# G_A should be identity if real_B is fed.
			self.idt_A = self.netG_A(self.real_B)
			self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
			# G_B should be identity if real_A is fed.
			self.idt_B = self.netG_B(self.real_A)
			self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
		else:
			self.loss_idt_A = 0
			self.loss_idt_B = 0
		
		# GAN loss D_A(G_A(A))
		self.loss_G_A = 2 * self.criterionGAN(self.netD_A(self.fake_B), True)
		# GAN loss D_B(G_B(B))
		self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
		# Forward / backward cycle losses
		self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
		self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
		self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
		
		# real_A(syn) -> fake_B -> classifier -> pred_fake_B should match pred_real_A.
		if opt.semantic_loss:
			kl = self.criterionSemantic(F.log_softmax(self.pred_fake_B, dim=1), F.softmax(self.pred_real_A, dim=1))
			self.loss_sem_AB = opt.dynamic_weight * kl
			# Normalise by the per-sample prediction size C*H*W.
			num_elems = self.pred_fake_B.shape[1] * self.pred_fake_B.shape[2] * self.pred_fake_B.shape[3]
			self.loss_sem_AB = opt.general_semantic_weight * torch.div(self.loss_sem_AB, num_elems)
			self.loss_G += self.loss_sem_AB
		else:
			# Keeps get_current_losses working when the semantic loss is disabled.
			self.loss_sem_AB = 0
		
		self.loss_G.backward()
	
	def optimize_parameters(self, opt):
		"""One training iteration: generator update, then discriminator update."""
		self.forward()
		# G_A and G_B (discriminators frozen)
		self.set_requires_grad([self.netD_A, self.netD_B], False)
		self.optimizer_G.zero_grad()
		self.backward_G(opt)
		self.optimizer_G.step()
		# D_A and D_B
		self.set_requires_grad([self.netD_A, self.netD_B], True)
		self.optimizer_D.zero_grad()
		self.backward_D_A()
		self.backward_D_B()
		self.optimizer_D.step()


================================================
FILE: cyclegan/models/multi_cycle_gan_semantic_model.py
================================================
import itertools
import sys

import torch
import torch.nn.functional as F
from util.image_pool import ImagePool

from . import networks
from .base_model import BaseModel

sys.path.append('/nfs/project/libo_iMADAN')
from cycada.models import get_model


class CycleGANSemanticModel(BaseModel):
	"""Two-source CycleGAN with optional extra discriminators and semantic losses.

	Two source domains, A_1 and A_2, are translated to a shared target domain B,
	each by its own generator pair (G_A_k: A_k -> B, G_B_k: B -> A_k).
	Switches read from ``opt``:

	* ``Shared_DT``: one target discriminator ``netD_A`` for both sources
	  instead of ``netD_A_1`` / ``netD_A_2``.
	* ``SAD``: discriminator ``netD_3`` trained with ``fake_B_2`` as "real" and
	  ``fake_B_1`` as "fake"; G_A_1 is pushed to fool it.
	* ``CCD`` / ``HF_CCD``: discriminators ``netD_21`` (real_A_1 vs fake_A_21,
	  i.e. A_2 -> B -> A_1) and ``netD_12`` (real_A_2 vs fake_A_12). ``CCD`` adds
	  their generator terms to the main loss. ``HF_CCD`` applies them in a
	  separate generator step (see ``backward_HF_CCD``).
	* ``semantic_loss``: KL semantic-consistency loss with one pretrained pixel
	  classifier per source (``weights_syn`` for A_1, ``weights_gta`` for A_2).
	"""
	
	def name(self):
		return 'CycleGANModel'
	
	def initialize(self, opt):
		"""Build all networks and, in training mode, image pools, losses and optimizers.

		NOTE(review): the semantic classifiers are built whenever
		``opt.semantic_loss`` is set, including at test time.
		"""
		BaseModel.initialize(self, opt)
		
		self.semantic_loss = opt.semantic_loss
		
		# Losses reported by base_model.get_current_losses (read as 'loss_' + name).
		self.loss_names = ['D_A_1', 'G_A_1', 'cycle_A_1', 'idt_A_1',
		                   'D_B_1', 'G_B_1', 'cycle_B_1', 'idt_B_1',
		                   'D_A_2', 'G_A_2', 'cycle_A_2', 'idt_A_2',
		                   'D_B_2', 'G_B_2', 'cycle_B_2', 'idt_B_2']
		if opt.SAD:
			self.loss_names.extend(['D_3_1', 'G_s1s2'])
		if opt.CCD or opt.HF_CCD:
			self.loss_names.extend(['D_21', 'G_s1s21', 'D_12', 'G_s2s12'])
		if self.semantic_loss:
			self.loss_names.extend(['sem_syn', 'sem_gta'])
		
		# Images returned by base_model.get_current_visuals (real_B is listed once).
		visual_names_A_1 = ['real_A_1', 'fake_B_1', 'rec_A_1']
		visual_names_B_1 = ['real_B', 'fake_A_1', 'rec_B_1']
		visual_names_A_2 = ['real_A_2', 'fake_B_2', 'rec_A_2']
		visual_names_B_2 = ['fake_A_2', 'rec_B_2']
		if self.isTrain and self.opt.lambda_identity > 0.0:
			visual_names_A_1.append('idt_A_1')
			visual_names_B_1.append('idt_B_1')
			visual_names_A_2.append('idt_A_2')
			visual_names_B_2.append('idt_B_2')
		self.visual_names = visual_names_A_1 + visual_names_B_1 + visual_names_A_2 + visual_names_B_2
		
		# Networks handled by base_model.save_networks / load_networks; only the
		# generators are needed at test time.
		if self.isTrain:
			if opt.Shared_DT:
				self.model_names = ['G_A_1', 'G_B_1', 'D_A', 'D_B_1', 'D_B_2', 'G_A_2', 'G_B_2']
			else:
				self.model_names = ['G_A_1', 'G_B_1', 'D_A_1', 'D_B_1', 'G_A_2', 'G_B_2', 'D_A_2', 'D_B_2']
			if opt.SAD:
				self.model_names.append('D_3')
			if opt.CCD or opt.HF_CCD:
				self.model_names.extend(['D_12', 'D_21'])
		else:
			self.model_names = ['G_A_1', 'G_B_1', 'G_A_2', 'G_B_2']
		
		def make_G(in_nc, out_nc):
			# One generator configured from opt.
			return networks.define_G(in_nc, out_nc, opt.ngf, opt.which_model_netG, opt.norm,
			                         not opt.no_dropout, opt.init_type, self.gpu_ids)
		
		# Code (paper) naming: G_A_k (G), G_B_k (F), D_A (D_Y), D_B_k (D_X).
		# Networks are created in the original order, so the sequence of random
		# weight initialisations is unchanged.
		self.netG_A_1 = make_G(opt.input_nc, opt.output_nc)
		self.netG_B_1 = make_G(opt.output_nc, opt.input_nc)
		self.netG_A_2 = make_G(opt.input_nc, opt.output_nc)
		self.netG_B_2 = make_G(opt.output_nc, opt.input_nc)
		
		if opt.semantic_loss:
			self.netPixelCLS_SYN = get_model(opt.weights_model_type, num_cls=opt.num_cls, pretrained=True, weights_init=opt.weights_syn)
			self.netPixelCLS_GTA = get_model(opt.weights_model_type, num_cls=opt.num_cls, pretrained=True, weights_init=opt.weights_gta)
			if len(self.gpu_ids) > 0:
				assert (torch.cuda.is_available())
				self.netPixelCLS_SYN.to(self.gpu_ids[0])
				self.netPixelCLS_SYN = torch.nn.DataParallel(self.netPixelCLS_SYN, self.gpu_ids)
				self.netPixelCLS_GTA.to(self.gpu_ids[0])
				self.netPixelCLS_GTA = torch.nn.DataParallel(self.netPixelCLS_GTA, self.gpu_ids)
		
		if not self.isTrain:
			return
		
		use_sigmoid = opt.no_lsgan
		
		def make_D(in_nc):
			# One discriminator for images with in_nc channels, configured from opt.
			return networks.define_D(in_nc, opt.ndf, opt.which_model_netD,
			                         opt.n_layers_D, opt.norm, use_sigmoid,
			                         opt.init_type, self.gpu_ids)
		
		def make_adam(*nets):
			# Adam over the parameters of all given networks.
			params = itertools.chain(*(net.parameters() for net in nets))
			return torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
		
		# Target-domain discriminators judge generator outputs (output_nc channels).
		if opt.Shared_DT:
			self.netD_A = make_D(opt.output_nc)
		else:
			self.netD_A_1 = make_D(opt.output_nc)
			self.netD_A_2 = make_D(opt.output_nc)
		# Source-domain discriminators judge source images (input_nc channels).
		self.netD_B_1 = make_D(opt.input_nc)
		self.netD_B_2 = make_D(opt.input_nc)
		if opt.SAD:
			# netD_3 compares fake_B_2 with fake_B_1, both outputs of G_A_k, so it
			# takes output_nc channels (input_nc only worked when the two were equal).
			self.netD_3 = make_D(opt.output_nc)
		if opt.CCD or opt.HF_CCD:
			# fake_A_12 / fake_A_21 are G_B outputs, i.e. source-domain images.
			self.netD_12 = make_D(opt.input_nc)
			self.netD_21 = make_D(opt.input_nc)
		
		# Buffers of previously generated images fed to the discriminators.
		self.fake_A_1_pool = ImagePool(opt.pool_size)
		self.fake_B_1_pool = ImagePool(opt.pool_size)
		self.fake_A_2_pool = ImagePool(opt.pool_size)
		self.fake_B_2_pool = ImagePool(opt.pool_size)
		self.fake_A_21_pool = ImagePool(opt.pool_size)
		self.fake_A_12_pool = ImagePool(opt.pool_size)
		
		self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
		self.criterionCycle = torch.nn.L1Loss()
		self.criterionIdt = torch.nn.L1Loss()
		self.criterionSemantic = torch.nn.KLDivLoss(reduction='batchmean')
		
		if opt.Shared_DT:
			self.optimizer_D = make_adam(self.netD_A, self.netD_B_1, self.netD_B_2)
		else:
			self.optimizer_D_1 = make_adam(self.netD_A_1, self.netD_B_1)
			self.optimizer_D_2 = make_adam(self.netD_A_2, self.netD_B_2)
		self.optimizer_G_1 = make_adam(self.netG_A_1, self.netG_B_1)
		self.optimizer_G_2 = make_adam(self.netG_A_2, self.netG_B_2)
		if opt.SAD:
			self.optimizer_D_3 = make_adam(self.netD_3)
		if opt.CCD or opt.HF_CCD:
			self.optimizer_D_21 = make_adam(self.netD_21)
			self.optimizer_D_12 = make_adam(self.netD_12)
		
		self.optimizers = [self.optimizer_G_1, self.optimizer_G_2]
		if opt.Shared_DT:
			self.optimizers.append(self.optimizer_D)
		else:
			self.optimizers.extend([self.optimizer_D_1, self.optimizer_D_2])
		if opt.SAD:
			self.optimizers.append(self.optimizer_D_3)
		if opt.CCD or opt.HF_CCD:
			self.optimizers.extend([self.optimizer_D_12, self.optimizer_D_21])
	
	def set_input(self, input):
		"""Unpack a batch with keys 'A_1', 'A_2', 'B', 'A_paths_1', 'A_paths_2'.

		``image_paths`` is the concatenation of both source path collections.
		Label maps are stored only when 'A_label_1', 'A_label_2' and 'B_label'
		are all present.
		"""
		self.real_A_1 = input['A_1'].to(self.device)
		self.real_A_2 = input['A_2'].to(self.device)
		self.real_B = input['B'].to(self.device)
		
		self.image_paths_1 = input['A_paths_1']
		self.image_paths_2 = input['A_paths_2']
		self.image_paths = self.image_paths_1 + self.image_paths_2
		if 'A_label_1' in input and 'B_label' in input and 'A_label_2' in input:
			self.input_A_label_1 = input['A_label_1'].to(self.device)
			self.input_A_label_2 = input['A_label_2'].to(self.device)
			self.input_B_label = input['B_label'].to(self.device)
	
	def forward(self, opt):
		"""Run both translation cycles, the CCD cross mappings and the semantic classifiers as configured."""
		# cycle for source 1
		self.fake_B_1 = self.netG_A_1(self.real_A_1)
		self.rec_A_1 = self.netG_B_1(self.fake_B_1)
		self.fake_A_1 = self.netG_B_1(self.real_B)
		self.rec_B_1 = self.netG_A_1(self.fake_A_1)
		
		# cycle for source 2
		self.fake_B_2 = self.netG_A_2(self.real_A_2)
		self.rec_A_2 = self.netG_B_2(self.fake_B_2)
		self.fake_A_2 = self.netG_B_2(self.real_B)
		self.rec_B_2 = self.netG_A_2(self.fake_A_2)
		
		if opt.CCD:
			# A_2 -> B -> A_1 for netD_21, and A_1 -> B -> A_2 for netD_12.
			self.fake_A_21 = self.netG_B_1(self.fake_B_2)
			self.fake_A_12 = self.netG_B_2(self.fake_B_1)
		
		if self.isTrain and self.semantic_loss:
			# Predictions on real source images are only fixed targets (and their
			# argmax), so they are computed without autograd history. The
			# predictions on translated images keep their graph, because the
			# semantic loss back-propagates through them into G_A_1 / G_A_2.
			with torch.no_grad():
				self.pred_real_A_SYN = self.netPixelCLS_SYN(self.real_A_1)
			_, self.gt_pred_A_SYN = self.pred_real_A_SYN.max(1)
			self.pred_fake_B_SYN = self.netPixelCLS_SYN(self.fake_B_1)
			
			with torch.no_grad():
				self.pred_real_A_GTA = self.netPixelCLS_GTA(self.real_A_2)
			_, self.gt_pred_A_GTA = self.pred_real_A_GTA.max(1)
			self.pred_fake_B_GTA = self.netPixelCLS_GTA(self.fake_B_2)
	
	def backward_D_basic(self, netD, real, fake, SAD=False):
		"""Back-propagate and return 0.5 * (GAN(netD(real), True) + GAN(netD(fake), False)).

		``fake`` is always detached. With ``SAD`` the "real" input is itself a
		generator output, so it is detached as well.
		"""
		pred_real = netD(real.detach() if SAD else real)
		loss_D_real = self.criterionGAN(pred_real, True)
		loss_D_fake = self.criterionGAN(netD(fake.detach()), False)
		loss_D = (loss_D_real + loss_D_fake) * 0.5
		loss_D.backward()
		return loss_D
	
	def backward_D_A(self, Shared_DT):
		"""Target-domain discriminators: real_B versus pooled fake_B_1 / fake_B_2."""
		disc_1 = self.netD_A if Shared_DT else self.netD_A_1
		fake_B_1 = self.fake_B_1_pool.query(self.fake_B_1)
		self.loss_D_A_1 = self.backward_D_basic(disc_1, self.real_B, fake_B_1)
		
		disc_2 = self.netD_A if Shared_DT else self.netD_A_2
		fake_B_2 = self.fake_B_2_pool.query(self.fake_B_2)
		self.loss_D_A_2 = self.backward_D_basic(disc_2, self.real_B, fake_B_2)
	
	def backward_D_B(self):
		"""Source-domain discriminators: real_A_k versus pooled fake_A_k."""
		fake_A_1 = self.fake_A_1_pool.query(self.fake_A_1)
		self.loss_D_B_1 = self.backward_D_basic(self.netD_B_1, self.real_A_1, fake_A_1)
		
		fake_A_2 = self.fake_A_2_pool.query(self.fake_A_2)
		self.loss_D_B_2 = self.backward_D_basic(self.netD_B_2, self.real_A_2, fake_A_2)
	
	def backward_D(self, which_D):
		"""Update step for an auxiliary discriminator: 'SAD', 'CCD_21' or 'CCD_12'.

		Raises:
			Exception: for any other ``which_D``.
		"""
		if which_D == 'SAD':
			# NOTE(review): the fake_B_1 pool is also queried in backward_D_A in
			# the same iteration, so each fake_B_1 is offered to the pool twice.
			fake_B_1 = self.fake_B_1_pool.query(self.fake_B_1)
			self.loss_D_3_1 = self.backward_D_basic(self.netD_3, self.fake_B_2, fake_B_1, SAD=True)
		elif which_D == 'CCD_21':
			fake_A_21 = self.fake_A_21_pool.query(self.fake_A_21)
			self.loss_D_21 = self.backward_D_basic(self.netD_21, self.real_A_1, fake_A_21)
		elif which_D == 'CCD_12':
			fake_A_12 = self.fake_A_12_pool.query(self.fake_A_12)
			self.loss_D_12 = self.backward_D_basic(self.netD_12, self.real_A_2, fake_A_12)
		else:
			raise Exception("Invalid Choice {}".format(which_D))
	
	@staticmethod
	def _adversarial_term_active(frozen_epoch, current_epoch):
		"""Return True once ``current_epoch`` is past ``frozen_epoch``.

		A ``frozen_epoch`` of -1 keeps the term disabled for the whole run.
		NOTE(review): confirm -1 is meant as "never" rather than "from the start".
		"""
		return frozen_epoch != -1 and current_epoch > frozen_epoch
	
	def backward_G(self, opt):
		"""Compute the generator loss of both source branches and back-propagate it.

		Per branch k: identity (when ``lambda_identity > 0``), adversarial
		against the target discriminator (weighted x2) and against D_B_k, plus
		both cycle terms. Optional additions: the SAD term, the CCD terms (each
		weighted by ``opt.D1D2_weight``) and the two semantic KL terms
		(normalised by C*H*W and weighted by ``opt.general_semantic_weight``).
		"""
		lambda_idt = self.opt.lambda_identity
		lambda_A = self.opt.lambda_A
		lambda_B = self.opt.lambda_B
		
		if lambda_idt > 0:
			# G_A_k(real_B) should reproduce real_B; G_B_k(real_A_k) should reproduce real_A_k.
			self.idt_A_1 = self.netG_A_1(self.real_B)
			self.loss_idt_A_1 = self.criterionIdt(self.idt_A_1, self.real_B) * lambda_B * lambda_idt
			
			self.idt_A_2 = self.netG_A_2(self.real_B)
			self.loss_idt_A_2 = self.criterionIdt(self.idt_A_2, self.real_B) * lambda_B * lambda_idt
			
			self.idt_B_1 = self.netG_B_1(self.real_A_1)
			self.loss_idt_B_1 = self.criterionIdt(self.idt_B_1, self.real_A_1) * lambda_A * lambda_idt
			
			self.idt_B_2 = self.netG_B_2(self.real_A_2)
			self.loss_idt_B_2 = self.criterionIdt(self.idt_B_2, self.real_A_2) * lambda_A * lambda_idt
		else:
			self.loss_idt_A_1 = 0
			self.loss_idt_A_2 = 0
			self.loss_idt_B_1 = 0
			self.loss_idt_B_2 = 0
		
		# Adversarial terms towards the target domain.
		target_D_1 = self.netD_A if opt.Shared_DT else self.netD_A_1
		target_D_2 = self.netD_A if opt.Shared_DT else self.netD_A_2
		self.loss_G_A_1 = 2 * self.criterionGAN(target_D_1(self.fake_B_1), True)
		self.loss_G_A_2 = 2 * self.criterionGAN(target_D_2(self.fake_B_2), True)
		
		# Adversarial terms towards each source domain.
		self.loss_G_B_1 = self.criterionGAN(self.netD_B_1(self.fake_A_1), True)
		self.loss_G_B_2 = self.criterionGAN(self.netD_B_2(self.fake_A_2), True)
		
		# Forward (A_k -> B -> A_k) and backward (B -> A_k -> B) cycle terms.
		self.loss_cycle_A_1 = self.criterionCycle(self.rec_A_1, self.real_A_1) * lambda_A
		self.loss_cycle_A_2 = self.criterionCycle(self.rec_A_2, self.real_A_2) * lambda_A
		self.loss_cycle_B_1 = self.criterionCycle(self.rec_B_1, self.real_B) * lambda_B
		self.loss_cycle_B_2 = self.criterionCycle(self.rec_B_2, self.real_B) * lambda_B
		
		self.loss_G_1 = self.loss_G_A_1 + self.loss_G_B_1 + self.loss_cycle_A_1 + self.loss_cycle_B_1 + self.loss_idt_A_1 + self.loss_idt_B_1
		self.loss_G_2 = self.loss_G_A_2 + self.loss_G_B_2 + self.loss_cycle_A_2 + self.loss_cycle_B_2 + self.loss_idt_A_2 + self.loss_idt_B_2
		self.loss_G = self.loss_G_1 + self.loss_G_2
		
		if opt.SAD:
			# G_A_1 is pushed to make fake_B_1 indistinguishable from fake_B_2.
			if self._adversarial_term_active(opt.SAD_frozen_epoch, opt.current_epoch):
				self.loss_G_s1s2 = self.criterionGAN(self.netD_3(self.fake_B_1), True)
			else:
				self.loss_G_s1s2 = 0
			self.loss_G += self.loss_G_s1s2
		
		if opt.CCD:
			if self._adversarial_term_active(opt.CCD_frozen_epoch, opt.current_epoch):
				self.loss_G_s1s21 = self.criterionGAN(self.netD_21(self.fake_A_21), True)
				self.loss_G += self.loss_G_s1s21 * opt.D1D2_weight
				self.loss_G_s2s12 = self.criterionGAN(self.netD_12(self.fake_A_12), True)
				self.loss_G += self.loss_G_s2s12 * opt.D1D2_weight
			else:
				self.loss_G_s1s21 = 0
				self.loss_G_s2s12 = 0
		
		if opt.semantic_loss:
			# The reported sem_* values are before normalisation and general weighting.
			self.loss_sem_syn = opt.dynamic_weight * self.criterionSemantic(F.log_softmax(self.pred_fake_B_SYN, dim=1),
			                                                                F.softmax(self.pred_real_A_SYN, dim=1))
			self.loss_sem_gta = opt.dynamic_weight * self.criterionSemantic(F.log_softmax(self.pred_fake_B_GTA, dim=1),
			                                                                F.softmax(self.pred_real_A_GTA, dim=1))
			syn_elems = self.pred_fake_B_SYN.shape[1] * self.pred_fake_B_SYN.shape[2] * self.pred_fake_B_SYN.shape[3]
			gta_elems = self.pred_fake_B_GTA.shape[1] * self.pred_fake_B_GTA.shape[2] * self.pred_fake_B_GTA.shape[3]
			self.loss_G += opt.general_semantic_weight * torch.div(self.loss_sem_syn, syn_elems)
			self.loss_G += opt.general_semantic_weight * torch.div(self.loss_sem_gta, gta_elems)
		
		self.loss_G.backward()
	
	def backward_HF_CCD(self, opt):
		"""Separate cross-cycle generator step used when ``opt.HF_CCD`` is set.

		Re-translates both sources (overwriting ``fake_B_1`` / ``fake_B_2``), maps
		each translation through the other branch's G_B, and back-propagates
		``CCD_weight * (GAN(netD_21(fake_A_21)) + GAN(netD_12(fake_A_12)))``.
		Nothing is back-propagated while the terms are inactive.
		"""
		self.fake_B_1 = self.netG_A_1(self.real_A_1)
		self.fake_B_2 = self.netG_A_2(self.real_A_2)
		# A_2 -> B -> A_1 for netD_21, and A_1 -> B -> A_2 for netD_12.
		self.fake_A_21 = self.netG_B_1(self.fake_B_2)
		self.fake_A_12 = self.netG_B_2(self.fake_B_1)
		
		if self._adversarial_term_active(opt.CCD_frozen_epoch, opt.current_epoch):
			self.loss_G_s2s12 = self.criterionGAN(self.netD_12(self.fake_A_12), True)
			self.loss_G_s1s21 = self.criterionGAN(self.netD_21(self.fake_A_21), True)
		else:
			self.loss_G_s2s12 = 0
			self.loss_G_s1s21 = 0
		
		self.loss_G_HF = self.loss_G_s1s21 * opt.CCD_weight + self.loss_G_s2s12 * opt.CCD_weight
		# While inactive both terms are plain 0, so there is nothing to back-propagate.
		if isinstance(self.loss_G_HF, torch.Tensor):
			self.loss_G_HF.backward()
	
	def optimize_parameters(self, opt):
		"""One training iteration.

		Order: forward pass; main generator step (all discriminators passed to
		set_requires_grad(..., False)); optional HF_CCD generator step (G_B_1 /
		G_B_2 passed to set_requires_grad(..., False)); target and source
		discriminators; then SAD and CCD discriminators when enabled.
		"""
		self.forward(opt)
		
		# --- main generator step ---
		if opt.Shared_DT:
			self.set_requires_grad([self.netD_A, self.netD_B_1, self.netD_B_2], False)
		else:
			self.set_requires_grad([self.netD_A_1, self.netD_B_1], False)
			self.set_requires_grad([self.netD_A_2, self.netD_B_2], False)
		if opt.SAD:
			self.set_requires_grad([self.netD_3], False)
		if opt.CCD or opt.HF_CCD:
			self.set_requires_grad([self.netD_21], False)
			self.set_requires_grad([self.netD_12], False)
		
		self.set_requires_grad([self.netG_A_1, self.netG_B_1], True)
		self.set_requires_grad([self.netG_A_2, self.netG_B_2], True)
		
		self.optimizer_G_1.zero_grad()
		self.optimizer_G_2.zero_grad()
		self.backward_G(opt)
		self.optimizer_G_1.step()
		self.optimizer_G_2.step()
		
		# --- optional cross-cycle generator step ---
		if opt.HF_CCD:
			self.optimizer_G_1.zero_grad()
			self.optimizer_G_2.zero_grad()
			self.set_requires_grad([self.netG_A_1, self.netG_A_2], True)
			self.set_requires_grad([self.netG_B_1, self.netG_B_2], False)
			self.backward_HF_CCD(opt)
			self.optimizer_G_1.step()
			self.optimizer_G_2.step()
		
		# --- target/source discriminators ---
		if opt.Shared_DT:
			self.set_requires_grad([self.netD_A, self.netD_B_1, self.netD_B_2], True)
			self.optimizer_D.zero_grad()
		else:
			self.set_requires_grad([self.netD_A_1, self.netD_B_1], True)
			self.set_requires_grad([self.netD_A_2, self.netD_B_2], True)
			self.optimizer_D_1.zero_grad()
			self.optimizer_D_2.zero_grad()
		
		self.backward_D_B()
		self.backward_D_A(opt.Shared_DT)
		if opt.Shared_DT:
			self.optimizer_D.step()
		else:
			self.optimizer_D_1.step()
			self.optimizer_D_2.step()
		
		# --- auxiliary discriminators ---
		if opt.SAD:
			self.set_requires_grad([self.netD_3], True)
			self.optimizer_D_3.zero_grad()
			self.backward_D('SAD')
			self.optimizer_D_3.step()
		
		if opt.CCD or opt.HF_CCD:
			self.set_requires_grad([self.netD_21], True)
			self.optimizer_D_21.zero_grad()
			self.backward_D('CCD_21')
			self.optimizer_D_21.step()
			
			self.set_requires_grad([self.netD_12], True)
			self.optimizer_D_12.zero_grad()
			self.backward_D('CCD_12')
			self.optimizer_D_12.step()


================================================
FILE: cyclegan/models/networks.py
================================================
import functools

import torch
import torch.nn as nn
from torch.nn import init
from torch.optim import lr_scheduler


###############################################################################
# Helper Functions
###############################################################################


def get_norm_layer(norm_type='instance'):
	"""Return a factory for the normalization layer named ``norm_type``.

	Args:
		norm_type: 'batch' (BatchNorm2d, affine=True), 'instance'
			(InstanceNorm2d, affine=False) or 'none'.

	Returns:
		A ``functools.partial`` that builds the layer when called with the
		channel count, or None for 'none'.

	Raises:
		NotImplementedError: for any other name.
	"""
	if norm_type == 'none':
		return None
	if norm_type == 'batch':
		return functools.partial(nn.BatchNorm2d, affine=True)
	if norm_type == 'instance':
		return functools.partial(nn.InstanceNorm2d, affine=False)
	raise NotImplementedError('normalization layer [%s] is not found' % norm_type)


def get_scheduler(optimizer, opt):
	"""Create the learning-rate scheduler selected by ``opt.lr_policy``.

	Args:
		optimizer: the torch optimizer whose learning rate is scheduled.
		opt: options object. Reads ``lr_policy`` and, depending on the policy,
			``epoch_count``, ``niter``, ``niter_decay`` or ``lr_decay_iters``.

	Policies:
		'lambda':  LambdaLR with multiplier
		           ``1 - max(0, epoch + 1 + epoch_count - niter) / (niter_decay + 1)``,
		           i.e. constant at first, then decreasing linearly.
		'step':    StepLR multiplying the LR by 0.1 every ``lr_decay_iters`` steps.
		'plateau': ReduceLROnPlateau(mode='min', factor=0.2, threshold=0.01, patience=5).

	Returns:
		The scheduler instance.

	Raises:
		NotImplementedError: if ``opt.lr_policy`` is not one of the above.
	"""
	if opt.lr_policy == 'lambda':
		def lambda_rule(epoch):
			return 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
		
		return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
	if opt.lr_policy == 'step':
		return lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
	if opt.lr_policy == 'plateau':
		return lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
	# Raise (not return) the error: a returned exception object would be used as a
	# scheduler and only fail much later, far from the misconfiguration.
	raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)


def init_weights(net, init_type='normal', gain=0.02):
	"""Initialise the weights of ``net`` in place (prints the chosen scheme first).

	Modules whose class name contains 'Conv' or 'Linear' (and that have a
	``weight``) get their weight drawn according to ``init_type``:
	'normal' N(0, gain), 'xavier' (xavier_normal_ with ``gain``), 'kaiming'
	(kaiming_normal_, a=0, fan_in) or 'orthogonal' (with ``gain``). Their bias,
	if any, is zeroed. BatchNorm2d modules get weight ~ N(1, gain) and bias 0.

	Raises:
		NotImplementedError: for an unknown ``init_type``. This is raised only
		when a Conv/Linear module is actually visited.
	"""
	def _init_module(module):
		kind = module.__class__.__name__
		if ('Conv' in kind or 'Linear' in kind) and hasattr(module, 'weight'):
			weight = module.weight.data
			if init_type == 'normal':
				init.normal_(weight, 0.0, gain)
			elif init_type == 'xavier':
				init.xavier_normal_(weight, gain=gain)
			elif init_type == 'kaiming':
				init.kaiming_normal_(weight, a=0, mode='fan_in')
			elif init_type == 'orthogonal':
				init.orthogonal_(weight, gain=gain)
			else:
				raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
			bias = getattr(module, 'bias', None)
			if bias is not None:
				init.constant_(bias.data, 0.0)
		elif 'BatchNorm2d' in kind:
			init.normal_(module.weight.data, 1.0, gain)
			init.constant_(module.bias.data, 0.0)
	
	print('initialize network with %s' % init_type)
	net.apply(_init_module)


def init_net(net, init_type='normal', gpu_ids=[]):
	"""Prepare ``net`` for training and return it.

	When ``gpu_ids`` is non-empty, the network is moved to ``gpu_ids[0]`` and
	wrapped in ``torch.nn.DataParallel`` over ``gpu_ids`` (CUDA must be
	available). The result is then initialised with ``init_weights``.
	``gpu_ids`` is only read, never mutated.
	"""
	wants_gpu = len(gpu_ids) > 0
	if wants_gpu:
		assert torch.cuda.is_available()
		net.to(gpu_ids[0])
		net = torch.nn.DataParallel(net, gpu_ids)
	init_weights(net, init_type)
	return net


def print_network(net):
	"""Print ``net`` followed by its total parameter count."""
	total = sum(p.numel() for p in net.parameters())
	print(net)
	print('Total number of parameters: %d' % total)


def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):
	"""Build a generator and initialise it with ``init_net``.

	Args:
		input_nc / output_nc: channel counts of the input and output images.
		ngf: base filter count, forwarded to the generator constructor.
		which_model_netG: 'resnet_9blocks' / 'resnet_6blocks' (ResnetGenerator
			with n_blocks=9 / 6) or 'unet_128' / 'unet_256' (UnetGenerator with
			7 / 8 as its third positional argument).
		norm: normalization name understood by ``get_norm_layer``.
		use_dropout, init_type, gpu_ids: forwarded to the constructor / ``init_net``.

	Raises:
		NotImplementedError: unknown ``norm`` or ``which_model_netG``.
	"""
	norm_layer = get_norm_layer(norm_type=norm)
	
	if which_model_netG in ('resnet_9blocks', 'resnet_6blocks'):
		n_blocks = 9 if which_model_netG == 'resnet_9blocks' else 6
		netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=n_blocks)
	elif which_model_netG in ('unet_128', 'unet_256'):
		depth = 7 if which_model_netG == 'unet_128' else 8
		netG = UnetGenerator(input_nc, output_nc, depth, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
	else:
		raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
	
	return init_net(netG, init_type, gpu_ids)


def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):
	"""Build a discriminator by name, then initialise and place it via ``init_net``.

	'basic' is a 3-layer ``NLayerDiscriminator``. 'n_layers' is the same
	network with ``n_layers_D`` layers. 'pixel' is a ``PixelDiscriminator``.

	Raises:
		NotImplementedError: if ``which_model_netD`` is not a supported name.
	"""
	norm_layer = get_norm_layer(norm_type=norm)
	
	if which_model_netD == 'pixel':
		netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
	elif which_model_netD in ('basic', 'n_layers'):
		depth = 3 if which_model_netD == 'basic' else n_layers_D
		netD = NLayerDiscriminator(input_nc, ndf, depth, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
	else:
		raise NotImplementedError('Discriminator model name [%s] is not recognized' %
		                          which_model_netD)
	return init_net(netD, init_type, gpu_ids)


def define_C(output_nc, ndf, init_type='normal', gpu_ids=[]):
	"""Build the auxiliary ``Classifier`` over ``output_nc``-channel images.

	The classifier is then initialised and placed via ``init_net``.
	"""
	# A DTN classifier (get_model('DTN', num_cls=10)) was once considered here
	# for 3-channel inputs; the small conv Classifier is used unconditionally.
	classifier = Classifier(output_nc, ndf)
	return init_net(classifier, init_type, gpu_ids)


##############################################################################
# Classes
##############################################################################


# Defines the GAN loss, which uses either LSGAN or the regular GAN objective.
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input.
class GANLoss(nn.Module):
	"""GAN objective that builds the real/fake target tensor on the fly.

	Args:
		use_lsgan: if True, use ``nn.MSELoss`` (least-squares GAN). Otherwise use
			``nn.BCELoss``, which expects probabilities in [0, 1] as input.
		target_real_label: target value for predictions that should be "real".
		target_fake_label: target value for predictions that should be "fake".

	The two labels are registered as buffers, so they move with the module on
	``.to(device)``.
	"""
	
	def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
		super(GANLoss, self).__init__()
		self.register_buffer('real_label', torch.tensor(target_real_label))
		self.register_buffer('fake_label', torch.tensor(target_fake_label))
		if use_lsgan:
			self.loss = nn.MSELoss()
		else:
			self.loss = nn.BCELoss()
	
	def get_target_tensor(self, input, target_is_real):
		"""Return the real or fake label broadcast (as a view) to ``input``'s shape."""
		target_tensor = self.real_label if target_is_real else self.fake_label
		return target_tensor.expand_as(input)
	
	def forward(self, input, target_is_real):
		"""Loss of ``input`` against an all-real (True) or all-fake (False) target.

		This is implemented as ``forward`` rather than overriding ``__call__``.
		``nn.Module.__call__`` then dispatches here and still runs any registered
		module hooks. Callers keep using ``criterion(pred, True)``.
		"""
		target_tensor = self.get_target_tensor(input, target_is_real)
		return self.loss(input, target_tensor)


# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
	"""ResNet-style image-to-image generator (after Justin Johnson's fast-neural-style).

	``self.model`` is an ``nn.Sequential`` with these stages, in order:
	- reflect-pad 3 + 7x7 conv + norm + ReLU;
	- two stride-2 3x3 conv stages doubling the width (ngf -> 2*ngf -> 4*ngf);
	- ``n_blocks`` ``ResnetBlock``s at 4*ngf;
	- two stride-2 transposed-conv stages halving it back to ngf;
	- reflect-pad 3 + 7x7 conv to ``output_nc`` + Tanh.

	Args:
		input_nc: channels of the input image.
		output_nc: channels of the generated image.
		ngf: width of the first convolution.
		norm_layer: normalisation layer class, or a functools.partial of one.
		use_dropout: forwarded to every ResnetBlock.
		n_blocks: number of residual blocks (must be >= 0).
		padding_type: forwarded to every ResnetBlock.
	"""
	
	def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
		assert (n_blocks >= 0)
		super(ResnetGenerator, self).__init__()
		self.input_nc = input_nc
		self.output_nc = output_nc
		self.ngf = ngf
		# Conv biases are enabled only when the norm layer is InstanceNorm2d
		# (possibly wrapped in functools.partial).
		norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
		use_bias = norm_cls == nn.InstanceNorm2d
		
		# The append order determines the ``model.<idx>`` keys in state_dict(),
		# so keep it stable for saved checkpoints to keep loading.
		layers = [nn.ReflectionPad2d(3),
		          nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
		          norm_layer(ngf),
		          nn.ReLU(True)]
		
		n_downsampling = 2
		width = ngf
		for _ in range(n_downsampling):
			layers.extend([nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
			               norm_layer(width * 2),
			               nn.ReLU(True)])
			width *= 2
		
		layers.extend([ResnetBlock(width, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)
		               for _ in range(n_blocks)])
		
		for _ in range(n_downsampling):
			layers.extend([nn.ConvTranspose2d(width, width // 2, kernel_size=3, stride=2,
			                                  padding=1, output_padding=1, bias=use_bias),
			               norm_layer(width // 2),
			               nn.ReLU(True)])
			width //= 2
		
		layers.extend([nn.ReflectionPad2d(3),
		               nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
		               nn.Tanh()])
		
		self.model = nn.Sequential(*layers)
	
	def forward(self, input):
		return self.model(input)


# Define a resnet block
class ResnetBlock(nn.Module):
	"""Residual block computing ``x + F(x)``.

	F is: pad, 3x3 conv, norm, ReLU, [Dropout(0.5)], pad, 3x3 conv, norm.

	Args:
		dim: channel count (input and output are the same).
		padding_type: 'reflect' or 'replicate' (explicit 1-pixel pad layer), or
			'zero' (the convolution pads by 1 itself).
		norm_layer: normalisation layer constructor, called with ``dim``.
		use_dropout: insert Dropout(0.5) between the two convolutions.
		use_bias: whether the 3x3 convolutions have a bias term.
	"""
	
	def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
		super(ResnetBlock, self).__init__()
		self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
	
	@staticmethod
	def _padded_conv(dim, padding_type, use_bias):
		"""Return ``[optional pad layer, 3x3 conv]`` for the requested padding mode."""
		if padding_type == 'reflect':
			return [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
		if padding_type == 'replicate':
			return [nn.ReplicationPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
		if padding_type == 'zero':
			return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
		raise NotImplementedError('padding [%s] is not implemented' % padding_type)
	
	def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
		"""Assemble the residual branch F as an ``nn.Sequential``."""
		layers = self._padded_conv(dim, padding_type, use_bias)
		layers += [norm_layer(dim), nn.ReLU(True)]
		if use_dropout:
			layers.append(nn.Dropout(0.5))
		layers += self._padded_conv(dim, padding_type, use_bias)
		layers.append(norm_layer(dim))
		return nn.Sequential(*layers)
	
	def forward(self, x):
		return x + self.conv_block(x)


# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
	"""U-Net generator made of nested ``UnetSkipConnectionBlock``s.

	``num_downs`` is the total number of stride-2 downsamplings; e.g. with 7 a
	128x128 image reaches 1x1 at the bottleneck. The nesting is built from the
	inside out:
	- the innermost ngf*8 block;
	- ``num_downs - 5`` further ngf*8 blocks (dropout optional);
	- tapering blocks of ngf*4, ngf*2 and ngf;
	- the outermost block mapping ``input_nc`` to ``output_nc``.
	"""
	
	def __init__(self, input_nc, output_nc, num_downs, ngf=64,
	             norm_layer=nn.BatchNorm2d, use_dropout=False):
		super(UnetGenerator, self).__init__()
		
		# construct unet structure, innermost level first
		block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
		for _ in range(num_downs - 5):
			block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block, norm_layer=norm_layer,
			                                use_dropout=use_dropout)
		# (outer, inner) channel multipliers of the three tapering levels.
		for outer_mult, inner_mult in ((4, 8), (2, 4), (1, 2)):
			block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * inner_mult, input_nc=None, submodule=block, norm_layer=norm_layer)
		
		self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer)
	
	def forward(self, input):
		return self.model(input)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
	"""One U-Net level: downsample, run the nested ``submodule``, upsample.

	Args:
		outer_nc: channels produced by this level's up-convolution.
		inner_nc: channels produced by this level's down-convolution.
		input_nc: channels of this level's input; defaults to ``outer_nc``.
		submodule: nested (deeper) block; not used when ``innermost``.
		outermost: top level. There is no skip concatenation and the output
			goes through Tanh.
		innermost: bottleneck level with no submodule.
		norm_layer: normalisation layer class, or a functools.partial of one.
		use_dropout: append Dropout(0.5). Only honoured for intermediate
			levels; ignored when outermost or innermost.

	Non-outermost levels return ``cat([x, model(x)], dim=1)``. That gives
	``input_nc + outer_nc`` channels, i.e. ``2 * outer_nc`` with the default
	``input_nc``. This is why the non-innermost up-convolutions take
	``inner_nc * 2`` inputs, assuming the submodule was built with
	``outer_nc == this block's inner_nc``.
	"""
	def __init__(self, outer_nc, inner_nc, input_nc=None,
	             submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
		super(UnetSkipConnectionBlock, self).__init__()
		self.outermost = outermost
		# Conv biases only when the norm layer is InstanceNorm2d (possibly in a partial).
		if type(norm_layer) == functools.partial:
			use_bias = norm_layer.func == nn.InstanceNorm2d
		else:
			use_bias = norm_layer == nn.InstanceNorm2d
		if input_nc is None:
			input_nc = outer_nc
		downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
		                     stride=2, padding=1, bias=use_bias)
		# NOTE: inplace=True. When it is the first layer (non-outermost levels),
		# it overwrites the caller's tensor x; see forward().
		downrelu = nn.LeakyReLU(0.2, True)
		downnorm = norm_layer(inner_nc)  # unused by the outermost/innermost branches
		uprelu = nn.ReLU(True)
		upnorm = norm_layer(outer_nc)  # unused by the outermost branch
		
		if outermost:
			# Input is the raw image: no activation before downconv, Tanh on output.
			upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
			                            kernel_size=4, stride=2,
			                            padding=1)
			down = [downconv]
			up = [uprelu, upconv, nn.Tanh()]
			model = down + [submodule] + up
		elif innermost:
			# No submodule, so nothing is concatenated: upconv takes inner_nc only.
			upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
			                            kernel_size=4, stride=2,
			                            padding=1, bias=use_bias)
			down = [downrelu, downconv]
			up = [uprelu, upconv, upnorm]
			model = down + up
		else:
			upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
			                            kernel_size=4, stride=2,
			                            padding=1, bias=use_bias)
			down = [downrelu, downconv, downnorm]
			up = [uprelu, upconv, upnorm]
			
			if use_dropout:
				model = down + [submodule] + up + [nn.Dropout(0.5)]
			else:
				model = down + [submodule] + up
		
		self.model = nn.Sequential(*model)
	
	def forward(self, x):
		"""Outermost: ``model(x)``. Otherwise: ``model(x)`` concatenated after ``x`` on dim 1."""
		if self.outermost:
			return self.model(x)
		else:
			# NOTE(review): the list is built left to right, but model(x) starts
			# with the in-place LeakyReLU, so by the time torch.cat reads x it
			# holds the activated values. The skip branch therefore carries
			# LeakyReLU(x), not the raw input. Confirm this is intended before
			# changing either the inplace flag or this expression.
			return torch.cat([x, self.model(x)], 1)


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
	"""PatchGAN discriminator producing a 1-channel map of patch scores.

	Layout of ``self.model``, in order:
	- 4x4 stride-2 conv + LeakyReLU(0.2);
	- ``n_layers - 1`` stages of 4x4 stride-2 conv + norm + LeakyReLU, with
	  width ``ndf * min(2**level, 8)``;
	- one 4x4 stride-1 conv + norm + LeakyReLU;
	- a 4x4 stride-1 conv to 1 channel;
	- an optional Sigmoid.
	"""
	
	def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
		super(NLayerDiscriminator, self).__init__()
		norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
		use_bias = norm_cls == nn.InstanceNorm2d
		
		kernel, pad = 4, 1
		
		def width(level):
			# Channel count at a given depth, capped at 8 * ndf.
			return ndf * min(2 ** level, 8)
		
		layers = [nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
		          nn.LeakyReLU(0.2, True)]
		
		for level in range(1, n_layers):
			layers += [nn.Conv2d(width(level - 1), width(level),
			                     kernel_size=kernel, stride=2, padding=pad, bias=use_bias),
			           norm_layer(width(level)),
			           nn.LeakyReLU(0.2, True)]
		
		# For n_layers <= 1 the stride-2 stages are skipped and the input width is ndf.
		prev_width = width(max(n_layers - 1, 0))
		layers += [nn.Conv2d(prev_width, width(n_layers),
		                     kernel_size=kernel, stride=1, padding=pad, bias=use_bias),
		           norm_layer(width(n_layers)),
		           nn.LeakyReLU(0.2, True),
		           nn.Conv2d(width(n_layers), 1, kernel_size=kernel, stride=1, padding=pad)]
		
		if use_sigmoid:
			layers.append(nn.Sigmoid())
		
		self.model = nn.Sequential(*layers)
	
	def forward(self, input):
		return self.model(input)


class PixelDiscriminator(nn.Module):
	"""Per-pixel (1x1-convolution) discriminator: one score per spatial location.

	The layers are:
	- 1x1 conv to ``ndf`` + LeakyReLU(0.2);
	- 1x1 conv to ``2 * ndf`` + norm + LeakyReLU(0.2);
	- 1x1 conv to 1 channel;
	- an optional Sigmoid.
	"""
	
	def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
		super(PixelDiscriminator, self).__init__()
		norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
		use_bias = norm_cls == nn.InstanceNorm2d
		
		layers = [nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
		          nn.LeakyReLU(0.2, True)]
		layers += [nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
		           norm_layer(ndf * 2),
		           nn.LeakyReLU(0.2, True)]
		layers.append(nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias))
		if use_sigmoid:
			layers.append(nn.Sigmoid())
		
		self.net = nn.Sequential(*layers)
	
	def forward(self, input):
		return self.net(input)


class Classifier(nn.Module):
	"""Small convolutional classifier producing ``num_cls`` raw logits.

	Four unpadded 3x3 stride-2 convolutions (widths ndf, ndf, 2*ndf, 4*ndf)
	are followed by two linear layers. The first linear layer expects exactly
	``4 * ndf`` flattened features, so the conv stack must reduce the input to
	1x1 spatially. For example, a 32x32 input goes 32 -> 15 -> 7 -> 3 -> 1.

	Args:
		input_nc: number of input channels.
		ndf: width of the first convolution.
		norm_layer: normalisation layer class; called with ``affine=True``.
		num_cls: number of output logits. Defaults to 10, the value that was
			previously hard-coded.
	"""
	
	def __init__(self, input_nc, ndf, norm_layer=nn.BatchNorm2d, num_cls=10):
		super(Classifier, self).__init__()
		
		kw = 3
		sequence = [
			nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2),
			nn.LeakyReLU(0.2, True)
		]
		
		nf_mult = 1
		nf_mult_prev = 1
		for n in range(3):
			nf_mult_prev = nf_mult
			nf_mult = min(2 ** n, 8)
			sequence += [
				nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
				          kernel_size=kw, stride=2),
				norm_layer(ndf * nf_mult, affine=True),
				nn.LeakyReLU(0.2, True)
			]
		self.before_linear = nn.Sequential(*sequence)
		
		# No activation between the two linear layers, and no softmax: the
		# output is raw logits.
		self.after_linear = nn.Sequential(
			nn.Linear(ndf * nf_mult, 1024),
			nn.Linear(1024, num_cls)
		)
	
	def forward(self, x):
		bs = x.size(0)
		# Flatten per sample; this only matches the Linear input when the
		# conv output is 1x1 spatially.
		features = self.before_linear(x).view(bs, -1)
		return self.after_linear(features)
#       return nn.functional.log_softmax(out, dim=1)


================================================
FILE: cyclegan/models/test_model.py
================================================
from . import networks
from .base_model import BaseModel


class TestModel(BaseModel):
	"""Inference-only model running one source->target generator on input 'A'.

	``opt.dataset_mode`` selects the generator:
	- 'synthia_cityscapes' builds ``netG_A_1`` (visual ``fake_B_1``);
	- 'gta5_cityscapes' builds ``netG_A_2`` (visual ``fake_B_2``);
	- any other mode builds no generator here.
	"""
	
	def name(self):
		return 'TestModel'
	
	def initialize(self, opt):
		assert (not opt.isTrain)
		BaseModel.initialize(self, opt)
		
		# Read by base_model.get_current_losses (none at test time) and
		# base_model.get_current_visuals.
		self.loss_names = []
		self.visual_names = ['real_A']
		
		index = {'synthia_cityscapes': '1', 'gta5_cityscapes': '2'}.get(opt.dataset_mode)
		if index is None:
			return
		# model_names drives base_model.save_networks / load_networks.
		self.model_names = ['G_A_' + index]
		self.visual_names.append('fake_B_' + index)
		generator = networks.define_G(opt.input_nc, opt.output_nc,
		                              opt.ngf, opt.which_model_netG,
		                              opt.norm, not opt.no_dropout,
		                              opt.init_type,
		                              self.gpu_ids)
		setattr(self, 'netG_A_' + index, generator)
	
	def set_input(self, input):
		# we need to use single_dataset mode: images under 'A', paths under 'A_paths'
		self.real_A = input['A'].to(self.device)
		self.image_paths = input['A_paths']
	
	def forward(self):
		# Run whichever generator initialize() created; netG_A_1 takes precedence.
		for index in ('1', '2'):
			generator_name = 'netG_A_' + index
			if hasattr(self, generator_name):
				setattr(self, 'fake_B_' + index, getattr(self, generator_name)(self.real_A))
				return


================================================
FILE: cyclegan/options/__init__.py
================================================


================================================
FILE: cyclegan/options/base_options.py
================================================
import argparse
import os

import torch
from util import util


class BaseOptions():
	"""Command-line options shared by training and testing.

	Subclasses extend ``initialize`` and must set ``self.isTrain``, which
	``parse`` copies onto the parsed options.
	"""
	
	def __init__(self):
		self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
		self.initialized = False
	
	@staticmethod
	def _str2bool(value):
		"""argparse ``type`` for boolean options.

		``type=bool`` is wrong for this: ``bool('False')`` is True. Accepts
		true/false, t/f, yes/no, y/n and 1/0, case-insensitively.

		Raises:
			argparse.ArgumentTypeError: for any other string.
		"""
		if isinstance(value, bool):
			return value
		lowered = value.strip().lower()
		if lowered in ('1', 'true', 't', 'yes', 'y'):
			return True
		if lowered in ('0', 'false', 'f', 'no', 'n'):
			return False
		raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)
	
	@staticmethod
	def _format_options(args):
		"""Return the option report lines: header, sorted 'key: value' lines, footer."""
		lines = ['------------ Options -------------']
		lines += ['%s: %s' % (str(k), str(v)) for k, v in sorted(args.items())]
		lines.append('-------------- End ----------------')
		return lines
	
	def initialize(self):
		self.parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
		self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
		self.parser.add_argument('--loadSize', type=int, default=600, help='scale images to this size')
		self.parser.add_argument('--fineSize', type=int, default=600, help='then crop to this size')
		self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
		self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
		self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
		self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
		self.parser.add_argument('--which_model_netD', type=str, default='n_layers', help='selects model to use for netD')
		self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG')
		self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
		self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
		self.parser.add_argument('--name', type=str, default='experiment_name',
		                         help='name of the experiment. It decides where to store samples and models')
		self.parser.add_argument('--dataset_mode', type=str, default='unaligned',
		                         help='chooses how datasets are loaded. [unaligned | aligned | single]')
		self.parser.add_argument('--model', type=str, default='cycle_gan',
		                         help='chooses which model to use. cycle_gan, pix2pix, test')
		self.parser.add_argument('--weights_model_type', type=str, default='drn26',
		                         help='chooses which model to use. drn26, fcn8s')
		self.parser.add_argument('--num_cls', default=19, type=int)
		self.parser.add_argument('--max_epoch', default=20, type=int)
		self.parser.add_argument('--current_epoch', default=0, type=int)
		self.parser.add_argument('--weights_init', type=str)
		self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
		self.parser.add_argument('--nThreads', default=16, type=int, help='# threads for loading data')
		self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
		self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
		self.parser.add_argument('--serial_batches', action='store_true',
		                         help='if true, takes images in order to make batches, otherwise takes them randomly')
		self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
		self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
		self.parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
		self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
		self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
		self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
		                         help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, '
		                              'only a subset is loaded.')
		self.parser.add_argument('--resize_or_crop', type=str, default='scale_width_and_crop',
		                         help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
		self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
		self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
		self.parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
		self.parser.add_argument('--suffix', default='', type=str,
		                         help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
		self.parser.add_argument('--out_all', action='store_true', help='output all stylized images(fake_B_{})')
		self.parser.add_argument('--SAD', action='store_true', help='Sub-domain Aggregation Discriminator module')
		self.parser.add_argument('--CCD', action='store_true', help='Cross-domain Cycle Discriminator module')
		self.parser.add_argument('--CCD_weight', type=float, default=1, help='weight for cross domain cycle discriminator loss')
		self.parser.add_argument('--HF_CCD', action='store_true', help='Half Freeze Cross-domain Cycle Discriminator module')
		self.parser.add_argument('--CCD_frozen_epoch', type=int, default=-1)
		self.parser.add_argument('--SAD_frozen_epoch', type=int, default=-1)
		# _str2bool instead of bool: with type=bool, "--Shared_DT False" parsed as True.
		self.parser.add_argument('--Shared_DT', type=self._str2bool, default=True, help="Through ")
		self.parser.add_argument('--model_type', type=str, default='fcn8s', help="choose to load which type of model (fcn8s, drn26, deeplabv2)")
		self.parser.add_argument('--semantic_loss', action='store_true', help='use semantic loss')
		self.parser.add_argument('--general_semantic_weight', type=float, default=0.2, help='weight for semantic loss')
		self.parser.add_argument('--weights_syn', type=str, default='', help='init weights for synthia')
		self.parser.add_argument('--weights_gta', type=str, default='', help='init weights for gta')
		
		self.parser.add_argument('--inference_script', type=str, default='', help='inference script')
		self.parser.add_argument('--dynamic_weight', type=float, default=10, help='Weight for Dynamic Semantic Loss(KL div) loss')
		self.initialized = True
	
	def parse(self):
		"""Parse sys.argv and write the options to opt.txt under the experiment dir.

		Also normalises ``gpu_ids``, selects the first GPU and applies
		``--suffix``. The options are printed and written to
		``<checkpoints_dir>/<name>/opt.txt``.

		Returns:
			The parsed ``argparse.Namespace`` (also stored as ``self.opt``).
		"""
		if not self.initialized:
			self.initialize()
		opt = self.parser.parse_args()
		opt.isTrain = self.isTrain  # train or test
		
		# '0,2' -> [0, 2]; negative ids (e.g. -1) mean CPU and are dropped.
		parsed_ids = [int(str_id) for str_id in opt.gpu_ids.split(',')]
		opt.gpu_ids = [gpu_id for gpu_id in parsed_ids if gpu_id >= 0]
		
		# set gpu ids
		if len(opt.gpu_ids) > 0:
			torch.cuda.set_device(opt.gpu_ids[0])
		
		# Apply the suffix before reporting so the console and opt.txt agree.
		if opt.suffix:
			opt.name = opt.name + '_' + opt.suffix.format(**vars(opt))
		
		lines = self._format_options(vars(opt))
		for line in lines:
			print(line)
		
		# save to the disk
		expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
		util.mkdirs(expr_dir)
		file_name = os.path.join(expr_dir, 'opt.txt')
		with open(file_name, 'wt') as opt_file:
			opt_file.write(''.join(line + '\n' for line in lines))
		self.opt = opt
		return self.opt


================================================
FILE: cyclegan/options/test_options.py
================================================
from .base_options import BaseOptions


class TestOptions(BaseOptions):
	"""Command-line options for inference: the shared options plus test-only flags.

	Sets ``isTrain`` to False.
	"""
	
	def initialize(self):
		BaseOptions.initialize(self)
		add = self.parser.add_argument
		add('--ntest', type=int, default=float("inf"), help='# of test examples.')
		add('--results_dir', type=str, default='./results/', help='saves results here.')
		add('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
		add('--phase', type=str, default='test', help='train, val, test, etc')
		add('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
		add('--how_many', type=int, default=50, help='how many test images to run')
		self.isTrain = False


================================================
FILE: cyclegan/options/train_options.py
================================================
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
	"""Command-line options for training: the shared options plus training flags.

	The training flags cover display, checkpointing, the optimiser, the
	learning-rate schedule and the loss weights. Sets ``isTrain`` to True.
	"""
	
	def initialize(self):
		BaseOptions.initialize(self)
		self.parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
		self.parser.add_argument('--display_ncols', type=int, default=4,
		                         help='if positive, display all images in a single visdom web panel with certain number of images per row.')
		self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
		self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
		self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
		self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
		self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
		self.parser.add_argument('--epoch_count', type=int, default=1,
		                         help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
		self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
		self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
		self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
		self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
		self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
		self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
		self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
		self.parser.add_argument('--lambda_A', type=float, default=1.0, help='weight for cycle loss (A -> B -> A)')
		self.parser.add_argument('--lambda_B', type=float, default=1.0, help='weight for cycle loss (B -> A -> B)')
		# Adjacent literals are concatenated, so each must end with a space
		# where a word break is needed.
		self.parser.add_argument('--lambda_identity', type=float, default=0,
		                         help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the '
		                              'identity mapping loss. '
		                              'For example, if the weight of the identity loss should be 10 times smaller than the weight of the '
		                              'reconstruction loss, please set lambda_identity = 0.1')
		self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
		self.parser.add_argument('--no_html', action='store_true',
		                         help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
		self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
		self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
		self.isTrain = True


================================================
FILE: cyclegan/test.py
================================================
import os
import sys

import torch
from models import create_model
from options.test_options import TestOptions
from util import html
from util.visualizer import save_images

from data import CreateDataLoader
import logging

# NOTE(review): this machine-specific path is appended after the imports above,
# so it can only affect imports performed later at runtime.
sys.path.append("/nfs/project/libo_i/MADAN")

if __name__ == '__main__':
	opt = TestOptions().parse()
	# Inference overrides: fixed order, no flip augmentation, no visdom.
	opt.serial_batches = True  # no shuffle
	opt.no_flip = True  # no flip
	opt.display_id = -1  # no visdom display
	data_loader = CreateDataLoader(opt)
	dataset = data_loader.load_data()
	model = create_model(opt)
	model.setup(opt)
	# create website
	web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
	webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
	# NOTE(review): this script never configures logging, so the logging.info
	# calls below only appear if the root logger is configured elsewhere.
	# test
	for i, data in enumerate(dataset):
		if i >= opt.how_many:
			break
		# check img size: log the shape of every tensor in the first batch
		if i == 0:
			for key, value in data.items():
				if isinstance(value, torch.Tensor):
					# The key must not be used as the format string: the old
					# logging.info(key, size) made logging report a formatting error.
					logging.info('%s: %s', key, value.size())
		
		model.set_input(data)
		model.test()
		visuals = model.get_current_visuals()
		# With --out_all, keep only the generated (fake_B*) images.
		if opt.out_all:
			for key in [k for k in visuals if 'fake_B' not in k]:
				del visuals[key]
		
		img_path = model.get_image_paths()
		if i % 5 == 0:
			logging.info('processing (%04d)-th image...', i * opt.batchSize)
		if 'mul' in opt.model:
			save_images(webpage.get_image_dir(), visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize, multi_flag=True)
		else:
			save_images(webpage.get_image_dir(), visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)



================================================
FILE: cyclegan/train.py
================================================
import subprocess
import sys
import time

sys.path.append("/nfs/project/libo_i/MADAN/cyclegan")
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
import torch
import logging

if __name__ == '__main__':
	opt = TrainOptions().parse()
	data_loader = CreateDataLoader(opt)
	dataset = data_loader.load_data()
	dataset_size = len(data_loader)
	logging.info('#training images = %d' % dataset_size)
	model = create_model(opt)
	model.setup(opt)
	visualizer = Visualizer(opt)
	# The loop runs niter epochs at the initial LR plus niter_decay decay epochs;
	# this is the last epoch actually trained (not opt.max_epoch).
	last_epoch = opt.niter + opt.niter_decay
	total_steps = 0
	for epoch in range(opt.epoch_count, last_epoch + 1):
		epoch_start_time = time.time()
		iter_data_time = time.time()
		epoch_iter = 0
		opt.current_epoch = epoch
		logging.info("Current epoch update to {}".format(opt.current_epoch))
		for data in dataset:
			# Log tensor shapes of the very first batch as a sanity check.
			if total_steps == 0:
				for key, value in data.items():
					if isinstance(value, torch.Tensor):
						logging.info('%s: %s', key, value.size())
			iter_start_time = time.time()
			# Data-loading time of this batch. Measure it every iteration so the
			# value printed below belongs to the iteration being reported; it was
			# previously refreshed on a different step than the one printed.
			t_data = iter_start_time - iter_data_time
			visualizer.reset()
			total_steps += opt.batchSize
			epoch_iter += opt.batchSize
			model.set_input(data)
			model.optimize_parameters(opt)
			
			if total_steps % opt.display_freq == 0:
				save_result = total_steps % opt.update_html_freq == 0
				visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
			
			if total_steps % opt.print_freq == 0:
				losses = model.get_current_losses()
				t = (time.time() - iter_start_time) / opt.batchSize
				visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
				if opt.display_id > 0:
					visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)
			
			if total_steps % opt.save_latest_freq == 0:
				logging.info('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
				model.save_networks('latest')
			iter_data_time = time.time()
		
		if epoch % opt.save_epoch_freq == 0:
			logging.info('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
			model.save_networks('latest')
			model.save_networks(epoch)
		
		logging.info('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, last_epoch, time.time() - epoch_start_time))
		model.update_learning_rate()


================================================
FILE: cyclegan/util/__init__.py
================================================


================================================
FILE: cyclegan/util/get_data.py
================================================
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename


class GetData(object):
    """

    Download CycleGAN or Pix2Pix Data.

    Args:
        technique : str
            One of: 'cyclegan' or 'pix2pix' (case-insensitive).
        verbose : bool
            If True, print additional information.

    Raises:
        ValueError
            If ``technique`` is not one of the supported names.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.

    """

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        self.url = url_dict.get(technique.lower())
        if self.url is None:
            # Fail fast: otherwise every later request would be sent to the
            # literal URL "None/<dataset>".
            raise ValueError("Unknown technique: {0!r}. Expected one of: {1}.".format(
                technique, ', '.join(sorted(url_dict))))
        self._verbose = verbose

    def _print(self, text):
        """Print ``text`` only when the instance was created with verbose=True."""
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        """Return the link texts in the HTML response ``r`` that name a
        ``.zip`` or ``.tar.gz`` archive (the same suffixes _download_data accepts)."""
        soup = BeautifulSoup(r.text, 'lxml')
        options = [h.text for h in soup.find_all('a', href=True)
                   if h.text.endswith(('.zip', '.tar.gz'))]
        return options

    def _present_options(self):
        """List the archives available at ``self.url`` and ask the user to pick one.

        Returns:
            The chosen archive file name (e.g. ``'facades.zip'``).
        """
        r = requests.get(self.url)
        r.raise_for_status()
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        return options[int(choice)]

    def _download_data(self, dataset_url, save_path):
        """Download the archive at ``dataset_url`` and unpack it into ``save_path``.

        The archive is first written to ``save_path/<basename>``. That
        temporary file is removed afterwards, even if unpacking fails.

        Raises:
            ValueError: if the file name is neither ``.tar.gz`` nor ``.zip``;
                this is checked before anything is downloaded or created.
        """
        base = basename(dataset_url)
        if base.endswith('.tar.gz'):
            opener = tarfile.open
        elif base.endswith('.zip'):
            opener = lambda path: ZipFile(path, 'r')
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))

        if not isdir(save_path):
            os.makedirs(save_path)
        temp_save_path = join(save_path, base)

        r = requests.get(dataset_url)
        # Surface HTTP errors instead of saving an error page as an archive.
        r.raise_for_status()
        with open(temp_save_path, "wb") as f:
            f.write(r.content)

        try:
            self._print("Unpacking Data...")
            # NOTE(review): extractall trusts member paths in the archive; this
            # is only acceptable because the archives come from the fixed URLs above.
            with opener(temp_save_path) as obj:
                obj.extractall(save_path)
        finally:
            os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """

        Download a dataset.

        Args:
            save_path : str
                A directory to save the data to.
            dataset : str, optional
                A specific dataset to download.
                Note: this must include the file extension.
                If None, options will be presented for you
                to choose from.

        Returns:
            save_path_full : str
                The absolute path to the downloaded data. If that directory
                already exists, a warning is issued and nothing is downloaded.

        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        save_path_full = join(save_path, selected_dataset.split('.')[0])

        if isdir(save_path_full):
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = "{0}/{1}".format(self.url, selected_dataset)
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)


================================================
FILE: cyclegan/util/html.py
================================================
import dominate
from dominate.tags import *
import os


class HTML:
    """Small builder for a static ``index.html`` page, using dominate.

    Images referenced by ``add_images`` are expected under ``<web_dir>/images``
    and are linked by relative path.
    """

    def __init__(self, web_dir, title, reflesh=0):
        """Create the output directories and an empty document.

        Args:
            web_dir: directory that will contain ``index.html``; created if missing
                (along with its ``images`` subdirectory).
            title: document title.
            reflesh: if > 0, emit a ``<meta http-equiv="refresh">`` tag with this
                value as its content (seconds). The parameter's spelling is kept
                for compatibility with existing callers.
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)

        self.doc = dominate.document(title=title)
        if reflesh > 0:
            with self.doc.head:
                # The HTTP-equivalent header name is "refresh"; the former value
                # "reflesh" is not recognized by browsers, so no reload happened.
                meta(http_equiv="refresh", content=str(reflesh))

    def get_image_dir(self):
        """Return the ``<web_dir>/images`` directory path."""
        return self.img_dir

    def add_header(self, str):
        """Append an ``<h3>`` header with the given text."""
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        """Append a new fixed-layout table and make it the current table ``self.t``."""
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=400):
        """Append a new one-row table with a cell per (image, caption, link) triple.

        Each cell shows ``images/<im>`` at ``width`` pixels, wrapped in a link to
        ``images/<link>``, followed by the caption text. Extra items in the longer
        lists are ignored (``zip`` semantics).
        """
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        """Render the document to ``<web_dir>/index.html``, overwriting it."""
        html_file = '%s/index.html' % self.web_dir
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())


if __name__ == '__main__':
    # Smoke test: a page with one header and a single row of four image cells.
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    indices = range(4)
    ims = ['image_%d.png' % n for n in indices]
    txts = ['text_%d' % n for n in indices]
    links = ['image_%d.png' % n for n in indices]
    html.add_images(ims, txts, links)
    html.save()


================================================
FILE: cyclegan/util/image_pool.py
================================================
import random
import torch


class ImagePool():
    """Buffer of previously seen images that ``query`` mixes into its output."""

    def __init__(self, pool_size):
        """
        Args:
            pool_size: maximum number of stored images. A value <= 0 disables
                the buffer; ``query`` then returns its input unchanged.
        """
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch drawn partly from the buffer.

        Args:
            images: tensor whose first dimension indexes individual images
                (presumably (N, C, H, W) — not checked here).

        Returns:
            ``images`` itself when the pool is disabled. Otherwise a new tensor
            with the same length along dim 0. While the buffer is filling, each
            incoming image is stored and returned. Once it is full, each image is,
            with probability 0.5, swapped with a randomly chosen stored image
            (the stored one is returned and the new one takes its slot);
            otherwise the incoming image is returned.
        """
        # `<= 0` (not `== 0`): __init__ skips the buffer setup for any
        # non-positive size, so negative sizes must also bypass it here.
        if self.pool_size <= 0:
            return images
        return_images = []
        for image in images:
            # .data detaches from autograd so stored images do not keep graphs alive.
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = torch.cat(return_images, 0)
        return return_images


================================================
FILE: cyclegan/util/util.py
================================================
from __future__ import print_function

import os

import numpy as np
import torch
from PIL import Image


# Converts a Tensor into an image array (numpy)
# |imtype|: the desired type of the converted numpy array
def tensor2im(input_image, imtype=np.uint8):
	"""Convert one channels-first image tensor into an (H, W, C) numpy array.

	Args:
		input_image: a ``torch.Tensor`` of shape (C, H, W), values presumably in
			[-1, 1]. Any non-tensor value is returned unchanged.
		imtype: numpy dtype of the result.

	Returns:
		Array of shape (H, W, C). A single-channel input is repeated to three
		channels. Values are mapped x -> (x + 1) / 2 * 255 and cast to
		``imtype`` without clipping.
	"""
	if not isinstance(input_image, torch.Tensor):
		return input_image
	array = input_image.data.cpu().float().numpy()
	if array.shape[0] == 1:
		# Replicate a grayscale channel so the output is always 3-channel.
		array = np.tile(array, (3, 1, 1))
	hwc = np.transpose(array, (1, 2, 0))
	return ((hwc + 1) / 2.0 * 255.0).astype(imtype)

# def tensor2im(input_image, imtype=np.uint8):
#     if isinstance(input_image, torch.Tensor):
#         image_tensor = input_image.data
#     else:
#         return input_image
#     image_numpy = image_tensor[0].cpu().float().numpy()
#     if image_numpy.shape[0] == 1:
#         image_numpy = np.tile(image_numpy, (3, 1, 1))
#     image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
#     return image_numpy.astype(imtype)


def diagnose_network(net, name='network'):
	"""Print ``name`` and then the average gradient magnitude of ``net``.

	The printed value is the mean, over parameters that currently have a
	gradient, of each parameter's mean absolute gradient. Parameters without a
	gradient are skipped. If none have one, ``0.0`` is printed.
	"""
	per_param = [torch.mean(torch.abs(param.grad.data))
				 for param in net.parameters() if param.grad is not None]
	total = sum(per_param, 0.0)
	if per_param:
		total = total / len(per_param)
	print(name)
	print(total)


def save_image(image_numpy, image_path):
	"""Write an image array to ``image_path`` using PIL.

	Args:
		image_numpy: array accepted by ``PIL.Image.fromarray`` (e.g. the
			(H, W, 3) uint8 output of ``tensor2im``).
		image_path: destination path; PIL picks the format (presumably from the
			file extension).
	"""
	Image.fromarray(image_numpy).save(image_path)


def print_numpy(x, val=True, shp=False):
	"""Print summary information about a numpy array.

	Args:
		x: array-like with ``astype``; converted to float64 first.
		shp: if True, print the shape first.
		val: if True, print mean/min/max/median/std over all elements.
	"""
	x = x.astype(np.float64)
	if shp:
		print('shape,', x.shape)
	if not val:
		return
	flat = x.flatten()
	stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
	print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)


def mkdirs(paths):
	"""Create one directory or several.

	Args:
		paths: a single path string, or a list/tuple of path strings. Each path
			is created (with parents) unless it already exists.
	"""
	if isinstance(paths, (list, tuple)):
		for path in paths:
			mkdir(path)
	else:
		mkdir(paths)


def mkdir(path):
	"""Create ``path`` and any missing parents unless something already exists there.

	Tolerates the directory being created concurrently (e.g. by another
	process) between the existence check and the ``makedirs`` call.
	"""
	if os.path.exists(path):
		return
	try:
		os.makedirs(path)
	except OSError:
		# Only re-raise if the path still does not exist (a real failure).
		if not os.path.exists(path):
			raise


================================================
FILE: cyclegan/util/visualizer.py
================================================
import ntpath
import os
import time

import numpy as np
from . import html, util


# save image to the disk
def save_images(image_dir, visuals, image_path, aspect_ratio=1.0, width=256, multi_flag=False):
	"""Save each visual of each image in a batch as ``<name>_<label>.png``.

	Args:
		image_dir: output directory (assumed to already exist).
		visuals: mapping label -> 4-D batched image data; index ``i`` along the
			first dimension belongs to ``image_path[i]``.
		image_path: sequence of source image paths. Each base name without its
			extension becomes ``<name>``.
		aspect_ratio, width: accepted for interface compatibility; unused here.
		multi_flag: if True, image ``i`` is saved only for labels containing the
			substring ``str(i + 1)``. This is a substring match, so e.g. '1' also
			matches a label containing '10'.
	"""
	for i, path in enumerate(image_path):
		name = os.path.splitext(ntpath.basename(path))[0]
		tag = str(i + 1)
		for label, im_data in visuals.items():
			# Align visual names with the i-th source image.
			if multi_flag is True and tag not in label:
				continue
			im = util.tensor2im(im_data[i, :, :, :])
			save_path = os.path.join(image_dir, '%s_%s.png' % (name, label))
			util.save_image(im, save_path)


class Visualizer():
	def __init__(self, opt):
		self.display_id = opt.display_id
		self.use_html = opt.isTrain and not opt.no_html
		self.win_size = opt.display_winsize
		self.name = opt.name
		self.opt = opt
		self.saved = False
		if self.display_id > 0:
			import visdom
			self.ncols = opt.display_ncols
			self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port)
		
		if self.use_html:
			self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
			self.img_dir = os.path.join(self.web_dir, 'images')
			print('create web directory %s...' % self.web_dir)
			util.mkdirs([self.web_dir, self.img_dir])
		self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
		with open(self.log_name, "a") as log_file:
			now = time.strftime("%c")
			log_file.write('================ Training Loss (%s) ================\n' % now)
	
	def reset(self):
		self.saved = False
	
	# |visuals|: dictionary of images to display or save
	def display_current_results(self, visuals, epoch, save_result):
		if self.display_id > 0:  # show images in the browser
			ncols = self.ncols
			if ncols > 0:
				ncols = min(ncols, len(visuals))
				h, w = next(iter(visuals.values())).shape[:2]
				table_css = """<style>
                        table
Download .txt
gitextract_rf2a9vy3/

├── .gitignore
├── .idea/
│   ├── MADAN.iml
│   ├── deployment.xml
│   ├── inspectionProfiles/
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── remote-mappings.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── cycada/
│   ├── __init__.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── adda_datasets.py
│   │   ├── bdds.py
│   │   ├── cityscapes.py
│   │   ├── cityscapes_labels.py
│   │   ├── cyclegan.py
│   │   ├── cyclegta5.py
│   │   ├── cyclesynthia.py
│   │   ├── cyclesynthia_cyclegta5.py
│   │   ├── data_loader.py
│   │   ├── gta5.py
│   │   ├── rotater.py
│   │   ├── synthia.py
│   │   └── util.py
│   ├── logging.yml
│   ├── models/
│   │   ├── MDAN.py
│   │   ├── __init__.py
│   │   ├── adda_net.py
│   │   ├── drn.py
│   │   ├── fcn8s.py
│   │   ├── models.py
│   │   ├── task_net.py
│   │   └── util.py
│   ├── tools/
│   │   ├── __init__.py
│   │   ├── train_adda_net.py
│   │   ├── train_task_net.py
│   │   └── util.py
│   ├── transforms.py
│   └── util.py
├── cyclegan/
│   ├── .gitignore
│   ├── data/
│   │   ├── __init__.py
│   │   ├── base_data_loader.py
│   │   ├── base_dataset.py
│   │   ├── cityscapes.py
│   │   ├── gta5_cityscapes.py
│   │   ├── gta_synthia_cityscapes.py
│   │   ├── image_folder.py
│   │   └── synthia_cityscapes.py
│   ├── environment.yml
│   ├── models/
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── cycle_gan_model.py
│   │   ├── cycle_gan_semantic_model.py
│   │   ├── multi_cycle_gan_semantic_model.py
│   │   ├── networks.py
│   │   └── test_model.py
│   ├── options/
│   │   ├── __init__.py
│   │   ├── base_options.py
│   │   ├── test_options.py
│   │   └── train_options.py
│   ├── test.py
│   ├── train.py
│   └── util/
│       ├── __init__.py
│       ├── get_data.py
│       ├── html.py
│       ├── image_pool.py
│       ├── util.py
│       └── visualizer.py
├── requirements.txt
├── scripts/
│   ├── ADDA/
│   │   ├── adda_cyclegta2cs_feat.sh
│   │   ├── adda_cyclegta2cs_score.sh
│   │   ├── adda_cyclesyn2cs_feat.sh
│   │   ├── adda_cyclesyn2cs_score.sh
│   │   └── adda_templates.sh
│   ├── CycleGAN/
│   │   ├── cyclegan_gta2cityscapes.sh
│   │   ├── cyclegan_gta_synthia2cityscapes.sh
│   │   ├── cyclegan_synthia2cityscapes.sh
│   │   ├── test_templates.sh
│   │   └── test_templates_cycle.sh
│   ├── FCN/
│   │   ├── train_fcn8s_cyclesgta5.sh
│   │   └── train_fcn8s_cyclesynthia.sh
│   ├── eval_fcn.py
│   ├── train_fcn.py
│   ├── train_fcn_adda.py
│   └── train_fcn_mdan.py
└── tools/
    ├── __init__.py
    └── eval_templates.sh
Download .txt
SYMBOL INDEX (386 symbols across 52 files)

FILE: cycada/data/adda_datasets.py
  class AddaDataLoader (line 9) | class AddaDataLoader(object):
    method __init__ (line 10) | def __init__(self, net_transform, dataset, rootdir, downscale, crop_si...
    method __iter__ (line 31) | def __iter__(self):
    method __next__ (line 34) | def __next__(self):
    method next (line 37) | def next(self):
    method __len__ (line 51) | def __len__(self):
    method set_loader_src (line 54) | def set_loader_src(self):
    method set_loader_tgt (line 69) | def set_loader_tgt(self):

FILE: cycada/data/bdds.py
  class BDDS (line 10) | class BDDS(data.Dataset):
    method __init__ (line 11) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
    method collect_ids (line 22) | def collect_ids(self):
    method img_path (line 32) | def img_path(self, filename):
    method label_path (line 35) | def label_path(self, filename):
    method __getitem__ (line 38) | def __getitem__(self, index, debug=False):
    method __len__ (line 51) | def __len__(self):

FILE: cycada/data/cityscapes.py
  function remap_labels_to_train_ids (line 10) | def remap_labels_to_train_ids(arr):
  class CityScapesParams (line 18) | class CityScapesParams(DatasetParams):
  class Cityscapes (line 28) | class Cityscapes(data.Dataset):
    method __init__ (line 29) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
    method collect_ids (line 43) | def collect_ids(self):
    method img_path (line 52) | def img_path(self, id):
    method label_path (line 58) | def label_path(self, id):
    method __getitem__ (line 64) | def __getitem__(self, index, debug=False):
    method __len__ (line 78) | def __len__(self):

FILE: cycada/data/cityscapes_labels.py
  function label_img_to_color (line 7) | def label_img_to_color(img):

FILE: cycada/data/cyclegan.py
  class CycleGANDataset (line 10) | class CycleGANDataset(data.Dataset):
    method __init__ (line 11) | def __init__(self, root, regexp, transform=None, target_transform=None,
    method find_images (line 19) | def find_images(self, regexp='*.png'):
    method __getitem__ (line 28) | def __getitem__(self, index):
    method __len__ (line 39) | def __len__(self):
  class Svhn2MNIST (line 44) | class Svhn2MNIST(CycleGANDataset):
    method __init__ (line 45) | def __init__(self, root, train=True, transform=None, target_transform=...
  class Svhn2MNISTParams (line 56) | class Svhn2MNISTParams(DatasetParams):
  class Usps2Mnist (line 76) | class Usps2Mnist(CycleGANDataset):
    method __init__ (line 77) | def __init__(self, root, train=True, transform=None, target_transform=...
  class Usps2MnistParams (line 88) | class Usps2MnistParams(DatasetParams):
  class Mnist2Usps (line 100) | class Mnist2Usps(CycleGANDataset):
    method __init__ (line 101) | def __init__(self, root, train=True, transform=None, target_transform=...
  class Mnist2UspsParams (line 112) | class Mnist2UspsParams(DatasetParams):

FILE: cycada/data/cyclegta5.py
  class CycleGTA5 (line 12) | class CycleGTA5(GTA5):
    method collect_ids (line 13) | def collect_ids(self):
    method __getitem__ (line 29) | def __getitem__(self, index, debug=False):

FILE: cycada/data/cyclesynthia.py
  function syn_relabel (line 55) | def syn_relabel(arr):
  class SYNTHIAParams (line 63) | class SYNTHIAParams(DatasetParams):
  class CycleSYNTHIA (line 73) | class CycleSYNTHIA(data.Dataset):
    method __init__ (line 75) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
    method collect_ids (line 85) | def collect_ids(self):
    method img_path (line 99) | def img_path(self, filename):
    method label_path (line 102) | def label_path(self, filename):
    method __getitem__ (line 110) | def __getitem__(self, index, debug=False):
    method __len__ (line 126) | def __len__(self):

FILE: cycada/data/cyclesynthia_cyclegta5.py
  function syn_relabel (line 56) | def syn_relabel(arr):
  class SYNTHIAParams (line 64) | class SYNTHIAParams(DatasetParams):
  class CycleSYNTHIACycleGTA5 (line 74) | class CycleSYNTHIACycleGTA5(data.Dataset):
    method __init__ (line 76) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
    method collect_ids (line 89) | def collect_ids(self, datasets_name):
    method img_path (line 110) | def img_path(self, prefix, filename):
    method syn_label_path (line 115) | def syn_label_path(self, filename):
    method gta_label_path (line 121) | def gta_label_path(self, filename):
    method __getitem__ (line 127) | def __getitem__(self, index, debug=False):
    method __len__ (line 166) | def __len__(self):

FILE: cycada/data/data_loader.py
  function load_data (line 15) | def load_data(name, dset, batch=64, rootdir='', num_channels=3,
  function get_transform_dataset (line 34) | def get_transform_dataset(dataset_name, rootdir, net_transform, downscal...
  function get_orig_size (line 43) | def get_orig_size(dataset_name):
  function get_transform2 (line 51) | def get_transform2(dataset_name, net_transform, downscale, resize):
  function get_transform (line 72) | def get_transform(params, image_size, num_channels):
  function get_target_transform (line 98) | def get_target_transform(params):
  class AddaDataset (line 108) | class AddaDataset(data.Dataset):
    method __init__ (line 110) | def __init__(self, src_data, tgt_data):
    method __getitem__ (line 114) | def __getitem__(self, index):
    method __len__ (line 121) | def __len__(self):
  function register_data_params (line 128) | def register_data_params(name):
  function register_dataset_obj (line 139) | def register_dataset_obj(name):
  class DatasetParams (line 147) | class DatasetParams(object):
  function get_dataset (line 157) | def get_dataset(name, rootdir, dset, image_size, num_channels, download=...
  function get_fcn_dataset (line 167) | def get_fcn_dataset(name, rootdir, **kwargs):

FILE: cycada/data/gta5.py
  class GTA5Params (line 13) | class GTA5Params(DatasetParams):
  class GTA5 (line 23) | class GTA5(data.Dataset):
    method __init__ (line 25) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
    method collect_ids (line 41) | def collect_ids(self):
    method img_path (line 46) | def img_path(self, id):
    method label_path (line 50) | def label_path(self, id):
    method __getitem__ (line 54) | def __getitem__(self, index, debug=False):
    method __len__ (line 71) | def __len__(self):

FILE: cycada/data/rotater.py
  class Rotater (line 1) | class Rotater(object):
    method __init__ (line 3) | def __init__(self, dataset, orientations=6, transform=None,
    method __getitem__ (line 10) | def __getitem__(self, index):
    method __len__ (line 21) | def __len__(self):

FILE: cycada/data/synthia.py
  function syn_relabel (line 9) | def syn_relabel(arr):
  class SYNTHIAParams (line 16) | class SYNTHIAParams(DatasetParams):
  class SYNTHIA (line 26) | class SYNTHIA(data.Dataset):
    method __init__ (line 28) | def __init__(self, root, num_cls=19, split='train', remap_labels=True,...
    method collect_ids (line 40) | def collect_ids(self):
    method img_path (line 48) | def img_path(self, filename):
    method label_path (line 56) | def label_path(self, filename):
    method __getitem__ (line 64) | def __getitem__(self, index, debug=False):
    method __len__ (line 87) | def __len__(self):

FILE: cycada/data/util.py
  function maybe_download (line 57) | def maybe_download(url, dest):
  function download (line 66) | def download(url, dest):

FILE: cycada/models/MDAN.py
  class GradientReversalLayer (line 13) | class GradientReversalLayer(torch.autograd.Function):
    method forward (line 19) | def forward(self, inputs):
    method backward (line 22) | def backward(self, grad_output):
  class MDANet (line 28) | class MDANet(nn.Module):
    method __init__ (line 33) | def __init__(self, configs):
    method forward (line 54) | def forward(self, sinputs_syn, sinputs_gta, tinputs):
    method inference (line 94) | def inference(self, inputs):

FILE: cycada/models/adda_net.py
  class AddaNet (line 10) | class AddaNet(nn.Module):
    method __init__ (line 12) | def __init__(self, num_cls=10, model='LeNet', src_weights_init=None,
    method forward (line 30) | def forward(self, x_s, x_t):
    method setup_net (line 44) | def setup_net(self):
    method load (line 61) | def load(self, init_path):
    method load_src_net (line 66) | def load_src_net(self, init_path):
    method save (line 72) | def save(self, out_path):
    method save_tgt_net (line 75) | def save_tgt_net(self, out_path):

FILE: cycada/models/drn.py
  function conv3x3 (line 20) | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
  class BasicBlock (line 25) | class BasicBlock(nn.Module):
    method __init__ (line 28) | def __init__(self, inplanes, planes, stride=1, downsample=None,
    method forward (line 42) | def forward(self, x):
  class Bottleneck (line 61) | class Bottleneck(nn.Module):
    method __init__ (line 64) | def __init__(self, inplanes, planes, stride=1, downsample=None,
    method forward (line 79) | def forward(self, x):
  class DRN (line 102) | class DRN(nn.Module):
    method __init__ (line 110) | def __init__(self, block, layers, num_cls=1000,
    method _make_layer (line 175) | def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
    method forward (line 199) | def forward(self, x):
  function drn26 (line 252) | def drn26(pretrained=True, finetune=False, out_map=True, **kwargs):
  function drn42 (line 267) | def drn42(pretrained=False, finetune=False, out_map=True, **kwargs):
  function drn58 (line 275) | def drn58(pretrained=False, **kwargs):

FILE: cycada/models/fcn8s.py
  function get_upsample_filter (line 14) | def get_upsample_filter(size):
  class Bilinear (line 27) | class Bilinear(nn.Module):
    method __init__ (line 29) | def __init__(self, factor, num_channels):
    method forward (line 38) | def forward(self, x):
  class VGG16_FCN8s (line 43) | class VGG16_FCN8s(nn.Module):
    method __init__ (line 51) | def __init__(self, num_cls=19, pretrained=True, weights_init=None,
    method load_base_vgg (line 86) | def load_base_vgg(self, weights_state_dict):
    method load_vgg_head (line 90) | def load_vgg_head(self, weights_state_dict):
    method get_dict_by_prefix (line 94) | def get_dict_by_prefix(self, weights_state_dict, prefix):
    method load_weights (line 99) | def load_weights(self, weights_state_dict):
    method split_vgg_head (line 103) | def split_vgg_head(self):
    method forward (line 107) | def forward(self, x):
    method load_base_weights (line 142) | def load_base_weights(self):
  class VGG16_FCN8s_caffe (line 161) | class VGG16_FCN8s_caffe(VGG16_FCN8s):
    method load_base_weights (line 171) | def load_base_weights(self):
  class Discriminator (line 188) | class Discriminator(nn.Module):
    method __init__ (line 189) | def __init__(self, input_dim=4096, output_dim=2, pretrained=False, wei...
    method forward (line 206) | def forward(self, x):
    method load_weights (line 210) | def load_weights(self, weights):
  class Transform_Module (line 215) | class Transform_Module(nn.Module):
    method __init__ (line 216) | def __init__(self, input_dim=4096):
    method forward (line 229) | def forward(self, x):
  function init_eye (line 234) | def init_eye(tensor):
  function _crop (line 241) | def _crop(input, shape, offset=0):
  function make_layers (line 246) | def make_layers(cfg, batch_norm=False):

FILE: cycada/models/models.py
  function register_model (line 4) | def register_model(name):
  function get_model (line 11) | def get_model(name, num_cls=10, **args):

FILE: cycada/models/task_net.py
  class TaskNet (line 8) | class TaskNet(nn.Module):
    method __init__ (line 15) | def __init__(self, num_cls=10, weights_init=None):
    method forward (line 25) | def forward(self, x, with_ft=False):
    method setup_net (line 35) | def setup_net(self):
    method load (line 39) | def load(self, init_path):
    method save (line 43) | def save(self, out_path):
  class LeNet (line 47) | class LeNet(TaskNet):
    method setup_net (line 55) | def setup_net(self):
  class DTNClassifier (line 76) | class DTNClassifier(TaskNet):
    method setup_net (line 84) | def setup_net(self):

FILE: cycada/models/util.py
  function init_weights (line 4) | def init_weights(obj):

FILE: cycada/tools/train_adda_net.py
  function train (line 21) | def train(loader_src, loader_tgt, net, opt_net, opt_dis, epoch):
  function train_adda (line 118) | def train_adda(src, tgt, model, num_cls, num_epoch=200,

FILE: cycada/tools/train_task_net.py
  function train_epoch (line 22) | def train_epoch(loader, net, opt_net, epoch):
  function train (line 55) | def train(data, datadir, model, num_cls, outdir='',

FILE: cycada/tools/util.py
  function make_variable (line 7) | def make_variable(tensor, volatile=False, requires_grad=True):
  function pairwise_distance (line 15) | def pairwise_distance(x, y):
  function gaussian_kernel_matrix (line 30) | def gaussian_kernel_matrix(x, y, sigmas):
  function maximum_mean_discrepancy (line 40) | def maximum_mean_discrepancy(x, y, kernel=gaussian_kernel_matrix):
  function mmd_loss (line 48) | def mmd_loss(source_features, target_features):

FILE: cycada/transforms.py
  class RandomCrop (line 17) | class RandomCrop(object):
    method __init__ (line 23) | def __init__(self, size):
    method __call__ (line 29) | def __call__(self, tensors):
  class HalfCrop (line 48) | class HalfCrop(object):
    method __call__ (line 54) | def __call__(self, tensors):
  class RandomHorizontalFlip (line 65) | class RandomHorizontalFlip(object):
    method __call__ (line 69) | def __call__(self, tensors):
  function augment_collate (line 79) | def augment_collate(batch, crop=None, halfcrop=None, flip=True, resize=N...

FILE: cycada/util.py
  class TqdmHandler (line 13) | class TqdmHandler(logging.StreamHandler):
    method __init__ (line 15) | def __init__(self):
    method emit (line 18) | def emit(self, record):
  function config_logging (line 23) | def config_logging(logfile=None):
  function to_tensor_raw (line 35) | def to_tensor_raw(im):
  function safe_load_state_dict (line 39) | def safe_load_state_dict(net, state_dict):
  function step_lr (line 66) | def step_lr(optimizer, mult):

FILE: cyclegan/data/__init__.py
  function CreateDataLoader (line 10) | def CreateDataLoader(opt):
  function CreateDataset (line 17) | def CreateDataset(opt):
  class CustomDatasetDataLoader (line 39) | class CustomDatasetDataLoader(BaseDataLoader):
    method name (line 40) | def name(self):
    method initialize (line 43) | def initialize(self, opt):
    method load_data (line 52) | def load_data(self):
    method __len__ (line 55) | def __len__(self):
    method __iter__ (line 58) | def __iter__(self):

FILE: cyclegan/data/base_data_loader.py
  class BaseDataLoader (line 1) | class BaseDataLoader():
    method __init__ (line 2) | def __init__(self):
    method initialize (line 5) | def initialize(self, opt):
    method load_data (line 9) | def load_data():

FILE: cyclegan/data/base_dataset.py
  class BaseDataset (line 8) | class BaseDataset(data.Dataset):
    method __init__ (line 9) | def __init__(self):
    method name (line 12) | def name(self):
    method initialize (line 15) | def initialize(self, opt):
  function get_transform (line 20) | def get_transform(opt):
  function get_label_transform (line 46) | def get_label_transform(opt):
  function __scale_width (line 71) | def __scale_width(img, target_width):
  function to_tensor_raw (line 80) | def to_tensor_raw(im):

FILE: cyclegan/data/cityscapes.py
  function remap_labels_to_train_ids (line 18) | def remap_labels_to_train_ids(arr):

FILE: cyclegan/data/gta5_cityscapes.py
  class GTAVCityscapesDataset (line 57) | class GTAVCityscapesDataset(BaseDataset):
    method initialize (line 58) | def initialize(self, opt):
    method __getitem__ (line 83) | def __getitem__(self, index):
    method __len__ (line 134) | def __len__(self):
    method name (line 137) | def name(self):

FILE: cyclegan/data/gta_synthia_cityscapes.py
  function syn_relabel (line 56) | def syn_relabel(arr):
  class GTASynthiaCityscapesDataset (line 63) | class GTASynthiaCityscapesDataset(BaseDataset):
    method initialize (line 64) | def initialize(self, opt):
    method __getitem__ (line 98) | def __getitem__(self, index):
    method __len__ (line 160) | def __len__(self):
    method name (line 163) | def name(self):

FILE: cyclegan/data/image_folder.py
  function is_image_file (line 21) | def is_image_file(filename):
  function make_cs_labels (line 24) | def make_cs_labels(dir):
  function make_dataset (line 37) | def make_dataset(dir):
  function load_labels (line 49) | def load_labels(dir, images):
  function default_loader (line 64) | def default_loader(path):
  class ImageFolder (line 68) | class ImageFolder(data.Dataset):
    method __init__ (line 70) | def __init__(self, root, transform=None, return_paths=False,
    method __getitem__ (line 84) | def __getitem__(self, index):
    method __len__ (line 94) | def __len__(self):

FILE: cyclegan/data/synthia_cityscapes.py
  function syn_relabel (line 57) | def syn_relabel(arr):
  class SynthiaCityscapesDataset (line 64) | class SynthiaCityscapesDataset(BaseDataset):
    method initialize (line 65) | def initialize(self, opt):
    method __getitem__ (line 90) | def __getitem__(self, index):
    method __len__ (line 137) | def __len__(self):
    method name (line 140) | def name(self):

FILE: cyclegan/models/__init__.py
  function create_model (line 3) | def create_model(opt):

FILE: cyclegan/models/base_model.py
  class BaseModel (line 9) | class BaseModel():
    method name (line 10) | def name(self):
    method initialize (line 13) | def initialize(self, opt):
    method set_input (line 26) | def set_input(self, input):
    method forward (line 29) | def forward(self):
    method setup (line 33) | def setup(self, opt):
    method eval (line 42) | def eval(self):
    method test (line 50) | def test(self):
    method get_image_paths (line 55) | def get_image_paths(self):
    method optimize_parameters (line 58) | def optimize_parameters(self):
    method update_learning_rate (line 62) | def update_learning_rate(self):
    method get_current_visuals (line 69) | def get_current_visuals(self):
    method get_current_losses (line 77) | def get_current_losses(self):
    method save_networks (line 86) | def save_networks(self, which_epoch):
    method __patch_instance_norm_state_dict (line 100) | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i...
    method load_networks (line 111) | def load_networks(self, which_epoch):
    method print_networks (line 130) | def print_networks(self, verbose):
    method set_requires_grad (line 144) | def set_requires_grad(self, nets, requires_grad=False):

FILE: cyclegan/models/cycle_gan_model.py
  class CycleGANModel (line 8) | class CycleGANModel(BaseModel):
    method name (line 9) | def name(self):
    method initialize (line 12) | def initialize(self, opt):
    method set_input (line 64) | def set_input(self, input):
    method forward (line 70) | def forward(self):
    method backward_D_basic (line 77) | def backward_D_basic(self, netD, real, fake):
    method backward_D_A (line 90) | def backward_D_A(self):
    method backward_D_B (line 94) | def backward_D_B(self):
    method backward_G (line 98) | def backward_G(self):
    method optimize_parameters (line 126) | def optimize_parameters(self):

FILE: cyclegan/models/cycle_gan_semantic_model.py
  class CycleGANSemanticModel (line 15) | class CycleGANSemanticModel(BaseModel):
    method name (line 16) | def name(self):
    method initialize (line 19) | def initialize(self, opt):
    method set_input (line 90) | def set_input(self, input):
    method forward (line 101) | def forward(self):
    method backward_D_basic (line 117) | def backward_D_basic(self, netD, real, fake):
    method backward_PixelCLS (line 130) | def backward_PixelCLS(self):
    method backward_D_A (line 137) | def backward_D_A(self):
    method backward_D_B (line 141) | def backward_D_B(self):
    method backward_G (line 145) | def backward_G(self, opt):
    method optimize_parameters (line 182) | def optimize_parameters(self, opt):

FILE: cyclegan/models/multi_cycle_gan_semantic_model.py
  class CycleGANSemanticModel (line 15) | class CycleGANSemanticModel(BaseModel):
    method name (line 16) | def name(self):
    method initialize (line 19) | def initialize(self, opt):
    method set_input (line 193) | def set_input(self, input):
    method forward (line 206) | def forward(self, opt):
    method backward_D_basic (line 242) | def backward_D_basic(self, netD, real, fake, SAD=False):
    method backward_D_A (line 259) | def backward_D_A(self, Shared_DT):
    method backward_D_B (line 273) | def backward_D_B(self):
    method backward_D (line 282) | def backward_D(self, which_D):
    method backward_G (line 301) | def backward_G(self, opt):
    method backward_HF_CCD (line 383) | def backward_HF_CCD(self, opt):
    method optimize_parameters (line 409) | def optimize_parameters(self, opt):

FILE: cyclegan/models/networks.py
  function get_norm_layer (line 14) | def get_norm_layer(norm_type='instance'):
  function get_scheduler (line 26) | def get_scheduler(optimizer, opt):
  function init_weights (line 42) | def init_weights(net, init_type='normal', gain=0.02):
  function init_net (line 66) | def init_net(net, init_type='normal', gpu_ids=[]):
  function print_network (line 75) | def print_network(net):
  function define_G (line 83) | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', u...
  function define_D (line 100) | def define_D(input_nc, ndf, which_model_netD,
  function define_C (line 117) | def define_C(output_nc, ndf, init_type='normal', gpu_ids=[]):
  class GANLoss (line 135) | class GANLoss(nn.Module):
    method __init__ (line 136) | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_...
    method get_target_tensor (line 145) | def get_target_tensor(self, input, target_is_real):
    method __call__ (line 152) | def __call__(self, input, target_is_real):
  class ResnetGenerator (line 161) | class ResnetGenerator(nn.Module):
    method __init__ (line 162) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
    method forward (line 205) | def forward(self, input):
  class ResnetBlock (line 210) | class ResnetBlock(nn.Module):
    method __init__ (line 211) | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
    method build_conv_block (line 215) | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout,...
    method forward (line 247) | def forward(self, x):
  class UnetGenerator (line 256) | class UnetGenerator(nn.Module):
    method __init__ (line 257) | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
    method forward (line 273) | def forward(self, input):
  class UnetSkipConnectionBlock (line 280) | class UnetSkipConnectionBlock(nn.Module):
    method __init__ (line 281) | def __init__(self, outer_nc, inner_nc, input_nc=None,
    method forward (line 326) | def forward(self, x):
  class NLayerDiscriminator (line 334) | class NLayerDiscriminator(nn.Module):
    method __init__ (line 335) | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNo...
    method forward (line 377) | def forward(self, input):
  class PixelDiscriminator (line 381) | class PixelDiscriminator(nn.Module):
    method __init__ (line 382) | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_si...
    method forward (line 402) | def forward(self, input):
  class Classifier (line 406) | class Classifier(nn.Module):
    method __init__ (line 407) | def __init__(self, input_nc, ndf, norm_layer=nn.BatchNorm2d):
    method forward (line 436) | def forward(self, x):

FILE: cyclegan/models/test_model.py
  class TestModel (line 5) | class TestModel(BaseModel):
    method name (line 6) | def name(self):
    method initialize (line 9) | def initialize(self, opt):
    method set_input (line 37) | def set_input(self, input):
    method forward (line 42) | def forward(self):

FILE: cyclegan/options/base_options.py
  class BaseOptions (line 8) | class BaseOptions():
    method __init__ (line 9) | def __init__(self):
    method initialize (line 13) | def initialize(self):
    method parse (line 77) | def parse(self):

FILE: cyclegan/options/test_options.py
  class TestOptions (line 4) | class TestOptions(BaseOptions):
    method initialize (line 5) | def initialize(self):

FILE: cyclegan/options/train_options.py
  class TrainOptions (line 4) | class TrainOptions(BaseOptions):
    method initialize (line 5) | def initialize(self):

FILE: cyclegan/util/get_data.py
  class GetData (line 11) | class GetData(object):
    method __init__ (line 29) | def __init__(self, technique='cyclegan', verbose=True):
    method _print (line 37) | def _print(self, text):
    method _get_options (line 42) | def _get_options(r):
    method _present_options (line 48) | def _present_options(self):
    method _download_data (line 58) | def _download_data(self, dataset_url, save_path):
    method get (line 81) | def get(self, save_path, dataset=None):

FILE: cyclegan/util/html.py
  class HTML (line 6) | class HTML:
    method __init__ (line 7) | def __init__(self, web_dir, title, reflesh=0):
    method get_image_dir (line 22) | def get_image_dir(self):
    method add_header (line 25) | def add_header(self, str):
    method add_table (line 29) | def add_table(self, border=1):
    method add_images (line 33) | def add_images(self, ims, txts, links, width=400):
    method save (line 45) | def save(self):

FILE: cyclegan/util/image_pool.py
  class ImagePool (line 5) | class ImagePool():
    method __init__ (line 6) | def __init__(self, pool_size):
    method query (line 12) | def query(self, images):

FILE: cyclegan/util/util.py
  function tensor2im (line 12) | def tensor2im(input_image, imtype=np.uint8):
  function diagnose_network (line 35) | def diagnose_network(net, name='network'):
  function save_image (line 48) | def save_image(image_numpy, image_path):
  function print_numpy (line 53) | def print_numpy(x, val=True, shp=False):
  function mkdirs (line 63) | def mkdirs(paths):
  function mkdir (line 71) | def mkdir(path):

FILE: cyclegan/util/visualizer.py
  function save_images (line 10) | def save_images(image_dir, visuals, image_path, aspect_ratio=1.0, width=...
  class Visualizer (line 26) | class Visualizer():
    method __init__ (line 27) | def __init__(self, opt):
    method reset (line 49) | def reset(self):
    method display_current_results (line 53) | def display_current_results(self, visuals, epoch, save_result):
    method plot_current_losses (line 119) | def plot_current_losses(self, epoch, counter_ratio, opt, losses):
    method print_current_losses (line 135) | def print_current_losses(self, epoch, i, losses, t, t_data):

FILE: scripts/eval_fcn.py
  function fmt_array (line 26) | def fmt_array(arr, fmt=','):
  function fast_hist (line 31) | def fast_hist(a, b, n):
  function result_stats (line 36) | def result_stats(hist):
  function main (line 57) | def main(path, dataset, datadir, model, gpu, num_cls, batch_size, loadSi...

FILE: scripts/train_fcn.py
  function to_tensor_raw (line 25) | def to_tensor_raw(im):
  function roundrobin_infinite (line 29) | def roundrobin_infinite(*loaders):
  function supervised_loss (line 43) | def supervised_loss(score, label, weights=None):
  function main (line 78) | def main(output, dataset, datadir, batch_size, lr, step, iterations,

FILE: scripts/train_fcn_adda.py
  function check_label (line 25) | def check_label(label, num_cls):
  function forward_pass (line 40) | def forward_pass(net, discriminator, im, requires_grad=False, discrim_fe...
  function supervised_loss (line 53) | def supervised_loss(score, label, weights=None):
  function discriminator_loss (line 59) | def discriminator_loss(score, target_val, lsgan=False):
  function fast_hist (line 69) | def fast_hist(a, b, n):
  function seg_accuracy (line 74) | def seg_accuracy(score, label, num_cls):
  function main (line 112) | def main(output, dataset, datadir, lr, momentum, snapshot, downscale, cl...

FILE: scripts/train_fcn_mdan.py
  function to_tensor_raw (line 29) | def to_tensor_raw(im):
  function roundrobin_infinite (line 33) | def roundrobin_infinite(*loaders):
  function multi_source_infinite (line 47) | def multi_source_infinite(loaders, target_loader):
  function supervised_loss (line 64) | def supervised_loss(score, label, weights=None):
  function main (line 96) | def main(output, dataset, target_name, datadir, batch_size, lr, iterations,
Condensed preview — 88 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (265K chars).
[
  {
    "path": ".gitignore",
    "chars": 1463,
    "preview": ".DS_Store\n# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm\n# "
  },
  {
    "path": ".idea/MADAN.iml",
    "chars": 558,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n  <component name=\"NewModuleRootManager"
  },
  {
    "path": ".idea/deployment.xml",
    "chars": 482,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"PublishConfigData\" autoUpload=\"Always\" s"
  },
  {
    "path": ".idea/inspectionProfiles/profiles_settings.xml",
    "chars": 174,
    "preview": "<component name=\"InspectionProjectProfileManager\">\n  <settings>\n    <option name=\"USE_PROJECT_PROFILE\" value=\"false\" />\n"
  },
  {
    "path": ".idea/misc.xml",
    "chars": 462,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"JavaScriptSettings\">\n    <option name=\"l"
  },
  {
    "path": ".idea/modules.xml",
    "chars": 262,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectModuleManager\">\n    <modules>\n   "
  },
  {
    "path": ".idea/remote-mappings.xml",
    "chars": 473,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"RemoteMappingsManager\">\n    <list>\n     "
  },
  {
    "path": ".idea/vcs.xml",
    "chars": 180,
    "preview": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"VcsDirectoryMappings\">\n    <mapping dire"
  },
  {
    "path": "LICENSE",
    "chars": 1066,
    "preview": "MIT License\n\nCopyright (c) 2019 liljprime\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\n"
  },
  {
    "path": "README.md",
    "chars": 4148,
    "preview": "# MADAN\n\nA Pytorch Code for [Multi-source Domain Adaptation for Semantic Segmentation](https://arxiv.org/abs/1910.12181)"
  },
  {
    "path": "cycada/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cycada/data/__init__.py",
    "chars": 122,
    "preview": "from . import gta5, cityscapes, cyclegta5, synthia, cyclesynthia, cyclesynthia_cyclegta5, bdds\nfrom . import adda_datase"
  },
  {
    "path": "cycada/data/adda_datasets.py",
    "chars": 3239,
    "preview": "import os.path\n\nimport torch.utils.data\n\nfrom .data_loader import get_transform_dataset\nfrom ..transforms import augment"
  },
  {
    "path": "cycada/data/bdds.py",
    "chars": 1516,
    "preview": "import os.path\n\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\nfrom .util import classes, igno"
  },
  {
    "path": "cycada/data/cityscapes.py",
    "chars": 2253,
    "preview": "import os.path\nimport sys\n\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\nfrom .util import cl"
  },
  {
    "path": "cycada/data/cityscapes_labels.py",
    "chars": 772,
    "preview": "# function for colorizing a label image:\n# camera-ready\n\nimport numpy as np\n\n\ndef label_img_to_color(img):\n\tlabel_to_col"
  },
  {
    "path": "cycada/data/cyclegan.py",
    "chars": 3660,
    "preview": "import os\nfrom os.path import join\nimport glob\nfrom PIL import Image\n\nimport torch.utils.data as data\nfrom .data_loader "
  },
  {
    "path": "cycada/data/cyclegta5.py",
    "chars": 1805,
    "preview": "import os.path\n\nimport numpy as np\nfrom PIL import Image\n\nfrom .cityscapes import remap_labels_to_train_ids\nfrom .data_l"
  },
  {
    "path": "cycada/data/cyclesynthia.py",
    "chars": 3323,
    "preview": "import os.path\n\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\n\nfrom .data_loader import Datas"
  },
  {
    "path": "cycada/data/cyclesynthia_cyclegta5.py",
    "chars": 5214,
    "preview": "import os.path\n\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\n\nfrom .cityscapes import remap_"
  },
  {
    "path": "cycada/data/data_loader.py",
    "chars": 5111,
    "preview": "from __future__ import print_function\n\nimport os\nfrom os.path import join\n\nimport numpy as np\nimport torch\nimport torch."
  },
  {
    "path": "cycada/data/gta5.py",
    "chars": 2074,
    "preview": "import os.path\n\nimport numpy as np\nimport scipy.io\nimport torch.utils.data as data\nfrom PIL import Image\n\nfrom .cityscap"
  },
  {
    "path": "cycada/data/rotater.py",
    "chars": 755,
    "preview": "class Rotater(object):\n\n    def __init__(self, dataset, orientations=6, transform=None,\n                 target_transfor"
  },
  {
    "path": "cycada/data/synthia.py",
    "chars": 2520,
    "preview": "import os.path\n\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\nfrom .util import classes, igno"
  },
  {
    "path": "cycada/data/util.py",
    "chars": 1812,
    "preview": "import logging\nimport os.path\n\nimport requests\n\nlogger = logging.getLogger(__name__)\n\nignore_label = 255\nid2label = {0: "
  },
  {
    "path": "cycada/logging.yml",
    "chars": 722,
    "preview": "---\nversion: 1\ndisable_existing_loggers: False\nformatters:\n    simple:\n        format: \"[%(asctime)s] %(levelname)-8s %("
  },
  {
    "path": "cycada/models/MDAN.py",
    "chars": 3563,
    "preview": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\nimport logging\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functi"
  },
  {
    "path": "cycada/models/__init__.py",
    "chars": 193,
    "preview": "from .models import get_model\nfrom .task_net import LeNet\nfrom .task_net import DTNClassifier\nfrom .adda_net import Adda"
  },
  {
    "path": "cycada/models/adda_net.py",
    "chars": 2499,
    "preview": "\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\nfrom .util import init_weights\nfrom .mo"
  },
  {
    "path": "cycada/models/drn.py",
    "chars": 8303,
    "preview": "import math\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.model_zoo as model_zoo\nimport torchvision\n\nfrom .mode"
  },
  {
    "path": "cycada/models/fcn8s.py",
    "chars": 7917,
    "preview": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torchvision\nfrom torch import nn\nfrom torch.autog"
  },
  {
    "path": "cycada/models/models.py",
    "chars": 308,
    "preview": "import torch\n\nmodels = {}\ndef register_model(name):\n    def decorator(cls):\n        models[name] = cls\n        return cl"
  },
  {
    "path": "cycada/models/task_net.py",
    "chars": 3070,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nfrom .models import register_model \nfrom .util import init_"
  },
  {
    "path": "cycada/models/util.py",
    "chars": 349,
    "preview": "import torch.nn as nn\nfrom torch.nn import init\n\ndef init_weights(obj):\n    for m in obj.modules():\n        if isinstanc"
  },
  {
    "path": "cycada/tools/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cycada/tools/train_adda_net.py",
    "chars": 5678,
    "preview": "from __future__ import print_function\n\nimport os\nfrom os.path import join\nimport numpy as np\n\n# Import from torch\nimport"
  },
  {
    "path": "cycada/tools/train_task_net.py",
    "chars": 3334,
    "preview": "from __future__ import print_function\n\nimport os\nfrom os.path import join\nimport numpy as np\nimport argparse\n\n# Import f"
  },
  {
    "path": "cycada/tools/util.py",
    "chars": 1560,
    "preview": "from functools import partial\n\nimport torch\nfrom torch.autograd import Variable\n\n\ndef make_variable(tensor, volatile=Fal"
  },
  {
    "path": "cycada/transforms.py",
    "chars": 2714,
    "preview": "\"\"\"These random transforms extend the transforms provided in torchvision to\nallow for transforming multiple images at th"
  },
  {
    "path": "cycada/util.py",
    "chars": 2005,
    "preview": "import logging\nimport logging.config\nimport os.path\nfrom collections import OrderedDict\n\nimport numpy as np\nimport torch"
  },
  {
    "path": "cyclegan/.gitignore",
    "chars": 736,
    "preview": ".DS_Store\ndebug*\ncheckpoints/\nresults/\nbuild/\ndist/\n*.png\ntorch.egg-info/\n*/**/__pycache__\ntorch/version.py\ntorch/csrc/g"
  },
  {
    "path": "cyclegan/data/__init__.py",
    "chars": 1819,
    "preview": "import sys\n\nimport torch.utils.data\nfrom data.base_data_loader import BaseDataLoader\n\nsys.path.append('/nfs/project/libo"
  },
  {
    "path": "cyclegan/data/base_data_loader.py",
    "chars": 171,
    "preview": "class BaseDataLoader():\n    def __init__(self):\n        pass\n\n    def initialize(self, opt):\n        self.opt = opt\n    "
  },
  {
    "path": "cyclegan/data/base_dataset.py",
    "chars": 2925,
    "preview": "import numpy as np\nimport torch\nimport torch.utils.data as data\nimport torchvision.transforms as transforms\nfrom PIL imp"
  },
  {
    "path": "cyclegan/data/cityscapes.py",
    "chars": 1216,
    "preview": "import numpy as np\n\nignore_label = 255\nid2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,\n"
  },
  {
    "path": "cyclegan/data/gta5_cityscapes.py",
    "chars": 3753,
    "preview": "import os.path\nimport random\n\nimport numpy as np\nfrom PIL import Image\nfrom data.base_dataset import BaseDataset, get_la"
  },
  {
    "path": "cyclegan/data/gta_synthia_cityscapes.py",
    "chars": 4626,
    "preview": "import os.path\nimport random\n\nimport numpy as np\nfrom PIL import Image\nfrom data.base_dataset import BaseDataset, get_la"
  },
  {
    "path": "cyclegan/data/image_folder.py",
    "chars": 3052,
    "preview": "###############################################################################\n# Code from\n# https://github.com/pytorch"
  },
  {
    "path": "cyclegan/data/synthia_cityscapes.py",
    "chars": 3775,
    "preview": "import os.path\nimport random\n\nimport numpy as np\nfrom PIL import Image\nfrom data.base_dataset import BaseDataset, get_la"
  },
  {
    "path": "cyclegan/environment.yml",
    "chars": 249,
    "preview": "name: pytorch-CycleGAN-and-pix2pix\nchannels:\n- peterjc123\n- defaults\ndependencies:\n- python=3.5.5\n- pytorch=0.3.1\n- scip"
  },
  {
    "path": "cyclegan/models/__init__.py",
    "chars": 745,
    "preview": "import logging\n\ndef create_model(opt):\n\tmodel = None\n\tif opt.model == 'cycle_gan':\n\t\t# assert(opt.dataset_mode == 'unali"
  },
  {
    "path": "cyclegan/models/base_model.py",
    "chars": 4958,
    "preview": "import os\nfrom collections import OrderedDict\n\nimport torch\n\nfrom . import networks\n\n\nclass BaseModel():\n\tdef name(self)"
  },
  {
    "path": "cyclegan/models/cycle_gan_model.py",
    "chars": 6300,
    "preview": "import torch\nimport itertools\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import netw"
  },
  {
    "path": "cyclegan/models/cycle_gan_semantic_model.py",
    "chars": 8035,
    "preview": "import itertools\nimport sys\n\nimport torch\nimport torch.nn.functional as F\nfrom util.image_pool import ImagePool\n\nfrom . "
  },
  {
    "path": "cyclegan/models/multi_cycle_gan_semantic_model.py",
    "chars": 19776,
    "preview": "import itertools\nimport sys\n\nimport torch\nimport torch.nn.functional as F\nfrom util.image_pool import ImagePool\n\nfrom . "
  },
  {
    "path": "cyclegan/models/networks.py",
    "chars": 15300,
    "preview": "import functools\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn import init\nfrom torch.optim import lr_scheduler\n\n\n##"
  },
  {
    "path": "cyclegan/models/test_model.py",
    "chars": 1795,
    "preview": "from . import networks\nfrom .base_model import BaseModel\n\n\nclass TestModel(BaseModel):\n\tdef name(self):\n\t\treturn 'TestMo"
  },
  {
    "path": "cyclegan/options/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cyclegan/options/base_options.py",
    "chars": 7391,
    "preview": "import argparse\nimport os\n\nimport torch\nfrom util import util\n\n\nclass BaseOptions():\n\tdef __init__(self):\n\t\tself.parser "
  },
  {
    "path": "cyclegan/options/test_options.py",
    "chars": 845,
    "preview": "from .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n    def initialize(self):\n        BaseOptions.in"
  },
  {
    "path": "cyclegan/options/train_options.py",
    "chars": 3352,
    "preview": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n\tdef initialize(self):\n\t\tBaseOptions.initialize"
  },
  {
    "path": "cyclegan/test.py",
    "chars": 1722,
    "preview": "import os\nimport sys\n\nimport torch\nfrom models import create_model\nfrom options.test_options import TestOptions\nfrom uti"
  },
  {
    "path": "cyclegan/train.py",
    "chars": 2354,
    "preview": "import subprocess\nimport sys\nimport time\n\nsys.path.append(\"/nfs/project/libo_i/MADAN/cyclegan\")\nfrom options.train_optio"
  },
  {
    "path": "cyclegan/util/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cyclegan/util/get_data.py",
    "chars": 3511,
    "preview": "from __future__ import print_function\nimport os\nimport tarfile\nimport requests\nfrom warnings import warn\nfrom zipfile im"
  },
  {
    "path": "cyclegan/util/html.py",
    "chars": 1912,
    "preview": "import dominate\nfrom dominate.tags import *\nimport os\n\n\nclass HTML:\n    def __init__(self, web_dir, title, reflesh=0):\n "
  },
  {
    "path": "cyclegan/util/image_pool.py",
    "chars": 1072,
    "preview": "import random\nimport torch\n\n\nclass ImagePool():\n    def __init__(self, pool_size):\n        self.pool_size = pool_size\n  "
  },
  {
    "path": "cyclegan/util/util.py",
    "chars": 1896,
    "preview": "from __future__ import print_function\n\nimport os\n\nimport numpy as np\nimport torch\nfrom PIL import Image\n\n\n# Converts a T"
  },
  {
    "path": "cyclegan/util/visualizer.py",
    "chars": 5282,
    "preview": "import ntpath\nimport os\nimport time\n\nimport numpy as np\nfrom . import html, util\n\n\n# save image to the disk\ndef save_ima"
  },
  {
    "path": "requirements.txt",
    "chars": 142,
    "preview": "scipy\ntorchvision\ntensorboardX\ntensorflow\nclick\ntqdm\nrequests\ncolorlog\npyyaml\ntorch>=1.1.0\ntorchvision>=0.3.0\ndominate>="
  },
  {
    "path": "scripts/ADDA/adda_cyclegta2cs_feat.sh",
    "chars": 1568,
    "preview": "#!/usr/bin/env bash\n\ngpu=0,1,2,3\n\n######################\n# loss weight params #\n######################\nlr=2e-5\nmomentum="
  },
  {
    "path": "scripts/ADDA/adda_cyclegta2cs_score.sh",
    "chars": 1600,
    "preview": "#!/usr/bin/env bash\n\ngpu=0,1,2,3\n\n######################\n# loss weight params #\n######################\nlr=2e-5\nmomentum="
  },
  {
    "path": "scripts/ADDA/adda_cyclesyn2cs_feat.sh",
    "chars": 1379,
    "preview": "#!/usr/bin/env bash\n\ngpu=0,1,2,3\n\n######################\n# loss weight params #\n######################\nlr=1e-5\nmomentum="
  },
  {
    "path": "scripts/ADDA/adda_cyclesyn2cs_score.sh",
    "chars": 1379,
    "preview": "#!/usr/bin/env bash\n\ngpu=0,1,2,3\n\n######################\n# loss weight params #\n######################\nlr=1e-5\nmomentum="
  },
  {
    "path": "scripts/ADDA/adda_templates.sh",
    "chars": 1274,
    "preview": "#!/usr/bin/env bash\n\ngpu=0,1,2,3\n\n######################\n# loss weight params #\n######################\nlr=1e-5\nmomentum="
  },
  {
    "path": "scripts/CycleGAN/cyclegan_gta2cityscapes.sh",
    "chars": 532,
    "preview": "#!/usr/bin/env bash\ncd /nfs/project/libo_i/MADAN/cyclegan\n\nsudo python3 train.py --name cyclegan_gta2cityscapes \\\n    --"
  },
  {
    "path": "scripts/CycleGAN/cyclegan_gta_synthia2cityscapes.sh",
    "chars": 710,
    "preview": "#!/usr/bin/env bash\ncd /nfs/project/libo_i/MADAN/cyclegan\n\npython3 train.py --name cyclegan_gta_synthia2cityscapes_noIde"
  },
  {
    "path": "scripts/CycleGAN/cyclegan_synthia2cityscapes.sh",
    "chars": 742,
    "preview": "#!/usr/bin/env bash\ncd /root/MADAN/cyclegan\n\npython3 train.py --name cycada_gta_synthia2cityscapes_noIdentity_D12D21D3_S"
  },
  {
    "path": "scripts/CycleGAN/test_templates.sh",
    "chars": 444,
    "preview": "#!/usr/bin/env bash\n\nhow_many=100000\n\ncd /root/MADAN/cyclegan\nname=$1\nepoch=$2\n\npython3 test.py --name ${name} --resize_"
  },
  {
    "path": "scripts/CycleGAN/test_templates_cycle.sh",
    "chars": 988,
    "preview": "#!/usr/bin/env bash\n# Sequentially load two generators(GTA, Synthia) and finish\nhow_many=100000\n\ncd /root/MADAN/cyclegan"
  },
  {
    "path": "scripts/FCN/train_fcn8s_cyclesgta5.sh",
    "chars": 617,
    "preview": "#!/usr/bin/env bash\ngpu=0,1,2,3\ndata=cyclegta5\nmodel=fcn8s\n\nexport LC_ALL=C.UTF-8\nexport LANG=C.UTF-8\n\ndatadir=/root/MAD"
  },
  {
    "path": "scripts/FCN/train_fcn8s_cyclesynthia.sh",
    "chars": 637,
    "preview": "#!/usr/bin/env bash\ngpu=0,1,2,3\ndata=cyclesynthia\nmodel=fcn8s\n\nexport LC_ALL=C.UTF-8\nexport LANG=C.UTF-8\n\ndatadir=/root/"
  },
  {
    "path": "scripts/eval_fcn.py",
    "chars": 4985,
    "preview": "import os\nimport sys\n\nfrom torchvision.transforms import transforms\n\nsys.path.append('/nfs/project/libo_iMADAN')\nimport "
  },
  {
    "path": "scripts/train_fcn.py",
    "chars": 6789,
    "preview": "import logging\nimport os.path\nimport sys\nfrom collections import deque\n\nimport click\nimport numpy as np\nimport torch\nimp"
  },
  {
    "path": "scripts/train_fcn_adda.py",
    "chars": 14835,
    "preview": "import logging\nimport os\nimport os.path\nimport sys\nfrom collections import deque\nfrom datetime import datetime\n\nimport c"
  },
  {
    "path": "scripts/train_fcn_mdan.py",
    "chars": 10341,
    "preview": "import itertools\nimport json\nimport logging\nimport os.path\nimport subprocess\nimport sys\nfrom collections import deque\n\ni"
  },
  {
    "path": "tools/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "tools/eval_templates.sh",
    "chars": 363,
    "preview": "#!/usr/bin/env bash\nexport LC_ALL=C.UTF-8\nexport LANG=C.UTF-8\ncd /nfs/project/libo_i/MADAN\n\nckpt_path=$1\ndatadir=/nfs/pr"
  }
]

About this extraction

This page contains the full source code of the Luodian/MADAN GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 88 files (237.8 KB), approximately 68.9k tokens, and a symbol index with 386 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract, a free GitHub-repository-to-text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!