Repository: Luodian/MADAN Branch: master Commit: 7a2918da44f5 Files: 88 Total size: 237.8 KB Directory structure: gitextract_rf2a9vy3/ ├── .gitignore ├── .idea/ │ ├── MADAN.iml │ ├── deployment.xml │ ├── inspectionProfiles/ │ │ └── profiles_settings.xml │ ├── misc.xml │ ├── modules.xml │ ├── remote-mappings.xml │ └── vcs.xml ├── LICENSE ├── README.md ├── cycada/ │ ├── __init__.py │ ├── data/ │ │ ├── __init__.py │ │ ├── adda_datasets.py │ │ ├── bdds.py │ │ ├── cityscapes.py │ │ ├── cityscapes_labels.py │ │ ├── cyclegan.py │ │ ├── cyclegta5.py │ │ ├── cyclesynthia.py │ │ ├── cyclesynthia_cyclegta5.py │ │ ├── data_loader.py │ │ ├── gta5.py │ │ ├── rotater.py │ │ ├── synthia.py │ │ └── util.py │ ├── logging.yml │ ├── models/ │ │ ├── MDAN.py │ │ ├── __init__.py │ │ ├── adda_net.py │ │ ├── drn.py │ │ ├── fcn8s.py │ │ ├── models.py │ │ ├── task_net.py │ │ └── util.py │ ├── tools/ │ │ ├── __init__.py │ │ ├── train_adda_net.py │ │ ├── train_task_net.py │ │ └── util.py │ ├── transforms.py │ └── util.py ├── cyclegan/ │ ├── .gitignore │ ├── data/ │ │ ├── __init__.py │ │ ├── base_data_loader.py │ │ ├── base_dataset.py │ │ ├── cityscapes.py │ │ ├── gta5_cityscapes.py │ │ ├── gta_synthia_cityscapes.py │ │ ├── image_folder.py │ │ └── synthia_cityscapes.py │ ├── environment.yml │ ├── models/ │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── cycle_gan_model.py │ │ ├── cycle_gan_semantic_model.py │ │ ├── multi_cycle_gan_semantic_model.py │ │ ├── networks.py │ │ └── test_model.py │ ├── options/ │ │ ├── __init__.py │ │ ├── base_options.py │ │ ├── test_options.py │ │ └── train_options.py │ ├── test.py │ ├── train.py │ └── util/ │ ├── __init__.py │ ├── get_data.py │ ├── html.py │ ├── image_pool.py │ ├── util.py │ └── visualizer.py ├── requirements.txt ├── scripts/ │ ├── ADDA/ │ │ ├── adda_cyclegta2cs_feat.sh │ │ ├── adda_cyclegta2cs_score.sh │ │ ├── adda_cyclesyn2cs_feat.sh │ │ ├── adda_cyclesyn2cs_score.sh │ │ └── adda_templates.sh │ ├── CycleGAN/ │ │ ├── cyclegan_gta2cityscapes.sh │ │ 
├── cyclegan_gta_synthia2cityscapes.sh │ │ ├── cyclegan_synthia2cityscapes.sh │ │ ├── test_templates.sh │ │ └── test_templates_cycle.sh │ ├── FCN/ │ │ ├── train_fcn8s_cyclesgta5.sh │ │ └── train_fcn8s_cyclesynthia.sh │ ├── eval_fcn.py │ ├── train_fcn.py │ ├── train_fcn_adda.py │ └── train_fcn_mdan.py └── tools/ ├── __init__.py └── eval_templates.sh ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .DS_Store # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 # User-specific stuff .idea/**/workspace.xml .idea/**/tasks.xml .idea/**/usage.statistics.xml .idea/**/dictionaries .idea/**/shelf # Generated files .idea/**/contentModel.xml # Sensitive or high-churn files .idea/**/dataSources/ .idea/**/dataSources.ids .idea/**/dataSources.local.xml .idea/**/sqlDataSources.xml .idea/**/dynamic.xml .idea/**/uiDesigner.xml .idea/**/dbnavigator.xml # Gradle .idea/**/gradle.xml .idea/**/libraries # Gradle and Maven with auto-import # When using Gradle or Maven with auto-import, you should exclude module files, # since they will be recreated, and may cause churn. Uncomment if using # auto-import. 
# .idea/modules.xml # .idea/*.iml # .idea/modules # *.iml # *.ipr # CMake cmake-build-*/ # Mongo Explorer plugin .idea/**/mongoSettings.xml # File-based project format *.iws # IntelliJ out/ # mpeltonen/sbt-idea plugin .idea_modules/ # JIRA plugin atlassian-ide-plugin.xml # Cursive Clojure plugin .idea/replstate.xml # Crashlytics plugin (for Android Studio and IntelliJ) com_crashlytics_export_strings.xml crashlytics.properties crashlytics-build.properties fabric.properties # Editor-based Rest Client .idea/httpRequests # Android studio 3.1+ serialized cache file .idea/caches/build_file_checksums.ser /models/ ================================================ FILE: .idea/MADAN.iml ================================================ ================================================ FILE: .idea/deployment.xml ================================================ ================================================ FILE: .idea/inspectionProfiles/profiles_settings.xml ================================================ ================================================ FILE: .idea/misc.xml ================================================ ================================================ FILE: .idea/modules.xml ================================================ ================================================ FILE: .idea/remote-mappings.xml ================================================ ================================================ FILE: .idea/vcs.xml ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2019 liljprime Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit 
persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # MADAN A Pytorch Code for [Multi-source Domain Adaptation for Semantic Segmentation](https://arxiv.org/abs/1910.12181) If you use this code in your research please consider citing: ``` @InProceedings{zhao2019madan, title = {Multi-source Domain Adaptation for Semantic Segmentation}, author = {Zhao, Sicheng and Li, Bo and Yue, Xiangyu and Gu, Yang and Xu, Pengfei and Tan, Hu, Runbo and Chai, Hua and Keutzer, Kurt}, booktitle = {Advances in Neural Information Processing Systems}, year = {2019} } ``` ## Quick Look Our multi-source domain adaptation builds on the work [CyCADA](https://github.com/jhoffman/cycada_release) and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). Since we focus on the Semantic Segmentation task, we remove the Digit Classification part of CyCADA. We add the following modules and achieve startling improvements. 1. Dynamic Semantic Consistency Module 2. Adversarial Aggregation Module 1. Sub-domain Aggregation Discriminator 2. Cross-domain Cycle Discriminator Meanwhile, we implement [MDAN](https://openreview.net/pdf?id=ryDNZZZAW) for the Semantic Segmentation task in Pytorch as our baseline comparison.
## Overall Structure ![image-20190608104531451](http://ww4.sinaimg.cn/large/006tNc79ly1g3tjype7qlj31vo0u0hb1.jpg) ## Setup Check out this repo: ```bash git clone https://github.com/pikachusocute/MADAN.git ``` Install Python3 requirements ```bash pip3 install -r requirements.txt ``` ## Dynamic Adversarial Image Generation We follow the way in CyCADA, in the first step, we need to train Image Adaptation module to transfer source image(GTA, Synthia or Multi-source) to "source as target". ![image-20190608111738818](http://ww4.sinaimg.cn/large/006tNc79ly1g3tkvxw9rrj31r40e8kjl.jpg) We refer Image Adaptation module from GTA to Cityscapes as GTA->Cityscapes in the following. #### GTA->Cityscapes ```bash cd scripts/CycleGAN bash cyclegan_gta2cityscapes.sh ``` In the training process, snapshot files will be stored in `cyclegan/checkpoints/[EXP_NAME]`. Usually, afer we run for 20 epochs, there'll be a file `20_net_G_A.pth ` in previous folder path. Then we run the test process. ```bash bash scripts/CycleGAN/test_templates.sh [EXP_NAME] 20 cycle_gan_semantic_fcn gta5_cityscapes ``` In multi-source case, there are both `20_net_G_A_1.pth` and `20_net_G_A_2.pth` exist. We use another script to run test process. ![image](https://tva1.sinaimg.cn/large/006y8mN6ly1g9cqt9m2kmj31j80skgsh.jpg) ```bash bash scripts/CycleGAN/test_templates_cycle.sh [EXP_NAME] 20 test synthia_cityscapes gta5_cityscapes ``` New dataset will be generated at `~/cyclegan/results/[EXP_NAME]/train_20`. After we obtain a new source stylized dataset, we then train segmenter on the new dataset. ## Pixel Level Adaptation In this part, we train our new segmenter on new dataset. ```bash ln -s ~/cyclegan/results/[EXP_NAME]/train_20 ~/data/cyclegta5/[EXP_NAME]_TRAIN_60 ``` Then we set `dataflag = [EXP_NAME]_TRAIN_60` to find datasets' paths, and follow instructions to train segmenter to perform pixel level adaptation. 
```bash bash scripts/FCN/train_fcn8s_cyclesgta5_DSC.sh ``` ## Feature Level Adaptation For adaptation, we use ```bash bash scripts/ADDA/adda_cyclegta2cs_score.sh ``` Make sure you choose the desired `src` and `tgt` and `datadir` before. In this process, you should load your `base_model` trained on synthetic dataset and perform adaptation in feature level to real scene dataset. ### Our Model We release our adaptation model in the `./models`, you can use `scripts/eval_templates.sh` to evaluate its validity. 1. [CycleGTA5_Dynamic_Semantic_Consistency](https://drive.google.com/file/d/1moGF7L2hkTHUPUzqsSxPwKNlHCHQm4Ms/view?usp=sharing) 2. [CycleSYNTHIA_Dynamic_Semantic_Consistency](https://drive.google.com/file/d/19V5J1zyF3ct3247gSSr-u3WVkDJqQvUk/view?usp=sharing) 3. [Multi_Source_SAD_CCD](https://drive.google.com/file/d/1xgmLwhsbwv-isy7R5FkNevVSH9gcMxuq/view?usp=sharing) ### Transfered Dataset We will release our transfer dataset soon, where our `CycleGTA5_Dynamic_Semantic_Consistency` model is trained to perform pixel level adaptation. ================================================ FILE: cycada/__init__.py ================================================ ================================================ FILE: cycada/data/__init__.py ================================================ from . import gta5, cityscapes, cyclegta5, synthia, cyclesynthia, cyclesynthia_cyclegta5, bdds from . 
# cycada/data/adda_datasets.py -- paired source/target loading for ADDA.
class AddaDataLoader(object):
    """Joint loader that yields aligned (source, target) batches.

    Wraps one source and one target segmentation dataset and keeps both
    streams alive indefinitely: whenever a stream has been fully consumed
    it is transparently restarted, so iteration never raises StopIteration.
    """

    def __init__(self, net_transform, dataset, rootdir, downscale,
                 crop_size=None, resize=None, batch_size=1, shuffle=False,
                 num_workers=2, half_crop=None, src_data_flag=None,
                 small=False):
        self.dataset = dataset
        self.downscale = downscale
        self.resize = resize
        self.crop_size = crop_size
        self.half_crop = half_crop
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        assert len(self.dataset) == 2, 'Requires two datasets: source, target'
        src_root = os.path.join(rootdir, self.dataset[0])
        tgt_root = os.path.join(rootdir, self.dataset[1])
        self.source = get_transform_dataset(self.dataset[0], src_root,
                                            net_transform, downscale, resize,
                                            src_data_flag=src_data_flag,
                                            small=small)
        self.target = get_transform_dataset(self.dataset[1], tgt_root,
                                            net_transform, downscale, resize,
                                            small=small)
        print('Source length:', len(self.source), 'Target length:',
              len(self.target))
        # make sure you see all images
        self.n = max(len(self.source), len(self.target))
        self.num = 0
        self.set_loader_src()
        self.set_loader_tgt()

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return one (img_src, img_tgt, label_src, label_tgt) batch."""
        # NOTE(review): len() of a DataLoader iterator equals len(loader),
        # so each stream restarts after one full pass -- confirm with the
        # installed torch version.
        if self.num % len(self.iters_src) == 0:
            print('restarting source dataset')
            self.set_loader_src()
        if self.num % len(self.iters_tgt) == 0:
            print('restarting target dataset')
            self.set_loader_tgt()
        img_src, label_src = next(self.iters_src)
        img_tgt, label_tgt = next(self.iters_tgt)
        self.num += 1
        return img_src, img_tgt, label_src, label_tgt

    def __len__(self):
        return min(len(self.source), len(self.target))

    def _build_loader(self, which):
        """Build a DataLoader for one dataset; shared by both streams."""
        if self.crop_size is not None or self.resize is not None:
            collate_fn = lambda batch: augment_collate(
                batch, resize=self.resize, crop=self.crop_size,
                halfcrop=self.half_crop, flip=True)
        else:
            collate_fn = torch.utils.data.dataloader.default_collate
        return torch.utils.data.DataLoader(
            which, batch_size=self.batch_size, shuffle=self.shuffle,
            num_workers=self.num_workers, collate_fn=collate_fn,
            pin_memory=True)

    def set_loader_src(self):
        """(Re)create the source DataLoader and its iterator."""
        self.loader_src = self._build_loader(self.source)
        self.iters_src = iter(self.loader_src)

    def set_loader_tgt(self):
        """(Re)create the target DataLoader and its iterator."""
        self.loader_tgt = self._build_loader(self.target)
        self.iters_tgt = iter(self.loader_tgt)
def remap_labels_to_train_ids(arr, mapping=None, ignore=None):
    """Remap raw Cityscapes label ids to the 19 train ids.

    Generalized so the lookup table and fill value can be supplied by the
    caller (the defaults keep the original behavior).

    Args:
        arr: integer numpy array of raw label ids.
        mapping: dict of raw id -> train id; defaults to the module-level
            ``id2label`` table imported from ``.util``.
        ignore: value written for ids absent from ``mapping``; defaults to
            the module-level ``ignore_label``.

    Returns:
        uint8 array of the same shape holding train ids, with unmapped
        ids set to the ignore value.
    """
    if mapping is None:
        mapping = id2label
    if ignore is None:
        ignore = ignore_label
    out = ignore * np.ones(arr.shape, dtype=np.uint8)
    for raw_id, train_id in mapping.items():
        out[arr == raw_id] = int(train_id)
    return out
# function for colorizing a label image: camera-ready figures
def label_img_to_color(img):
    """Map a (H, W) array of train ids in [0, 18] to an RGB image.

    Args:
        img: integer array of shape (H, W) with values 0..18.

    Returns:
        float64 array of shape (H, W, 3) colored with the Cityscapes
        train-id palette (matches the original's np.zeros float output).

    Performance fix: a single vectorized palette lookup replaces the
    original O(H*W) nested Python loops.
    """
    palette = np.array([
        [128, 64, 128],   # 0  road
        [244, 35, 232],   # 1  sidewalk
        [70, 70, 70],     # 2  building
        [102, 102, 156],  # 3  wall
        [190, 153, 153],  # 4  fence
        [153, 153, 153],  # 5  pole
        [250, 170, 30],   # 6  traffic light
        [220, 220, 0],    # 7  traffic sign
        [107, 142, 35],   # 8  vegetation
        [152, 251, 152],  # 9  terrain
        [70, 130, 180],   # 10 sky
        [220, 20, 60],    # 11 person
        [255, 0, 0],      # 12 rider
        [0, 0, 142],      # 13 car
        [0, 0, 70],       # 14 truck
        [0, 60, 100],     # 15 bus
        [0, 80, 100],     # 16 train
        [0, 0, 230],      # 17 motorcycle
        [119, 11, 32],    # 18 bicycle
    ], dtype=np.float64)
    # fancy indexing broadcasts the (H, W) id array into (H, W, 3) colors
    return palette[np.asarray(img)]
class CycleGANDataset(data.Dataset):
    """Digit images produced by CycleGAN, labeled by filename prefix.

    Filenames are expected to look like ``<label>_*_fake_B.png``; the
    integer before the first underscore is the class label.
    """

    def __init__(self, root, regexp, transform=None, target_transform=None,
                 download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.image_paths, self.labels = self.find_images(regexp)

    def find_images(self, regexp='*.png'):
        """Collect sorted matching image paths under ``self.root``.

        Bug fix: ``glob.glob(join(root, regexp))`` already returns paths
        rooted at ``self.root``; the original re-joined them with
        ``self.root``, which duplicated the prefix whenever ``root`` was a
        relative path (``os.path.join`` only discards its first argument
        when the second is absolute).
        """
        paths = sorted(glob.glob(join(self.root, regexp)))
        image_paths = []
        labels = []
        for path in paths:
            image_paths.append(path)
            # leading "<label>_" encodes the class id
            labels.append(int(os.path.basename(path).split('_')[0]))
        return image_paths, labels

    def __getitem__(self, index):
        im = Image.open(self.image_paths[index])  # .convert('L')
        target = self.labels[index]
        if self.transform is not None:
            im = self.transform(im)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return im, target

    def __len__(self):
        return len(self.image_paths)


@register_dataset_obj('svhn2mnist')
class Svhn2MNIST(CycleGANDataset):
    """SVHN translated toward MNIST (``*_fake_B.png``); train split only."""

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False):
        if not train:
            print('No test set for svhn2mnist.')
            self.image_paths = []
        else:
            super(Svhn2MNIST, self).__init__(
                root, '*_fake_B.png', transform=transform,
                target_transform=target_transform, download=download)


@register_data_params('svhn2mnist')
class Svhn2MNISTParams(DatasetParams):
    num_channels = 3
    image_size = 32
    mean = 0.5
    std = 0.5
    num_cls = 10
    target_transform = None


@register_dataset_obj('usps2mnist')
class Usps2Mnist(CycleGANDataset):
    """USPS translated toward MNIST (``*_fake_A.png``); train split only."""

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False):
        if not train:
            print('No test set for usps2mnist.')
            self.image_paths = []
        else:
            super(Usps2Mnist, self).__init__(
                root, '*_fake_A.png', transform=transform,
                target_transform=target_transform, download=download)


@register_data_params('usps2mnist')
class Usps2MnistParams(DatasetParams):
    num_channels = 3
    image_size = 16
    mean = 0.5
    std = 0.5
    num_cls = 10
    target_transform = None


@register_dataset_obj('mnist2usps')
class Mnist2Usps(CycleGANDataset):
    """MNIST translated toward USPS (``*_fake_B.png``); train split only."""

    def __init__(self, root, train=True, transform=None,
                 target_transform=None, download=False):
        if not train:
            print('No test set for mnist2usps.')
            self.image_paths = []
        else:
            super(Mnist2Usps, self).__init__(
                root, '*_fake_B.png', transform=transform,
                target_transform=target_transform, download=download)


@register_data_params('mnist2usps')
class Mnist2UspsParams(DatasetParams):
    num_channels = 3
    image_size = 16  # this seems wrong...
    mean = 0.5
    std = 0.5
    num_cls = 10
    target_transform = None
# cycada/data/cyclegta5.py -- GTA5 frames translated by CycleGAN.
@register_dataset_obj('cyclegta5')
class CycleGTA5(GTA5):
    """GTA5 variant whose images come from a CycleGAN output folder.

    Images are read from ``root/<data_flag>`` when ``data_flag`` is set,
    otherwise from ``root/images``; labels always come from
    ``root/labels_600x1080``.
    """

    def collect_ids(self):
        """List image filenames present on disk, sorted."""
        if self.data_flag:
            path = os.path.join(self.root, self.data_flag)
        else:
            path = os.path.join(self.root, "images")
        existing_ids = []
        for item in os.listdir(path):
            full_path = os.path.join(path, item)
            # skip entries that vanish between listdir and use (best effort)
            if os.path.exists(full_path) is False:
                continue
            existing_ids.append(full_path.split('/')[-1])
        return sorted(existing_ids)

    def __getitem__(self, index, debug=False):
        filename = self.ids[index]
        if self.data_flag == '' or self.data_flag is None:
            img_path = os.path.join(self.root, "images", filename)
            label_path = os.path.join(self.root, 'labels_600x1080', filename)
        else:
            img_path = os.path.join(self.root, self.data_flag, filename)
            if filename.endswith('_fake_B.png'):
                label_path = os.path.join(
                    self.root, 'labels_600x1080',
                    filename.replace('_fake_B.png', '.png'))
            elif filename.endswith('_fake_B_2.png'):
                label_path = os.path.join(
                    self.root, 'labels_600x1080',
                    filename.replace('_fake_B_2.png', '.png'))
            else:
                # Bug fix: label_path was left unbound (UnboundLocalError)
                # for filenames matching neither CycleGAN suffix; fall back
                # to the same-name label.
                label_path = os.path.join(self.root, 'labels_600x1080',
                                          filename)
        img = Image.open(img_path).convert('RGB')
        target = Image.open(label_path)
        # CycleGAN outputs may differ in size from the 600x1080 labels
        img = img.resize(target.size, resample=Image.BILINEAR)
        if self.transform is not None:
            img = self.transform(img)
        if self.remap_labels:
            target = np.asarray(target)
            target = remap_labels_to_train_ids(target)
            target = Image.fromarray(target, 'L')
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
# cycada/data/cyclesynthia.py -- SYNTHIA frames translated by CycleGAN.
ignore_label = 255

# raw SYNTHIA id -> Cityscapes train id (255 = ignore)
id2label = {0: ignore_label, 1: 10, 2: 2, 3: 0, 4: 1, 5: 4, 6: 8, 7: 5,
            8: 13, 9: 7, 10: 11, 11: 18, 12: 17, 13: ignore_label,
            14: ignore_label, 15: 6, 16: 9, 17: 12, 18: 14, 19: 15,
            20: 16, 21: 3, 22: ignore_label}

classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
           'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
           'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle']


def syn_relabel(arr):
    """Remap raw SYNTHIA ids to Cityscapes train ids (uint8, 255=ignore)."""
    out = ignore_label * np.ones(arr.shape, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        out[arr == raw_id] = int(train_id)
    return out


@register_data_params('cyclesynthia')
class SYNTHIAParams(DatasetParams):
    num_channels = 3
    image_size = 1024
    mean = 0.5
    std = 0.5
    num_cls = 19
    target_transform = None


@register_dataset_obj('cyclesynthia')
class CycleSYNTHIA(data.Dataset):
    """CycleGAN-translated SYNTHIA with labels remapped to train ids."""

    def __init__(self, root, num_cls=19, split='train', remap_labels=True,
                 transform=None, target_transform=None, data_flag=None):
        self.root = root.replace('cycle', '')
        self.split = split
        self.remap_labels = remap_labels
        self.transform = transform
        self.target_transform = target_transform
        self.classes = classes
        self.num_cls = num_cls
        # Bug fix: collect_ids() reads self.data_flag, but the original
        # __init__ never assigned it, so every instantiation raised
        # AttributeError.  Accept it as an optional argument (mirroring
        # BDDS); None falls back to the legacy 'Cycle' folder.
        self.data_flag = data_flag
        self.ids = self.collect_ids()

    def collect_ids(self):
        """List translated image filenames (``*_fake_B[_1].png``)."""
        if self.data_flag:
            path = os.path.join(self.root, self.data_flag)
        else:
            path = os.path.join(self.root, 'Cycle')
        splits = []
        for item in os.listdir(path):
            fip = os.path.join(path, item)
            if (fip.endswith('_fake_B_1.png') or fip.endswith('_fake_B.png')):
                splits.append(fip.split('/')[-1])
        return splits

    def img_path(self, filename):
        # NOTE(review): joins root directly, not the folder scanned in
        # collect_ids -- presumably images sit under root; confirm layout.
        return os.path.join(self.root, filename)

    def label_path(self, filename):
        # Multi-source CycleGAN runs emit *_fake_B_1.png for SYNTHIA
        # (and *_fake_B_2.png for GTA5); single-source runs *_fake_B.png.
        if filename.endswith('_fake_B_1.png'):
            return os.path.join(self.root, 'GT', 'parsed_LABELS',
                                filename.replace('_fake_B_1.png', '.png'))
        elif filename.endswith('_fake_B.png'):
            return os.path.join(self.root, 'GT', 'parsed_LABELS',
                                filename.replace('_fake_B.png', '.png'))

    def __getitem__(self, index, debug=False):
        id = self.ids[index]
        img_path = self.img_path(id)
        label_path = self.label_path(id)
        img = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = Image.open(label_path)
        if self.remap_labels:
            target = np.asarray(target)
            target = syn_relabel(target)
            target = Image.fromarray(target, 'L')
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        return len(self.ids)
# In this class, we iteratively load transferred images from cyclesynthia
# and cyclegta5.
@register_dataset_obj('cyclesynthia_cyclegta5')
class CycleSYNTHIACycleGTA5(data.Dataset):
    """Interleaves CycleGAN-translated SYNTHIA and GTA5 images.

    Odd indices draw from the sibling ``synthia`` folder, even indices from
    ``cyclegta5``; labels are remapped to the 19 Cityscapes train ids with
    the relabeling function matching each sub-domain.
    """

    # NOTE(review): absolute cluster paths baked in below -- should come
    # from configuration; kept verbatim to preserve behavior.
    SYN_LABEL_ROOT = "/nfs/project/libo_i/MADAN/data/synthia"
    GTA_LABEL_ROOT = "/nfs/project/libo_i/MADAN/data/cyclegta5"

    def __init__(self, root, num_cls=19, split='train', remap_labels=True,
                 transform=None, target_transform=None):
        self.dataset_name = os.path.basename(root)
        self.parent_path = root.replace(self.dataset_name, '')
        self.syn_name = os.path.join(self.parent_path, 'synthia')
        self.gta_name = os.path.join(self.parent_path, 'cyclegta5')
        self.remap_labels = remap_labels
        self.transform = transform
        self.target_transform = target_transform
        self.classes = classes
        self.num_cls = num_cls
        self.syn_ids = self.collect_ids('syn')
        self.gta_ids = self.collect_ids('gta')

    def collect_ids(self, datasets_name):
        """List translated filenames for one sub-dataset ('syn' or 'gta')."""
        splits = []
        if datasets_name == 'syn':
            folder, suffixes = self.syn_name, ('_fake_B_1.png', '_fake_B.png')
        elif datasets_name == 'gta':
            folder, suffixes = self.gta_name, ('_fake_B_2.png', '_fake_B.png')
        else:
            print("Don't Recognize {}".format(datasets_name))
            return splits
        for item in os.listdir(folder):
            fip = os.path.join(folder, item)
            if fip.endswith(suffixes):
                splits.append(fip.split('/')[-1])
        return splits

    def img_path(self, prefix, filename):
        return os.path.join(prefix, filename)

    # Case for loading images generated in multi-source cycle: CycleGAN
    # emits fake_B_1 for cyclesynthia and fake_B_2 for cyclegta5.
    def syn_label_path(self, filename):
        for suffix in ('_fake_B_1.png', '_fake_B.png'):
            if filename.endswith(suffix):
                return os.path.join(self.SYN_LABEL_ROOT, 'GT',
                                    'parsed_LABELS',
                                    filename.replace(suffix, '.png'))

    def gta_label_path(self, filename):
        for suffix in ('_fake_B_2.png', '_fake_B.png'):
            if filename.endswith(suffix):
                return os.path.join(self.GTA_LABEL_ROOT, 'labels',
                                    filename.replace(suffix, '.png'))

    def __getitem__(self, index, debug=False):
        # Alternate between the two translated source domains; the shared
        # tail was deduplicated from the original's two mirrored branches.
        if index % 2:
            id = self.syn_ids[index % len(self.syn_ids)]
            img_path = self.img_path(self.syn_name, id)
            label_path = self.syn_label_path(id)
            relabel = syn_relabel
        else:
            id = self.gta_ids[index % len(self.gta_ids)]
            img_path = self.img_path(self.gta_name, id)
            label_path = self.gta_label_path(id)
            relabel = remap_labels_to_train_ids
        img = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        target = Image.open(label_path)
        if self.remap_labels:
            target = np.asarray(target)
            target = relabel(target)
            target = Image.fromarray(target, 'L')
        if self.target_transform is not None:
            target = self.target_transform(target)
        # Bug fix: the return appears to have been swallowed by the
        # commented-out debug block in the original, which would make
        # __getitem__ yield None -- confirm against upstream history.
        return img, target

    def __len__(self):
        return len(self.syn_ids) + len(self.gta_ids)
def get_transform_dataset(dataset_name, rootdir, net_transform, downscale,
                          resize=None, src_data_flag=None, small=False):
    """Build an FCN dataset with image/label transforms applied.

    Bug fix: the original read ``os.environ['PYTHONPATH']`` into an unused
    local, raising KeyError whenever PYTHONPATH was not set; the dead read
    is removed.
    """
    transform, target_transform = get_transform2(dataset_name, net_transform,
                                                 downscale, resize)
    return get_fcn_dataset(dataset_name, rootdir, transform=transform,
                           target_transform=target_transform,
                           data_flag=src_data_flag, small=small)


# native image width of each supported dataset, used for relative scaling
sizes = {'cyclesynthia_cyclegta5': 1280,
         'cyclesynthia': 1280,
         'cityscapes': 1280,
         'gta5': 1280,
         'cyclegta5': 1280,
         "synthia": 1280}


def get_orig_size(dataset_name):
    "Size of images in the dataset for relative scaling."
    try:
        return sizes[dataset_name]
    except KeyError:
        # narrowed from a bare except: only a missing key means the
        # dataset is unknown
        raise ValueError('Unknown dataset size: {}'.format(dataset_name))


def get_transform2(dataset_name, net_transform, downscale, resize):
    "Returns image and label transform to downscale, crop and prepare for net."
    orig_size = get_orig_size(dataset_name)
    transform = []
    target_transform = []
    if downscale is not None:
        transform.append(transforms.Resize(orig_size // downscale))
        target_transform.append(
            transforms.Resize(orig_size // downscale,
                              interpolation=Image.NEAREST))
    if resize is not None:
        # width is fixed at 1.8x the requested height (panoramic frames)
        transform.append(
            transforms.Resize([int(resize), int(int(resize) * 1.8)],
                              interpolation=Image.BICUBIC))
        target_transform.append(
            transforms.Resize([int(resize), int(int(resize) * 1.8)],
                              interpolation=Image.NEAREST))
    transform.append(net_transform)
    # labels are converted without normalization
    target_transform.append(to_tensor_raw)
    return transforms.Compose(transform), transforms.Compose(target_transform)
if not num_channels == params.num_channels: if num_channels == 1: transform.append(RGB2Gray) elif num_channels == 3: transform.append(Gray2RGB) else: print('NumChannels should be 1 or 3', num_channels) raise Exception transform += [transforms.ToTensor(), transforms.Normalize((params.mean,), (params.std,))] return transforms.Compose(transform) def get_target_transform(params): transform = params.target_transform t_uniform = transforms.Lambda(lambda x: x[:, 0] if isinstance(x, (list, np.ndarray)) and len(x) == 2 else x) if transform is None: return t_uniform else: return transforms.Compose([transform, t_uniform]) class AddaDataset(data.Dataset): def __init__(self, src_data, tgt_data): self.src = src_data self.tgt = tgt_data def __getitem__(self, index): ns = len(self.src) nt = len(self.tgt) xs, ys = self.src[index % ns] xt, yt = self.tgt[index % nt] return (xs, ys), (xt, yt) def __len__(self): return min(len(self.src), len(self.tgt)) data_params = {} def register_data_params(name): def decorator(cls): data_params[name] = cls return cls return decorator dataset_obj = {} def register_dataset_obj(name): def decorator(cls): dataset_obj[name] = cls return cls return decorator class DatasetParams(object): "Class variables defined." 
num_channels = 1 image_size = 16 mean = 0.1307 std = 0.3081 num_cls = 10 target_transform = None def get_dataset(name, rootdir, dset, image_size, num_channels, download=True): is_train = (dset == 'train') print('get dataset:', name, rootdir, dset) params = data_params[name] transform = get_transform(params, image_size, num_channels) target_transform = get_target_transform(params) return dataset_obj[name](rootdir, train=is_train, transform=transform, target_transform=target_transform, download=download) def get_fcn_dataset(name, rootdir, **kwargs): return dataset_obj[name](rootdir, **kwargs) ================================================ FILE: cycada/data/gta5.py ================================================ import os.path import numpy as np import scipy.io import torch.utils.data as data from PIL import Image from .cityscapes import id2label as LABEL2TRAIN, remap_labels_to_train_ids from .data_loader import DatasetParams, register_data_params, register_dataset_obj @register_data_params('gta5') class GTA5Params(DatasetParams): num_channels = 3 image_size = 1024 mean = 0.5 std = 0.5 num_cls = 19 target_transform = None @register_dataset_obj('gta5') class GTA5(data.Dataset): def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None, data_flag=None): self.root = root self.split = split self.remap_labels = remap_labels self.data_flag = data_flag self.ids = self.collect_ids() self.transform = transform self.target_transform = target_transform m = scipy.io.loadmat(os.path.join(self.root, 'mapping.mat')) full_classes = [x[0] for x in m['classes'][0]] self.classes = [] for old_id, new_id in LABEL2TRAIN.items(): if not new_id == 255 and old_id > 0: self.classes.append(full_classes[old_id]) self.num_cls = num_cls def collect_ids(self): splits = scipy.io.loadmat(os.path.join(self.root, 'split.mat')) ids = splits['{}Ids'.format(self.split)].squeeze() return ids def img_path(self, id): filename = '{:05d}.png'.format(id) 
return os.path.join(self.root, 'images', filename) def label_path(self, id): filename = '{:05d}.png'.format(id) return os.path.join(self.root, 'labels', filename) def __getitem__(self, index, debug=False): id = self.ids[index] img_path = self.img_path(id) label_path = self.label_path(id) img = Image.open(img_path).convert('RGB') if self.transform is not None: img = self.transform(img) target = Image.open(label_path) if self.remap_labels: target = np.asarray(target) target = remap_labels_to_train_ids(target) target = Image.fromarray(target, 'L') if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): return len(self.ids) ================================================ FILE: cycada/data/rotater.py ================================================ class Rotater(object): def __init__(self, dataset, orientations=6, transform=None, target_transform=None): self.dataset = dataset self.orientations = orientations self.transform = transform self.target_transform = target_transform def __getitem__(self, index): im, target = self.dataset[index] rotation = index % self.orientations degrees = 360 / self.orientations * rotation im = im.rotate(degrees) if self.transform is not None: im = self.transform(im) if self.target_transform is not None: target = self.target_transform(target) return im, target, degrees def __len__(self): return len(self.dataset) ================================================ FILE: cycada/data/synthia.py ================================================ import os.path import numpy as np import torch.utils.data as data from PIL import Image from .util import classes, ignore_label, id2label from .data_loader import DatasetParams, register_data_params, register_dataset_obj def syn_relabel(arr): out = ignore_label * np.ones(arr.shape, dtype=np.uint8) for id, label in id2label.items(): out[arr == id] = int(label) return out @register_data_params('synthia') class SYNTHIAParams(DatasetParams): num_channels 
= 3 image_size = 1024 mean = 0.5 std = 0.5 num_cls = 19 target_transform = None @register_dataset_obj('synthia') class SYNTHIA(data.Dataset): def __init__(self, root, num_cls=19, split='train', remap_labels=True, transform=None, target_transform=None, data_flag=None, small=2): self.root = root self.split = split self.small = small self.remap_labels = remap_labels self.ids = self.collect_ids() self.transform = transform self.target_transform = target_transform self.classes = classes self.num_cls = num_cls self.data_flag = data_flag def collect_ids(self): splits = [] with open(os.path.join(self.root, 'SYNTHIA_imagelist_{}.txt'.format(self.split))) as f: for line in f: line = line.strip('\n') splits.append(line.split('/')[-1]) return splits def img_path(self, filename): if self.small == 0: return os.path.join(self.root, 'RGB_300x540', filename) elif self.small == 1: return os.path.join(self.root, 'RGB_600x1080', filename) else: return os.path.join(self.root, 'RGB', filename) def label_path(self, filename): if self.small == 0: return os.path.join(self.root, 'GT', 'parsed_LABELS_300x540', filename) elif self.small == 1: return os.path.join(self.root, 'GT', 'parsed_LABELS_600x1080', filename) else: return os.path.join(self.root, 'GT', 'parsed_LABELS', filename) def __getitem__(self, index, debug=False): id = self.ids[index] img_path = self.img_path(id) label_path = self.label_path(id) if debug: print(self.__class__.__name__) print("IMG Path: {}".format(img_path)) print("Label Path: {}".format(label_path)) img = Image.open(img_path).convert('RGB') if self.transform is not None: img = self.transform(img) target = Image.open(label_path) if self.remap_labels: target = np.asarray(target) target = syn_relabel(target) target = Image.fromarray(target, 'L') if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): return len(self.ids) ================================================ FILE: cycada/data/util.py 
# ================================================
# FILE: cycada/data/util.py
# ================================================
import logging
import os.path

logger = logging.getLogger(__name__)

# Label value for pixels the loss should ignore.
ignore_label = 255

# SYNTHIA raw class id -> Cityscapes train id (ignore_label drops the class).
id2label = {0: ignore_label,
            1: 10,
            2: 2,
            3: 0,
            4: 1,
            5: 4,
            6: 8,
            7: 5,
            8: 13,
            9: 7,
            10: 11,
            11: 18,
            12: 17,
            13: ignore_label,
            14: ignore_label,
            15: 6,
            16: 9,
            17: 12,
            18: 14,
            19: 15,
            20: 16,
            21: 3,
            22: ignore_label}

# The 19 Cityscapes training classes, indexed by train id.
classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
           'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
           'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle']

# Flat RGB palette, 3 ints per class, used to colorize label maps.
palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153,
           153, 153, 153, 153, 250, 170, 30, 220, 220, 0, 107, 142, 35, 152,
           251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0,
           70, 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]


def maybe_download(url, dest):
    """Download ``url`` to ``dest`` unless ``dest`` already exists.

    NOTE(review): despite the original docstring, no file-integrity check is
    performed -- existence of ``dest`` is the only test.
    """
    if not os.path.exists(dest):
        logger.info('Downloading %s to %s', url, dest)
        download(url, dest)


def download(url, dest):
    """Download ``url`` to ``dest``, overwriting ``dest`` if it exists.

    Raises
    ------
    requests.HTTPError
        If the server answers with an error status.  (The original wrote
        whatever body the server returned -- including error pages -- to
        ``dest`` without checking the status, and could hang forever with
        no timeout.)
    """
    # Imported lazily: requests is only needed when a download actually
    # happens, so the rest of this module works without it installed.
    import requests

    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    with open(dest, 'wb') as f:
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:  # filter out keep-alive chunks
                f.write(chunk)


# ================================================
# FILE: cycada/logging.yml
# ================================================
# NOTE(review): YAML configuration embedded in this flattened dump; kept as
# a comment so this chunk stays parseable.  Indentation reconstructed from
# the standard logging-config schema -- verify against the original file.
# ---
# version: 1
# disable_existing_loggers: False
# formatters:
#   simple:
#     format: "[%(asctime)s] %(levelname)-8s %(message)s"
#   color:
#     class: colorlog.ColoredFormatter
#     format: "[%(asctime)s] %(log_color)s%(levelname)-8s%(reset)s %(message)s"
#     log_colors:
#       DEBUG: "cyan"
#       INFO: "green"
#       WARNING: "yellow"
#       ERROR: "red"
#       CRITICAL: "red,bg_white"
# handlers:
#   console:
#     class: cycada.util.TqdmHandler
#     level: INFO
#     formatter: color
#   file_handler:
#     class: logging.FileHandler
#     level: INFO
#     formatter: simple
#     encoding: utf8
# root:
#   level: INFO
#   handlers: [console, file_handler]
# ================================================
# FILE: cycada/models/MDAN.py
# ================================================
# !/usr/bin/env python
# -*- coding: utf-8 -*-
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class _GradientReversalFn(torch.autograd.Function):
    """Identity in the forward pass, sign-flipped gradient in the backward pass."""

    @staticmethod
    def forward(ctx, inputs):
        return inputs

    @staticmethod
    def backward(ctx, grad_output):
        return -grad_output


class GradientReversalLayer(nn.Module):
    """
    Gradient reversal layer for adversarial domain adaptation.

    The forward part is the identity function while the backward part
    negates the gradient.

    NOTE(review): the original subclassed ``torch.autograd.Function`` with
    *instance* ``forward``/``backward`` methods -- a legacy API; calling the
    instance directly (as ``MDANet.forward`` does via ``self.grls[i](x)``)
    does not route through autograd on modern PyTorch, so the gradient was
    never actually reversed.  Rewritten as an ``nn.Module`` delegating to a
    static-method ``Function`` so the existing call sites keep working.
    """

    def forward(self, inputs):
        return _GradientReversalFn.apply(inputs)


class MDANet(nn.Module):
    """
    Multi-source domain-adversarial network head (two sources: SYNTHIA, GTA).

    Pools and projects backbone features, pushes them through shared fully
    connected layers, then emits (a) class log-probabilities for each source
    domain and (b) per-domain source/target discrimination log-probabilities
    through gradient reversal layers.

    Parameters
    ----------
    configs : dict
        Requires keys ``"input_dim"``, ``"hidden_layers"`` (list of layer
        widths), ``"num_domains"`` and ``"num_classes"``.
    """

    def __init__(self, configs):
        super(MDANet, self).__init__()
        # Reduce backbone activations (N, 4096, H, W) -> (N, 512, 2, 2).
        self.pooling_layer = nn.AdaptiveAvgPool2d((2, 2))
        self.dim_reduction = nn.Conv2d(4096, 512, kernel_size=1)
        nn.init.xavier_normal_(self.dim_reduction.weight)
        nn.init.constant_(self.dim_reduction.bias, 0.1)
        self.input_dim = configs["input_dim"]
        self.num_hidden_layers = len(configs["hidden_layers"])
        self.num_neurons = [self.input_dim] + configs["hidden_layers"]
        self.num_domains = configs["num_domains"]
        # Shared feature-learning component (fully connected hidden layers).
        self.hiddens = nn.ModuleList(
            [nn.Linear(self.num_neurons[i], self.num_neurons[i + 1])
             for i in range(self.num_hidden_layers)])
        # Final softmax classification layer.
        self.softmax = nn.Linear(self.num_neurons[-1], configs["num_classes"])
        # One binary (source vs. target) domain classifier per source domain.
        self.domains = nn.ModuleList(
            [nn.Linear(self.num_neurons[-1], 2)
             for _ in range(self.num_domains)])
        # Gradient reversal layers are parameter-free, so a plain list is fine.
        self.grls = [GradientReversalLayer() for _ in range(self.num_domains)]

    def forward(self, sinputs_syn, sinputs_gta, tinputs):
        """
        :param sinputs_syn: features from the SYNTHIA source domain.
        :param sinputs_gta: features from the GTA source domain.
        :param tinputs: features from the target domain.
        :return: (class log-probs per source, source-domain log-probs,
                  target-domain log-probs)
        """
        sinputs_gta = self.pooling_layer(sinputs_gta)
        sinputs_syn = self.pooling_layer(sinputs_syn)
        tinputs = self.pooling_layer(tinputs)

        sinputs_gta = self.dim_reduction(sinputs_gta)
        sinputs_syn = self.dim_reduction(sinputs_syn)
        tinputs = self.dim_reduction(tinputs)

        b = sinputs_gta.size()[0]
        syn_relu, gta_relu, th_relu = (sinputs_syn.view(b, -1),
                                       sinputs_gta.view(b, -1),
                                       tinputs.view(b, -1))
        assert (syn_relu[0].size()[0] == self.input_dim)

        for hidden in self.hiddens:
            syn_relu = F.relu(hidden(syn_relu))
            gta_relu = F.relu(hidden(gta_relu))
        for hidden in self.hiddens:
            th_relu = F.relu(hidden(th_relu))

        # Classification probabilities on k source domains.
        logprobs = [F.log_softmax(self.softmax(syn_relu), dim=1),
                    F.log_softmax(self.softmax(gta_relu), dim=1)]

        # Domain classification log-probabilities, through gradient reversal.
        sdomains, tdomains = [], []
        sdomains.append(F.log_softmax(self.domains[0](self.grls[0](syn_relu)), dim=1))
        tdomains.append(F.log_softmax(self.domains[0](self.grls[0](th_relu)), dim=1))
        sdomains.append(F.log_softmax(self.domains[1](self.grls[1](gta_relu)), dim=1))
        tdomains.append(F.log_softmax(self.domains[1](self.grls[1](th_relu)), dim=1))
        return logprobs, sdomains, tdomains

    def inference(self, inputs):
        """Class log-probabilities for already-flattened feature vectors."""
        h_relu = inputs
        for hidden in self.hiddens:
            h_relu = F.relu(hidden(h_relu))
        # Classification probability.
        logprobs = F.log_softmax(self.softmax(h_relu), dim=1)
        return logprobs


# ================================================
# FILE: cycada/models/__init__.py
# ================================================
from .models import get_model
from .task_net import LeNet
from .task_net import DTNClassifier
from .adda_net import AddaNet
from .fcn8s import VGG16_FCN8s, Discriminator
from .drn import drn26

# ================================================
# FILE: cycada/models/adda_net.py
# ================================================
import numpy as np
import torch
import torch.nn as nn
from torch.nn import init

from .util import init_weights
from .models import register_model, get_model


@register_model('AddaNet')
class AddaNet(nn.Module):
    """Defines an ADDA network: source/target task nets plus a discriminator.

    Parameters
    ----------
    num_cls : int
        Number of task classes.
    model : str
        Registered name of the base task network (e.g. ``'LeNet'``).
    src_weights_init : str, optional
        Path to source-network weights, copied into both src and tgt nets.
    weights_init : str, optional
        Path to a full AddaNet checkpoint.
    discrim_feat : bool
        When True, feed features (not class scores) to the discriminator.
        NOTE(review): the original read ``self.discrim_feat`` in ``forward``
        without ever assigning it, so ``forward`` always raised
        ``AttributeError``.  It now defaults to False, which matches
        ``setup_net``'s score-sized (``num_cls``) discriminator input.
    """

    def __init__(self, num_cls=10, model='LeNet', src_weights_init=None,
                 weights_init=None, discrim_feat=False):
        super(AddaNet, self).__init__()
        self.name = 'AddaNet'
        self.base_model = model
        self.num_cls = num_cls
        self.discrim_feat = discrim_feat
        self.cls_criterion = nn.CrossEntropyLoss()
        self.gan_criterion = nn.CrossEntropyLoss()
        self.setup_net()
        if weights_init is not None:
            self.load(weights_init)
        elif src_weights_init is not None:
            self.load_src_net(src_weights_init)
        else:
            raise Exception('AddaNet must be initialized with weights.')

    def forward(self, x_s, x_t):
        """Pass source and target images through their respective networks."""
        score_s, x_s = self.src_net(x_s, with_ft=True)
        score_t, x_t = self.tgt_net(x_t, with_ft=True)
        if self.discrim_feat:
            d_s = self.discriminator(x_s)
            d_t = self.discriminator(x_t)
        else:
            d_s = self.discriminator(score_s)
            d_t = self.discriminator(score_t)
        return score_s, score_t, d_s, d_t

    def setup_net(self):
        """Setup source, target and discriminator networks."""
        self.src_net = get_model(self.base_model, num_cls=self.num_cls)
        self.tgt_net = get_model(self.base_model, num_cls=self.num_cls)
        # Discriminator consumes the task nets' class scores.
        input_dim = self.num_cls
        self.discriminator = nn.Sequential(
            nn.Linear(input_dim, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, 2),
        )
        self.image_size = self.src_net.image_size
        self.num_channels = self.src_net.num_channels

    def load(self, init_path):
        """Loads full src and tgt models."""
        net_init_dict = torch.load(init_path)
        self.load_state_dict(net_init_dict)

    def load_src_net(self, init_path):
        """Initialize source and target with source weights."""
        self.src_net.load(init_path)
        self.tgt_net.load(init_path)

    def save(self, out_path):
        torch.save(self.state_dict(), out_path)

    def save_tgt_net(self, out_path):
        torch.save(self.tgt_net.state_dict(), out_path)


# ================================================
# FILE: cycada/models/drn.py
# ================================================
import math

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torchvision

from .models import register_model
from ..util import safe_load_state_dict

__all__ = ['DRN', 'drn26', 'drn42', 'drn58']

# Pretrained ImageNet checkpoints for the dilated residual networks.
model_urls = {
    'drn26': 'https://tigress-web.princeton.edu/~fy/drn/models/drn26-ddedf421.pth',
    'drn42': 'https://tigress-web.princeton.edu/~fy/drn/models/drn42-9d336e8c.pth',
    'drn58': 'https://tigress-web.princeton.edu/~fy/drn/models/drn58-0a53a92c.pth'
}


def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
    """3x3 convolution without bias (every user follows it with BatchNorm)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=padding, bias=False, dilation=dilation)


class BasicBlock(nn.Module):
    """Two-conv residual block; ``residual=False`` disables the skip path."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride,
                             padding=dilation[0], dilation=dilation[0])
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes,
                             padding=dilation[1], dilation=dilation[1])
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        if self.residual:
            out += residual
        out = self.relu(out)
        return out


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (channel expansion 4)."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 dilation=(1, 1), residual=True):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation[1], bias=False,
                               dilation=dilation[1])
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class DRN(nn.Module):
    """Dilated Residual Network.

    With ``out_map=True`` the 1x1 ``fc`` conv produces a per-location score
    map upsampled back to the input size (dense prediction); otherwise
    average pooling plus ``fc`` yields a per-image score vector.
    """

    # Standard ImageNet preprocessing for this backbone.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]),
    ])

    def __init__(self, block, layers, num_cls=1000,
                 channels=(16, 32, 64, 128, 256, 512, 512, 512),
                 out_map=False, out_middle=False, pool_size=28,
                 weights_init=None, pretrained=True, finetune=False,
                 output_last_ft=False, modelname='drn26'):
        if output_last_ft:
            print('DRN discrim feat not implemented, using scores')
        super(DRN, self).__init__()
        self.inplanes = channels[0]
        self.output_last_ft = output_last_ft
        self.out_map = out_map
        self.out_dim = channels[-1]
        self.out_middle = out_middle
        self.conv1 = nn.Conv2d(3, channels[0], kernel_size=7, stride=1,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(channels[0])
        self.relu = nn.ReLU(inplace=True)
        # Levels 1-4: conventional (strided) residual stages.
        self.layer1 = self._make_layer(BasicBlock, channels[0], layers[0],
                                       stride=1)
        self.layer2 = self._make_layer(BasicBlock, channels[1], layers[1],
                                       stride=2)
        self.layer3 = self._make_layer(block, channels[2], layers[2], stride=2)
        self.layer4 = self._make_layer(block, channels[3], layers[3], stride=2)
        # Levels 5-8: dilation instead of striding to preserve resolution.
        self.layer5 = self._make_layer(block, channels[4], layers[4],
                                       dilation=2, new_level=False)
        self.layer6 = None if layers[5] == 0 else \
            self._make_layer(block, channels[5], layers[5],
                             dilation=4, new_level=False)
        self.layer7 = None if layers[6] == 0 else \
            self._make_layer(BasicBlock, channels[6], layers[6], dilation=2,
                             new_level=False, residual=False)
        self.layer8 = None if layers[7] == 0 else \
            self._make_layer(BasicBlock, channels[7], layers[7], dilation=1,
                             new_level=False, residual=False)

        if num_cls > 0:
            self.avgpool = nn.AvgPool2d(pool_size)
            # 1x1 conv instead of nn.Linear so the head also works on maps.
            self.fc = nn.Conv2d(self.out_dim, num_cls, kernel_size=1,
                                stride=1, padding=0, bias=True)
        # He init for convs, constant init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        if pretrained:
            if weights_init is not None:
                state_dict = torch.load(weights_init)
                print('Using state dict from', weights_init)
            else:
                state_dict = model_zoo.load_url(model_urls[modelname])
            if finetune:
                # Drop the classifier weights so num_cls may differ.
                del state_dict['fc.weight']
                del state_dict['fc.bias']
                safe_load_state_dict(self, state_dict)
                print('Finetune: remove last layer')
            else:
                self.load_state_dict(state_dict)
                print('Loading full model')

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    new_level=True, residual=True):
        """Stack ``blocks`` residual blocks; project the skip path whenever
        resolution or channel count changes."""
        assert dilation == 1 or dilation % 2 == 0
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(
            self.inplanes, planes, stride, downsample,
            dilation=(1, 1) if dilation == 1 else (
                dilation // 2 if new_level else dilation, dilation),
            residual=residual))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, residual=residual,
                                dilation=(dilation, dilation)))
        return nn.Sequential(*layers)

    def forward(self, x):
        _, _, h, w = x.size()
        y = list()  # intermediate activations, returned when out_middle
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        y.append(x)
        x = self.layer2(x)
        y.append(x)
        x = self.layer3(x)
        y.append(x)
        x = self.layer4(x)
        y.append(x)
        x = self.layer5(x)
        y.append(x)
        if self.layer6 is not None:
            x = self.layer6(x)
            y.append(x)
        if self.layer7 is not None:
            x = self.layer7(x)
            y.append(x)
        if self.layer8 is not None:
            x = self.layer8(x)
            y.append(x)

        if self.output_last_ft:
            ft_to_save = x
        if self.out_map:
            # Dense prediction: score map upsampled back to the input size.
            x = self.fc(x)
            x = nn.functional.interpolate(x, (h, w), mode='bilinear',
                                          align_corners=True)
        else:
            x = self.avgpool(x)
            x = self.fc(x)
            x = x.view(x.size(0), -1)

        if self.out_middle:
            return x, y
        elif self.output_last_ft:
            return x, ft_to_save
        else:
            return x
@register_model('drn26') def drn26(pretrained=True, finetune=False, out_map=True, **kwargs): model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], modelname='drn26', out_map=out_map, finetune=finetune, **kwargs) # if pretrained: # state_dict = model_zoo.load_url(model_urls['drn26']) # if finetune: # del state_dict['fc.weight'] # del state_dict['fc.bias'] # safe_load_state_dict(model, state_dict) # else: # model.load_state_dict(state_dict) return model @register_model('drn42') def drn42(pretrained=False, finetune=False, out_map=True, **kwargs): model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], modelname='drn42', out_map=out_map, finetune=finetune, **kwargs) # if pretrained: # model.load_state_dict(model_zoo.load_url(model_urls['drn42'])) return model def drn58(pretrained=False, **kwargs): model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['drn58'])) return model ================================================ FILE: cycada/models/fcn8s.py ================================================ import numpy as np import torch import torch.nn.functional as F import torchvision from torch import nn from torch.autograd import Variable from torch.nn import init from torch.utils import model_zoo from torchvision.models import vgg from .models import register_model def get_upsample_filter(size): """Make a 2D bilinear kernel suitable for upsampling""" factor = (size + 1) // 2 if size % 2 == 1: center = factor - 1 else: center = factor - 0.5 og = np.ogrid[:size, :size] filter = (1 - abs(og[0] - center) / factor) * \ (1 - abs(og[1] - center) / factor) return torch.from_numpy(filter).float() class Bilinear(nn.Module): def __init__(self, factor, num_channels): super().__init__() self.factor = factor filter = get_upsample_filter(factor * 2) w = torch.zeros(num_channels, num_channels, factor * 2, factor * 2) for i in range(num_channels): w[i, i] = filter self.register_buffer('w', w) def forward(self, x): return 
F.conv_transpose2d(x, Variable(self.w), stride=self.factor) @register_model('fcn8s') class VGG16_FCN8s(nn.Module): transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def __init__(self, num_cls=19, pretrained=True, weights_init=None, output_last_ft=False): super().__init__() self.output_last_ft = output_last_ft if weights_init: batch_norm = False else: batch_norm = True self.vgg = make_layers(vgg.cfg['D'], batch_norm=False) self.vgg_head = nn.Sequential( nn.Conv2d(512, 4096, 7), nn.ReLU(inplace=True), nn.Dropout2d(p=0.5), nn.Conv2d(4096, 4096, 1), nn.ReLU(inplace=True), nn.Dropout2d(p=0.5), nn.Conv2d(4096, num_cls, 1) ) self.upscore2 = self.upscore_pool4 = Bilinear(2, num_cls) self.upscore8 = Bilinear(8, num_cls) self.score_pool4 = nn.Conv2d(512, num_cls, 1) for param in self.score_pool4.parameters(): # init.constant(param, 0) init.constant_(param, 0) self.score_pool3 = nn.Conv2d(256, num_cls, 1) for param in self.score_pool3.parameters(): # init.constant(param, 0) init.constant_(param, 0) if pretrained: if weights_init is not None: self.load_weights(torch.load(weights_init)) else: self.load_base_weights() def load_base_vgg(self, weights_state_dict): vgg_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg.') self.vgg.load_state_dict(vgg_state_dict) def load_vgg_head(self, weights_state_dict): vgg_head_state_dict = self.get_dict_by_prefix(weights_state_dict, 'vgg_head.') self.vgg_head.load_state_dict(vgg_head_state_dict) def get_dict_by_prefix(self, weights_state_dict, prefix): return {k[len(prefix):]: v for k, v in weights_state_dict.items() if k.startswith(prefix)} def load_weights(self, weights_state_dict): self.load_base_vgg(weights_state_dict) self.load_vgg_head(weights_state_dict) def split_vgg_head(self): self.classifier = list(self.vgg_head.children())[-1] self.vgg_head_feat = nn.Sequential(*list(self.vgg_head.children())[:-1]) 
def forward(self, x): input = x x = F.pad(x, (99, 99, 99, 99), mode='constant', value=0) intermediates = {} fts_to_save = {16: 'pool3', 23: 'pool4'} for i, module in enumerate(self.vgg): x = module(x) if i in fts_to_save: intermediates[fts_to_save[i]] = x ft_to_save = 5 # Dropout before classifier last_ft = {} for i, module in enumerate(self.vgg_head): x = module(x) if i == ft_to_save: last_ft = x _, _, h, w = x.size() upscore2 = self.upscore2(x) pool4 = intermediates['pool4'] score_pool4 = self.score_pool4(0.01 * pool4) score_pool4c = _crop(score_pool4, upscore2, offset=5) fuse_pool4 = upscore2 + score_pool4c upscore_pool4 = self.upscore_pool4(fuse_pool4) pool3 = intermediates['pool3'] score_pool3 = self.score_pool3(0.0001 * pool3) score_pool3c = _crop(score_pool3, upscore_pool4, offset=9) fuse_pool3 = upscore_pool4 + score_pool3c upscore8 = self.upscore8(fuse_pool3) score = _crop(upscore8, input, offset=31) if self.output_last_ft: return score, last_ft else: return score def load_base_weights(self): """This is complicated because we converted the base model to be fully convolutional, so some surgery needs to happen here.""" base_state_dict = model_zoo.load_url(vgg.model_urls['vgg16']) vgg_state_dict = {k[len('features.'):]: v for k, v in base_state_dict.items() if k.startswith('features.')} self.vgg.load_state_dict(vgg_state_dict) vgg_head_params = self.vgg_head.parameters() for k, v in base_state_dict.items(): if not k.startswith('classifier.'): continue if k.startswith('classifier.6.'): # skip final classifier output continue vgg_head_param = next(vgg_head_params) vgg_head_param.data = v.view(vgg_head_param.size()) class VGG16_FCN8s_caffe(VGG16_FCN8s): transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( mean=[0.485, 0.458, 0.408], std=[0.00392156862745098] * 3), torchvision.transforms.Lambda( lambda x: torch.stack(torch.unbind(x, 1)[::-1], 1)) ]) def load_base_weights(self): base_state_dict = 
model_zoo.load_url('https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg16-00b39a1b.pth') vgg_state_dict = {k[len('features.'):]: v for k, v in base_state_dict.items() if k.startswith('features.')} self.vgg.load_state_dict(vgg_state_dict) vgg_head_params = self.vgg_head.parameters() for k, v in base_state_dict.items(): if not k.startswith('classifier.'): continue if k.startswith('classifier.6.'): # skip final classifier output continue vgg_head_param = next(vgg_head_params) vgg_head_param.data = v.view(vgg_head_param.size()) class Discriminator(nn.Module): def __init__(self, input_dim=4096, output_dim=2, pretrained=False, weights_init=''): super().__init__() dim1 = 1024 if input_dim == 4096 else 512 dim2 = int(dim1 / 2) self.D = nn.Sequential( nn.Conv2d(input_dim, dim1, 1), nn.Dropout2d(p=0.5), nn.ReLU(inplace=True), nn.Conv2d(dim1, dim2, 1), nn.Dropout2d(p=0.5), nn.ReLU(inplace=True), nn.Conv2d(dim2, output_dim, 1) ) if pretrained and weights_init is not None: self.load_weights(weights_init) def forward(self, x): d_score = self.D(x) return d_score def load_weights(self, weights): print('Loading discriminator weights') self.load_state_dict(torch.load(weights)) class Transform_Module(nn.Module): def __init__(self, input_dim=4096): super().__init__() self.transform = nn.Sequential( nn.Conv2d(input_dim, input_dim, 1), nn.ReLU(inplace=True), # nn.Conv2d(input_dim, input_dim, 1), # nn.ReLU(inplace=True), ) for m in self.modules(): if isinstance(m, nn.Conv2d): init_eye(m.weight) m.bias.data.zero_() def forward(self, x): t_x = self.transform(x) return t_x def init_eye(tensor): if isinstance(tensor, Variable): init_eye(tensor.data) return tensor return tensor.copy_(torch.eye(tensor.size(0), tensor.size(1))) def _crop(input, shape, offset=0): _, _, h, w = shape.size() return input[:, :, offset:offset + h, offset:offset + w].contiguous() def make_layers(cfg, batch_norm=False): """This is almost verbatim from torchvision.models.vgg, except that the MaxPool2d modules are 
configured with ceil_mode=True. """ layers = [] in_channels = 3 for v in cfg: if v == 'M': layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)) else: conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) modules = [conv2d, nn.ReLU(inplace=True)] if batch_norm: modules.insert(1, nn.BatchNorm2d(v)) layers.extend(modules) in_channels = v return nn.Sequential(*layers) ================================================ FILE: cycada/models/models.py ================================================ import torch models = {} def register_model(name): def decorator(cls): models[name] = cls return cls return decorator def get_model(name, num_cls=10, **args): net = models[name](num_cls=num_cls, **args) if torch.cuda.is_available(): net = net.cuda() return net ================================================ FILE: cycada/models/task_net.py ================================================ import torch import torch.nn as nn from torch.nn import init from .models import register_model from .util import init_weights import numpy as np class TaskNet(nn.Module): num_channels = 3 image_size = 32 name = 'TaskNet' "Basic class which does classification." def __init__(self, num_cls=10, weights_init=None): super(TaskNet, self).__init__() self.num_cls = num_cls self.setup_net() self.criterion = nn.CrossEntropyLoss() if weights_init is not None: self.load(weights_init) else: init_weights(self) def forward(self, x, with_ft=False): x = self.conv_params(x) x = x.view(x.size(0), -1) x = self.fc_params(x) score = self.classifier(x) if with_ft: return score, x else: return score def setup_net(self): """Method to be implemented in each class.""" pass def load(self, init_path): net_init_dict = torch.load(init_path) self.load_state_dict(net_init_dict) def save(self, out_path): torch.save(self.state_dict(), out_path) @register_model('LeNet') class LeNet(TaskNet): "Network used for MNIST or USPS experiments." 
num_channels = 1 image_size = 28 name = 'LeNet' out_dim = 500 # dim of last feature layer def setup_net(self): self.conv_params = nn.Sequential( nn.Conv2d(self.num_channels, 20, kernel_size=5), nn.MaxPool2d(2), nn.ReLU(), nn.Conv2d(20, 50, kernel_size=5), nn.Dropout2d(p=0.5), nn.MaxPool2d(2), nn.ReLU(), ) self.fc_params = nn.Linear(50*4*4, 500) self.classifier = nn.Sequential( nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(500, self.num_cls) ) @register_model('DTN') class DTNClassifier(TaskNet): "Classifier used for SVHN->MNIST Experiment" num_channels = 3 image_size = 32 name = 'DTN' out_dim = 512 # dim of last feature layer def setup_net(self): self.conv_params = nn.Sequential ( nn.Conv2d(self.num_channels, 64, kernel_size=5, stride=2, padding=2), nn.BatchNorm2d(64), nn.Dropout2d(0.1), nn.ReLU(), nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), nn.BatchNorm2d(128), nn.Dropout2d(0.3), nn.ReLU(), nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), nn.BatchNorm2d(256), nn.Dropout2d(0.5), nn.ReLU() ) self.fc_params = nn.Sequential ( nn.Linear(256*4*4, 512), nn.BatchNorm1d(512), ) self.classifier = nn.Sequential( nn.ReLU(), nn.Dropout(), nn.Linear(512, self.num_cls) ) ================================================ FILE: cycada/models/util.py ================================================ import torch.nn as nn from torch.nn import init def init_weights(obj): for m in obj.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): init.xavier_normal_(m.weight) m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): m.reset_parameters() ================================================ FILE: cycada/tools/__init__.py ================================================ ================================================ FILE: cycada/tools/train_adda_net.py ================================================ from __future__ import print_function import os from os.path import join import numpy as np # Import from torch import torch 
def train(loader_src, loader_tgt, net, opt_net, opt_dis, epoch):
    """Run one ADDA epoch: alternate discriminator / target-net updates.

    Args:
        loader_src, loader_tgt: source- and target-domain data loaders.
        net: AddaNet-style model exposing `src_net`, `tgt_net`,
            `discriminator` and `gan_criterion`.
        opt_net: optimizer over the target network parameters.
        opt_dis: optimizer over the discriminator parameters.
        epoch: current epoch index, used only for logging.

    Returns:
        The index of the last batch on which the target network was
        updated, or -1 if the discriminator never exceeded 60% accuracy
        (the caller uses this to detect a collapsed discriminator).
    """
    log_interval = 100  # how often (in batches) to print progress
    N = min(len(loader_src.dataset), len(loader_tgt.dataset))
    joint_loader = zip(loader_src, loader_tgt)
    net.train()
    last_update = -1
    for batch_idx, ((data_s, _), (data_t, _)) in enumerate(joint_loader):
        # Bug fix: the percentage used `100 * batch_idx / N` (batch count
        # over *sample* count); scale by the batch size so it matches the
        # samples-seen counter printed next to it.
        info_str = "[Train Adda] Epoch: {} [{}/{} ({:.2f}%)]".format(
            epoch, batch_idx * len(data_t), N,
            100. * batch_idx * len(data_t) / N)

        ########################
        # Setup data variables #
        ########################
        data_s = make_variable(data_s, requires_grad=False)
        data_t = make_variable(data_t, requires_grad=False)

        ##########################
        # Optimize discriminator #
        ##########################
        opt_dis.zero_grad()
        # extract and concat features from both domains
        score_s = net.src_net(data_s)
        score_t = net.tgt_net(data_t)
        f = torch.cat((score_s, score_t), 0)
        # predict domain with the discriminator
        pred_concat = net.discriminator(f)
        # prepare real and fake labels: source=1, target=0
        target_dom_s = make_variable(torch.ones(len(data_s)).long(),
                                     requires_grad=False)
        target_dom_t = make_variable(torch.zeros(len(data_t)).long(),
                                     requires_grad=False)
        label_concat = torch.cat((target_dom_s, target_dom_t), 0)
        # compute loss for the discriminator and step it
        loss_dis = net.gan_criterion(pred_concat, label_concat)
        loss_dis.backward()
        opt_dis.step()
        # discriminator accuracy on this batch (logging + gating below)
        pred_dis = torch.squeeze(pred_concat.max(1)[1])
        acc = (pred_dis == label_concat).float().mean()
        info_str += " acc: {:0.1f} D: {:.3f}".format(acc.item() * 100,
                                                     loss_dis.item())

        ###########################
        # Optimize target network #
        ###########################
        # only update the target net while the discriminator is strong
        if acc.item() > 0.6:
            last_update = batch_idx
            # zero out optimizer gradients
            opt_dis.zero_grad()
            opt_net.zero_grad()
            # extract target features and predict with the discriminator
            score_t = net.tgt_net(data_t)
            pred_tgt = net.discriminator(score_t)
            # fool the discriminator: label target features as source (1)
            label_tgt = make_variable(torch.ones(pred_tgt.size(0)).long(),
                                      requires_grad=False)
            loss_gan_t = net.gan_criterion(pred_tgt, label_tgt)
            loss_gan_t.backward()
            opt_net.step()
            info_str += " G: {:.3f}".format(loss_gan_t.item())

        ###########
        # Logging #
        ###########
        if batch_idx % log_interval == 0:
            print(info_str)
    return last_update
def train_epoch(loader, net, opt_net, epoch):
    """Run one supervised training pass of `net` over `loader`."""
    log_interval = 100  # batches between progress prints
    net.train()
    for batch_idx, (data, target) in enumerate(loader):
        # wrap the batch for the (legacy) Variable-based training loop
        data = make_variable(data, requires_grad=False)
        target = make_variable(target, requires_grad=False)

        # forward, backward, step
        opt_net.zero_grad()
        score = net(data)
        loss = net.criterion(score, target)
        loss.backward()
        opt_net.step()

        if batch_idx % log_interval != 0:
            continue
        # batch-level accuracy, for logging only
        pred = score.data.max(1)[1]
        n_correct = pred.eq(target.data).cpu().sum()
        acc = n_correct.item() / len(pred) * 100.0
        print('[Train] Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(loader.dataset),
            100. * batch_idx / len(loader), loss.item()), end="")
        print(' Acc: {:.2f}'.format(acc))
def pairwise_distance(x, y):
    """Squared Euclidean distances between all rows of `x` and `y`.

    Returns a matrix of shape (y.rows, x.rows) — transposed relative to
    the usual convention; kept that way for gaussian_kernel_matrix.
    """
    if len(x.shape) != len(y.shape):
        raise ValueError('Both inputs should be matrices.')
    if x.shape[1] != y.shape[1]:
        raise ValueError('The number of features should be the same.')
    # broadcast (n, d, 1) against (d, m) -> (n, d, m), reduce over d
    diff = x.view(x.shape[0], x.shape[1], 1) - torch.transpose(y, 0, 1)
    return torch.transpose(torch.sum(diff ** 2, 1), 0, 1)


def gaussian_kernel_matrix(x, y, sigmas):
    """Multi-bandwidth RBF kernel between `x` and `y`, summed over `sigmas`."""
    beta = 1. / (2. * sigmas.view(sigmas.shape[0], 1))
    dist = pairwise_distance(x, y).contiguous()
    # one row of exponents per bandwidth, then sum the bank
    s = torch.matmul(beta, dist.view(1, -1))
    return torch.exp(-s).sum(0).view_as(dist)


def maximum_mean_discrepancy(x, y, kernel=gaussian_kernel_matrix):
    """Biased MMD^2 estimate between samples `x` and `y` under `kernel`."""
    return (torch.mean(kernel(x, x))
            + torch.mean(kernel(y, y))
            - 2 * torch.mean(kernel(x, y)))


def mmd_loss(source_features, target_features):
    """MMD loss over a fixed bank of RBF bandwidths (expects CUDA tensors)."""
    sigmas = [
        1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
        1e3, 1e4, 1e5, 1e6
    ]
    kernel = partial(gaussian_kernel_matrix,
                     sigmas=Variable(torch.cuda.FloatTensor(sigmas)))
    return maximum_mean_discrepancy(source_features, target_features,
                                    kernel=kernel)
class HalfCrop(object):
    """Crops every tensor to its left or right half (same side for all).

    Bug fixes vs. the original: `self.size` was referenced but never set
    (the class had no __init__), the side variable was misspelled
    (`left_size` vs `left_side`), and the slice used two ellipses
    (`tensor[..., ..., x1:...]`), which is invalid indexing. The half
    width is now derived from each tensor's own last dimension.
    """

    def __call__(self, tensors):
        output = []
        # 0 -> keep the left half, 1 -> keep the right half
        left_side = random.randint(0, 1)
        for tensor in tensors:
            half_w = tensor.size(-1) // 2
            x1 = left_side * half_w
            output.append(tensor[..., x1:x1 + half_w].contiguous())
        return output
def safe_load_state_dict(net, state_dict):
    """Copies parameters and buffers from :attr:`state_dict` into this
    module and its descendants. Any params in :attr:`state_dict` that do
    not match the keys returned by :attr:`net`'s :func:`state_dict()`
    method or have differing sizes are skipped.

    Arguments:
        state_dict (dict): A dict containing parameters and persistent
            buffers.
    """
    target = net.state_dict()
    skipped = []
    for key, value in state_dict.items():
        if isinstance(value, Parameter):
            # backwards compatibility for serialized parameters
            value = value.data
        if key in target and target[key].size() == value.size():
            target[key].copy_(value)
        else:
            # unknown key or shape mismatch — leave the current value alone
            skipped.append(key)
    if skipped:
        logging.info('Skipped loading some parameters: {}'.format(skipped))
class BaseDataLoader():
    """Minimal interface that every cyclegan data loader implements."""

    def __init__(self):
        pass

    def initialize(self, opt):
        # Remember the options object for use by subclasses.
        self.opt = opt

    def load_data(self):
        # Bug fix: the original declared `def load_data():` without `self`,
        # so calling it on an instance raised TypeError; subclasses
        # override this to return the actual loader.
        return None
def get_transform(opt):
    """Build the torchvision pipeline for input *images* according to
    `opt.resize_or_crop`, ending with ToTensor + [-1, 1] normalization.

    Fixes vs. the original: the 'resize_only' branch used a stray `if`
    (instead of `elif`) and computed an `osize` list it never used —
    dead code removed, chain made consistent.
    """
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [int(opt.loadSize), int(opt.loadSize)]
        transform_list.append(transforms.Resize(osize, interpolation=Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'resize_only':
        # NOTE(review): this passes an int to Resize (shorter-side scaling,
        # aspect ratio kept) while get_label_transform passes a square
        # [loadSize, loadSize] list — confirm image/label alignment is
        # intended before relying on 'resize_only'.
        transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.BICUBIC))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.BICUBIC))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Resize(opt.loadSize, interpolation=Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.RandomHorizontalFlip())
    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
import numpy as np

ignore_label = 255

# Raw Cityscapes label id -> train id (19 classes); everything else ignored.
id2label = {-1: ignore_label, 0: ignore_label, 1: ignore_label, 2: ignore_label,
            3: ignore_label, 4: ignore_label, 5: ignore_label, 6: ignore_label,
            7: 0, 8: 1, 9: ignore_label, 10: ignore_label, 11: 2, 12: 3,
            13: 4, 14: ignore_label, 15: ignore_label, 16: ignore_label,
            17: 5, 18: ignore_label, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10,
            24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 29: ignore_label,
            30: ignore_label, 31: 16, 32: 17, 33: 18}

# Flat RGB palette for the 19 train classes (3 ints per class).
palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153,
           153, 153, 153, 250, 170, 30, 220, 220, 0, 107, 142, 35, 152, 251, 152,
           70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70, 0, 60, 100,
           0, 80, 100, 0, 0, 230, 119, 11, 32]

classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
           'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
           'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle']


def remap_labels_to_train_ids(arr):
    """Map an array of raw Cityscapes label ids to train ids (uint8).

    Ids without a train id become `ignore_label` (255).
    """
    out = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        out[arr == raw_id] = int(train_id)
    return out
# Unaligned paired dataset driving the GTA5 -> Cityscapes translation stage.
class GTAVCityscapesDataset(BaseDataset):

    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, 'gta5', 'images')
        self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
        self.dir_A_label = os.path.join(opt.dataroot, 'gta5', 'labels')
        self.dir_B_label = os.path.join(opt.dataroot, 'cityscapes', 'gtFine')
        self.A_paths = sorted(make_dataset(self.dir_A))
        self.B_paths = sorted(make_dataset(self.dir_B))
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        self.A_labels = sorted(make_dataset(self.dir_A_label))
        self.B_labels = sorted(make_cs_labels(self.dir_B_label))
        self.transform = get_transform(opt)
        self.label_transform = get_label_transform(opt)

    def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_size]
        # B is drawn at random unless serial batching is requested
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]

        A_label = self._load_label(self.A_labels[index % self.A_size])
        B_label = self._load_label(self.B_labels[index_B])

        A = self.transform(Image.open(A_path).convert('RGB'))
        B = self.transform(Image.open(B_path).convert('RGB'))
        A_label = self.label_transform(A_label)
        B_label = self.label_transform(B_label)

        if self.opt.which_direction == 'BtoA':
            input_nc, output_nc = self.opt.output_nc, self.opt.input_nc
        else:
            input_nc, output_nc = self.opt.input_nc, self.opt.output_nc
        if input_nc == 1:
            A = self._rgb_to_gray(A)
        if output_nc == 1:
            B = self._rgb_to_gray(B)
        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path,
                'A_label': A_label, 'B_label': B_label}

    @staticmethod
    def _load_label(path):
        # Remap raw label ids to the 19 Cityscapes train ids.
        arr = remap_labels_to_train_ids(np.asarray(Image.open(path)))
        return Image.fromarray(arr, 'L')

    @staticmethod
    def _rgb_to_gray(t):
        # ITU-R BT.601 luma weights, single-channel output.
        return (t[0, ...] * 0.299 + t[1, ...] * 0.587 + t[2, ...] * 0.114).unsqueeze(0)

    def __len__(self):
        return max(self.A_size, self.B_size)

    def name(self):
        return 'GTA5_Cityscapes'
# Three-way dataset driving two CycleGANs at once:
# SYNTHIA -> Cityscapes (domain A_1) and GTA5 -> Cityscapes (domain A_2).
class GTASynthiaCityscapesDataset(BaseDataset):

    def initialize(self, opt):
        # SYNTHIA is source 1, GTA5 is source 2, Cityscapes is target B.
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A_1 = os.path.join(opt.dataroot, 'synthia', 'RGB')
        self.dir_A_2 = os.path.join(opt.dataroot, 'gta5', 'images')
        self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
        self.dir_A_label_1 = os.path.join(opt.dataroot, 'synthia', 'GT', 'parsed_LABELS')
        self.dir_A_label_2 = os.path.join(opt.dataroot, 'gta5', 'labels')
        self.A_paths_1 = sorted(make_dataset(self.dir_A_1))
        self.A_paths_2 = sorted(make_dataset(self.dir_A_2))
        self.B_paths = sorted(make_dataset(self.dir_B))
        self.A_size_1 = len(self.A_paths_1)
        self.A_size_2 = len(self.A_paths_2)
        self.B_size = len(self.B_paths)
        self.A_labels_1 = sorted(make_dataset(self.dir_A_label_1))
        self.A_labels_2 = sorted(make_dataset(self.dir_A_label_2))
        self.transform = get_transform(opt)
        self.label_transform = get_label_transform(opt)

    def __getitem__(self, index):
        A_path_1 = self.A_paths_1[index % self.A_size_1]
        A_path_2 = self.A_paths_2[index % self.A_size_2]
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]

        # SYNTHIA labels use their own id scheme; GTA5 labels use the raw
        # Cityscapes ids — each source gets its own remapping function.
        A_label_1 = self._open_label(self.A_labels_1[index % self.A_size_1],
                                     syn_relabel)
        A_label_2 = self._open_label(self.A_labels_2[index % self.A_size_2],
                                     remap_labels_to_train_ids)

        A_1 = self.transform(Image.open(A_path_1).convert('RGB'))
        A_2 = self.transform(Image.open(A_path_2).convert('RGB'))
        B = self.transform(Image.open(B_path).convert('RGB'))
        A_label_1 = self.label_transform(A_label_1)
        A_label_2 = self.label_transform(A_label_2)

        if self.opt.which_direction == 'BtoA':
            input_nc, output_nc = self.opt.output_nc, self.opt.input_nc
        else:
            input_nc, output_nc = self.opt.input_nc, self.opt.output_nc
        if input_nc == 1:
            A_1 = self._to_gray(A_1)
            A_2 = self._to_gray(A_2)
        if output_nc == 1:
            B = self._to_gray(B)
        return {'A_1': A_1, 'A_2': A_2, 'B': B,
                'A_paths_1': A_path_1, 'A_paths_2': A_path_2, 'B_paths': B_path,
                'A_label_1': A_label_1, 'A_label_2': A_label_2}

    @staticmethod
    def _open_label(path, relabel):
        # Open a label image and remap its ids to the 19 train ids.
        arr = relabel(np.asarray(Image.open(path)))
        return Image.fromarray(arr, 'L')

    @staticmethod
    def _to_gray(t):
        # ITU-R BT.601 luma weights, single-channel output.
        return (t[0, ...] * 0.299 + t[1, ...] * 0.587 + t[2, ...] * 0.114).unsqueeze(0)

    def __len__(self):
        return max(self.A_size_1, self.B_size, self.A_size_2)

    def name(self):
        return 'GTA5_Synthia_Cityscapes'
def load_labels(dir, images):
    """Read per-image integer labels for `images` from `dir`/labels.txt.

    Each line of labels.txt is "<image_id> <label>", where image_id is
    an image's basename without extension.

    Returns:
        A list of int labels aligned with `images`.

    Raises:
        Exception: if `dir` contains neither labels.txt nor a labels/
            folder, or if the (unimplemented) labels/ folder path is hit.

    Bug fixes vs. the original: the Exception objects were constructed
    but never raised, the computed `labels` list was never returned, and
    the numpy round-trip silently coerced labels to strings.
    """
    labels_txt = os.path.join(dir, 'labels.txt')
    if os.path.exists(labels_txt):
        with open(labels_txt, 'r') as f:
            lines = f.read().splitlines()
        # keep labels as ints (the old np.array coerced them to str)
        label_dict = {}
        for line in lines:
            parts = line.split(' ')
            label_dict[parts[0]] = int(parts[1])
        labels = []
        for image in images:
            im_id = image.split('/')[-1].split('.')[0]
            labels.append(label_dict[im_id])
        return labels
    elif os.path.isdir(os.path.join(dir, 'labels')):
        raise Exception('Not yet implemented load_labels for image folder')
    else:
        raise Exception(
            'load_labels expects %s to contain labels.txt or labels folder' % dir)
ignore_label = 255

# SYNTHIA raw id -> Cityscapes train id; unlisted ids are ignored.
id2label = {0: ignore_label, 1: 10, 2: 2, 3: 0, 4: 1, 5: 4, 6: 8, 7: 5, 8: 13,
            9: 7, 10: 11, 11: 18, 12: 17, 13: ignore_label, 14: ignore_label,
            15: 6, 16: 9, 17: 12, 18: 14, 19: 15, 20: 16, 21: 3,
            22: ignore_label}

classes = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole',
           'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky',
           'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
           'bicycle']


def syn_relabel(arr):
    """Map raw SYNTHIA label ids to Cityscapes train ids (uint8)."""
    out = np.full(arr.shape, ignore_label, dtype=np.uint8)
    for raw_id, train_id in id2label.items():
        out[arr == raw_id] = int(train_id)
    return out


# Unaligned paired dataset driving the SYNTHIA -> Cityscapes stage.
class SynthiaCityscapesDataset(BaseDataset):

    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, 'synthia', 'RGB')
        self.dir_B = os.path.join(opt.dataroot, 'cityscapes', 'leftImg8bit')
        self.dir_A_label = os.path.join(opt.dataroot, 'synthia', 'GT', 'parsed_LABELS')
        self.dir_B_label = os.path.join(opt.dataroot, 'cityscapes', 'gtFine')
        self.A_paths = sorted(make_dataset(self.dir_A))
        self.B_paths = sorted(make_dataset(self.dir_B))
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        self.A_labels = sorted(make_dataset(self.dir_A_label))
        self.B_labels = sorted(make_cs_labels(self.dir_B_label))
        self.transform = get_transform(opt)
        self.label_transform = get_label_transform(opt)

    def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_size]
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]

        # SYNTHIA labels get syn_relabel; Cityscapes labels use the raw-id map.
        A_label = self._open_label(self.A_labels[index % self.A_size], syn_relabel)
        B_label = self._open_label(self.B_labels[index_B], remap_labels_to_train_ids)

        A = self.transform(Image.open(A_path).convert('RGB'))
        B = self.transform(Image.open(B_path).convert('RGB'))
        A_label = self.label_transform(A_label)
        B_label = self.label_transform(B_label)

        if self.opt.which_direction == 'BtoA':
            input_nc, output_nc = self.opt.output_nc, self.opt.input_nc
        else:
            input_nc, output_nc = self.opt.input_nc, self.opt.output_nc
        if input_nc == 1:
            A = self._to_gray(A)
        if output_nc == 1:
            B = self._to_gray(B)
        return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path,
                'A_label': A_label, 'B_label': B_label}

    @staticmethod
    def _open_label(path, relabel):
        # Open a label image and remap its ids to the 19 train ids.
        arr = relabel(np.asarray(Image.open(path)))
        return Image.fromarray(arr, 'L')

    @staticmethod
    def _to_gray(t):
        # ITU-R BT.601 luma weights, single-channel output.
        return (t[0, ...] * 0.299 + t[1, ...] * 0.587 + t[2, ...] * 0.114).unsqueeze(0)

    def __len__(self):
        return max(self.A_size, self.B_size)

    def name(self):
        return 'Synthia_Cityscapes'
CycleGANSemanticModel()
    else:
        raise NotImplementedError('model [%s] not implemented.' % opt.model)
    model.initialize(opt)
    logging.info("model [%s] was created" % (model.name()))
    return model


================================================
FILE: cyclegan/models/base_model.py
================================================
import os
from collections import OrderedDict
import torch
from . import networks


class BaseModel():
    # Shared base for all CycleGAN-style models: device/checkpoint setup,
    # LR-scheduler handling, and saving/loading/printing of the named
    # sub-networks ("net" + name for every entry of self.model_names).

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # First listed GPU is the primary device; fall back to CPU.
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
        if opt.resize_or_crop != 'scale_width':
            # Fixed input sizes allow cudnn autotuning.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.image_paths = []

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # load and print networks; create shedulars
    def setup(self, opt):
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            self.load_networks(opt.which_epoch)
        self.print_networks(opt.verbose)

    # make models eval mode during test time
    def eval(self):
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    # used in test time, wrapping `forward` in no_grad() so we don't save
    # intermediate steps for backprop
    def test(self):
        with torch.no_grad():
            self.forward()

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def optimize_parameters(self):
        pass

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    # return visualization images. train.py will display these images, and save the images to a html
    def get_current_visuals(self):
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    # return traning losses/errors. train.py will print out these errors as debugging information
    def get_current_losses(self):
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                # float(...) works for both scalar tensor and float number
                errors_ret[name] = float(getattr(self, 'loss_' + name))
        return errors_ret

    # save models to the disk
    def save_networks(self, which_epoch):
        for name in self.model_names:
            # Don't save semantic consistency networks
            if isinstance(name, str) and ("PixelCLS" not in name):
                save_filename = '%s_net_%s.pth' % (which_epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)
                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    # Save CPU tensors for portability, then move the net back.
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        # Recursively walk the dotted key; drop InstanceNorm running stats
        # that old (pre-0.4) checkpoints contain but newer modules reject.
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    # load models from the disk
    def load_networks(self, which_epoch):
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (which_epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    # print network information
    def print_networks(self, verbose):
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    # set requies_grad=Fasle to avoid computation
    def set_requires_grad(self, nets, requires_grad=False):
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad


================================================
FILE: cyclegan/models/cycle_gan_model.py
================================================
import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks


class CycleGANModel(BaseModel):
    # Plain (two-domain) CycleGAN: generators G_A: A->B, G_B: B->A plus
    # discriminators D_A (on domain B) and D_B (on domain A).

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display.
# The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            # Identity outputs only exist when the identity loss is enabled.
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG,
                                        opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG,
                                        opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain:
            # LSGAN discriminators use no final sigmoid.
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        # Both translation cycles: A -> B -> A and B -> A -> B.
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so the generator gets no gradient from the D step)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        # forward
        self.forward()
        # G_A and G_B: freeze discriminators while updating generators.
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()


================================================
FILE: cyclegan/models/cycle_gan_semantic_model.py
================================================
import itertools
import sys
import torch
import torch.nn.functional as F
from util.image_pool import ImagePool
from . import networks
from .base_model import BaseModel

sys.path.append('/nfs/project/libo_iMADAN')
from cycada.models import get_model


class CycleGANSemanticModel(BaseModel):
    # CycleGAN plus a frozen pixel classifier (netPixelCLS) that enforces
    # semantic consistency between real source images and their translations.

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'sem_AB']
        # specify the images you want to save/display.
# The program will call base_model.get_current_visuals
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A.append('idt_A')
            visual_names_B.append('idt_B')
        self.visual_names = visual_names_A + visual_names_B
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG,
                                        opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG,
                                        opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        # Here for semantic consistency loss, load a fcn network as fs here.
        self.netPixelCLS = get_model(opt.weights_model_type, num_cls=opt.num_cls,
                                     pretrained=True, weights_init=opt.weights_init)
        # Specially initialize Pixel CLS network
        if len(self.gpu_ids) > 0:
            assert (torch.cuda.is_available())
            self.netPixelCLS.to(self.gpu_ids[0])
            self.netPixelCLS = torch.nn.DataParallel(self.netPixelCLS, self.gpu_ids)
        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # self.criterionCLS = torch.nn.modules.CrossEntropyLoss()
            self.criterionSemantic = torch.nn.KLDivLoss(reduction='batchmean')
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
        if 'A_label' in input and 'B_label' in input:
            self.input_A_label = input['A_label' if AtoB else 'B_label'].to(self.device)
            self.input_B_label = input['B_label' if AtoB else 'A_label'].to(self.device)
            # self.image_paths = input['B_paths'] # Hack!! forcing the labels to corresopnd to B domain

    def forward(self):
        self.fake_B = self.netG_A(self.real_A)
        self.rec_A = self.netG_B(self.fake_B)
        self.fake_A = self.netG_B(self.real_B)
        self.rec_B = self.netG_A(self.fake_A)
        if self.isTrain:
            # Forward all four images through classifier
            # Keep predictions from fake images only
            self.pred_real_A = self.netPixelCLS(self.real_A)
            _, self.gt_pred_A = self.pred_real_A.max(1)
            self.pred_fake_B = self.netPixelCLS(self.fake_B)
            _, pfB = self.pred_fake_B.max(1)

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined Loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_PixelCLS(self):
        label_A = self.input_A_label
        # forward only real source image through semantic classifier
        pred_A = self.netPixelCLS(self.real_A)
        # NOTE(review): criterionSemantic is KLDivLoss, which expects a
        # probability distribution as the target, yet integer class labels
        # are passed here — looks suspect; confirm whether this method is
        # actually invoked anywhere before relying on it.
        self.loss_PixelCLS = self.criterionSemantic(F.log_softmax(pred_A, dim=1), label_A.long())
        self.loss_PixelCLS.backward()

    def backward_D_A(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self, opt):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed.
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed.
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0
        # GAN loss D_A(G_A(A)) — weighted x2 relative to plain CycleGAN.
        self.loss_G_A = 2 * self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss standard cyclegan
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        # real_A(syn)->fake_B->(fcn_frozen)->pred_fake_B == input_A_label
        if opt.semantic_loss:
            # KL(fake-B predictions || real-A predictions), normalized by
            # C*H*W so the magnitude is resolution independent.
            self.loss_sem_AB = opt.dynamic_weight * self.criterionSemantic(
                F.log_softmax(self.pred_fake_B, dim=1), F.softmax(self.pred_real_A, dim=1))
            self.loss_sem_AB = opt.general_semantic_weight * torch.div(
                self.loss_sem_AB,
                self.pred_fake_B.shape[1] * self.pred_fake_B.shape[2] * self.pred_fake_B.shape[3])
            self.loss_G += self.loss_sem_AB
        self.loss_G.backward()

    def optimize_parameters(self, opt):
        # forward
        self.forward()
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)
        self.optimizer_G.zero_grad()
        # self.optimizer_CLS.zero_grad()
        self.backward_G(opt)
        self.optimizer_G.step()
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        self.backward_D_B()
        self.optimizer_D.step()


================================================
FILE: cyclegan/models/multi_cycle_gan_semantic_model.py
================================================
import itertools
import sys
import torch
import torch.nn.functional as F
from util.image_pool import ImagePool
from .
import networks
from .base_model import BaseModel

sys.path.append('/nfs/project/libo_iMADAN')
from cycada.models import get_model


class CycleGANSemanticModel(BaseModel):
    # Multi-source MADAN model: two CycleGANs (source 1 = SYNTHIA, source 2
    # = GTA, shared target = Cityscapes) with optional shared target
    # discriminator (Shared_DT), sub-domain aggregation discriminator (SAD),
    # cross-domain cycle discriminators (CCD / HF_CCD) and per-source
    # semantic-consistency losses.

    def name(self):
        return 'CycleGANModel'

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.semantic_loss = opt.semantic_loss
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = ['D_A_1', 'G_A_1', 'cycle_A_1', 'idt_A_1', 'D_B_1', 'G_B_1', 'cycle_B_1', 'idt_B_1',
                           'D_A_2', 'G_A_2', 'cycle_A_2', 'idt_A_2', 'D_B_2', 'G_B_2', 'cycle_B_2', 'idt_B_2']
        if opt.SAD:
            self.loss_names.extend(['D_3_1', 'G_s1s2'])
        if opt.CCD or opt.HF_CCD:
            self.loss_names.extend(['D_21', 'G_s1s21'])
            self.loss_names.extend(['D_12', 'G_s2s12'])
        if self.semantic_loss:
            self.loss_names.extend(['sem_syn', 'sem_gta'])
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        visual_names_A_1 = ['real_A_1', 'fake_B_1', 'rec_A_1']
        visual_names_B_1 = ['real_B', 'fake_A_1', 'rec_B_1']
        visual_names_A_2 = ['real_A_2', 'fake_B_2', 'rec_A_2']
        visual_names_B_2 = ['fake_A_2', 'rec_B_2']
        if self.isTrain and self.opt.lambda_identity > 0.0:
            visual_names_A_1.append('idt_A_1')
            visual_names_B_1.append('idt_B_1')
            visual_names_A_2.append('idt_A_2')
            visual_names_B_2.append('idt_B_2')
        self.visual_names = visual_names_A_1 + visual_names_B_1 + visual_names_A_2 + visual_names_B_2
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if self.isTrain:
            # self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
            if opt.Shared_DT:
                self.model_names = ['G_A_1', 'G_B_1', 'D_A', 'D_B_1', 'D_B_2', 'G_A_2', 'G_B_2']
            else:
                self.model_names = ['G_A_1', 'G_B_1', 'D_A_1', 'D_B_1', 'G_A_2', 'G_B_2', 'D_A_2', 'D_B_2']
            if opt.SAD:
                self.model_names.append('D_3')
            if opt.CCD or opt.HF_CCD:
                self.model_names.append('D_12')
                self.model_names.append('D_21')
        else:  # during test time, only load Gs
            self.model_names = ['G_A_1', 'G_B_1', 'G_A_2', 'G_B_2']
        # load/define networks
        # The naming conversion is different from those used in the paper
        # Code (paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        self.netG_A_1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG,
                                          opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B_1 = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG,
                                          opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_A_2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG,
                                          opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        self.netG_B_2 = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG,
                                          opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)
        if opt.semantic_loss:
            # One frozen pixel classifier per source domain.
            self.netPixelCLS_SYN = get_model(opt.weights_model_type, num_cls=opt.num_cls,
                                             pretrained=True, weights_init=opt.weights_syn)
            self.netPixelCLS_GTA = get_model(opt.weights_model_type, num_cls=opt.num_cls,
                                             pretrained=True, weights_init=opt.weights_gta)
            if len(self.gpu_ids) > 0:
                assert (torch.cuda.is_available())
                self.netPixelCLS_SYN.to(self.gpu_ids[0])
                self.netPixelCLS_SYN = torch.nn.DataParallel(self.netPixelCLS_SYN, self.gpu_ids)
                self.netPixelCLS_GTA.to(self.gpu_ids[0])
                self.netPixelCLS_GTA = torch.nn.DataParallel(self.netPixelCLS_GTA, self.gpu_ids)
        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            if opt.Shared_DT:
                self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD,
                                                opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            else:
                self.netD_A_1 = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD,
                                                  opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
                self.netD_A_2 = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD,
                                                  opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B_1 = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                              opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B_2 = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                              opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            if opt.SAD:
                self.netD_3 = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                                opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            if opt.CCD or opt.HF_CCD:
                self.netD_12 = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                                 opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
                self.netD_21 = networks.define_D(opt.input_nc, opt.ndf, opt.which_model_netD,
                                                 opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
        if self.isTrain:
            self.fake_A_1_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_1_pool = ImagePool(opt.pool_size)
            self.fake_A_2_pool = ImagePool(opt.pool_size)
            self.fake_B_2_pool = ImagePool(opt.pool_size)
            self.fake_A_21_pool = ImagePool(opt.pool_size)
            self.fake_A_12_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionSemantic = torch.nn.KLDivLoss(reduction='batchmean')
            # initialize optimizers
            if opt.Shared_DT:
                self.optimizer_D = torch.optim.Adam(
                    itertools.chain(self.netD_A.parameters(), self.netD_B_1.parameters(), self.netD_B_2.parameters()),
                    lr=opt.lr, betas=(opt.beta1, 0.999))
            else:
                self.optimizer_D_1 = torch.optim.Adam(
                    itertools.chain(self.netD_A_1.parameters(), self.netD_B_1.parameters()),
                    lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizer_D_2 = torch.optim.Adam(
                    itertools.chain(self.netD_A_2.parameters(), self.netD_B_2.parameters()),
                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_G_1 = torch.optim.Adam(
                itertools.chain(self.netG_A_1.parameters(), self.netG_B_1.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_G_2 = torch.optim.Adam(
                itertools.chain(self.netG_A_2.parameters(), self.netG_B_2.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.SAD:
                self.optimizer_D_3 = torch.optim.Adam(self.netD_3.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.CCD or opt.HF_CCD:
                self.optimizer_D_21 = torch.optim.Adam(self.netD_21.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
                self.optimizer_D_12 = torch.optim.Adam(self.netD_12.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers = []
            self.optimizers.append(self.optimizer_G_1)
            self.optimizers.append(self.optimizer_G_2)
            if opt.Shared_DT:
                self.optimizers.append(self.optimizer_D)
            else:
                self.optimizers.append(self.optimizer_D_1)
                self.optimizers.append(self.optimizer_D_2)
            if opt.SAD:
                self.optimizers.append(self.optimizer_D_3)
            if opt.CCD or opt.HF_CCD:
                self.optimizers.append(self.optimizer_D_12)
                self.optimizers.append(self.optimizer_D_21)

    def set_input(self, input):
        self.real_A_1 = input['A_1'].to(self.device)
        self.real_A_2 = input['A_2'].to(self.device)
        self.real_B = input['B'].to(self.device)
        self.image_paths_1 = input['A_paths_1']
        self.image_paths_2 = input['A_paths_2']
        self.image_paths = self.image_paths_1 + self.image_paths_2
        if 'A_label_1' in input and 'B_label' in input and 'A_label_2' in input:
            self.input_A_label_1 = input['A_label_1'].to(self.device)
            self.input_A_label_2 = input['A_label_2'].to(self.device)
            self.input_B_label = input['B_label'].to(self.device)

    def forward(self, opt):
        # cycle for data input #1
        self.fake_B_1 = self.netG_A_1(self.real_A_1)
        self.rec_A_1 = self.netG_B_1(self.fake_B_1)
        self.fake_A_1 = self.netG_B_1(self.real_B)
        self.rec_B_1 = self.netG_A_1(self.fake_A_1)
        # cycle for data input #2
        self.fake_B_2 = self.netG_A_2(self.real_A_2)
        self.rec_A_2 = self.netG_B_2(self.fake_B_2)
        self.fake_A_2 = self.netG_B_2(self.real_B)
        self.rec_B_2 = self.netG_A_2(self.fake_A_2)
        if opt.CCD:
            # generate s21 for d21 branch
            self.fake_A_21 = self.netG_B_1(self.fake_B_2)
            # generate s12 for d12 branch
            self.fake_A_12 = self.netG_B_2(self.fake_B_1)
        if self.isTrain and self.semantic_loss:
            # Forward all four images through classifier
            # Keep predictions from fake images only
            self.pred_real_A_SYN = self.netPixelCLS_SYN(self.real_A_1)
            _, self.gt_pred_A_SYN = self.pred_real_A_SYN.max(1)
            self.pred_fake_B_SYN = self.netPixelCLS_SYN(self.fake_B_1)
            _, pfB_SYN = self.pred_fake_B_SYN.max(1)
            self.pred_real_A_GTA = self.netPixelCLS_GTA(self.real_A_2)
            _, self.gt_pred_A_GTA = self.pred_real_A_GTA.max(1)
            self.pred_fake_B_GTA = self.netPixelCLS_GTA(self.fake_B_2)
            _, pfB_GTA = self.pred_fake_B_GTA.max(1)

    def backward_D_basic(self, netD, real, fake, SAD=False):
        # Real ("real" is detached in the SAD case, where it is itself a
        # generated image)
        if SAD == False:
            pred_real = netD(real)
        else:
            pred_real = netD(real.detach())
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D_A(self, Shared_DT):
        # data 1 A1->B
        fake_B_1 = self.fake_B_1_pool.query(self.fake_B_1)
        if Shared_DT:
            self.loss_D_A_1 = self.backward_D_basic(self.netD_A, self.real_B, fake_B_1)
        else:
            self.loss_D_A_1 = self.backward_D_basic(self.netD_A_1, self.real_B, fake_B_1)
        # data 2 A2->B
        fake_B_2 = self.fake_B_2_pool.query(self.fake_B_2)
        if Shared_DT:
            self.loss_D_A_2 = self.backward_D_basic(self.netD_A, self.real_B, fake_B_2)
        else:
            self.loss_D_A_2 = self.backward_D_basic(self.netD_A_2, self.real_B, fake_B_2)

    def backward_D_B(self):
        # data 1 B->A1
        fake_A_1 = self.fake_A_1_pool.query(self.fake_A_1)
        self.loss_D_B_1 = self.backward_D_basic(self.netD_B_1, self.real_A_1, fake_A_1)
        # data 2 B->A2
        fake_A_2 = self.fake_A_2_pool.query(self.fake_A_2)
        self.loss_D_B_2 = self.backward_D_basic(self.netD_B_2, self.real_A_2, fake_A_2)

    def backward_D(self, which_D):
        # Dispatch for the auxiliary discriminators (SAD / CCD branches).
        if which_D == 'SAD':
            fake_B_1 = self.fake_B_1_pool.query(self.fake_B_1)
            self.loss_D_3_1 = self.backward_D_basic(self.netD_3, self.fake_B_2, fake_B_1, SAD=True)
        elif which_D == 'CCD_21':
            fake_A_21 = self.fake_A_21_pool.query(self.fake_A_21)
            self.loss_D_21 = self.backward_D_basic(self.netD_21, self.real_A_1, fake_A_21)
        elif which_D == 'CCD_12':
            fake_A_12 = self.fake_A_12_pool.query(self.fake_A_12)
            self.loss_D_12 = self.backward_D_basic(self.netD_12, self.real_A_2, fake_A_12)
        else:
            raise Exception("Invalid Choice {}".format(which_D))
        # fake_B_2 = self.fake_B_pool.query(self.fake_B_2)
        # self.loss_D_3_2 = self.backward_D_basic(self.netD_3, self.fake_B_1, fake_B_2)

    def backward_G(self, opt):
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            self.idt_A_1 = self.netG_A_1(self.real_B)
            self.loss_idt_A_1 = self.criterionIdt(self.idt_A_1, self.real_B) * lambda_B * lambda_idt
            self.idt_A_2 = self.netG_A_2(self.real_B)
            self.loss_idt_A_2 = self.criterionIdt(self.idt_A_2, self.real_B) * lambda_B * lambda_idt
            self.idt_B_1 = self.netG_B_1(self.real_A_1)
            self.loss_idt_B_1 = self.criterionIdt(self.idt_B_1, self.real_A_1) * lambda_A * lambda_idt
            self.idt_B_2 = self.netG_B_2(self.real_A_2)
            self.loss_idt_B_2 = self.criterionIdt(self.idt_B_2, self.real_A_2) * lambda_A * lambda_idt
        else:
            self.loss_idt_A_1 = 0
            self.loss_idt_A_2 = 0
            self.loss_idt_B_1 = 0
            self.loss_idt_B_2 = 0
        # GAN loss on the target side (weighted x2)
        if opt.Shared_DT:
            self.loss_G_A_1 = 2 * self.criterionGAN(self.netD_A(self.fake_B_1), True)
            self.loss_G_A_2 = 2 * self.criterionGAN(self.netD_A(self.fake_B_2), True)
        else:
            self.loss_G_A_1 = 2 * self.criterionGAN(self.netD_A_1(self.fake_B_1), True)
            self.loss_G_A_2 = 2 * self.criterionGAN(self.netD_A_2(self.fake_B_2), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B_1 = self.criterionGAN(self.netD_B_1(self.fake_A_1), True)
        self.loss_G_B_2 = self.criterionGAN(self.netD_B_2(self.fake_A_2), True)
        # Forward cycle loss
        self.loss_cycle_A_1 = self.criterionCycle(self.rec_A_1, self.real_A_1) * lambda_A
        self.loss_cycle_A_2 = self.criterionCycle(self.rec_A_2, self.real_A_2) * lambda_A
        # Backward cycle loss
        self.loss_cycle_B_1 = self.criterionCycle(self.rec_B_1, self.real_B) * lambda_B
        self.loss_cycle_B_2 = self.criterionCycle(self.rec_B_2, self.real_B) * lambda_B
        # combined loss standard cyclegan
        self.loss_G_1 = self.loss_G_A_1 + self.loss_G_B_1 + self.loss_cycle_A_1 + self.loss_cycle_B_1 + self.loss_idt_A_1 + self.loss_idt_B_1
        self.loss_G_2 = self.loss_G_A_2 + self.loss_G_B_2 + self.loss_cycle_A_2 + self.loss_cycle_B_2 + self.loss_idt_A_2 + self.loss_idt_B_2
        self.loss_G = self.loss_G_1 + self.loss_G_2
        if opt.SAD:
            # D3 loss (only after the warm-up epoch)
            if opt.SAD_frozen_epoch != -1 and opt.current_epoch > opt.SAD_frozen_epoch:
                self.loss_G_s1s2 = self.criterionGAN(self.netD_3(self.fake_B_1), True)
            else:
                self.loss_G_s1s2 = 0
            self.loss_G += self.loss_G_s1s2
        if opt.CCD:
            # D21 loss
            if opt.CCD_frozen_epoch != -1 and opt.current_epoch > opt.CCD_frozen_epoch:
                self.loss_G_s1s21 = self.criterionGAN(self.netD_21(self.fake_A_21), True)
                self.loss_G += self.loss_G_s1s21 * opt.D1D2_weight
            else:
                self.loss_G_s1s21 = 0
            if opt.CCD_frozen_epoch != -1 and opt.current_epoch > opt.CCD_frozen_epoch:
                self.loss_G_s2s12 = self.criterionGAN(self.netD_12(self.fake_A_12), True)
                self.loss_G += self.loss_G_s2s12 * opt.D1D2_weight
            else:
                self.loss_G_s2s12 = 0
        if opt.semantic_loss:
            self.loss_sem_syn = opt.dynamic_weight * self.criterionSemantic(
                F.log_softmax(self.pred_fake_B_SYN, dim=1), F.softmax(self.pred_real_A_SYN, dim=1))
            self.loss_sem_gta = opt.dynamic_weight * self.criterionSemantic(
                F.log_softmax(self.pred_fake_B_GTA, dim=1), F.softmax(self.pred_real_A_GTA, dim=1))
            # Normalize by C*H*W so the weight is resolution independent.
            self.loss_G += opt.general_semantic_weight * torch.div(
                self.loss_sem_syn,
                self.pred_fake_B_SYN.shape[1] * self.pred_fake_B_SYN.shape[2] * self.pred_fake_B_SYN.shape[3])
            self.loss_G += opt.general_semantic_weight * torch.div(
                self.loss_sem_gta,
                self.pred_fake_B_GTA.shape[1] * self.pred_fake_B_GTA.shape[2] * self.pred_fake_B_GTA.shape[3])
        self.loss_G.backward()

    def backward_HF_CCD(self, opt):
        # Re-generate the translations so this backward pass owns its graph.
        self.fake_B_1 = self.netG_A_1(self.real_A_1)
        self.fake_B_2 = self.netG_A_2(self.real_A_2)
        # generate s21 for d21 branch
        self.fake_A_21 = self.netG_B_1(self.fake_B_2)
        # generate s12 for d12 branch
        self.fake_A_12 = self.netG_B_2(self.fake_B_1)
        # D12 loss
        if opt.CCD_frozen_epoch != -1 and opt.current_epoch > opt.CCD_frozen_epoch:
            self.loss_G_s2s12 = self.criterionGAN(self.netD_12(self.fake_A_12), True)
        else:
            self.loss_G_s2s12 = 0
        # D21 loss
        if opt.CCD_frozen_epoch != -1 and opt.current_epoch > opt.CCD_frozen_epoch:
            self.loss_G_s1s21 = self.criterionGAN(self.netD_21(self.fake_A_21), True)
        else:
            self.loss_G_s1s21 = 0
        # self.loss_G_s2s12 = self.criterionGAN(self.netD_12(self.fake_A_12), True)
        # self.loss_G_s1s21 = self.criterionGAN(self.netD_21(self.fake_A_21), True)
        self.loss_G_HF = self.loss_G_s1s21 * opt.CCD_weight + self.loss_G_s2s12 * opt.CCD_weight
        # Both branches may be frozen (plain int 0), leaving nothing to
        # backpropagate.
        if isinstance(self.loss_G_HF, torch.Tensor):
            self.loss_G_HF.backward()

    def optimize_parameters(self, opt):
        # forward
        self.forward(opt)
        # G_A and G_B
        # set D to false, back prop G's gradients
        if opt.Shared_DT:
            self.set_requires_grad([self.netD_A, self.netD_B_1, self.netD_B_2], False)
        else:
            self.set_requires_grad([self.netD_A_1, self.netD_B_1], False)
            self.set_requires_grad([self.netD_A_2, self.netD_B_2], False)
        if opt.SAD:
            self.set_requires_grad([self.netD_3], False)
        if opt.CCD or opt.HF_CCD:
            self.set_requires_grad([self.netD_21], False)
            self.set_requires_grad([self.netD_12], False)
        self.set_requires_grad([self.netG_A_1, self.netG_B_1], True)
        self.set_requires_grad([self.netG_A_2, self.netG_B_2],
True) self.optimizer_G_1.zero_grad() self.optimizer_G_2.zero_grad() # self.optimizer_CLS.zero_grad() self.backward_G(opt) self.optimizer_G_1.step() self.optimizer_G_2.step() if opt.HF_CCD: self.optimizer_G_1.zero_grad() self.optimizer_G_2.zero_grad() self.set_requires_grad([self.netG_A_1, self.netG_A_2], True) self.set_requires_grad([self.netG_B_1, self.netG_B_2], False) self.backward_HF_CCD(opt) self.optimizer_G_1.step() self.optimizer_G_2.step() # D_A and D_B if opt.Shared_DT: self.set_requires_grad([self.netD_A, self.netD_B_1, self.netD_B_2], True) else: self.set_requires_grad([self.netD_A_1, self.netD_B_1], True) self.set_requires_grad([self.netD_A_2, self.netD_B_2], True) if opt.Shared_DT: self.optimizer_D.zero_grad() else: self.optimizer_D_1.zero_grad() self.optimizer_D_2.zero_grad() self.backward_D_B() self.backward_D_A(opt.Shared_DT) if opt.Shared_DT: self.optimizer_D.step() else: self.optimizer_D_1.step() self.optimizer_D_2.step() if opt.SAD: self.set_requires_grad([self.netD_3], True) self.optimizer_D_3.zero_grad() self.backward_D('SAD') self.optimizer_D_3.step() if opt.CCD or opt.HF_CCD: self.set_requires_grad([self.netD_21], True) self.optimizer_D_21.zero_grad() self.backward_D('CCD_21') self.optimizer_D_21.step() self.set_requires_grad([self.netD_12], True) self.optimizer_D_12.zero_grad() self.backward_D('CCD_12') self.optimizer_D_12.step() ================================================ FILE: cyclegan/models/networks.py ================================================ import functools import torch import torch.nn as nn from torch.nn import init from torch.optim import lr_scheduler ############################################################################### # Helper Functions ############################################################################### def get_norm_layer(norm_type='instance'): if norm_type == 'batch': norm_layer = functools.partial(nn.BatchNorm2d, affine=True) elif norm_type == 'instance': norm_layer = 
def get_norm_layer(norm_type='instance'):
    """Return a constructor for the requested normalization layer.

    'batch' -> BatchNorm2d with learnable affine params,
    'instance' -> InstanceNorm2d without them, 'none' -> None.
    Raises NotImplementedError for anything else.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_scheduler(optimizer, opt):
    """Build the learning-rate scheduler selected by opt.lr_policy.

    'lambda'  : constant for opt.niter epochs, then linear decay to zero
                over opt.niter_decay epochs (offset by opt.epoch_count);
    'step'    : decay by 10x every opt.lr_decay_iters epochs;
    'plateau' : reduce on metric plateau.
    Raises NotImplementedError for unknown policies.
    """
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    else:
        # BUG FIX: the original *returned* an un-raised NotImplementedError
        # instance (with printf-style args that were never formatted), so a
        # bad lr_policy silently handed callers an exception object instead
        # of failing. Raise it properly.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


def init_weights(net, init_type='normal', gain=0.02):
    """Initialize Conv/Linear weights with the chosen scheme and zero biases;
    BatchNorm weights are drawn from N(1, gain) with zero bias."""

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


def init_net(net, init_type='normal', gpu_ids=[]):
    """Move `net` to the first GPU id (wrapping in DataParallel) when gpu_ids
    is non-empty, then apply init_weights. Returns the (possibly wrapped) net.

    NOTE(review): the mutable default [] is safe here because gpu_ids is never
    mutated, but callers should still pass an explicit list.
    """
    if len(gpu_ids) > 0:
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)
    init_weights(net, init_type)
    return net


def print_network(net):
    """Print the module tree followed by its total parameter count."""
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)


def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', gpu_ids=[]):
    """Instantiate and initialize a generator by name
    (resnet_9blocks / resnet_6blocks / unet_128 / unet_256)."""
    netG = None
    norm_layer = get_norm_layer(norm_type=norm)
    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    return init_net(netG, init_type, gpu_ids)


def define_D(input_nc, ndf, which_model_netD, n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):
    """Instantiate and initialize a discriminator by name
    (basic = 3-layer PatchGAN, n_layers = configurable depth, pixel = 1x1)."""
    netD = None
    norm_layer = get_norm_layer(norm_type=norm)
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    elif which_model_netD == 'pixel':
        netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD)
    return init_net(netD, init_type, gpu_ids)


def define_C(output_nc, ndf, init_type='normal', gpu_ids=[]):
    """Instantiate and initialize the auxiliary Classifier head."""
    # if output_nc == 3:
    # 	netC = get_model('DTN', num_cls=10)
    # else:
    # 	Exception('classifier only implemented for 32x32x3 images')
    netC = Classifier(output_nc, ndf)
    return init_net(netC, init_type, gpu_ids)
############################################################################## # Classes ############################################################################## # Defines the GAN loss which uses either LSGAN or the regular GAN. # When LSGAN is used, it is basically same as MSELoss, # but it abstracts away the need to create the target label tensor # that has the same size as the input class GANLoss(nn.Module): def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): super(GANLoss, self).__init__() self.register_buffer('real_label', torch.tensor(target_real_label)) self.register_buffer('fake_label', torch.tensor(target_fake_label)) if use_lsgan: self.loss = nn.MSELoss() else: self.loss = nn.BCELoss() def get_target_tensor(self, input, target_is_real): if target_is_real: target_tensor = self.real_label else: target_tensor = self.fake_label return target_tensor.expand_as(input) def __call__(self, input, target_is_real): target_tensor = self.get_target_tensor(input, target_is_real) return self.loss(input, target_tensor) # Defines the generator that consists of Resnet blocks between a few # downsampling/upsampling operations. # Code and idea originally from Justin Johnson's architecture. 
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    """Resnet-based generator: ReflectionPad + 7x7 stem, two stride-2
    downsampling convs, `n_blocks` residual blocks, two transposed-conv
    upsamplers, then a 7x7 projection to output_nc with Tanh.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6,
                 padding_type='reflect'):
        assert (n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        # InstanceNorm has no affine bias by default here, so convs keep theirs.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]
        n_downsampling = 2
        for i in range(n_downsampling):
            # Channel count doubles at each downsampling step: ngf -> 2ngf -> 4ngf.
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias)]
        for i in range(n_downsampling):
            # Mirror of the downsampling path: halve channels back to ngf.
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
                                         output_padding=1, bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input):
        return self.model(input)


# Define a resnet block
class ResnetBlock(nn.Module):
    """Two 3x3 conv+norm layers with a residual skip (out = x + F(x));
    padding style is selectable, optional dropout between the convs."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble conv->norm->ReLU[->dropout]->conv->norm with the chosen
        padding ('reflect'/'replicate' as pad layers, 'zero' via conv padding)."""
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        # Residual connection; no ReLU after the addition (per CycleGAN).
        out = x + self.conv_block(x)
        return out


# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
    """U-Net built recursively from UnetSkipConnectionBlock, innermost out."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()
        # construct unet structure: bottleneck first, then wrap outward.
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
                                             innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
                                             norm_layer=norm_layer)
        self.model = unet_block

    def forward(self, input):
        return self.model(input)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X # |-- downsampling -- |submodule| -- upsampling --| class UnetSkipConnectionBlock(nn.Module): def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): super(UnetSkipConnectionBlock, self).__init__() self.outermost = outermost if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d if input_nc is None: input_nc = outer_nc downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) downrelu = nn.LeakyReLU(0.2, True) downnorm = norm_layer(inner_nc) uprelu = nn.ReLU(True) upnorm = norm_layer(outer_nc) if outermost: upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1) down = [downconv] up = [uprelu, upconv, nn.Tanh()] model = down + [submodule] + up elif innermost: upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) down = [downrelu, downconv] up = [uprelu, upconv, upnorm] model = down + up else: upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) down = [downrelu, downconv, downnorm] up = [uprelu, upconv, upnorm] if use_dropout: model = down + [submodule] + up + [nn.Dropout(0.5)] else: model = down + [submodule] + up self.model = nn.Sequential(*model) def forward(self, x): if self.outermost: return self.model(x) else: return torch.cat([x, self.model(x)], 1) # Defines the PatchGAN discriminator with the specified arguments. 
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: n_layers stride-2 4x4 convs (channels doubling,
    capped at 8*ndf), one stride-1 conv, then a 1-channel patch-score map.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw = 4
        padw = 1
        # First conv has no norm, per the pix2pix/CycleGAN convention.
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        # Final 1-channel map: one logit per receptive-field patch.
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        if use_sigmoid:
            sequence += [nn.Sigmoid()]
        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)


class PixelDiscriminator(nn.Module):
    """1x1-conv discriminator: classifies every pixel independently."""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
        super(PixelDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
        if use_sigmoid:
            self.net.append(nn.Sigmoid())
        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        return self.net(input)


class Classifier(nn.Module):
    """Small conv classifier head: four stride-2 3x3 convs then two Linears
    down to 10 logits.

    NOTE(review): the flatten in forward assumes the conv stack reduces the
    input to 1x1 spatially — presumably ~32x32 inputs (see define_C's
    commented-out 32x32x3 remark); verify against callers.
    NOTE(review): norm_layer(..., affine=True) would raise if callers pass a
    functools.partial that already binds `affine` (as get_norm_layer returns);
    only the plain nn.BatchNorm2d default path is safe here.
    """

    def __init__(self, input_nc, ndf, norm_layer=nn.BatchNorm2d):
        super(Classifier, self).__init__()
        kw = 3
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2),
            nn.LeakyReLU(0.2, True)
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(3):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2),
                norm_layer(ndf * nf_mult, affine=True),
                nn.LeakyReLU(0.2, True)
            ]
        self.before_linear = nn.Sequential(*sequence)
        sequence = [
            nn.Linear(ndf * nf_mult, 1024),
            nn.Linear(1024, 10)
        ]
        self.after_linear = nn.Sequential(*sequence)

    def forward(self, x):
        bs = x.size(0)
        out = self.after_linear(self.before_linear(x).view(bs, -1))
        return out
    # return nn.functional.log_softmax(out, dim=1)


# ================================================
# FILE: cyclegan/models/test_model.py
# ================================================
from . import networks
from .base_model import BaseModel


class TestModel(BaseModel):
    """Inference-only wrapper: loads one A->B generator (picked by
    dataset_mode) and exposes real_A / fake_B_* visuals."""

    def name(self):
        return 'TestModel'

    def initialize(self, opt):
        # Test-time only; training must go through the full CycleGAN models.
        assert (not opt.isTrain)
        BaseModel.initialize(self, opt)
        # specify the training losses you want to print out. The program will call base_model.get_current_losses
        self.loss_names = []
        # specify the images you want to save/display. The program will call base_model.get_current_visuals
        self.visual_names = ['real_A']
        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
        if opt.dataset_mode == 'synthia_cityscapes':
            self.model_names = ['G_A_1']
            self.visual_names.append('fake_B_1')
            self.netG_A_1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm,
                                              not opt.no_dropout, opt.init_type, self.gpu_ids)
        elif opt.dataset_mode == 'gta5_cityscapes':
            self.model_names = ['G_A_2']
            self.visual_names.append('fake_B_2')
            self.netG_A_2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm,
                                              not opt.no_dropout, opt.init_type, self.gpu_ids)

    def set_input(self, input):
        # we need to use single_dataset mode
        self.real_A = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        # Only one of the two generators exists; dispatch on which was built.
        if hasattr(self, 'netG_A_1'):
            self.fake_B_1 = self.netG_A_1(self.real_A)
        elif hasattr(self, 'netG_A_2'):
            self.fake_B_2 = self.netG_A_2(self.real_A)
def _str2bool(v):
    """argparse-friendly bool parser.

    BUG FIX: `type=bool` on an argparse option converts ANY non-empty string
    to True (bool("False") is True), so `--Shared_DT False` silently enabled
    the flag. Accept the usual true/false spellings instead.
    """
    if isinstance(v, bool):
        return v
    return v.strip().lower() in ('true', '1', 'yes', 'y')


class BaseOptions():
    """Declares every CLI option shared by train and test, parses them, and
    persists the resulting configuration to <checkpoints_dir>/<name>/opt.txt.

    Subclasses (TrainOptions/TestOptions) add phase-specific flags and must
    set self.isTrain before parse() is called.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        """Register all shared command-line arguments on self.parser."""
        # --- data / model geometry ---
        self.parser.add_argument('--dataroot', required=True,
                                 help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
        self.parser.add_argument('--loadSize', type=int, default=600, help='scale images to this size')
        self.parser.add_argument('--fineSize', type=int, default=600, help='then crop to this size')
        self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
        self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
        self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
        self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
        self.parser.add_argument('--which_model_netD', type=str, default='n_layers', help='selects model to use for netD')
        self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG')
        self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
        self.parser.add_argument('--name', type=str, default='experiment_name',
                                 help='name of the experiment. It decides where to store samples and models')
        self.parser.add_argument('--dataset_mode', type=str, default='unaligned',
                                 help='chooses how datasets are loaded. [unaligned | aligned | single]')
        self.parser.add_argument('--model', type=str, default='cycle_gan',
                                 help='chooses which model to use. cycle_gan, pix2pix, test')
        self.parser.add_argument('--weights_model_type', type=str, default='drn26',
                                 help='chooses which model to use. drn26, fcn8s')
        self.parser.add_argument('--num_cls', default=19, type=int)
        self.parser.add_argument('--max_epoch', default=20, type=int)
        self.parser.add_argument('--current_epoch', default=0, type=int)
        self.parser.add_argument('--weights_init', type=str)
        self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
        self.parser.add_argument('--nThreads', default=16, type=int, help='# threads for loading data')
        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        self.parser.add_argument('--norm', type=str, default='instance',
                                 help='instance normalization or batch normalization')
        self.parser.add_argument('--serial_batches', action='store_true',
                                 help='if true, takes images in order to make batches, otherwise takes them randomly')
        # --- display / visdom ---
        self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
        self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
        self.parser.add_argument('--display_server', type=str, default="http://localhost",
                                 help='visdom server of the web display')
        self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
        self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
                                 help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, '
                                      'only a subset is loaded.')
        self.parser.add_argument('--resize_or_crop', type=str, default='scale_width_and_crop',
                                 help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
        self.parser.add_argument('--no_flip', action='store_true',
                                 help='if specified, do not flip the images for data augmentation')
        self.parser.add_argument('--init_type', type=str, default='normal',
                                 help='network initialization [normal|xavier|kaiming|orthogonal]')
        self.parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        self.parser.add_argument('--suffix', default='', type=str,
                                 help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
        self.parser.add_argument('--out_all', action='store_true', help='output all stylized images(fake_B_{})')
        # --- MADAN-specific modules ---
        self.parser.add_argument('--SAD', action='store_true', help='Sub-domain Aggregation Discriminator module')
        self.parser.add_argument('--CCD', action='store_true', help='Cross-domain Cycle Discriminator module')
        self.parser.add_argument('--CCD_weight', type=float, default=1,
                                 help='weight for cross domain cycle discriminator loss')
        self.parser.add_argument('--HF_CCD', action='store_true', help='Half Freeze Cross-domain Cycle Discriminator module')
        self.parser.add_argument('--CCD_frozen_epoch', type=int, default=-1)
        self.parser.add_argument('--SAD_frozen_epoch', type=int, default=-1)
        # BUG FIX: was type=bool, which parses every non-empty string as True
        # (so "--Shared_DT False" still enabled sharing). _str2bool keeps the
        # default and attribute type unchanged while honoring false values.
        self.parser.add_argument('--Shared_DT', type=_str2bool, default=True, help="Through ")
        self.parser.add_argument('--model_type', type=str, default='fcn8s',
                                 help="choose to load which type of model (fcn8s, drn26, deeplabv2)")
        self.parser.add_argument('--semantic_loss', action='store_true', help='use semantic loss')
        self.parser.add_argument('--general_semantic_weight', type=float, default=0.2, help='weight for semantic loss')
        self.parser.add_argument('--weights_syn', type=str, default='', help='init weights for synthia')
        self.parser.add_argument('--weights_gta', type=str, default='', help='init weights for gta')
        self.parser.add_argument('--inference_script', type=str, default='', help='inference script')
        self.parser.add_argument('--dynamic_weight', type=float, default=10,
                                 help='Weight for Dynamic Semantic Loss(KL div) loss')
        self.initialized = True

    def parse(self):
        """Parse sys.argv, normalize gpu_ids, print and persist the options.

        Side effects: selects the first CUDA device, applies the optional
        name suffix, and writes opt.txt under the experiment directory.
        """
        if not self.initialized:
            self.initialize()
        opt = self.parser.parse_args()
        opt.isTrain = self.isTrain  # train or test
        # Convert the comma-separated gpu string into a list of ints,
        # dropping negative ids (-1 means CPU).
        str_ids = opt.gpu_ids.split(',')
        opt.gpu_ids = []
        for str_id in str_ids:
            gpu_id = int(str_id)  # renamed from `id` (shadowed the builtin)
            if gpu_id >= 0:
                opt.gpu_ids.append(gpu_id)
        # set gpu ids
        if len(opt.gpu_ids) > 0:
            torch.cuda.set_device(opt.gpu_ids[0])
        args = vars(opt)
        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')
        if opt.suffix:
            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
            opt.name = opt.name + suffix
        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write('------------ Options -------------\n')
            for k, v in sorted(args.items()):
                opt_file.write('%s: %s\n' % (str(k), str(v)))
            opt_file.write('-------------- End ----------------\n')
        self.opt = opt
        return self.opt
# ================================================
# FILE: cyclegan/options/test_options.py
# ================================================
from .base_options import BaseOptions


class TestOptions(BaseOptions):
    """Test-time options: adds result-output flags and sets isTrain=False."""

    def initialize(self):
        BaseOptions.initialize(self)
        self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest',
                                 help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
        # Marks the parsed options as inference-mode for BaseOptions.parse().
        self.isTrain = False


# ================================================
# FILE: cyclegan/options/train_options.py
# ================================================
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """Training options: display/save cadence, optimizer and loss weights;
    sets isTrain=True."""

    def initialize(self):
        BaseOptions.initialize(self)
        # --- logging / checkpoint cadence ---
        self.parser.add_argument('--display_freq', type=int, default=400,
                                 help='frequency of showing training results on screen')
        self.parser.add_argument('--display_ncols', type=int, default=4,
                                 help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        self.parser.add_argument('--update_html_freq', type=int, default=1000,
                                 help='frequency of saving training results to html')
        self.parser.add_argument('--print_freq', type=int, default=100,
                                 help='frequency of showing training results on console')
        self.parser.add_argument('--save_latest_freq', type=int, default=5000,
                                 help='frequency of saving the latest results')
        self.parser.add_argument('--save_epoch_freq', type=int, default=5,
                                 help='frequency of saving checkpoints at the end of epochs')
        self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        self.parser.add_argument('--epoch_count', type=int, default=1,
                                 help='the starting epoch count, we save the model by , +, ...')
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest',
                                 help='which epoch to load? set to latest to use latest cached model')
        # --- optimizer / loss weights ---
        self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
        self.parser.add_argument('--niter_decay', type=int, default=100,
                                 help='# of iter to linearly decay learning rate to zero')
        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        self.parser.add_argument('--no_lsgan', action='store_true',
                                 help='do *not* use least square GAN, if false, use vanilla GAN')
        self.parser.add_argument('--lambda_A', type=float, default=1.0, help='weight for cycle loss (A -> B -> A)')
        self.parser.add_argument('--lambda_B', type=float, default=1.0, help='weight for cycle loss (B -> A -> B)')
        self.parser.add_argument('--lambda_identity', type=float, default=0,
                                 help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the '
                                      'identity mapping loss.'
                                      'For example, if the weight of the identity loss should be 10 times smaller than the weight of the '
                                      'reconstruction loss, please set lambda_identity = 0.1')
        self.parser.add_argument('--pool_size', type=int, default=50,
                                 help='the size of image buffer that stores previously generated images')
        self.parser.add_argument('--no_html', action='store_true',
                                 help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        self.parser.add_argument('--lr_policy', type=str, default='lambda',
                                 help='learning rate policy: lambda|step|plateau')
        self.parser.add_argument('--lr_decay_iters', type=int, default=50,
                                 help='multiply by a gamma every lr_decay_iters iterations')
        # Marks the parsed options as training-mode for BaseOptions.parse().
        self.isTrain = True
# ================================================
# FILE: cyclegan/test.py
# ================================================
import os
import sys
import logging

import torch

from models import create_model
from options.test_options import TestOptions
from util import html
from util.visualizer import save_images
from data import CreateDataLoader

sys.path.append("/nfs/project/libo_i/MADAN")

if __name__ == '__main__':
    # BUG FIX: without a configured handler, every logging.info below was
    # silently dropped (root logger defaults to WARNING).
    logging.basicConfig(level=logging.INFO)
    opt = TestOptions().parse()
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.display_id = -1  # no visdom display
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)
    model.setup(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    # test
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        # check img size on the first batch only
        if i == 0:
            for item in data.items():
                if isinstance(item[1], torch.Tensor):
                    # BUG FIX: was logging.info(item[0], item[1].size()) —
                    # logging treated the dict key as a %-format string and
                    # raised an internal formatting error at runtime.
                    logging.info("%s %s", item[0], item[1].size())
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        # remove redundant files when outputting
        if opt.out_all:
            remove_list = []
            for item in visuals:
                if 'fake_B' not in item:
                    remove_list.append(item)
            for rm_item in remove_list:
                del visuals[rm_item]
        img_path = model.get_image_paths()
        if i % 5 == 0:
            logging.info('processing (%04d)-th image...' % (i * opt.batchSize))
        if 'mul' in opt.model:
            save_images(webpage.get_image_dir(), visuals, img_path, aspect_ratio=opt.aspect_ratio,
                        width=opt.display_winsize, multi_flag=True)
        else:
            save_images(webpage.get_image_dir(), visuals, img_path, aspect_ratio=opt.aspect_ratio,
                        width=opt.display_winsize)


# ================================================
# FILE: cyclegan/train.py
# ================================================
import subprocess  # kept: may be used by code outside this view
import sys
import time
import logging

sys.path.append("/nfs/project/libo_i/MADAN/cyclegan")

from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
import torch

if __name__ == '__main__':
    # BUG FIX: see test.py — logging.info emitted nothing without a handler.
    logging.basicConfig(level=logging.INFO)
    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    logging.info('#training images = %d' % dataset_size)
    model = create_model(opt)
    model.setup(opt)
    visualizer = Visualizer(opt)
    total_steps = 0
    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        # The model gates SAD/CCD warm-up on opt.current_epoch; keep it fresh.
        opt.current_epoch = epoch
        logging.info("Current epoch update to {}".format(opt.current_epoch))
        for i, data in enumerate(dataset):
            if total_steps == 0:
                for item in data.items():
                    if isinstance(item[1], torch.Tensor):
                        logging.info(item[1].size())
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters(opt)
            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
            if total_steps % opt.print_freq == 0:
                losses = model.get_current_losses()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)
            if total_steps % opt.save_latest_freq == 0:
                logging.info('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
                model.save_networks('latest')
            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:
            logging.info('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            model.save_networks('latest')
            model.save_networks(epoch)
        # BUG FIX: the progress log reported opt.max_epoch as the total, but
        # the loop actually runs to opt.niter + opt.niter_decay.
        logging.info('End of epoch %d / %d \t Time Taken: %d sec' %
                     (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()
""" def __init__(self, technique='cyclegan', verbose=True): url_dict = { 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' } self.url = url_dict.get(technique.lower()) self._verbose = verbose def _print(self, text): if self._verbose: print(text) @staticmethod def _get_options(r): soup = BeautifulSoup(r.text, 'lxml') options = [h.text for h in soup.find_all('a', href=True) if h.text.endswith(('.zip', 'tar.gz'))] return options def _present_options(self): r = requests.get(self.url) options = self._get_options(r) print('Options:\n') for i, o in enumerate(options): print("{0}: {1}".format(i, o)) choice = input("\nPlease enter the number of the " "dataset above you wish to download:") return options[int(choice)] def _download_data(self, dataset_url, save_path): if not isdir(save_path): os.makedirs(save_path) base = basename(dataset_url) temp_save_path = join(save_path, base) with open(temp_save_path, "wb") as f: r = requests.get(dataset_url) f.write(r.content) if base.endswith('.tar.gz'): obj = tarfile.open(temp_save_path) elif base.endswith('.zip'): obj = ZipFile(temp_save_path, 'r') else: raise ValueError("Unknown File Type: {0}.".format(base)) self._print("Unpacking Data...") obj.extractall(save_path) obj.close() os.remove(temp_save_path) def get(self, save_path, dataset=None): """ Download a dataset. Args: save_path : str A directory to save the data to. dataset : str, optional A specific dataset to download. Note: this must include the file extension. If None, options will be presented for you to choose from. Returns: save_path_full : str The absolute path to the downloaded data. """ if dataset is None: selected_dataset = self._present_options() else: selected_dataset = dataset save_path_full = join(save_path, selected_dataset.split('.')[0]) if isdir(save_path_full): warn("\n'{0}' already exists. 
Voiding Download.".format( save_path_full)) else: self._print('Downloading Data...') url = "{0}/{1}".format(self.url, selected_dataset) self._download_data(url, save_path=save_path) return abspath(save_path_full) ================================================ FILE: cyclegan/util/html.py ================================================ import dominate from dominate.tags import * import os class HTML: def __init__(self, web_dir, title, reflesh=0): self.title = title self.web_dir = web_dir self.img_dir = os.path.join(self.web_dir, 'images') if not os.path.exists(self.web_dir): os.makedirs(self.web_dir) if not os.path.exists(self.img_dir): os.makedirs(self.img_dir) # print(self.img_dir) self.doc = dominate.document(title=title) if reflesh > 0: with self.doc.head: meta(http_equiv="reflesh", content=str(reflesh)) def get_image_dir(self): return self.img_dir def add_header(self, str): with self.doc: h3(str) def add_table(self, border=1): self.t = table(border=border, style="table-layout: fixed;") self.doc.add(self.t) def add_images(self, ims, txts, links, width=400): self.add_table() with self.t: with tr(): for im, txt, link in zip(ims, txts, links): with td(style="word-wrap: break-word;", halign="center", valign="top"): with p(): with a(href=os.path.join('images', link)): img(style="width:%dpx" % width, src=os.path.join('images', im)) br() p(txt) def save(self): html_file = '%s/index.html' % self.web_dir f = open(html_file, 'wt') f.write(self.doc.render()) f.close() if __name__ == '__main__': html = HTML('web/', 'test_html') html.add_header('hello world') ims = [] txts = [] links = [] for n in range(4): ims.append('image_%d.png' % n) txts.append('text_%d' % n) links.append('image_%d.png' % n) html.add_images(ims, txts, links) html.save() ================================================ FILE: cyclegan/util/image_pool.py ================================================ import random import torch class ImagePool(): def __init__(self, pool_size): self.pool_size = pool_size 
if self.pool_size > 0: self.num_imgs = 0 self.images = [] def query(self, images): if self.pool_size == 0: return images return_images = [] for image in images: image = torch.unsqueeze(image.data, 0) if self.num_imgs < self.pool_size: self.num_imgs = self.num_imgs + 1 self.images.append(image) return_images.append(image) else: p = random.uniform(0, 1) if p > 0.5: random_id = random.randint(0, self.pool_size - 1) # randint is inclusive tmp = self.images[random_id].clone() self.images[random_id] = image return_images.append(tmp) else: return_images.append(image) return_images = torch.cat(return_images, 0) return return_images ================================================ FILE: cyclegan/util/util.py ================================================ from __future__ import print_function import os import numpy as np import torch from PIL import Image # Converts a Tensor into an image array (numpy) # |imtype|: the desired type of the converted numpy array def tensor2im(input_image, imtype=np.uint8): if isinstance(input_image, torch.Tensor): image_tensor = input_image.data else: return input_image image_numpy = image_tensor.cpu().float().numpy() if image_numpy.shape[0] == 1: image_numpy = np.tile(image_numpy, (3, 1, 1)) image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 return image_numpy.astype(imtype) # def tensor2im(input_image, imtype=np.uint8): # if isinstance(input_image, torch.Tensor): # image_tensor = input_image.data # else: # return input_image # image_numpy = image_tensor[0].cpu().float().numpy() # if image_numpy.shape[0] == 1: # image_numpy = np.tile(image_numpy, (3, 1, 1)) # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # return image_numpy.astype(imtype) def diagnose_network(net, name='network'): mean = 0.0 count = 0 for param in net.parameters(): if param.grad is not None: mean += torch.mean(torch.abs(param.grad.data)) count += 1 if count > 0: mean = mean / count print(name) print(mean) def 
save_image(image_numpy, image_path): image_pil = Image.fromarray(image_numpy) image_pil.save(image_path) def print_numpy(x, val=True, shp=False): x = x.astype(np.float64) if shp: print('shape,', x.shape) if val: x = x.flatten() print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) def mkdirs(paths): if isinstance(paths, list) and not isinstance(paths, str): for path in paths: mkdir(path) else: mkdir(paths) def mkdir(path): if not os.path.exists(path): os.makedirs(path) ================================================ FILE: cyclegan/util/visualizer.py ================================================ import ntpath import os import time import numpy as np from . import html, util # save image to the disk def save_images(image_dir, visuals, image_path, aspect_ratio=1.0, width=256, multi_flag=False): for i in range(len(image_path)): short_path = ntpath.basename(image_path[i]) name = os.path.splitext(short_path)[0] for ind, (label, im_data) in enumerate(visuals.items()): # align visual names and real image name if multi_flag is True and (str(i + 1) not in label): continue im = util.tensor2im(im_data[i, :, :, :]) image_name = '%s_%s.png' % (name, label) save_path = os.path.join(image_dir, image_name) h, w, _ = im.shape util.save_image(im, save_path) class Visualizer(): def __init__(self, opt): self.display_id = opt.display_id self.use_html = opt.isTrain and not opt.no_html self.win_size = opt.display_winsize self.name = opt.name self.opt = opt self.saved = False if self.display_id > 0: import visdom self.ncols = opt.display_ncols self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port) if self.use_html: self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') print('create web directory %s...' 
% self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') with open(self.log_name, "a") as log_file: now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) def reset(self): self.saved = False # |visuals|: dictionary of images to display or save def display_current_results(self, visuals, epoch, save_result): if self.display_id > 0: # show images in the browser ncols = self.ncols if ncols > 0: ncols = min(ncols, len(visuals)) h, w = next(iter(visuals.values())).shape[:2] table_css = """""" % (w, h) title = self.name label_html = '' label_html_row = '' images = [] idx = 0 for label, image in visuals.items(): image_numpy = util.tensor2im(image) label_html_row += '%s' % label images.append(image_numpy.transpose([2, 0, 1])) idx += 1 if idx % ncols == 0: label_html += '%s' % label_html_row label_html_row = '' white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 while idx % ncols != 0: images.append(white_image) label_html_row += '' idx += 1 if label_html_row != '': label_html += '%s' % label_html_row # pane col = image row self.vis.images(images, nrow=ncols, win=self.display_id + 1, padding=2, opts=dict(title=title + ' images')) label_html = '%s
' % label_html self.vis.text(table_css + label_html, win=self.display_id + 2, opts=dict(title=title + ' labels')) else: idx = 1 for label, image in visuals.items(): image_numpy = util.tensor2im(image) self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), win=self.display_id + idx) idx += 1 if self.use_html and (save_result or not self.saved): # save images to a html file self.saved = True for label, image in visuals.items(): image_numpy = util.tensor2im(image[0, :, :, :]) img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) util.save_image(image_numpy, img_path) # update website webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) for n in range(epoch, 0, -1): webpage.add_header('epoch [%d]' % n) ims, txts, links = [], [], [] for label, image_numpy in visuals.items(): # image_numpy = util.tensor2im(image) img_path = 'epoch%.3d_%s.png' % (n, label) ims.append(img_path) txts.append(label) links.append(img_path) webpage.add_images(ims, txts, links, width=self.win_size) webpage.save() # losses: dictionary of error labels and values def plot_current_losses(self, epoch, counter_ratio, opt, losses): if not hasattr(self, 'plot_data'): self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} self.plot_data['X'].append(epoch + counter_ratio) self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) self.vis.line( X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), Y=np.array(self.plot_data['Y']), opts={ 'title': self.name + ' loss over time', 'legend': self.plot_data['legend'], 'xlabel': 'epoch', 'ylabel': 'loss'}, win=self.display_id) # losses: same format as |losses| of plot_current_losses def print_current_losses(self, epoch, i, losses, t, t_data): message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) for k, v in losses.items(): message += '%s: %.3f ' % (k, v) print(message) with open(self.log_name, "a") as log_file: 
log_file.write('%s\n' % message) ================================================ FILE: requirements.txt ================================================ scipy torchvision tensorboardX tensorflow click tqdm requests colorlog pyyaml torch>=1.1.0 torchvision>=0.3.0 dominate>=2.3.1 visdom>=0.1.8.3 ================================================ FILE: scripts/ADDA/adda_cyclegta2cs_feat.sh ================================================ #!/usr/bin/env bash gpu=0,1,2,3 ###################### # loss weight params # ###################### lr=2e-5 momentum=0.9 lambda_d=1 lambda_g=0.1 export LC_ALL=C.UTF-8 export LANG=C.UTF-8 export PYTHONPATH='/usr/bin/python3' ################ # train params # ################ max_iter=50000 crop=600 snapshot=5000 batch=8 weight_share='weights_shared' discrim='discrim_feat' ######## # Data # ######## src='cyclegta5' tgt='cityscapes' data_flag='V4_SEM_300' datadir='/nfs/project/libo_i/cycada/data/' resdir="results/${src}_to_${tgt}/adda_sgd_${weight_share}_nolsgan_${discrim}_${data_flag}" # init with pre-trained cyclegta5 model #model='drn26' #baseiter=115000 model='fcn8s' baseiter=100000 base_model="/nfs/project/libo_i/cycada/pretrained_models/GTA2CS_score_V4_SEM_300_net-itercurr.pth" discrim_model="/nfs/project/libo_i/cycada/pretrained_models/new_Dis_cyclegta5_sem_v4_300x540_iter-iter5440_abv60.pth" outdir="${resdir}/${model}/lr${lr}_crop${crop}_ld${lambda_d}_lg${lambda_g}_momentum${momentum}_${discrim}" echo $outdir echo $base_model cd /nfs/project/libo_i/cycada # Run python script # python3 scripts/train_fcn_adda.py \ ${outdir} \ --dataset ${src} --dataset ${tgt} --datadir ${datadir} \ --lr ${lr} --momentum ${momentum} --gpu ${gpu} \ --lambda_d ${lambda_d} --lambda_g ${lambda_g} \ --weights_init ${base_model} --model ${model} \ --"${weight_share}" --${discrim} --no_lsgan \ --max_iter ${max_iter} --batch ${batch} \ --snapshot ${snapshot} --no_mmd_loss --small 0 --resize 300 --data_flag ${data_flag} 
================================================ FILE: scripts/ADDA/adda_cyclegta2cs_score.sh ================================================
#!/usr/bin/env bash
# ADDA adaptation GTA(cycle) -> Cityscapes, score-space discriminator,
# warm-started from a pre-trained discriminator.
gpu=0,1,2,3

######################
# loss weight params #
######################
lr=2e-5
momentum=0.9
lambda_d=1
lambda_g=0.1

export LC_ALL=C.UTF-8
export LANG=C.UTF-8
export PYTHONPATH='/usr/bin/python3'

################
# train params #
################
max_iter=25000
crop=600
snapshot=1000
batch=8
weight_share='weights_shared'
discrim='discrim_score'

########
# Data #
########
src='cyclegta5'
tgt='cityscapes'
data_flag='V4_SEM_300'
datadir='/nfs/project/libo_i/cycada/data/'

resdir="results/${src}_to_${tgt}/adda_sgd_${weight_share}_nolsgan_${discrim}_${data_flag}"

# init with pre-trained cyclegta5 model
#model='drn26'
#baseiter=115000
model='fcn8s'
baseiter=100000

base_model="/nfs/project/libo_i/cycada/pretrained_models/cyclegta_V4_SEM_Final_best_model.pth"
discrim_model="/nfs/project/libo_i/cycada/pretrained_models/new_Dis_cyclegta5_sem_v4_300x540_iter-iter5440_abv60.pth"
outdir="${resdir}/${model}/lr${lr}_crop${crop}_ld${lambda_d}_lg${lambda_g}_momentum${momentum}_${discrim}"

echo $outdir
echo $base_model
cd /nfs/project/libo_i/cycada

# Run python script #
python3 scripts/train_fcn_adda.py \
  ${outdir} \
  --dataset ${src} --dataset ${tgt} --datadir ${datadir} \
  --lr ${lr} --momentum ${momentum} --gpu ${gpu} \
  --lambda_d ${lambda_d} --lambda_g ${lambda_g} \
  --weights_init ${base_model} --model ${model} \
  --"${weight_share}" --${discrim} --no_lsgan \
  --max_iter ${max_iter} --batch ${batch} --weights_discrim ${discrim_model} \
  --snapshot ${snapshot} --no_mmd_loss --small 0 --resize 300 --data_flag ${data_flag}

================================================ FILE: scripts/ADDA/adda_cyclesyn2cs_feat.sh ================================================
#!/usr/bin/env bash
# ADDA adaptation SYNTHIA(cycle) -> Cityscapes.
# NOTE(review): despite the "_feat" filename this script sets
# discrim='discrim_score' (identical to the *_score variant) -- confirm intent.
gpu=0,1,2,3

######################
# loss weight params #
######################
lr=1e-5
momentum=0.99
lambda_d=1
lambda_g=0.1

export LC_ALL=C.UTF-8
export LANG=C.UTF-8
export PYTHONPATH='/usr/bin/python3'

################
# train params #
################
max_iter=100000
crop=800
snapshot=5000
batch=4
weight_share='weights_shared'
discrim='discrim_score'

########
# Data #
########
src='cyclesynthia'
tgt='cityscapes'
data_flag='V2_SEM'
datadir='/nfs/project/libo_i/cycada/data/'

resdir="results/${src}_to_${tgt}/adda_sgd/${weight_share}_nolsgan_${discrim}"

# init with pre-trained cyclegta5 model
#model='drn26'
#baseiter=115000
model='fcn8s'
baseiter=100000

base_model="/nfs/project/libo_i/cycada/pretrained_models/cyclesynthia_V2_SEM_fcn8s-iter21000.pth"
outdir="${resdir}/${model}/lr${lr}_crop${crop}_ld${lambda_d}_lg${lambda_g}_momentum${momentum}"

echo $outdir
echo $base_model
cd /nfs/project/libo_i/cycada

# Run python script #
python3 scripts/train_fcn_adda.py ${outdir} \
  --dataset ${src} --dataset ${tgt} --datadir ${datadir} \
  --lr ${lr} --momentum ${momentum} --gpu ${gpu} \
  --lambda_d ${lambda_d} --lambda_g ${lambda_g} \
  --weights_init ${base_model} --model ${model} \
  --"${weight_share}" --${discrim} --no_lsgan \
  --max_iter ${max_iter} --crop_size ${crop} --batch ${batch} \
  --snapshot ${snapshot}

================================================ FILE: scripts/ADDA/adda_cyclesyn2cs_score.sh ================================================
#!/usr/bin/env bash
# ADDA adaptation SYNTHIA(cycle) -> Cityscapes, score-space discriminator.
gpu=0,1,2,3

######################
# loss weight params #
######################
lr=1e-5
momentum=0.99
lambda_d=1
lambda_g=0.1

export LC_ALL=C.UTF-8
export LANG=C.UTF-8
export PYTHONPATH='/usr/bin/python3'

################
# train params #
################
max_iter=100000
crop=800
snapshot=5000
batch=4
weight_share='weights_shared'
discrim='discrim_score'

########
# Data #
########
src='cyclesynthia'
tgt='cityscapes'
data_flag='V2_SEM'
datadir='/nfs/project/libo_i/cycada/data/'

resdir="results/${src}_to_${tgt}/adda_sgd/${weight_share}_nolsgan_${discrim}"

# init with pre-trained cyclegta5 model
#model='drn26'
#baseiter=115000
model='fcn8s'
baseiter=100000

base_model="/nfs/project/libo_i/cycada/pretrained_models/cyclesynthia_V2_SEM_fcn8s-iter21000.pth"
outdir="${resdir}/${model}/lr${lr}_crop${crop}_ld${lambda_d}_lg${lambda_g}_momentum${momentum}"

echo $outdir
echo $base_model
cd /nfs/project/libo_i/cycada

# Run python script #
python3 scripts/train_fcn_adda.py ${outdir} \
  --dataset ${src} --dataset ${tgt} --datadir ${datadir} \
  --lr ${lr} --momentum ${momentum} --gpu ${gpu} \
  --lambda_d ${lambda_d} --lambda_g ${lambda_g} \
  --weights_init ${base_model} --model ${model} \
  --"${weight_share}" --${discrim} --no_lsgan \
  --max_iter ${max_iter} --crop_size ${crop} --batch ${batch} \
  --snapshot ${snapshot}

================================================ FILE: scripts/ADDA/adda_templates.sh ================================================
#!/usr/bin/env bash
# Parameterized ADDA launcher.
gpu=0,1,2,3

######################
# loss weight params #
######################
lr=1e-5
momentum=0.99
lambda_d=1
lambda_g=0.1

export LC_ALL=C.UTF-8
export LANG=C.UTF-8
export PYTHONPATH='/usr/bin/python3'

################
# train params #
################
max_iter=100000
crop=800
snapshot=5000
batch=4
weight_share='weights_shared'
discrim='discrim_score'

########
# Data #
########
src=$1
tgt='cityscapes'
data_flag=$2
datadir='/nfs/project/libo_i/cycada/data/'

resdir="results/${src}_to_${tgt}/adda_sgd/${weight_share}_nolsgan_${discrim}"

# init with pre-trained cyclegta5 model
#model='drn26'
#baseiter=115000
# NOTE(review): model reuses $2, which is also data_flag above -- the
# positional arguments look shifted (baseiter=$3, base_model=$4). Kept as-is
# to preserve the existing CLI contract; confirm the intended argument order.
model=$2
baseiter=$3

base_model=$4
outdir="${resdir}/${model}/lr${lr}_crop${crop}_ld${lambda_d}_lg${lambda_g}_momentum${momentum}"

echo $outdir
echo $base_model
cd /nfs/project/libo_i/cycada

# Run python script #
python3 scripts/train_fcn_adda.py \
  ${outdir} \
  --dataset ${src} --dataset ${tgt} --datadir ${datadir} \
  --lr ${lr} --momentum ${momentum} --gpu ${gpu} \
  --lambda_d ${lambda_d} --lambda_g ${lambda_g} \
  --weights_init ${base_model} --model ${model} \
  --"${weight_share}" --${discrim} --no_lsgan \
  --max_iter ${max_iter} --crop_size ${crop} --batch ${batch} \
  --snapshot ${snapshot}

================================================ FILE: scripts/CycleGAN/cyclegan_gta2cityscapes.sh ================================================
#!/usr/bin/env bash
cd /nfs/project/libo_i/MADAN/cyclegan
sudo python3 train.py --name cyclegan_gta2cityscapes \
  --resize_or_crop scale_width_and_crop --loadSize 600 --fineSize 500 --which_model_netD n_layers --n_layers_D 3 \
  --model cycle_gan_semantic_fcn --no_flip --batchSize 2 --nThreads 8 \
  --dataset_mode gta5_cityscapes --dataroot /nfs/project/libo_i/MADAN/data \
  --model_type drn26 --weights_init /nfs/project/libo_i/MADAN/pretrained_models/drn26_cycada_cyclegta2cityscapes.pth \
  --semantic_loss --gpu 0

================================================ FILE: scripts/CycleGAN/cyclegan_gta_synthia2cityscapes.sh ================================================
#!/usr/bin/env bash
cd /nfs/project/libo_i/MADAN/cyclegan
python3 train.py --name cyclegan_gta_synthia2cityscapes_noIdentity \
  --resize_or_crop scale_width_and_crop --loadSize 500 --fineSize 400 \
  --model multi_cycle_gan_semantic --no_flip --batchSize 4 \
  --dataset_mode gta_synthia_cityscapes --dataroot /nfs/project/libo_i/MADAN/data \
  --DSC --general_semantic_weight 20 --CCD --SAD --CCD_weight 0.2 --SAD_frozen_epoch 5 --CCD_frozen_epoch 10 --max_epoch 40 \
  --weights_syn /nfs/project/libo_i/MADAN/pretrained_models/cyclesynthia_drn26_iter2000.pth \
  --weights_gta /nfs/project/libo_i/cycada/pretrained_models/drn26_cycada_cyclegta2cityscapes.pth \
  --gpu 0,1,2,3 --semantic_loss

================================================ FILE: scripts/CycleGAN/cyclegan_synthia2cityscapes.sh ================================================
#!/usr/bin/env bash
# NOTE(review): despite the filename, this runs the multi-source
# gta_synthia_cityscapes pipeline. Also removed a duplicate "--gpu 0,1,2,3"
# that was passed twice (argparse keeps the last occurrence, so behavior is
# unchanged).
cd /root/MADAN/cyclegan
python3 train.py --name cycada_gta_synthia2cityscapes_noIdentity_D12D21D3_SEM_final_scale \
  --resize_or_crop scale_width_and_crop --loadSize 500 --fineSize 400 \
  --model multi_cycle_gan_semantic --no_flip --batchSize 4 \
  --dataset_mode gta_synthia_cityscapes --dataroot /nfs/project/libo_i/MADAN/data \
  --DSC --general_semantic_weight 20 --CCD --SAD --CCD_weight 0.2 --SAD_frozen_epoch 5 --CCD_frozen_epoch 10 --max_epoch 40 \
  --weights_syn /nfs/project/libo_i/cycada/pretrained_models/cyclesynthia_V4_SEM_Final_iter_6000.pth \
  --weights_gta /nfs/project/libo_i/cycada/pretrained_models/drn26_cycada_cyclegta2cityscapes.pth \
  --gpu 0,1,2,3 --semantic_loss

================================================ FILE: scripts/CycleGAN/test_templates.sh ================================================
#!/usr/bin/env bash
# Usage: test_templates.sh <name> <epoch> <model> <dataset_mode>
how_many=100000
cd /root/MADAN/cyclegan
name=$1
epoch=$2
python3 test.py --name ${name} --resize_or_crop=None \
  --which_model_netD n_layers --n_layers_D 3 \
  --model $3 --loadSize 600 \
  --no_flip --batchSize 32 --nThreads 16 \
  --dataset_mode $4 --dataroot /nfs/project/libo_i/cycada/data \
  --which_direction AtoB \
  --phase train --out_all \
  --how_many ${how_many} --which_epoch ${epoch} --gpu 0

================================================ FILE: scripts/CycleGAN/test_templates_cycle.sh ================================================
#!/usr/bin/env bash
# Sequentially load two generators(GTA, Synthia) and finish
# Usage: test_templates_cycle.sh <name> <epoch> <model> <dataset_mode_1> <dataset_mode_2>
how_many=100000
cd /root/MADAN/cyclegan
model=$1
epoch=$2
python3 test.py --name ${model} --resize_or_crop=None \
  --which_model_netD n_layers --n_layers_D 3 \
  --model $3 \
  --no_flip --batchSize 32 --nThreads 16 \
  --dataset_mode $4 --dataroot /nfs/project/libo_i/cycada/data \
  --which_direction AtoB \
  --phase train --out_all \
  --how_many ${how_many} --which_epoch ${epoch} --gpu 0
python3 test.py --name ${model} --resize_or_crop=None \
  --which_model_netD n_layers --n_layers_D 3 \
  --model $3 \
  --no_flip --batchSize 32 --nThreads 16 \
  --dataset_mode $5 --dataroot /nfs/project/libo_i/cycada/data \
  --which_direction AtoB \
  --phase train --out_all \
  --how_many ${how_many} --which_epoch ${epoch} --gpu 0

# cyclegan/test_templates_cycle.sh cycada_gta_synthia2cityscapes_noIdentity_D12D21D3_SEM_final_scale 15
test synthia_cityscapes gta5_cityscapes ================================================ FILE: scripts/FCN/train_fcn8s_cyclesgta5.sh ================================================ #!/usr/bin/env bash gpu=0,1,2,3 data=cyclegta5 model=fcn8s export LC_ALL=C.UTF-8 export LANG=C.UTF-8 datadir=/root/MADAN/data batch=8 iterations=30000 snapshot=2000 num_cls=19 data_flag=V4_SEM_Final_Scale cd /root/MADAN outdir=/root/MADAN/results/${data}/${data}_${data_flag}/${model} mkdir -p results/${data}/${data}_${data_flag}/${model} echo $outdir python3 scripts/train_fcn.py ${outdir} --model ${model} \ --num_cls ${num_cls} --gpu ${gpu} \ -b ${batch} --adam \ --iterations ${iterations} \ --datadir ${datadir} \ --snapshot ${snapshot} \ --dataset ${data} --data_flag ${data_flag} ================================================ FILE: scripts/FCN/train_fcn8s_cyclesynthia.sh ================================================ #!/usr/bin/env bash gpu=0,1,2,3 data=cyclesynthia model=fcn8s export LC_ALL=C.UTF-8 export LANG=C.UTF-8 datadir=/root/MADAN/data batch=28 iterations=20000 snapshot=1000 num_cls=19 data_flag=V4_SEM_Final_Scale cd /root/MADAN outdir=/root/MADAN/cycada/results/${data}/${data}_${data_flag}/${model} mkdir -p results/${data}/${data}_${data_flag}/${model} echo $outdir python3 scripts/train_fcn.py ${outdir} --model ${model} \ --num_cls ${num_cls} --gpu ${gpu} \ -b ${batch} --adam \ --iterations ${iterations} \ --datadir ${datadir} \ --snapshot ${snapshot} --small 1 \ --dataset ${data} --data_flag ${data_flag} ================================================ FILE: scripts/eval_fcn.py ================================================ import os import sys from torchvision.transforms import transforms sys.path.append('/nfs/project/libo_iMADAN') import json import click import numpy as np import torch from torch.autograd import Variable from tqdm import * from cycada.data.data_loader import dataset_obj, get_fcn_dataset from cycada.models.models import get_model, models from 
cycada.util import to_tensor_raw import torchvision from PIL import Image loader = transforms.Compose([ transforms.ToTensor()]) unloader = transforms.ToPILImage() def fmt_array(arr, fmt=','): strs = ['{:.3f}'.format(x) for x in arr] return fmt.join(strs) def fast_hist(a, b, n): k = (a >= 0) & (a < n) return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) def result_stats(hist): acc_overall = np.diag(hist).sum() / hist.sum() * 100 acc_percls = np.diag(hist) / (hist.sum(1) + 1e-8) * 100 iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-8) * 100 freq = hist.sum(1) / hist.sum() fwIU = (freq[freq > 0] * iu[freq > 0]).sum() return acc_overall, acc_percls, iu, fwIU @click.command() @click.argument('path', type=click.Path(exists=True)) @click.option('--dataset', default='cityscapes', type=click.Choice(dataset_obj.keys())) @click.option('--datadir', default='', type=click.Path(exists=True)) @click.option('--model', default='fcn8s', type=click.Choice(models.keys())) @click.option('--gpu', default='0') @click.option('--num_cls', default=19) @click.option('--batch_size', default=16) @click.option('--loadSize', default=None) @click.option('--fineSize', default=None) def main(path, dataset, datadir, model, gpu, num_cls, batch_size, loadSize, fineSize): os.environ['CUDA_VISIBLE_DEVICES'] = gpu net = get_model(model, num_cls=num_cls, weights_init=path) str_ids = gpu.split(',') gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >= 0: gpu_ids.append(id) # set gpu ids if len(gpu_ids) > 0: torch.cuda.set_device(gpu_ids[0]) assert (torch.cuda.is_available()) net.to(gpu_ids[0]) net = torch.nn.DataParallel(net, gpu_ids) net.eval() if (loadSize and fineSize) is not None: print("Loading Center Crop DataLoader Transform") data_transform = torchvision.transforms.Compose([transforms.Resize([int(loadSize), int(int(fineSize) * 1.8)], interpolation=Image.BICUBIC), net.module.transform.transforms[0], net.module.transform.transforms[1]]) 
target_transform = torchvision.transforms.Compose([transforms.Resize([int(loadSize), int(int(fineSize) * 1.8)], interpolation=Image.NEAREST), transforms.Lambda(lambda img: to_tensor_raw(img))]) else: data_transform = net.module.transform target_transform = torchvision.transforms.Compose([transforms.Lambda(lambda img: to_tensor_raw(img))]) ds = get_fcn_dataset(dataset, datadir, num_cls=num_cls, split='val', transform=data_transform, target_transform=target_transform) classes = ds.classes loader = torch.utils.data.DataLoader(ds, num_workers=16, batch_size=batch_size) errs = [] hist = np.zeros((num_cls, num_cls)) if len(loader) == 0: print('Empty data loader') return iterations = tqdm(enumerate(loader)) for im_i, (im, label) in iterations: if im_i == 0: print(im.size()) print(label.size()) if im_i > 32: break im = Variable(im.cuda()) score = net(im).data _, preds = torch.max(score, 1) hist += fast_hist(label.numpy().flatten(), preds.cpu().numpy().flatten(), num_cls) acc_overall, acc_percls, iu, fwIU = result_stats(hist) iterations.set_postfix({'mIoU': ' {:0.2f} fwIoU: {:0.2f} pixel acc: {:0.2f} per cls acc: {:0.2f}'.format(np.nanmean(iu), fwIU, acc_overall, np.nanmean(acc_percls))}) print() synthia_metric_iu = 0 # line = "" for index, item in enumerate(classes): print(classes[index], " {:0.1f}".format(iu[index])) if classes[index] != 'terrain' and classes[index] != 'truck' and classes[index] != 'train': synthia_metric_iu += iu[index] # line += " {:0.1f} &".format(iu[index]) # variable "line" is used for adding format results into latex grids # print(line) print(np.nanmean(iu), fwIU, acc_overall, np.nanmean(acc_percls)) print("16 Class-Wise mIOU is {}".format(synthia_metric_iu / 16)) print('Errors:', errs) cur_path = path.split('/')[-1] parent_path = path.replace(cur_path, '') results_dict_path = os.path.join(parent_path, 'result.json') results_dict = {} results_dict[cur_path] = [np.nanmean(iu), synthia_metric_iu / 16] if os.path.exists(results_dict_path) is False: 
with open(results_dict_path, 'w') as fp: json.dump(results_dict, fp) else: with open(results_dict_path, 'r') as fp: exist_dict = json.load(fp) with open(results_dict_path, 'w') as fp: exist_dict.update(results_dict) json.dump(exist_dict, fp) if __name__ == '__main__': main() ================================================ FILE: scripts/train_fcn.py ================================================ import logging import os.path import sys from collections import deque import click import numpy as np import torch import torch.nn.functional as F import torchvision from PIL import Image from tensorboardX import SummaryWriter sys.path.append('/nfs/project/libo_iMADAN') from cycada.data.data_loader import get_fcn_dataset as get_dataset from cycada.models import get_model from cycada.models.models import models from cycada.transforms import augment_collate from cycada.util import config_logging from cycada.util import to_tensor_raw, step_lr from cycada.tools.util import make_variable def to_tensor_raw(im): return torch.from_numpy(np.array(im, np.int64, copy=False)) def roundrobin_infinite(*loaders): if not loaders: return iters = [iter(loader) for loader in loaders] while True: for i in range(len(iters)): it = iters[i] try: yield next(it) except StopIteration: iters[i] = iter(loaders[i]) yield next(iters[i]) def supervised_loss(score, label, weights=None): loss_fn_ = torch.nn.NLLLoss2d(weight=weights, size_average=True, ignore_index=255) loss = loss_fn_(F.log_softmax(score), label) return loss @click.command() @click.argument('output') @click.option('--dataset', required=True, multiple=True) @click.option('--datadir', default="", type=click.Path(exists=True)) @click.option('--batch_size', '-b', default=1) @click.option('--lr', '-l', default=0.001) @click.option('--step', type=int) @click.option('--iterations', '-i', default=100000) @click.option('--momentum', '-m', default=0.9) @click.option('--snapshot', '-s', default=5000) @click.option('--downscale', type=int) 
@click.option('--resize_to', type=int, default=720) @click.option('--augmentation/--no-augmentation', default=False) @click.option('--adam/--sgd', default=False) @click.option('--small', type=int, default=2) @click.option('--preprocessing', default=False) @click.option('--force_split', default=False) @click.option('--fyu/--torch', default=False) @click.option('--crop_size', default=720) @click.option('--weights', type=click.Path(exists=True)) @click.option('--model_weights', type=click.Path(exists=True)) @click.option('--model', default='fcn8s', type=click.Choice(models.keys())) @click.option('--num_cls', default=19, type=int) @click.option('--nthreads', default=8, type=int) @click.option('--gpu', default='0') @click.option('--start_step', default=0) @click.option('--data_flag', default='', type=str) @click.option('--rundir_flag', default='', type=str) @click.option('--serial_batches', type=bool, default=False, help='if true, takes images in order to make batches, otherwise takes them randomly') def main(output, dataset, datadir, batch_size, lr, step, iterations, momentum, snapshot, downscale, augmentation, fyu, crop_size, weights, model, gpu, num_cls, nthreads, model_weights, data_flag, serial_batches, resize_to, start_step, preprocessing, small, rundir_flag, force_split, adam): if weights is not None: raise RuntimeError("weights don't work because eric is bad at coding") os.environ['CUDA_VISIBLE_DEVICES'] = gpu config_logging() logdir_flag = data_flag if rundir_flag != "": logdir_flag += "_{}".format(rundir_flag) logdir = 'runs/{:s}/{:s}/{:s}'.format(model, '-'.join(dataset), logdir_flag) writer = SummaryWriter(log_dir=logdir) if model == 'fcn8s': net = get_model(model, num_cls=num_cls, weights_init=model_weights) else: net = get_model(model, num_cls=num_cls, finetune=True, weights_init=model_weights) net.cuda() str_ids = gpu.split(',') gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >= 0: gpu_ids.append(id) # set gpu ids if len(gpu_ids) > 0: 
torch.cuda.set_device(gpu_ids[0]) assert (torch.cuda.is_available()) net.to(gpu_ids[0]) net = torch.nn.DataParallel(net, gpu_ids) transform = [] target_transform = [] if preprocessing: transform.extend([torchvision.transforms.Resize([int(resize_to), int(int(resize_to) * 1.8)])]) target_transform.extend([torchvision.transforms.Resize([int(resize_to), int(int(resize_to) * 1.8)], interpolation=Image.NEAREST)]) transform.extend([net.module.transform]) target_transform.extend([to_tensor_raw]) transform = torchvision.transforms.Compose(transform) target_transform = torchvision.transforms.Compose(target_transform) if force_split: datasets = [] datasets.append( get_dataset(dataset[0], os.path.join(datadir, dataset[0]), num_cls=num_cls, transform=transform, target_transform=target_transform, data_flag=data_flag)) datasets.append( get_dataset(dataset[1], os.path.join(datadir, dataset[1]), num_cls=num_cls, transform=transform, target_transform=target_transform)) else: datasets = [get_dataset(name, os.path.join(datadir, name), num_cls=num_cls, transform=transform, target_transform=target_transform, data_flag=data_flag) for name in dataset] if weights is not None: weights = np.loadtxt(weights) if adam: print("Using Adam") opt = torch.optim.Adam(net.module.parameters(), lr=1e-4) else: print("Using SGD") opt = torch.optim.SGD(net.module.parameters(), lr=lr, momentum=momentum, weight_decay=0.0005) if augmentation: collate_fn = lambda batch: augment_collate(batch, crop=crop_size, flip=True) else: collate_fn = torch.utils.data.dataloader.default_collate loaders = [torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=not serial_batches, num_workers=nthreads, collate_fn=collate_fn, pin_memory=True) for dataset in datasets] iteration = start_step losses = deque(maxlen=10) for loader in loaders: loader.dataset.__getitem__(0, debug=True) for im, label in roundrobin_infinite(*loaders): # Clear out gradients opt.zero_grad() # load data/label im = make_variable(im, 
requires_grad=False) label = make_variable(label, requires_grad=False) if iteration == 0: print("im size: {}".format(im.size())) print("label size: {}".format(label.size())) # forward pass and compute loss preds = net(im) loss = supervised_loss(preds, label) # backward pass loss.backward() losses.append(loss.item()) # step gradients opt.step() # log results if iteration % 10 == 0: logging.info('Iteration {}:\t{}'.format(iteration, np.mean(losses))) writer.add_scalar('loss', np.mean(losses), iteration) iteration += 1 if step is not None and iteration % step == 0: logging.info('Decreasing learning rate by 0.1.') step_lr(opt, 0.1) if iteration % snapshot == 0: torch.save(net.module.state_dict(), '{}/iter_{}.pth'.format(output, iteration)) if iteration >= iterations: logging.info('Optimization complete.') if __name__ == '__main__': main() ================================================ FILE: scripts/train_fcn_adda.py ================================================ import logging import os import os.path import sys from collections import deque from datetime import datetime import click import numpy as np import torch import torch.nn.functional as F from tensorboardX import SummaryWriter from torch.autograd import Variable sys.path.append('/nfs/project/libo_iMADAN') from cycada.data.adda_datasets import AddaDataLoader from cycada.models import get_model from cycada.models.models import models from cycada.models import Discriminator from cycada.util import config_logging from cycada.tools.util import make_variable, mmd_loss def check_label(label, num_cls): "Check that no labels are out of range" label_classes = np.unique(label.numpy().flatten()) label_classes = label_classes[label_classes < 255] if len(label_classes) == 0: print('All ignore labels') return False class_too_large = label_classes.max() > num_cls if class_too_large or label_classes.min() < 0: print('Labels out of bound') print(label_classes) return False return True def forward_pass(net, discriminator, im, 
                 requires_grad=False, discrim_feat=False):
    """Run the segmenter and the discriminator on one image batch.

    Returns (score, dis_score). With discrim_feat the discriminator sees the
    last feature map instead of the class-score map; with requires_grad=False
    the score is detached so generator gradients do not flow.
    """
    if discrim_feat:
        score, feat = net(im)
        dis_score = discriminator(feat)
    else:
        score = net(im)
        dis_score = discriminator(score)
    if not requires_grad:
        score = Variable(score.data, requires_grad=False)
    return score, dis_score


def supervised_loss(score, label, weights=None):
    """Mean pixel-wise cross-entropy on raw scores; label 255 is ignored."""
    loss_fn_ = torch.nn.NLLLoss(weight=weights, reduction='mean', ignore_index=255)
    loss = loss_fn_(F.log_softmax(score, dim=1), label)
    return loss


def discriminator_loss(score, target_val, lsgan=False):
    """Domain-classification loss: least-squares (LSGAN) or cross-entropy.

    In the non-lsgan branch target_val broadcasts into a (1, h, w) long map of
    the desired domain label.
    """
    if lsgan:
        loss = 0.5 * torch.mean((score - target_val) ** 2)
    else:
        _, _, h, w = score.size()
        target_val_vec = Variable(target_val * torch.ones(1, h, w),
                                  requires_grad=False).long().cuda()
        loss = supervised_loss(score, target_val_vec)
    return loss


def fast_hist(a, b, n):
    """n x n confusion matrix of flat label array ``a`` vs prediction ``b``."""
    # Mask out entries outside [0, n); 255-ignore pixels drop out here.
    k = (a >= 0) & (a < n)
    return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n)


def seg_accuracy(score, label, num_cls):
    """Per-class intersections/unions and overall pixel accuracy for a batch.

    NOTE(review): unions are scaled by 100 (with a 1e-8 epsilon against empty
    classes); callers compensate for the factor when forming IoU percentages.
    """
    _, preds = torch.max(score.data, 1)
    hist = fast_hist(label.cpu().numpy().flatten(),
                     preds.cpu().numpy().flatten(), num_cls)
    intersections = np.diag(hist)
    unions = (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-8) * 100
    acc = np.diag(hist).sum() / hist.sum()
    return intersections, unions, acc


@click.command()
@click.argument('output')
@click.option('--dataset', required=True, multiple=True)
@click.option('--datadir', default="", type=click.Path(exists=True))
@click.option('--lr', '-l', default=0.0001)
@click.option('--momentum', '-m', default=0.9)
@click.option('--batch', default=1)
@click.option('--snapshot', '-s', default=5000)
@click.option('--downscale', type=int)
@click.option('--resize', default=None, type=int)
@click.option('--crop_size', default=None, type=int)
@click.option('--half_crop', default=None)
@click.option('--cls_weights', type=click.Path(exists=True))
@click.option('--weights_discrim', type=click.Path(exists=True))
@click.option('--weights_init', type=click.Path(exists=True))
@click.option('--model', default='fcn8s',
              type=click.Choice(models.keys()))
@click.option('--lsgan/--no_lsgan', default=False)
@click.option('--num_cls', type=int, default=19)
@click.option('--gpu', default='0')
@click.option('--max_iter', default=10000)
@click.option('--lambda_d', default=1.0)
@click.option('--lambda_g', default=1.0)
@click.option('--train_discrim_only', default=False)
@click.option('--with_mmd_loss/--no_mmd_loss', default=False)
@click.option('--discrim_feat/--discrim_score', default=False)
@click.option('--weights_shared/--weights_unshared', default=False)
@click.option('--data_flag', type=str, default=None)
@click.option('--small', type=int, default=2)
def main(output, dataset, datadir, lr, momentum, snapshot, downscale, cls_weights, gpu,
         weights_init, num_cls, lsgan, max_iter, lambda_d, lambda_g, train_discrim_only,
         weights_discrim, crop_size, weights_shared, discrim_feat, half_crop, batch, model,
         data_flag, resize, with_mmd_loss, small):
    """Adversarial (ADDA-style) adaptation of an FCN from dataset[0] to dataset[1]."""
    # So data is sampled in consistent way
    np.random.seed(1336)
    torch.manual_seed(1336)
    # Encode the full configuration into the tensorboard run directory.
    logdir = 'runs/{:s}/{:s}_to_{:s}/lr{:.1g}_ld{:.2g}_lg{:.2g}'.format(model, dataset[0],
                                                                        dataset[1], lr, lambda_d,
                                                                        lambda_g)
    if weights_shared:
        logdir += '_weights_shared'
    else:
        logdir += '_weights_unshared'
    if discrim_feat:
        logdir += '_discrim_feat'
    else:
        logdir += '_discrim_score'
    logdir += '/' + datetime.now().strftime('%Y_%b_%d-%H:%M')
    writer = SummaryWriter(log_dir=logdir)
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    config_logging()
    print('Train Discrim Only', train_discrim_only)
    # output_last_ft makes the net also return its last feature map, needed
    # when the discriminator operates on features rather than scores.
    if model == 'fcn8s':
        net = get_model(model, num_cls=num_cls, pretrained=True, weights_init=weights_init,
                        output_last_ft=discrim_feat)
    else:
        net = get_model(model, num_cls=num_cls, finetune=True, pretrained=True,
                        weights_init=weights_init, output_last_ft=discrim_feat)
    net.cuda()
    str_ids = gpu.split(',')
    gpu_ids = []
    for str_id in str_ids:
        id = int(str_id)
        if id >= 0:
            gpu_ids.append(id)
    # set gpu ids
    if len(gpu_ids) > 0:
        torch.cuda.set_device(gpu_ids[0])
        assert (torch.cuda.is_available())
        net.to(gpu_ids[0])
    net = torch.nn.DataParallel(net, gpu_ids)
    # Source network: either the same (shared) weights, or a frozen copy that
    # only provides source scores/features for the discriminator.
    if weights_shared:
        net_src = net  # shared weights
    else:
        net_src = get_model(model, num_cls=num_cls, finetune=True, pretrained=True,
                            weights_init=weights_init, output_last_ft=discrim_feat)
        net_src.eval()
    # initialize Discrminator
    odim = 1 if lsgan else 2
    # NOTE(review): 4096 assumes the feature map channel width of the backbone
    # when discrim_feat is set -- confirm against the chosen model.
    idim = num_cls if not discrim_feat else 4096
    print('Discrim_feat', discrim_feat, idim)
    print('Discriminator init weights: ', weights_discrim)
    discriminator = Discriminator(input_dim=idim, output_dim=odim,
                                  pretrained=not (weights_discrim == None),
                                  weights_init=weights_discrim).cuda()
    discriminator.to(gpu_ids[0])
    discriminator = torch.nn.DataParallel(discriminator, gpu_ids)
    # Paired source/target loader yielding aligned (im_s, im_t, label_s, label_t).
    loader = AddaDataLoader(net.module.transform, dataset, datadir, downscale, resize=resize,
                            crop_size=crop_size, half_crop=half_crop, batch_size=batch,
                            shuffle=True, num_workers=16, src_data_flag=data_flag, small=small)
    print('dataset', dataset)
    # Class weighted loss?
    if cls_weights is not None:
        weights = np.loadtxt(cls_weights)
    else:
        weights = None
    # setup optimizers
    opt_dis = torch.optim.SGD(discriminator.module.parameters(), lr=lr,
                              momentum=momentum, weight_decay=0.0005)
    opt_rep = torch.optim.SGD(net.module.parameters(), lr=lr,
                              momentum=momentum, weight_decay=0.0005)
    iteration = 0
    num_update_g = 0
    last_update_g = -1
    # Rolling windows (length 100) for the logged running averages.
    losses_super_s = deque(maxlen=100)
    losses_super_t = deque(maxlen=100)
    losses_dis = deque(maxlen=100)
    losses_rep = deque(maxlen=100)
    accuracies_dom = deque(maxlen=100)
    intersections = np.zeros([100, num_cls])
    iu_deque = deque(maxlen=100)
    unions = np.zeros([100, num_cls])
    accuracy = deque(maxlen=100)
    print('Max Iter:', max_iter)
    net.train()
    discriminator.train()
    # Prime both datasets once with debug output (project-specific hook).
    loader.loader_src.dataset.__getitem__(0, debug=True)
    loader.loader_tgt.dataset.__getitem__(0, debug=True)
    while iteration < max_iter:
        for im_s, im_t, label_s, label_t in loader:
            if iteration == 0:
                print("IM S: {}".format(im_s.size()))
                print("Label S: {}".format(label_s.size()))
                print("IM T: {}".format(im_t.size()))
                print("Label T: {}".format(label_t.size()))
            if iteration > max_iter:
                break
            info_str = 'Iteration {}: '.format(iteration)
            # Skip batches whose source labels are degenerate/out of range.
            if not check_label(label_s, num_cls):
                continue
            ###########################
            # 1. Setup Data Variables #
            ###########################
            im_s = make_variable(im_s, requires_grad=False)
            label_s = make_variable(label_s, requires_grad=False)
            im_t = make_variable(im_t, requires_grad=False)
            label_t = make_variable(label_t, requires_grad=False)
            #############################
            # 2. Optimize Discriminator #
            #############################
            # zero gradients for optimizer
            opt_dis.zero_grad()
            opt_rep.zero_grad()
            # extract features -- everything is detached here so only the
            # discriminator receives gradients in this phase.
            if discrim_feat:
                score_s, feat_s = net_src(im_s)
                score_s = Variable(score_s.data, requires_grad=False)
                f_s = Variable(feat_s.data, requires_grad=False)
            else:
                score_s = Variable(net_src(im_s).data, requires_grad=False)
                f_s = score_s
            dis_score_s = discriminator(f_s)
            if discrim_feat:
                score_t, feat_t = net(im_t)
                score_t = Variable(score_t.data, requires_grad=False)
                f_t = Variable(feat_t.data, requires_grad=False)
            else:
                score_t = Variable(net(im_t).data, requires_grad=False)
                f_t = score_t
            dis_score_t = discriminator(f_t)
            dis_pred_concat = torch.cat((dis_score_s, dis_score_t))
            # prepare real and fake labels: source pixels = 1, target pixels = 0.
            batch_t, _, h, w = dis_score_t.size()
            batch_s, _, _, _ = dis_score_s.size()
            dis_label_concat = make_variable(
                torch.cat(
                    [torch.ones(batch_s, h, w).long(),
                     torch.zeros(batch_t, h, w).long()]
                ), requires_grad=False)
            # compute loss for discriminator
            loss_dis = supervised_loss(dis_pred_concat, dis_label_concat)
            (lambda_d * loss_dis).backward()
            losses_dis.append(loss_dis.item())
            # optimize discriminator
            opt_dis.step()
            # compute discriminator acc
            pred_dis = torch.squeeze(dis_pred_concat.max(1)[1])
            dom_acc = (pred_dis == dis_label_concat).float().mean().item()
            accuracies_dom.append(dom_acc * 100.)
# add discriminator info to log info_str += " domacc:{:0.1f} D:{:.3f}".format(np.mean(accuracies_dom), np.mean(losses_dis)) writer.add_scalar('loss/discriminator', np.mean(losses_dis), iteration) writer.add_scalar('acc/discriminator', np.mean(accuracies_dom), iteration) ########################### # Optimize Target Network # ########################### np.mean(accuracies_dom) > dom_acc_thresh dom_acc_thresh = 60 if train_discrim_only and np.mean(accuracies_dom) > dom_acc_thresh: os.makedirs(output, exist_ok=True) torch.save(discriminator.module.state_dict(), '{}/discriminator_abv60.pth'.format(output, iteration)) break if not train_discrim_only and np.mean(accuracies_dom) > dom_acc_thresh: last_update_g = iteration num_update_g += 1 if num_update_g % 1 == 0: print('Updating G with adversarial loss ({:d} times)'.format(num_update_g)) # zero out optimizer gradients opt_dis.zero_grad() opt_rep.zero_grad() # extract features if discrim_feat: score_t, feat_t = net(im_t) score_t = Variable(score_t.data, requires_grad=False) f_t = feat_t else: score_t = net(im_t) f_t = score_t # score_t = net(im_t) dis_score_t = discriminator(f_t) # create fake label batch, _, h, w = dis_score_t.size() target_dom_fake_t = make_variable(torch.ones(batch, h, w).long(), requires_grad=False) # compute loss for target net loss_gan_t = supervised_loss(dis_score_t, target_dom_fake_t) (lambda_g * loss_gan_t).backward() losses_rep.append(loss_gan_t.item()) writer.add_scalar('loss/generator', np.mean(losses_rep), iteration) # optimize target net opt_rep.step() # log net update info info_str += ' G:{:.3f}'.format(np.mean(losses_rep)) if (not train_discrim_only) and weights_shared and np.mean(accuracies_dom) > dom_acc_thresh: print('Updating G using source supervised loss.') # zero out optimizer gradients opt_dis.zero_grad() opt_rep.zero_grad() # extract features if discrim_feat: score_s, feat_s = net(im_s) else: score_s = net(im_s) loss_supervised_s = supervised_loss(score_s, label_s, 
weights=weights) if with_mmd_loss: print("Updating G using discrepancy loss") lambda_discrepancy = 0.1 loss_mmd = mmd_loss(feat_s, feat_t) * 0.5 + mmd_loss(score_s, score_t) * 0.5 loss_supervised_s += lambda_discrepancy * loss_mmd loss_supervised_s.backward() losses_super_s.append(loss_supervised_s.item()) info_str += ' clsS:{:.2f}'.format(np.mean(losses_super_s)) writer.add_scalar('loss/supervised/source', np.mean(losses_super_s), iteration) # optimize target net opt_rep.step() # compute supervised losses for target -- monitoring only!!!no backward() loss_supervised_t = supervised_loss(score_t, label_t, weights=weights) losses_super_t.append(loss_supervised_t.item()) info_str += ' clsT:{:.2f}'.format(np.mean(losses_super_t)) writer.add_scalar('loss/supervised/target', np.mean(losses_super_t), iteration) ########################### # Log and compute metrics # ########################### if iteration % 10 == 0 and iteration > 0: # compute metrics intersection, union, acc = seg_accuracy(score_t, label_t.data, num_cls) intersections = np.vstack([intersections[1:, :], intersection[np.newaxis, :]]) unions = np.vstack([unions[1:, :], union[np.newaxis, :]]) accuracy.append(acc.item() * 100) acc = np.mean(accuracy) mIoU = np.mean(np.maximum(intersections, 1) / np.maximum(unions, 1)) * 100 iu = (intersection / union) * 10000 iu_deque.append(np.nanmean(iu)) info_str += ' acc:{:0.2f} mIoU:{:0.2f}'.format(acc, np.mean(iu_deque)) writer.add_scalar('metrics/acc', np.mean(accuracy), iteration) writer.add_scalar('metrics/mIoU', np.mean(mIoU), iteration) logging.info(info_str) iteration += 1 ################ # Save outputs # ################ # every 500 iters save current model if iteration % 500 == 0: os.makedirs(output, exist_ok=True) if not train_discrim_only: torch.save(net.module.state_dict(), '{}/net-itercurr.pth'.format(output)) torch.save(discriminator.module.state_dict(), '{}/discriminator-itercurr.pth'.format(output)) # save labeled snapshots if iteration % snapshot == 0: 
os.makedirs(output, exist_ok=True) if not train_discrim_only: torch.save(net.module.state_dict(), '{}/net-iter{}.pth'.format(output, iteration)) torch.save(discriminator.module.state_dict(), '{}/discriminator-iter{}.pth'.format(output, iteration)) if iteration - last_update_g >= 3 * len(loader): print('No suitable discriminator found -- returning.') torch.save(net.module.state_dict(), '{}/net-iter{}.pth'.format(output, iteration)) iteration = max_iter # make sure outside loop breaks break writer.close() if __name__ == '__main__': main() ================================================ FILE: scripts/train_fcn_mdan.py ================================================ import itertools import json import logging import os.path import subprocess import sys from collections import deque import click import numpy as np import torch import torch.nn.functional as F import torchvision from PIL import Image from tensorboardX import SummaryWriter sys.path.append('/nfs/project/libo_iMADAN') from cycada.data.data_loader import get_fcn_dataset as get_dataset from cycada.models import get_model from cycada.models.models import models from cycada.models.MDAN import MDANet from cycada.transforms import augment_collate from cycada.util import config_logging from cycada.util import to_tensor_raw, step_lr from cycada.tools.util import make_variable def to_tensor_raw(im): return torch.from_numpy(np.array(im, np.int64, copy=False)) def roundrobin_infinite(*loaders): if not loaders: return iters = [iter(loader) for loader in loaders] while True: for i in range(len(iters)): it = iters[i] try: yield next(it) except StopIteration: iters[i] = iter(loaders[i]) yield next(iters[i]) def multi_source_infinite(loaders, target_loader): if not loaders: return iters_syn = iter(loaders[0]) iters_gta = iter(loaders[1]) iters_cs = iter(target_loader) while True: try: yield next(iters_syn), next(iters_gta), next(iters_cs) except StopIteration: iters_syn = iter(loaders[0]) iters_gta = iter(loaders[1]) 
iters_cs = iter(target_loader) yield next(iters_syn), next(iters_gta), next(iters_cs) def supervised_loss(score, label, weights=None): loss_fn_ = torch.nn.NLLLoss2d(weight=weights, size_average=True, ignore_index=255) loss = loss_fn_(F.log_softmax(score), label) return loss @click.command() @click.argument('output') @click.option('--dataset', required=True, multiple=True) @click.option('--target_name', required=True) @click.option('--datadir', default="", type=click.Path(exists=True)) @click.option('--batch_size', '-b', default=1) @click.option('--lr', '-l', default=0.001) @click.option('--iterations', '-i', default=100000) @click.option('--momentum', '-m', default=0.9) @click.option('--snapshot', '-s', default=5000) @click.option('--downscale', type=int) @click.option('--resize_to', type=int, default=720) @click.option('--augmentation/--no-augmentation', default=False) @click.option('--small', type=int, default=2) @click.option('--preprocessing', default=False) @click.option('--fyu/--torch', default=False) @click.option('--crop_size', default=720) @click.option('--weights', type=click.Path(exists=True)) @click.option('--model_weights', type=click.Path(exists=True)) @click.option('--model', default='fcn8s', type=click.Choice(models.keys())) @click.option('--num_cls', default=19, type=int) @click.option('--nthreads', default=16, type=int) @click.option('--gpu', default='0') @click.option('--start_step', default=0) @click.option('--data_flag', default='', type=str) @click.option('--rundir_flag', default='', type=str) @click.option('--serial_batches', type=bool, default=False, help='if true, takes images in order to make batches, otherwise takes them randomly') def main(output, dataset, target_name, datadir, batch_size, lr, iterations, momentum, snapshot, downscale, augmentation, fyu, crop_size, weights, model, gpu, num_cls, nthreads, model_weights, data_flag, serial_batches, resize_to, start_step, preprocessing, small, rundir_flag): if weights is not None: raise 
RuntimeError("weights don't work because eric is bad at coding") os.environ['CUDA_VISIBLE_DEVICES'] = gpu config_logging() logdir_flag = data_flag if rundir_flag != "": logdir_flag += "_{}".format(rundir_flag) logdir = 'runs/{:s}/{:s}/{:s}'.format(model, '-'.join(dataset), logdir_flag) writer = SummaryWriter(log_dir=logdir) if model == 'fcn8s': net = get_model(model, num_cls=num_cls, weights_init=model_weights, output_last_ft=True) else: net = get_model(model, num_cls=num_cls, finetune=True, weights_init=model_weights) net.cuda() str_ids = gpu.split(',') gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >= 0: gpu_ids.append(id) # set gpu ids if len(gpu_ids) > 0: torch.cuda.set_device(gpu_ids[0]) assert (torch.cuda.is_available()) net.to(gpu_ids[0]) net = torch.nn.DataParallel(net, gpu_ids) transform = [] target_transform = [] if preprocessing: transform.extend([torchvision.transforms.Resize([int(resize_to), int(int(resize_to) * 1.8)], interpolation=Image.BICUBIC)]) target_transform.extend([torchvision.transforms.Resize([int(resize_to), int(int(resize_to) * 1.8)], interpolation=Image.NEAREST)]) transform.extend([net.module.transform]) target_transform.extend([to_tensor_raw]) transform = torchvision.transforms.Compose(transform) target_transform = torchvision.transforms.Compose(target_transform) datasets = [get_dataset(name, os.path.join(datadir, name), num_cls=num_cls, transform=transform, target_transform=target_transform, data_flag=data_flag, small=small) for name in dataset] target_dataset = get_dataset(target_name, os.path.join(datadir, target_name), num_cls=num_cls, transform=transform, target_transform=target_transform, data_flag=data_flag, small=small) if weights is not None: weights = np.loadtxt(weights) if augmentation: collate_fn = lambda batch: augment_collate(batch, crop=crop_size, flip=True) else: collate_fn = torch.utils.data.dataloader.default_collate loaders = [torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=not 
serial_batches, num_workers=nthreads, collate_fn=collate_fn, pin_memory=True, drop_last=True) for dataset in datasets] target_loader = torch.utils.data.DataLoader(target_dataset, batch_size=batch_size, shuffle=not serial_batches, num_workers=nthreads, collate_fn=collate_fn, pin_memory=True, drop_last=True) iteration = start_step losses = deque(maxlen=10) losses_domain_syn = deque(maxlen=10) losses_domain_gta = deque(maxlen=10) losses_task = deque(maxlen=10) for loader in loaders: loader.dataset.__getitem__(0, debug=True) input_dim = 2048 configs = {"input_dim": input_dim, "hidden_layers": [1000, 500, 100], "num_classes": 2, 'num_domains': 2, 'mode': 'dynamic', 'mu': 1e-2, 'gamma': 10.0} mdan = MDANet(configs).to(gpu_ids[0]) mdan = torch.nn.DataParallel(mdan, gpu_ids) mdan.train() opt = torch.optim.Adam(itertools.chain(mdan.module.parameters(), net.module.parameters()), lr=1e-4) # cnt = 0 for (im_syn, label_syn), (im_gta, label_gta), (im_cs, label_cs) in multi_source_infinite(loaders, target_loader): # cnt += 1 # print(cnt) # Clear out gradients opt.zero_grad() # load data/label im_syn = make_variable(im_syn, requires_grad=False) label_syn = make_variable(label_syn, requires_grad=False) im_gta = make_variable(im_gta, requires_grad=False) label_gta = make_variable(label_gta, requires_grad=False) im_cs = make_variable(im_cs, requires_grad=False) label_cs = make_variable(label_cs, requires_grad=False) if iteration == 0: print("im_syn size: {}".format(im_syn.size())) print("label_syn size: {}".format(label_syn.size())) print("im_gta size: {}".format(im_gta.size())) print("label_gta size: {}".format(label_gta.size())) print("im_cs size: {}".format(im_cs.size())) print("label_cs size: {}".format(label_cs.size())) if not (im_syn.size() == im_gta.size() == im_cs.size()): print(im_syn.size()) print(im_gta.size()) print(im_cs.size()) # forward pass and compute loss preds_syn, ft_syn = net(im_syn) # pooled_ft_syn = avg_pool(ft_syn) preds_gta, ft_gta = net(im_gta) # 
pooled_ft_gta = avg_pool(ft_gta) preds_cs, ft_cs = net(im_cs) # pooled_ft_cs = avg_pool(ft_cs) loss_synthia = supervised_loss(preds_syn, label_syn) loss_gta = supervised_loss(preds_gta, label_gta) loss = loss_synthia + loss_gta losses_task.append(loss.item()) logprobs, sdomains, tdomains = mdan(ft_syn, ft_gta, ft_cs) slabels = torch.ones(batch_size, requires_grad=False).type(torch.LongTensor).to(gpu_ids[0]) tlabels = torch.zeros(batch_size, requires_grad=False).type(torch.LongTensor).to(gpu_ids[0]) # TODO: increase task loss # Compute prediction accuracy on multiple training sources. domain_losses = torch.stack([F.nll_loss(sdomains[j], slabels) + F.nll_loss(tdomains[j], tlabels) for j in range(configs['num_domains'])]) losses_domain_syn.append(domain_losses[0].item()) losses_domain_gta.append(domain_losses[1].item()) # Different final loss function depending on different training modes. if configs['mode'] == "maxmin": loss = torch.max(loss) + configs['mu'] * torch.min(domain_losses) elif configs['mode'] == "dynamic": loss = torch.log(torch.sum(torch.exp(configs['gamma'] * (loss + configs['mu'] * domain_losses)))) / configs['gamma'] # backward pass loss.backward() losses.append(loss.item()) torch.nn.utils.clip_grad_norm_(net.module.parameters(), 10) torch.nn.utils.clip_grad_norm_(mdan.module.parameters(), 10) # step gradients opt.step() # log results if iteration % 10 == 0: logging.info( 'Iteration {}:\t{:.3f} Domain SYN: {:.3f} Domain GTA: {:.3f} Task: {:.3f}'.format(iteration, np.mean(losses), np.mean(losses_domain_syn), np.mean(losses_domain_gta), np.mean(losses_task))) writer.add_scalar('loss', np.mean(losses), iteration) writer.add_scalar('domain_syn', np.mean(losses_domain_syn), iteration) writer.add_scalar('domain_gta', np.mean(losses_domain_gta), iteration) writer.add_scalar('task', np.mean(losses_task), iteration) iteration += 1 if iteration % 500 == 0: os.makedirs(output, exist_ok=True) torch.save(net.module.state_dict(), 
'{}/net-itercurr.pth'.format(output)) if iteration % snapshot == 0: torch.save(net.module.state_dict(), '{}/iter_{}.pth'.format(output, iteration)) if iteration >= iterations: logging.info('Optimization complete.') if __name__ == '__main__': main() ================================================ FILE: tools/__init__.py ================================================ ================================================ FILE: tools/eval_templates.sh ================================================ #!/usr/bin/env bash export LC_ALL=C.UTF-8 export LANG=C.UTF-8 cd /nfs/project/libo_i/MADAN ckpt_path=$1 datadir=/nfs/project/libo_i/MADAN/data/cityscapes model=fcn8s num_cls=19 gpu=0 sudo python3 scripts/eval_fcn.py ${ckpt_path} \ --dataset cityscapes \ --datadir ${datadir} \ --model ${model} --num_cls ${num_cls} \ --gpu ${gpu}