Repository: junyanz/CycleGAN
Branch: master
Commit: 40b4498526de
Files: 33
Total size: 150.2 KB

Directory structure:
gitextract_7vccy3wl/

├── .gitignore
├── LICENSE
├── README.md
├── data/
│   ├── aligned_data_loader.lua
│   ├── base_data_loader.lua
│   ├── data.lua
│   ├── data_util.lua
│   ├── dataset.lua
│   ├── donkey_folder.lua
│   └── unaligned_data_loader.lua
├── examples/
│   ├── test_vangogh_style_on_ae_photos.sh
│   └── train_maps.sh
├── models/
│   ├── architectures.lua
│   ├── base_model.lua
│   ├── bigan_model.lua
│   ├── content_gan_model.lua
│   ├── cycle_gan_model.lua
│   ├── one_direction_test_model.lua
│   └── pix2pix_model.lua
├── options.lua
├── pretrained_models/
│   ├── download_model.sh
│   ├── download_vgg.sh
│   └── places_vgg.prototxt
├── test.lua
├── train.lua
└── util/
    ├── InstanceNormalization.lua
    ├── VGG_preprocess.lua
    ├── content_loss.lua
    ├── cudnn_convert_custom.lua
    ├── image_pool.lua
    ├── plot_util.lua
    ├── util.lua
    └── visualizer.lua

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
datasets/
checkpoints/
results/
build/
dist/
*.png
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
*~


================================================
FILE: LICENSE
================================================
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


--------------------------- LICENSE FOR pix2pix --------------------------------
BSD License

For pix2pix software
Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

----------------------------- LICENSE FOR DCGAN --------------------------------
BSD License

For dcgan.torch software

Copyright (c) 2015, Facebook, Inc. All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: README.md
================================================
<img src='imgs/horse2zebra.gif' align="right" width=384>

<br><br><br>

# CycleGAN
### [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) | [project page](https://junyanz.github.io/CycleGAN/) |   [paper](https://arxiv.org/pdf/1703.10593.pdf)

Torch implementation for learning an image-to-image translation (i.e. [pix2pix](https://github.com/phillipi/pix2pix)) **without** input-output pairs, for example:

**New**:  Please check out [contrastive-unpaired-translation](https://github.com/taesungp/contrastive-unpaired-translation) (CUT), our new unpaired image-to-image translation model that enables fast and memory-efficient training.

<img src="https://junyanz.github.io/CycleGAN/images/teaser_high_res.jpg" width="1000px"/>

[Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://junyanz.github.io/CycleGAN/)  
 [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*,  [Taesung Park](https://taesung.me/)\*, [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/)  
 Berkeley AI Research Lab, UC Berkeley  
 In ICCV 2017. (* equal contributions)  

This package includes CycleGAN, [pix2pix](https://github.com/phillipi/pix2pix), as well as other methods like [BiGAN](https://arxiv.org/abs/1605.09782)/[ALI](https://ishmaelbelghazi.github.io/ALI/) and Apple's paper [S+U learning](https://arxiv.org/pdf/1612.07828.pdf).  
The code was written by [Jun-Yan Zhu](https://github.com/junyanz) and [Taesung Park](https://github.com/taesung).  
**Update**: Please check out the [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation of CycleGAN and pix2pix.
The PyTorch version is under active development and can produce results comparable to or better than this Torch version.

## Other implementations:
<p><a href="https://github.com/leehomyc/cyclegan-1"> [Tensorflow]</a> (by Harry Yang),
<a href="https://github.com/architrathore/CycleGAN/">[Tensorflow]</a> (by Archit Rathore),
<a href="https://github.com/vanhuyz/CycleGAN-TensorFlow">[Tensorflow]</a> (by Van Huy),
<a href="https://github.com/XHUJOY/CycleGAN-tensorflow">[Tensorflow]</a> (by Xiaowei Hu),
<a href="https://github.com/LynnHo/CycleGAN-Tensorflow-Simple"> [Tensorflow-simple]</a> (by Zhenliang He),
<a href="https://github.com/luoxier/CycleGAN_Tensorlayer"> [TensorLayer]</a> (by luoxier),
<a href="https://github.com/Aixile/chainer-cyclegan">[Chainer]</a> (by Yanghua Jin),
<a href="https://github.com/yunjey/mnist-svhn-transfer">[Minimal PyTorch]</a> (by yunjey),
<a href="https://github.com/Ldpe2G/DeepLearningForFun/tree/master/Mxnet-Scala/CycleGAN">[Mxnet]</a> (by Ldpe2G),
<a href="https://github.com/tjwei/GANotebooks">[lasagne/Keras]</a> (by tjwei),
<a href="https://github.com/simontomaskarlsson/CycleGAN-Keras">[Keras]</a> (by Simon Karlsson)</p>

## Applications
### Monet Paintings to Photos
<img src="https://junyanz.github.io/CycleGAN/images/painting2photo.jpg" width="1000px"/>

### Collection Style Transfer
<img src="https://junyanz.github.io/CycleGAN/images/photo2painting.jpg" width="1000px"/>

### Object Transfiguration
<img src="https://junyanz.github.io/CycleGAN/images/objects.jpg" width="1000px"/>

### Season Transfer
<img src="https://junyanz.github.io/CycleGAN/images/season.jpg" width="1000px"/>

### Photo Enhancement: Narrow depth of field
<img src="https://junyanz.github.io/CycleGAN/images/photo_enhancement.jpg" width="1000px"/>



## Prerequisites
- Linux or OSX
- NVIDIA GPU + CUDA and cuDNN (CPU mode and CUDA without cuDNN may work with minimal modification, but are untested)
- For Mac users, you need the Linux/GNU commands `gfind` and `gwc`, which can be installed with `brew install findutils coreutils`.
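`data/dataset.lua` shells out to `find`, `wc` and `cut`, and switches to the GNU-prefixed `gfind`, `gwc` and `gcut` when it detects OSX. Below is a minimal pre-flight sketch that mirrors that choice and reports any missing tool. It is not part of the repo; the variable names are hypothetical and it assumes a bash shell.

```bash
#!/usr/bin/env bash
# Pick the same tool names that data/dataset.lua uses:
# GNU tools with a "g" prefix on OSX, plain names everywhere else.
if [ "$(uname -s)" = "Darwin" ]; then
  FIND_CMD=gfind; WC_CMD=gwc; CUT_CMD=gcut
else
  FIND_CMD=find; WC_CMD=wc; CUT_CMD=cut
fi

# Report every tool that is not on the PATH.
missing=0
for tool in "$FIND_CMD" "$WC_CMD" "$CUT_CMD"; do
  if ! command -v "$tool" >/dev/null 2>&1; then
    echo "missing: $tool (on OSX: brew install findutils coreutils)" >&2
    missing=1
  fi
done

if [ "$missing" -eq 0 ]; then
  echo "all tools found"
fi
```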

## Getting Started
### Installation
- Install torch and dependencies from https://github.com/torch/distro
- Install torch packages `nngraph`, `class`, `display`
```bash
luarocks install nngraph
luarocks install class
luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec
```
- Clone this repo:
```bash
git clone https://github.com/junyanz/CycleGAN
cd CycleGAN
```

### Apply a Pre-trained Model
- Download the test photos (taken by [Alexei Efros](https://www.flickr.com/photos/aaefros)):
```
bash ./datasets/download_dataset.sh ae_photos
```
- Download the pre-trained model `style_cezanne` (for CPU mode, use `style_cezanne_cpu`):
```
bash ./pretrained_models/download_model.sh style_cezanne
```
- Now, let's generate Paul Cézanne style images:
```
DATA_ROOT=./datasets/ae_photos name=style_cezanne_pretrained model=one_direction_test phase=test loadSize=256 fineSize=256 resize_or_crop="scale_width" th test.lua
```
The test results will be saved to `./results/style_cezanne_pretrained/latest_test/index.html`.  
Please refer to [Model Zoo](#model-zoo) for more pre-trained models.
`./examples/test_vangogh_style_on_ae_photos.sh` is an example script that downloads the pretrained Van Gogh style network and runs it on Efros's photos.

### Train
- Download a dataset (e.g. zebra and horse images from ImageNet):
```bash
bash ./datasets/download_dataset.sh horse2zebra
```
- Train a model:
```bash
DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model th train.lua
```
- (CPU only) The same training command without using a GPU or cuDNN. Setting the environment variables `gpu=0 cudnn=0` forces CPU-only mode:
```bash
DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model gpu=0 cudnn=0 th train.lua
```
- (Optionally) start the display server to view results as the model trains. (See [Display UI](#display-ui) for more details):
```bash
th -ldisplay.start 8000 0.0.0.0
```

### Test
- Finally, test the model:
```bash
DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model phase=test th test.lua
```
The test results will be saved to an HTML file here: `./results/horse2zebra_model/latest_test/index.html`.


## Model Zoo
Download the pre-trained models with the following script. The model will be saved to `./checkpoints/model_name/latest_net_G.t7`.
```bash
bash ./pretrained_models/download_model.sh model_name
```
- `orange2apple` (orange -> apple) and `apple2orange`: trained on ImageNet categories `apple` and `orange`.
- `horse2zebra` (horse -> zebra) and `zebra2horse` (zebra -> horse): trained on ImageNet categories `horse` and `zebra`.
- `style_monet` (landscape photo -> Monet painting style),  `style_vangogh` (landscape photo  -> Van Gogh painting style), `style_ukiyoe` (landscape photo  -> Ukiyo-e painting style), `style_cezanne` (landscape photo  -> Cezanne painting style): trained on paintings and Flickr landscape photos.
- `monet2photo` (Monet paintings -> real landscape): trained on paintings and Flickr landscape photographs.
- `cityscapes_photo2label` (street scene -> label) and `cityscapes_label2photo` (label -> street scene): trained on the Cityscapes dataset.
- `map2sat` (map -> aerial photo) and `sat2map` (aerial photo -> map): trained on Google maps.
- `iphone2dslr_flower` (iPhone photos of flowers -> DSLR photos of flowers): trained on Flickr photos.

CPU models can be downloaded using:
```bash
bash pretrained_models/download_model.sh <name>_cpu
```
where `<name>` can be `horse2zebra`, `style_monet`, etc. You just need to append `_cpu` to the target model name.
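If you script downloads for both GPU and CPU machines, the `_cpu` naming rule can be wrapped in a small helper. This is a hypothetical convenience function, not part of the repo. It assumes `gpu=0` means CPU mode, as in the training example, and treats any other value (default `1`) as GPU.

```bash
#!/usr/bin/env bash
# Hypothetical helper: return the download name of a pretrained model.
# Assumption: CPU variants are published as "<name>_cpu" and gpu=0 selects CPU.
model_download_name() {
  local name="$1" gpu="${2:-1}"
  if [ "$gpu" -eq 0 ]; then
    printf '%s_cpu\n' "$name"
  else
    printf '%s\n' "$name"
  fi
}
```

For example, `bash ./pretrained_models/download_model.sh "$(model_download_name style_monet 0)"` would request `style_monet_cpu`.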

## Training and Test Details
To train a model,  
```bash
DATA_ROOT=/path/to/data/ name=expt_name th train.lua
```
Models are saved to `./checkpoints/expt_name` (this can be changed by passing `checkpoint_dir=your_dir` to `train.lua`).  
See `opt_train` in `options.lua` for additional training options.

To test the model,
```bash
DATA_ROOT=/path/to/data/ name=expt_name phase=test th test.lua
```
This will run the model named `expt_name` in both directions on all images in `/path/to/data/testA` and `/path/to/data/testB`.  
A webpage with result images will be saved to `./results/expt_name` (this can be changed by passing `results_dir=your_dir` to `test.lua`).  
See `opt_test` in `options.lua` for additional test options. If you only want to generate outputs of the trained network in one direction, use `model=one_direction_test` and specify `which_direction=AtoB` or `which_direction=BtoA` to set the direction.

Other options are available as well. For example, you can specify the `resize_or_crop=crop` option to avoid resizing the images to squares. This is how we trained the GTA2Cityscapes model shown on the project [webpage](https://junyanz.github.io/CycleGAN/) and the [Cycada](https://arxiv.org/pdf/1711.03213.pdf) model. We prepared the images at 1024px resolution and used `resize_or_crop=crop fineSize=360` to train on 360x360 crops. We also used `lambda_identity=1.0`.

## Datasets
Download the datasets using the following script. Many of the datasets were collected by other researchers. Please cite their papers if you use the data.
```bash
bash ./datasets/download_dataset.sh dataset_name
```
- `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). [[Citation](datasets/bibtex/facades.tex)]
- `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). [[Citation](datasets/bibtex/cityscapes.tex)]. Note: due to licensing issues, we do not host the dataset in our repo. Please download it directly from the Cityscapes webpage, and refer to `./datasets/prepare_cityscapes_dataset.py` for more details.
- `maps`: 1096 training images scraped from Google Maps.
- `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `wild horse` and `zebra`.
- `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `apple` and `navel orange`.
- `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images were downloaded using the Flickr API. See more details in our paper.
- `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos were downloaded from Flickr using the combination of the tags *landscape* and *landscapephotography*. The training set size of each class is Monet:1074, Cezanne:584, Van Gogh:401, Ukiyo-e:1433, Photographs:6853.
- `iphone2dslr_flower`: both classes of images were downloaded from Flickr. The training set size of each class is iPhone:1813, DSLR:3316. See more details in our paper.
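`data/dataset.lua` only picks up files ending in `jpg`, `png`, `jpeg`, `ppm` or `bmp`, matched case-insensitively with `find -H ... -iname`. The sketch below sanity-checks a downloaded or hand-built dataset by counting the files the loader would see in each split folder. It assumes the `trainA`/`trainB`/`testA`/`testB` layout described in [Setup Training and Test data](#setup-training-and-test-data); `count_images` and `report_dataset` are hypothetical helpers, not part of the repo.

```bash
#!/usr/bin/env bash
# Count the images data/dataset.lua would find in one directory:
# same extensions, case-insensitive, following a symlinked root (-H).
count_images() {
  find -H "$1" \( -iname '*.jpg' -o -iname '*.png' -o -iname '*.jpeg' \
    -o -iname '*.ppm' -o -iname '*.bmp' \) | wc -l | tr -d ' '
}

# Print "<split>: <count>" for every split folder that exists under a data root.
report_dataset() {
  local root="$1" split
  for split in trainA trainB testA testB; do
    if [ -d "$root/$split" ]; then
      echo "$split: $(count_images "$root/$split")"
    fi
  done
}
```

For example, `report_dataset ./datasets/horse2zebra` prints one line per split folder it finds.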


## Display UI
Optionally, for displaying images during training and test, use the [display package](https://github.com/szym/display).

- Install it with: `luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec`
- Then start the server with: `th -ldisplay.start`
- Open this URL in your browser: [http://localhost:8000](http://localhost:8000)

By default, the server listens on localhost. Pass `0.0.0.0` to allow external connections on any interface:
```bash
th -ldisplay.start 8000 0.0.0.0
```
Then open `http://(hostname):(port)/` in your browser to load the remote desktop.

## Setup Training and Test data
To train a CycleGAN model on your own datasets, you need to create a data folder with two subdirectories, `trainA` and `trainB`, that contain images from domains A and B. You can test your model on your training set by setting `phase='train'` when running `test.lua`. You can also create subdirectories `testA` and `testB` if you have test data.
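A minimal sketch of preparing such a folder, assuming the images for each domain currently sit in flat source directories. `make_cyclegan_dataset` and its argument order are hypothetical, not part of the repo. It copies only top-level files, so `cp` will fail if a source directory contains subdirectories.

```bash
#!/usr/bin/env bash
# Build the layout the loaders expect:
#   <root>/trainA, <root>/trainB (required)
#   <root>/testA,  <root>/testB  (only if both test sources are given)
# Usage: make_cyclegan_dataset ROOT SRC_A SRC_B [TEST_SRC_A TEST_SRC_B]
make_cyclegan_dataset() {
  local root="$1" src_a="$2" src_b="$3" test_a="$4" test_b="$5"
  mkdir -p "$root/trainA" "$root/trainB" || return 1
  cp "$src_a"/* "$root/trainA/" || return 1
  cp "$src_b"/* "$root/trainB/" || return 1
  if [ -n "$test_a" ] && [ -n "$test_b" ]; then
    mkdir -p "$root/testA" "$root/testB" || return 1
    cp "$test_a"/* "$root/testA/" || return 1
    cp "$test_b"/* "$root/testB/" || return 1
  fi
}
```

After `make_cyclegan_dataset ./datasets/my_data /path/to/domainA /path/to/domainB`, training follows the usual pattern: `DATA_ROOT=./datasets/my_data name=my_model th train.lua`.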

You should **not** expect our method to work on just any random combination of input and output datasets (e.g. `cats<->keyboards`). In our experiments, it works better if the two datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting<->landscape photographs`. `zebras<->horses` achieves compelling results, while `cats<->dogs` completely fails. See the following section for more discussion.


## Failure cases
<img align="left" style="padding:10px" src="https://junyanz.github.io/CycleGAN/images/failure_putin.jpg" width=320>

Our model does not work well when the test image is rather different from the images on which the model is trained, as is the case in the figure to the left (we trained on horses and zebras without riders, but test here on a horse with a rider). See additional typical failure cases [here](https://junyanz.github.io/CycleGAN/images/failures.jpg). On translation tasks that involve color and texture changes, like many of those reported above, the method often succeeds. We have also explored tasks that require geometric changes, with little success. For example, on the task of `dog<->cat` transfiguration, the learned translation degenerates into making minimal changes to the input. We also observe a lingering gap between the results achievable with paired training data and those achieved by our unpaired method. In some cases, this gap may be very hard, or even impossible, to close: for example, our method sometimes permutes the labels for tree and building in the output of the cityscapes photos->labels task.



## Citation
If you use this code for your research, please cite our [paper](https://junyanz.github.io/CycleGAN/):

```
@inproceedings{CycleGAN2017,
  title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
  author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
  booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on},
  year={2017}
}

```


## Related Projects:
**[contrastive-unpaired-translation](https://github.com/taesungp/contrastive-unpaired-translation) (CUT)**<br>
**[pix2pix-Torch](https://github.com/phillipi/pix2pix) | [pix2pixHD](https://github.com/NVIDIA/pix2pixHD) |
[BicycleGAN](https://github.com/junyanz/BicycleGAN) | [vid2vid](https://tcwang0509.github.io/vid2vid/) | [SPADE/GauGAN](https://github.com/NVlabs/SPADE)**<br>
**[iGAN](https://github.com/junyanz/iGAN) | [GAN Dissection](https://github.com/CSAILVision/GANDissect) | [GAN Paint](http://ganpaint.io/)**

## Cat Paper Collection
If you love cats, and love reading cool graphics, vision, and ML papers, please check out the Cat Paper [Collection](https://github.com/junyanz/CatPapers).  


## Acknowledgments
Code borrows from [pix2pix](https://github.com/phillipi/pix2pix) and [DCGAN](https://github.com/soumith/dcgan.torch). The data loader is modified from [DCGAN](https://github.com/soumith/dcgan.torch) and [Context-Encoder](https://github.com/pathak22/context-encoder). The generative network is adapted from [neural-style](https://github.com/jcjohnson/neural-style) with [Instance Normalization](https://github.com/DmitryUlyanov/texture_nets/blob/master/InstanceNormalization.lua).


================================================
FILE: data/aligned_data_loader.lua
================================================
--------------------------------------------------------------------------------
-- Subclass of BaseDataLoader that provides data from two datasets.
-- The samples from the datasets are aligned
-- The datasets are of the same size
--------------------------------------------------------------------------------
require 'data.base_data_loader'

local class = require 'class'
data_util = paths.dofile('data_util.lua')

AlignedDataLoader = class('AlignedDataLoader', 'BaseDataLoader')

function AlignedDataLoader:__init(conf)
  BaseDataLoader.__init(self, conf)
  conf = conf or {}
end

function AlignedDataLoader:name()
  return 'AlignedDataLoader'
end

function AlignedDataLoader:Initialize(opt)
  opt.align_data = 1
  self.idx_A = {1, opt.input_nc}
  self.idx_B = {opt.input_nc+1, opt.input_nc+opt.output_nc}
  local nc = 3--opt.input_nc + opt.output_nc
  self.data = data_util.load_dataset('', opt, nc)
end

-- actually fetches the data
-- |return|: the batch for dataset A, the batch for dataset B, and the image
-- paths for each (A and B are split from the same aligned sample, so both
-- path values are the same table)
function AlignedDataLoader:LoadBatchForAllDatasets()
  local batch_data, path = self.data:getBatch()
  local batchA = batch_data[{ {}, self.idx_A, {}, {} }]
  local batchB = batch_data[{ {}, self.idx_B, {}, {} }]

  return batchA, batchB, path, path
end

-- returns the size of each dataset
function AlignedDataLoader:size(dataset)
  return self.data:size()
end


================================================
FILE: data/base_data_loader.lua
================================================
--------------------------------------------------------------------------------
-- Base Class for Providing Data
--------------------------------------------------------------------------------

local class = require 'class'
require 'torch'

BaseDataLoader = class('BaseDataLoader')

function BaseDataLoader:__init(conf)
   conf = conf or {}   
   self.data_tm = torch.Timer()
end

function BaseDataLoader:name()
  return 'BaseDataLoader'
end

function BaseDataLoader:Initialize(opt)  
end

-- actually fetches the data; meant to be overridden by subclasses
-- |return|: the batch for dataset A, the batch for dataset B,
-- and the image paths for A and B
function BaseDataLoader:LoadBatchForAllDatasets()
  return {},{},{},{}
end

-- returns the next batch and records how long loading it took
-- a wrapper around LoadBatchForAllDatasets(), which subclasses override
-- |return|: the batch for dataset A, the batch for dataset B,
-- and the image paths for A and B
function BaseDataLoader:GetNextBatch()
  self.data_tm:reset()
  self.data_tm:resume()
  local dataA, dataB, pathA, pathB = self:LoadBatchForAllDatasets()
  self.data_tm:stop()
  return dataA, dataB, pathA, pathB
end

function BaseDataLoader:time_elapsed_to_fetch_data()
  return self.data_tm:time().real
end

-- returns the size of each dataset
function BaseDataLoader:size(dataset)
  return 0
end







================================================
FILE: data/data.lua
================================================
--[[
    This data loader is a modified version of the one from dcgan.torch
    (see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua).

    Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
]]--

local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

local data = {}

local result = {}
local unpack = unpack and unpack or table.unpack

function data.new(n, opt_)
   opt_ = opt_ or {}
   local self = {}
   for k,v in pairs(data) do
      self[k] = v
   end

   local donkey_file = 'donkey_folder.lua'
--   print('n..' .. n)
   if n > 0 then
      local options = opt_
      self.threads = Threads(n,
                             function() require 'torch' end,
                             function(idx)
                                opt = options
                                tid = idx
                                local seed = (opt.manualSeed and opt.manualSeed or 0) + idx
                                torch.manualSeed(seed)
                                torch.setnumthreads(1)
                                print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
                                assert(options, 'options not found')
                                assert(opt, 'opt not given')
                                print(opt)
                                paths.dofile(donkey_file)
                             end

      )
   else
      if donkey_file then paths.dofile(donkey_file) end
--      print('empty threads')
      self.threads = {}
      function self.threads:addjob(f1, f2) f2(f1()) end
      function self.threads:dojob() end
      function self.threads:synchronize() end
   end

   local nSamples = 0
   self.threads:addjob(function() return trainLoader:size() end,
         function(c) nSamples = c end)
   self.threads:synchronize()
   self._size = nSamples

   for i = 1, n do
      self.threads:addjob(self._getFromThreads,
                          self._pushResult)
   end
--   print(self.threads)
   return self
end

function data._getFromThreads()
   assert(opt.batchSize, 'opt.batchSize not found')
   return trainLoader:sample(opt.batchSize)
end

function data._pushResult(...)
   -- store the values returned by _getFromThreads for getBatch() to consume
   -- (getBatch() reads the images from slot 1 and the paths from slot 3)
   result[1] = {...}
end



function data:getBatch()
   -- queue another job
   self.threads:addjob(self._getFromThreads, self._pushResult)
   self.threads:dojob()
   local res = result[1]

   local img_data = res[1]
   local img_paths = res[3]

   result[1] = nil
   if torch.type(img_data) == 'table' then
      img_data = unpack(img_data)
   end


   return img_data, img_paths
end

function data:size()
   return self._size
end

return data


================================================
FILE: data/data_util.lua
================================================
local data_util = {}

require 'torch'
-- options =  require '../options.lua'
-- load dataset from the file system
-- |name|: suffix appended to opt.phase to select the data folder:
-- 'A' or 'B' for unaligned data, '' for aligned data
function data_util.load_dataset(name, opt, nc)
  local tensortype = torch.getdefaulttensortype()
  torch.setdefaulttensortype('torch.FloatTensor')

  local new_opt = options.clone(opt)
  new_opt.manualSeed = torch.random(1, 10000) -- draw a per-loader seed from the global RNG
  new_opt.nc = nc
  torch.manualSeed(new_opt.manualSeed)
  local data_loader = paths.dofile('../data/data.lua')
  new_opt.phase = new_opt.phase .. name
  local data = data_loader.new(new_opt.nThreads, new_opt)
  print("Dataset Size " .. name .. ": ", data:size())

  torch.setdefaulttensortype(tensortype)
  return data
end

return data_util


================================================
FILE: data/dataset.lua
================================================
--[[
    Copyright (c) 2015-present, Facebook, Inc.
    All rights reserved.

    This source code is licensed under the BSD-style license found in the
    LICENSE file in the root directory of this source tree. An additional grant
    of patent rights can be found in the PATENTS file in the same directory.
]]--

require 'torch'
torch.setdefaulttensortype('torch.FloatTensor')
local ffi = require 'ffi'
local class = require('pl.class')
local dir = require 'pl.dir'
local tablex = require 'pl.tablex'
local argcheck = require 'argcheck'
require 'sys'
require 'xlua'
require 'image'

local dataset = torch.class('dataLoader')

local initcheck = argcheck{
   pack=true,
   help=[[
     A dataset class for images in a flat folder structure (folder-name is class-name).
     Optimized for extremely large datasets (upwards of 14 million images).
     Tested only on Linux (as it uses command-line linux utilities to scale up)
]],
   {check=function(paths)
       local out = true;
       for k,v in ipairs(paths) do
          if type(v) ~= 'string' then
             print('paths can only be of string input');
             out = false
          end
       end
       return out
   end,
    name="paths",
    type="table",
    help="Multiple paths of directories with images"},

   {name="sampleSize",
    type="table",
    help="a consistent sample size to resize the images"},

   {name="split",
    type="number",
    help="Percentage of split to go to Training"
   },
   {name="serial_batches",
    type="number",
    help="if 1, take images in order to form batches; otherwise sample them randomly"},

   {name="samplingMode",
    type="string",
    help="Sampling mode: random | balanced ",
    default = "balanced"},

   {name="verbose",
    type="boolean",
    help="Verbose mode during initialization",
    default = false},

   {name="loadSize",
    type="table",
    help="a size to load the images to, initially",
    opt = true},

   {name="forceClasses",
    type="table",
    help="If you want this loader to map certain classes to certain indices, "
       .. "pass a classes table that has {classindex : classname} pairs."
       .. " For example: {[3] = 'dog', [5] = 'cat'}. "
       .. "This is very useful when you want two loaders to have the same "
       .. "class indices (trainLoader/testLoader for example)",
    opt = true},

   {name="sampleHookTrain",
    type="function",
    help="applied to each sample during training (e.g. for lighting jitter). "
       .. "It takes the image path as input",
    opt = true},

   {name="sampleHookTest",
    type="function",
    help="applied to sample during testing",
    opt = true},
}

function dataset:__init(...)

   -- argcheck
   local args =  initcheck(...)
   print(args)
   for k,v in pairs(args) do self[k] = v end

   if not self.loadSize then self.loadSize = self.sampleSize; end

   if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end
   if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end
   self.image_count = 1
--   print('image_count_init', self.image_count)
   -- find class names
   self.classes = {}
   local classPaths = {}
   if self.forceClasses then
      for k,v in pairs(self.forceClasses) do
         self.classes[k] = v
         classPaths[k] = {}
      end
   end
   local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end
   -- loop over each paths folder, get list of unique class names,
   -- also store the directory paths per class
   -- for each class,
   for k,path in ipairs(self.paths) do
--      print('path', path)
      local dirs = {} -- use each path directly as one class directory (no subdirectory scan)
      dirs[1] = path
--      local dirs = dir.getdirectories(path);
      for k,dirpath in ipairs(dirs) do
         local class = paths.basename(dirpath)
         local idx = tableFind(self.classes, class)
--         print(class)
--         print(idx)
         if not idx then
            table.insert(self.classes, class)
            idx = #self.classes
            classPaths[idx] = {}
         end
         if not tableFind(classPaths[idx], dirpath) then
            table.insert(classPaths[idx], dirpath);
         end
      end
   end

   self.classIndices = {}
   for k,v in ipairs(self.classes) do
      self.classIndices[v] = k
   end

   -- define command-line tools, try your best to maintain OSX compatibility
   local wc = 'wc'
   local cut = 'cut'
   local find = 'find -H'  -- if folder name is symlink, do find inside it after dereferencing

   if ffi.os == 'OSX' then
      wc = 'gwc'
      cut = 'gcut'
      find = 'gfind'
   end
   ----------------------------------------------------------------------
   -- Options for the GNU find command
   local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
   local findOptions = ' -iname "*.' .. extensionList[1] .. '"'
   for i=2,#extensionList do
      findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"'
   end

   -- find the image path names
   self.imagePath = torch.CharTensor()  -- path to each image in dataset
   self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes)
   self.classList = {}                  -- per-class list of indices into self.imagePath
   self.classListSample = self.classList -- the main list used when sampling data

   print('running "find" on each class directory, and concatenate all'
         .. ' those filenames into a single file containing all image paths for a given class')
   -- so, generates one file per class
   local classFindFiles = {}
   for i=1,#self.classes do
      classFindFiles[i] = os.tmpname()
   end
   local combinedFindList = os.tmpname();

   local tmpfile = os.tmpname()
   local tmphandle = assert(io.open(tmpfile, 'w'))
   -- iterate over classes
   for i, class in ipairs(self.classes) do
      -- iterate over classPaths
      for j,path in ipairs(classPaths[i]) do
         local command = find .. ' "' .. path .. '" ' .. findOptions
            .. ' >>"' .. classFindFiles[i] .. '" \n'
         tmphandle:write(command)
      end
   end
   io.close(tmphandle)
   os.execute('bash ' .. tmpfile)
   os.execute('rm -f ' .. tmpfile)

   print('now combine all the files to a single large file')
   local tmpfile = os.tmpname()
   local tmphandle = assert(io.open(tmpfile, 'w'))
   -- concat all finds to a single large file in the order of self.classes
   for i=1,#self.classes do
      local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n'
      tmphandle:write(command)
   end
   io.close(tmphandle)
   os.execute('bash ' .. tmpfile)
   os.execute('rm -f ' .. tmpfile)

   --==========================================================================
   print('load the large concatenated list of sample paths to self.imagePath')
   -- longest path length (wc -L), plus one byte for the NUL terminator written by ffi.copy
   local cmd = wc .. " -L '" .. combinedFindList .. "' |" .. cut .. " -f1 -d' '"
   print('cmd: ' .. cmd)
   local maxPathLength = tonumber(sys.fexecute(cmd)) + 1
   local length = tonumber(sys.fexecute(wc .. " -l '"
                                           .. combinedFindList .. "' |"
                                           .. cut .. " -f1 -d' '"))
   assert(length > 0, "Could not find any image file in the given input paths")
   assert(maxPathLength > 0, "paths of files are length 0?")
   self.imagePath:resize(length, maxPathLength):fill(0)
   local s_data = self.imagePath:data()
   local count = 0
   for line in io.lines(combinedFindList) do
      ffi.copy(s_data, line)
      s_data = s_data + maxPathLength
      if self.verbose and count % 10000 == 0 then
         xlua.progress(count, length)
      end;
      count = count + 1
   end

   self.numSamples = self.imagePath:size(1)
   if self.verbose then print(self.numSamples ..  ' samples found.') end
   --==========================================================================
   print('Updating classList and imageClass appropriately')
   self.imageClass:resize(self.numSamples)
   local runningIndex = 0
   for i=1,#self.classes do
      if self.verbose then xlua.progress(i, #(self.classes)) end
      local length = tonumber(sys.fexecute(wc .. " -l '"
                                              .. classFindFiles[i] .. "' |"
                                              .. cut .. " -f1 -d' '"))
      if length == 0 then
         error('Class has zero samples: ' .. self.classes[i])
      else
         self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + length, length):long()
         self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i)
      end
      runningIndex = runningIndex + length
   end

   --==========================================================================
   -- clean up temporary files
   print('Cleaning up temporary files')
   local tmpfilelistall = ''
   for i=1,#(classFindFiles) do
      tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"'
      if i % 1000 == 0 then
         os.execute('rm -f ' .. tmpfilelistall)
         tmpfilelistall = ''
      end
   end
   os.execute('rm -f '  .. tmpfilelistall)
   os.execute('rm -f "' .. combinedFindList .. '"')
   --==========================================================================

   if self.split == 100 then
      self.testIndicesSize = 0
   else
      print('Splitting training and test sets to a ratio of '
               .. self.split .. '/' .. (100-self.split))
      self.classListTrain = {}
      self.classListTest  = {}
      self.classListSample = self.classListTrain
      local totalTestSamples = 0
      -- split the classList into classListTrain and classListTest
      for i=1,#self.classes do
         local list = self.classList[i]
         local count = self.classList[i]:size(1)
         local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round
         local perm = torch.randperm(count)
         self.classListTrain[i] = torch.LongTensor(splitidx)
         for j=1,splitidx do
            self.classListTrain[i][j] = list[perm[j]]
         end
         if splitidx == count then -- all samples were allocated to train set
            self.classListTest[i]  = torch.LongTensor()
         else
            self.classListTest[i]  = torch.LongTensor(count-splitidx)
            totalTestSamples = totalTestSamples + self.classListTest[i]:size(1)
            local idx = 1
            for j=splitidx+1,count do
               self.classListTest[i][idx] = list[perm[j]]
               idx = idx + 1
            end
         end
      end
      -- Now combine classListTest into a single tensor
      self.testIndices = torch.LongTensor(totalTestSamples)
      self.testIndicesSize = totalTestSamples
      local tdata = self.testIndices:data()
      local tidx = 0
      for i=1,#self.classes do
         local list = self.classListTest[i]
         if list:dim() ~= 0 then
            local ldata = list:data()
            for j=0,list:size(1)-1 do
               tdata[tidx] = ldata[j]
               tidx = tidx + 1
            end
         end
      end
   end
end
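
-- Illustrative sketch (not part of the original loader): a hypothetical
-- pure-Lua helper, splitSizes, that reproduces the per-class split arithmetic
-- used in dataset:__init above, i.e. splitidx = floor(count * split / 100 + 0.5).
-- It assumes `split` is a percentage in [0, 100] and only computes the sizes;
-- the real code additionally shuffles the indices with torch.randperm.
local function splitSizes(count, split)
   local nTrain = math.floor(count * split / 100 + 0.5)  -- round to nearest
   return nTrain, count - nTrain                          -- train size, test size
end
-- e.g. splitSizes(5, 50) gives 3 training and 2 test samples, because 2.5 rounds up.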

-- size(), size(class)
function dataset:size(class, list)
   list = list or self.classList
   if not class then
      return self.numSamples
   elseif type(class) == 'string' then
      return list[self.classIndices[class]]:size(1)
   elseif type(class) == 'number' then
      return list[class]:size(1)
   end
end

-- getByClass
function dataset:getByClass(class)
   local index = 0
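   if self.serial_batches == 1 then
      -- serial mode: walk through the class list in order, wrapping around at the end
      index = math.fmod(self.image_count-1, self.classListSample[class]:nElement())+1
      self.image_count = self.image_count + 1
   else
      -- random mode: draw a random index into the class list
      index = math.ceil(torch.uniform() * self.classListSample[class]:nElement())
   end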

   local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]]))
   return self:sampleHookTrain(imgpath),  imgpath
end
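
-- Illustrative sketch (not part of the original loader): the serial_batches
-- branch of getByClass walks a class list cyclically. This hypothetical helper,
-- serialIndex, isolates that index arithmetic. It assumes a 1-based running
-- counter `imageCount` and a list of `n` > 0 elements.
local function serialIndex(imageCount, n)
   -- fmod(imageCount - 1, n) lies in [0, n-1]; +1 maps it back to 1-based indexing
   return math.fmod(imageCount - 1, n) + 1
end
-- With n = 3, counters 1, 2, 3, 4, 5 give indices 1, 2, 3, 1, 2.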

-- converts a table of samples (and corresponding labels) to a clean tensor
local function tableToOutput(self, dataTable, scalarTable)
   local data, scalarLabels
   if opt.resize_or_crop == 'crop' or opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
      -- these modes produce images of varying size, so only a batch of one is supported
      assert(#scalarTable == 1)
      data = torch.Tensor(1,
                  dataTable[1]:size(1), dataTable[1]:size(2), dataTable[1]:size(3))
      data[1]:copy(dataTable[1])
      scalarLabels = torch.LongTensor(#scalarTable):fill(-1111)
   else
      local quantity = #scalarTable
      data = torch.Tensor(quantity,
                  self.sampleSize[1], self.sampleSize[2], self.sampleSize[3])
      scalarLabels = torch.LongTensor(quantity):fill(-1111)
      for i=1,#dataTable do
         data[i]:copy(dataTable[i])
         scalarLabels[i] = scalarTable[i]
      end
   end
   return data, scalarLabels
end

-- sampler, samples from the training set.
function dataset:sample(quantity)
   assert(quantity)
   local dataTable = {}
   local scalarTable = {}
   local samplePaths = {}
   for i=1,quantity do
      local class = torch.random(1, #self.classes)
      local out, imgpath = self:getByClass(class)
      table.insert(dataTable, out)
      table.insert(scalarTable, class)
      samplePaths[i] = imgpath
   end

   local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
   return data, scalarLabels, samplePaths -- file paths of the sampled images
end

function dataset:get(i1, i2)
   local indices = torch.range(i1, i2);
   local quantity = i2 - i1 + 1;
   assert(quantity > 0)
   -- now that indices has been initialized, get the samples
   local dataTable = {}
   local scalarTable = {}
   for i=1,quantity do
      -- load the sample
      local imgpath = ffi.string(torch.data(self.imagePath[indices[i]]))
      local out = self:sampleHookTest(imgpath)
      table.insert(dataTable, out)
      table.insert(scalarTable, self.imageClass[indices[i]])
   end
   local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
   return data, scalarLabels
end

return dataset


================================================
FILE: data/donkey_folder.lua
================================================

--[[
    This data loader is a modified version of the one from dcgan.torch
    (see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua).
    Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
    Copyright (c) 2015-present, Facebook, Inc.
    All rights reserved.
    This source code is licensed under the BSD-style license found in the
    LICENSE file in the root directory of this source tree. An additional grant
    of patent rights can be found in the PATENTS file in the same directory.
]]--

require 'image'
paths.dofile('dataset.lua')
-- This file contains the data-loading logic and details.
-- It is run by each data-loader thread.
------------------------------------------
-------- COMMON CACHES and PATHS
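-- Resolve opt.data from opt.DATA_ROOT, falling back to the DATA_ROOT environment variable
if opt.DATA_ROOT then
  opt.data = paths.concat(opt.DATA_ROOT, opt.phase)
else
  local data_root = os.getenv('DATA_ROOT')
  print(data_root)
  assert(data_root, 'DATA_ROOT is not set (neither opt.DATA_ROOT nor the DATA_ROOT environment variable)')
  opt.data = paths.concat(data_root, opt.phase)
end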

if not paths.dirp(opt.data) then
    error('Did not find directory: ' .. opt.data)
end

-- a cache file for the training metadata (created if it does not exist)
local cache_prefix = opt.data:gsub('/', '_')
os.execute(('mkdir -p %s'):format(opt.cache_dir))
local trainCache = paths.concat(opt.cache_dir, cache_prefix .. '_trainCache.t7')

--------------------------------------------------------------------------------------------
local input_nc = opt.nc -- input channels
local loadSize   = {input_nc, opt.loadSize}
local sampleSize = {input_nc, opt.fineSize}

local function loadImage(path)
  local input = image.load(path, 3, 'float')
  local h = input:size(2)
  local w = input:size(3)

  local imA = image.crop(input, 0, 0, w/2, h)
  imA = image.scale(imA, loadSize[2], loadSize[2])
  local imB = image.crop(input, w/2, 0, w, h)
  imB = image.scale(imB, loadSize[2], loadSize[2])

  -- reorder channels (RGB -> BGR) and rescale pixel values from [0, 1] to [-1, 1]
  local perm = torch.LongTensor{3, 2, 1}
  imA = imA:index(1, perm)
  imA = imA:mul(2):add(-1)
  imB = imB:index(1, perm)
  imB = imB:mul(2):add(-1)

  assert(imA:max()<=1,"A: badly scaled inputs")
  assert(imA:min()>=-1,"A: badly scaled inputs")
  assert(imB:max()<=1,"B: badly scaled inputs")
  assert(imB:min()>=-1,"B: badly scaled inputs")


  local oW = sampleSize[2]
  local oH = sampleSize[2]
  local iH = imA:size(2)
  local iW = imA:size(3)
  local h1, w1 = 0, 0  -- crop offsets; stay 0 along a dimension that needs no crop

  if iH~=oH then
    h1 = math.ceil(torch.uniform(1e-2, iH-oH))
  end

  if iW~=oW then
    w1 = math.ceil(torch.uniform(1e-2, iW-oW))
  end
  if iH ~= oH or iW ~= oW then
    imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH)
    imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH)
  end

  if opt.flip == 1 and torch.uniform() > 0.5 then
    imA = image.hflip(imA)
    imB = image.hflip(imB)
  end

  local concatenated = torch.cat(imA,imB,1)

  return concatenated
end
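
-- Illustrative sketch (not part of the original loader): after loading,
-- loadImage (and loadSingleImage, when input_nc == 3) applies two per-pixel
-- transforms. The first is a channel permutation with perm = {3, 2, 1}. The
-- second rescales x -> 2x - 1, mapping [0, 1] to [-1, 1], which is the range of
-- the generators' final nn.Tanh. The hypothetical helpers below show the same
-- arithmetic on plain Lua values. The inverse mapping, fromTanhRange, is an
-- assumption added for illustration only.
local function permuteChannels(pixel, perm)
   local out = {}
   for i = 1, #perm do
      out[i] = pixel[perm[i]]  -- out[i] = pixel[perm[i]], like Tensor:index(1, perm)
   end
   return out
end

local function toTanhRange(x) return x * 2 - 1 end     -- [0, 1] -> [-1, 1]
local function fromTanhRange(y) return (y + 1) / 2 end -- [-1, 1] -> [0, 1]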


local function loadSingleImage(path)
    local im = image.load(path, input_nc, 'float')
    if opt.resize_or_crop == 'resize_and_crop' then
      im = image.scale(im, loadSize[2], loadSize[2])
    end
    if input_nc == 3 then
      -- reorder channels (RGB -> BGR) and rescale pixel values from [0, 1] to [-1, 1]
      local perm = torch.LongTensor{3, 2, 1}
      im = im:index(1, perm)
      im = im:mul(2):add(-1)
    end
    assert(im:max()<=1,"badly scaled inputs")
    assert(im:min()>=-1,"badly scaled inputs")

    local oW = sampleSize[2]
    local oH = sampleSize[2]
    local iH = im:size(2)
    local iW = im:size(3)
    if (opt.resize_or_crop == 'resize_and_crop' ) then
      local h1, w1 = 0, 0
      if iH~=oH then
        h1 = math.ceil(torch.uniform(1e-2, iH-oH))
      end
      if iW~=oW then
        w1 = math.ceil(torch.uniform(1e-2, iW-oW))
      end
      if iH ~= oH or iW ~= oW then
        im = image.crop(im, w1, h1, w1 + oW, h1 + oH)
      end
    elseif (opt.resize_or_crop == 'combined') then
      local sH = math.min(math.ceil(oH * torch.uniform(1+1e-2, 2.0-1e-2)), iH-1e-2)
      local sW = math.min(math.ceil(oW * torch.uniform(1+1e-2, 2.0-1e-2)), iW-1e-2)
      local h1 = math.ceil(torch.uniform(1e-2, iH-sH))
      local w1 = math.ceil(torch.uniform(1e-2, iW-sW))
      im = image.crop(im, w1, h1, w1 + sW, h1 + sH)
      im = image.scale(im, oW, oH)
    elseif (opt.resize_or_crop == 'crop') then
      local w = math.min(math.min(oH, iH),iW)
      w = math.floor(w/4)*4
      local x = math.floor(torch.uniform(0, iW - w))
      local y = math.floor(torch.uniform(0, iH - w))
      im = image.crop(im, x, y, x+w, y+w)
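    elseif (opt.resize_or_crop == 'scale_width') then
      -- fix the width to oW and keep the aspect ratio
      local w = oW
      local h = math.floor(iH * oW / iW)
      im = image.scale(im, w, h)
    elseif (opt.resize_or_crop == 'scale_height') then
      -- fix the height to oH and keep the aspect ratio
      local h = oH
      local w = math.floor(iW * oH / iH)
      im = image.scale(im, w, h)
    end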

    if opt.flip == 1 and torch.uniform() > 0.5 then
        im = image.hflip(im)
    end

  return im

end
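
-- Illustrative sketch (not part of the original loader): the 'scale_width'
-- and 'scale_height' modes of loadSingleImage fix one side of the output and
-- derive the other side from the input aspect ratio. This hypothetical helper,
-- targetSize, computes the resulting (width, height) with plain Lua math. It
-- assumes the same floor rounding as the code above.
local function targetSize(mode, iW, iH, oW, oH)
   if mode == 'scale_width' then
      return oW, math.floor(iH * oW / iW)   -- width fixed, height follows aspect ratio
   elseif mode == 'scale_height' then
      return math.floor(iW * oH / iH), oH   -- height fixed, width follows aspect ratio
   end
   error('unsupported mode: ' .. tostring(mode))
end
-- e.g. a 512x256 (width x height) image with oW = 256 becomes 256x128 under 'scale_width'.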

-- channel-wise mean and std. Calculate or load them from disk later in the script.
local mean,std
--------------------------------------------------------------------------------
-- Hooks that are used for each image that is loaded

-- function to load the image, jitter it appropriately (random crops etc.)
local trainHook_singleimage = function(self, path)
   collectgarbage()
  --  print('load single image')
   local im = loadSingleImage(path)
   return im
end

-- function that loads an image in which a domain-A image and a domain-B image
-- are placed side by side (A on the left half, B on the right half)
local trainHook_doubleimage = function(self, path)
  -- print('load double image')
  collectgarbage()

  local im = loadImage(path)
  return im
end


if opt.align_data > 0 then
  sample_nc = input_nc*2
  trainHook = trainHook_doubleimage
else
  sample_nc = input_nc
  trainHook = trainHook_singleimage
end

trainLoader = dataLoader{
    paths = {opt.data},
    loadSize = {input_nc, loadSize[2], loadSize[2]},
    sampleSize = {sample_nc, sampleSize[2], sampleSize[2]},
    split = 100,
    serial_batches = opt.serial_batches,
    verbose = true
 }

trainLoader.sampleHookTrain = trainHook
collectgarbage()

-- do some sanity checks on trainLoader
do
   local class = trainLoader.imageClass
   local nClasses = #trainLoader.classes
   assert(class:max() <= nClasses, "class logic has error")
   assert(class:min() >= 1, "class logic has error")
end


================================================
FILE: data/unaligned_data_loader.lua
================================================
--------------------------------------------------------------------------------
-- Subclass of BaseDataLoader that provides data from two datasets.
-- The samples from the datasets are not aligned.
-- The datasets can have different sizes
--------------------------------------------------------------------------------
require 'data.base_data_loader'

local class = require 'class'
data_util = paths.dofile('data_util.lua')

UnalignedDataLoader = class('UnalignedDataLoader', 'BaseDataLoader')

function UnalignedDataLoader:__init(conf)
  BaseDataLoader.__init(self, conf)
  conf = conf or {}
end

function UnalignedDataLoader:name()
  return 'UnalignedDataLoader'
end

function UnalignedDataLoader:Initialize(opt)
  opt.align_data = 0
  self.dataA = data_util.load_dataset('A', opt, opt.input_nc)
  self.dataB = data_util.load_dataset('B', opt, opt.output_nc)
end

-- actually fetches the data
-- |return|: the batch for dataset A, the batch for dataset B,
-- and the image paths of each batch (pathA, pathB)
function UnalignedDataLoader:LoadBatchForAllDatasets()
  local batchA, pathA = self.dataA:getBatch()
  local batchB, pathB = self.dataB:getBatch()
  return batchA, batchB, pathA, pathB
end

-- returns the size of each dataset
function UnalignedDataLoader:size(dataset)
  if dataset == 'A' then
    return self.dataA:size()
  end

  if dataset == 'B' then
    return self.dataB:size()
  end

  -- return the size of the largest dataset by default
  return math.max(self.dataA:size(), self.dataB:size())
end


================================================
FILE: examples/test_vangogh_style_on_ae_photos.sh
================================================
#!/bin/sh

## This script downloads the dataset and the pre-trained network,
## and generates style-transferred images.

# Download the dataset. The downloaded dataset is stored in ./datasets/${DATASET_NAME}
DATASET_NAME='ae_photos'
bash ./datasets/download_dataset.sh $DATASET_NAME

# Download the pre-trained model. The downloaded model is stored in ./models/${MODEL_NAME}_pretrained/latest_net_G.t7
MODEL_NAME='style_vangogh'
bash ./pretrained_models/download_model.sh $MODEL_NAME

# Run style transfer using the downloaded dataset and model
DATA_ROOT=./datasets/$DATASET_NAME name=${MODEL_NAME}_pretrained model=one_direction_test phase=test how_many='all' loadSize=256 fineSize=256 resize_or_crop='scale_width' th test.lua

if [ $? -eq 0 ]; then
    echo "The result can be viewed at ./results/${MODEL_NAME}_pretrained/latest_test/index.html"
fi


================================================
FILE: examples/train_maps.sh
================================================
DB_NAME='maps'
GPU_ID=1
DISPLAY_ID=1
NET_G=resnet_6blocks
NET_D=basic
MODEL=cycle_gan
SAVE_EPOCH=5
ALIGN_DATA=0
LAMBDA=10
NF=64


EXPR_NAME=${DB_NAME}_${MODEL}_${LAMBDA}

CHECKPOINT_DIR=./checkpoints/
LOG_FILE=${CHECKPOINT_DIR}${EXPR_NAME}/log.txt
mkdir -p ${CHECKPOINT_DIR}${EXPR_NAME}

DATA_ROOT=./datasets/$DB_NAME align_data=$ALIGN_DATA use_lsgan=1 \
which_direction='AtoB' display_plot=$PLOT pool_size=50 niter=100 niter_decay=100 \
which_model_netG=$NET_G which_model_netD=$NET_D model=$MODEL lr=0.0002 print_freq=200 lambda_A=$LAMBDA lambda_B=$LAMBDA \
loadSize=143 fineSize=128 gpu=$GPU_ID display_winsize=128 \
name=$EXPR_NAME flip=1 save_epoch_freq=$SAVE_EPOCH \
continue_train=0 display_id=$DISPLAY_ID \
checkpoints_dir=$CHECKPOINT_DIR \
th train.lua | tee -a $LOG_FILE


================================================
FILE: models/architectures.lua
================================================
require 'nngraph'


----------------------------------------------------------------------------
local function weights_init(m)
  local name = torch.type(m)
  if name:find('Convolution') then
    m.weight:normal(0.0, 0.02)
    m.bias:fill(0)
  elseif name:find('Normalization') then
    if m.weight then m.weight:normal(1.0, 0.02) end
    if m.bias then m.bias:fill(0) end
  end
end


normalization = nil

function set_normalization(norm)
  if norm == 'instance' then
    require 'util.InstanceNormalization'
    print('use InstanceNormalization')
    normalization = nn.InstanceNormalization
  elseif norm == 'batch' then
    print('use SpatialBatchNormalization')
    normalization = nn.SpatialBatchNormalization
  end
end

function defineG(input_nc, output_nc, ngf, which_model_netG, nz, arch)
  local netG = nil
  if     which_model_netG == "encoder_decoder" then netG = defineG_encoder_decoder(input_nc, output_nc, ngf)
  elseif which_model_netG == "unet128" then netG = defineG_unet128(input_nc, output_nc, ngf)
  elseif which_model_netG == "unet256" then netG = defineG_unet256(input_nc, output_nc, ngf)
  elseif which_model_netG == "resnet_6blocks" then netG = defineG_resnet_6blocks(input_nc, output_nc, ngf)
  elseif which_model_netG == "resnet_9blocks" then netG = defineG_resnet_9blocks(input_nc, output_nc, ngf)
  else error("unsupported netG model")
  end
  netG:apply(weights_init)

  return netG
end

function defineD(input_nc, ndf, which_model_netD, n_layers_D, use_sigmoid)
  local netD = nil
  if     which_model_netD == "basic" then netD = defineD_basic(input_nc, ndf, use_sigmoid)
  elseif which_model_netD == "imageGAN" then netD = defineD_imageGAN(input_nc, ndf, use_sigmoid)
  elseif which_model_netD == "n_layers" then netD = defineD_n_layers(input_nc, ndf, n_layers_D, use_sigmoid)
  else error("unsupported netD model")
  end
  netD:apply(weights_init)

  return netD
end

function defineG_encoder_decoder(input_nc, output_nc, ngf)
    -- input is (nc) x 256 x 256
    local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
    -- input is (ngf) x 128 x 128
    local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
    -- input is (ngf * 2) x 64 x 64
    local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
    -- input is (ngf * 4) x 32 x 32
    local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 16 x 16
    local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 8 x 8
    local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 4 x 4
    local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 2 x 2
    local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- normalization(ngf * 8)
    -- input is (ngf * 8) x 1 x 1

    local d1 = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 2 x 2
    local d2 = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 4 x 4
    local d3 = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 8 x 8
    local d4 = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 16 x 16
    local d5 = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
    -- input is (ngf * 4) x 32 x 32
    local d6 = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
    -- input is (ngf * 2) x 64 x 64
    local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
    -- input is (ngf) x 128 x 128
    local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf, output_nc, 4, 4, 2, 2, 1, 1)
    -- input is (nc) x 256 x 256
    local o1 = d8 - nn.Tanh()

    local netG = nn.gModule({e1},{o1})
    return netG
end
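
-- Illustrative sketch (not part of the original model definitions): the
-- shape comments above follow the standard output-size formulas for
-- nn.SpatialConvolution and nn.SpatialFullConvolution:
--   conv:       out = floor((in + 2*pad - k) / stride) + 1
--   full conv:  out = (in - 1) * stride - 2*pad + k + adj
-- The hypothetical helpers below apply these formulas to the 4x4, stride-2,
-- pad-1 layers used here. Each encoder layer halves the spatial size and each
-- decoder layer doubles it, so 256 shrinks to 1 after eight layers and then
-- grows back.
local function convOutSize(n, k, stride, pad)
   return math.floor((n + 2 * pad - k) / stride) + 1
end

local function fullConvOutSize(n, k, stride, pad, adj)
   return (n - 1) * stride - 2 * pad + k + (adj or 0)
end

local function encoderDecoderSizes(inputSize, nLayers)
   local n = inputSize
   for _ = 1, nLayers do n = convOutSize(n, 4, 2, 1) end      -- e1 .. e8
   local bottleneck = n
   for _ = 1, nLayers do n = fullConvOutSize(n, 4, 2, 1) end  -- d1 .. d8
   return bottleneck, n
end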


function defineG_unet128(input_nc, output_nc, ngf)
    local netG = nil
    -- input is (nc) x 128 x 128
    local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
    -- input is (ngf) x 64 x 64
    local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
    -- input is (ngf * 2) x 32 x 32
    local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
    -- input is (ngf * 4) x 16 x 16
    local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 8 x 8
    local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 4 x 4
    local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 2 x 2
    local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- normalization(ngf * 8)
    -- input is (ngf * 8) x 1 x 1

    local d1_ = e7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 2 x 2
    local d1 = {d1_,e6} - nn.JoinTable(2)
    local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 4 x 4
    local d2 = {d2_,e5} - nn.JoinTable(2)
    local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 8 x 8
    local d3 = {d3_,e4} - nn.JoinTable(2)
    local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
    -- input is (ngf * 4) x 16 x 16
    local d4 = {d4_,e3} - nn.JoinTable(2)
    local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
    -- input is (ngf * 2) x 32 x 32
    local d5 = {d5_,e2} - nn.JoinTable(2)
    local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
    -- input is (ngf) x 64 x 64
    local d6 = {d6_,e1} - nn.JoinTable(2)

    local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
    -- input is (nc) x 128 x 128

    local o1 = d7 - nn.Tanh()
    local netG = nn.gModule({e1},{o1})
    return netG
end


function defineG_unet256(input_nc, output_nc, ngf)
    local netG = nil
    -- input is (nc) x 256 x 256
    local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
    -- input is (ngf) x 128 x 128
    local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
    -- input is (ngf * 2) x 64 x 64
    local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
    -- input is (ngf * 4) x 32 x 32
    local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 16 x 16
    local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 8 x 8
    local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 4 x 4
    local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 2 x 2
    local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- - normalization(ngf * 8)
    -- input is (ngf * 8) x 1 x 1

    local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 2 x 2
    local d1 = {d1_,e7} - nn.JoinTable(2)
    local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 4 x 4
    local d2 = {d2_,e6} - nn.JoinTable(2)
    local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
    -- input is (ngf * 8) x 8 x 8
    local d3 = {d3_,e5} - nn.JoinTable(2)
    local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
    -- input is (ngf * 8) x 16 x 16
    local d4 = {d4_,e4} - nn.JoinTable(2)
    local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
    -- input is (ngf * 4) x 32 x 32
    local d5 = {d5_,e3} - nn.JoinTable(2)
    local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
    -- input is (ngf * 2) x 64 x 64
    local d6 = {d6_,e2} - nn.JoinTable(2)
    local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
    -- input is (ngf) x 128 x 128
    local d7 = {d7_,e1} - nn.JoinTable(2)
    local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
    -- input is (nc) x 256 x 256

    local o1 = d8 - nn.Tanh()
    local netG = nn.gModule({e1},{o1})
    return netG
end
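
-- Illustrative sketch (not part of the original model definitions): in the
-- U-Net generators, every decoder layer after the first consumes the previous
-- decoder output concatenated (nn.JoinTable(2)) with the mirrored encoder
-- activation. That is why their input widths are ngf * 8 * 2, ngf * 4 * 2,
-- and so on. This hypothetical helper, unetDecoderInputs, recomputes those
-- input channel counts from per-layer multipliers of ngf:
--   enc[i]    : channel multiplier of encoder layer e_i
--   decOut[i] : channel multiplier produced by decoder layer d_i (before joining)
local function unetDecoderInputs(enc, decOut)
   local nEnc = #enc
   local inputs = {enc[nEnc]}  -- the first decoder layer sees only the bottleneck
   for k = 2, #decOut + 1 do
      -- previous decoder output + skip connection from the mirrored encoder layer
      inputs[k] = decOut[k - 1] + enc[nEnc - k + 1]
   end
   return inputs
end
-- For unet256: enc = {1, 2, 4, 8, 8, 8, 8, 8}, decOut = {8, 8, 8, 8, 4, 2, 1}
-- gives inputs {8, 16, 16, 16, 16, 8, 4, 2}, matching d1_ .. d8 above.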

--------------------------------------------------------------------------------
-- Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/
--------------------------------------------------------------------------------

local function build_conv_block(dim, padding_type)
  local conv_block = nn.Sequential()
  local p = 0
  if padding_type == 'reflect' then
    conv_block:add(nn.SpatialReflectionPadding(1, 1, 1, 1))
  elseif padding_type == 'replicate' then
    conv_block:add(nn.SpatialReplicationPadding(1, 1, 1, 1))
  elseif padding_type == 'zero' then
    p = 1
  end
  conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p))
  conv_block:add(normalization(dim))
  conv_block:add(nn.ReLU(true))
  if padding_type == 'reflect' then
    conv_block:add(nn.SpatialReflectionPadding(1, 1, 1, 1))
  elseif padding_type == 'replicate' then
    conv_block:add(nn.SpatialReplicationPadding(1, 1, 1, 1))
  end
  conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p))
  conv_block:add(normalization(dim))
  return conv_block
end


local function build_res_block(dim, padding_type)
  local conv_block = build_conv_block(dim, padding_type)
  local res_block = nn.Sequential()
  local concat = nn.ConcatTable()
  concat:add(conv_block)
  concat:add(nn.Identity())
  
  res_block:add(concat):add(nn.CAddTable())
  return res_block
end
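
-- Illustrative sketch (not part of the original model definitions):
-- build_res_block wires nn.ConcatTable{conv_block, nn.Identity()} into
-- nn.CAddTable(), so it computes y = F(x) + x. The hypothetical helper,
-- residual, shows the same skip connection on a plain Lua array. An arbitrary
-- function `f` stands in for the convolutional branch.
local function residual(f, x)
   local fx = f(x)            -- branch 1: the transformation F(x)
   local y = {}
   for i = 1, #x do
      y[i] = fx[i] + x[i]     -- element-wise sum with branch 2 (identity)
   end
   return y
end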

function defineG_resnet_6blocks(input_nc, output_nc, ngf)
  local padding_type = 'reflect'
  local ks = 3
  local netG = nil
  local f = 7
  local p = (f - 1) / 2
  local data = -nn.Identity()
  local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)
  local d1 = e3 - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type)
  - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type) - build_res_block(ngf*4, padding_type)
  local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf*2) - nn.ReLU(true)
  local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf) - nn.ReLU(true)
  local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()
  netG = nn.gModule({data},{d4})
  return netG
end

function defineG_resnet_9blocks(input_nc, output_nc, ngf)
  local padding_type = 'reflect'
  local ks = 3
  local netG = nil
  local f = 7
  local p = (f - 1) / 2
  local data = -nn.Identity()
  local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)
  local d1 = e3
  for _ = 1, 9 do  -- 9 residual blocks at ngf*4 channels
    d1 = d1 - build_res_block(ngf*4, padding_type)
  end
  local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf*2) - nn.ReLU(true)
  local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf) - nn.ReLU(true)
  local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()
  netG = nn.gModule({data},{d4})
  return netG
end
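
-- Illustrative sketch (hypothetical helper, not called anywhere in this repo):
-- traces the spatial size through the ResNet generators above using the
-- standard output-size formulas of nn.SpatialConvolution and
-- nn.SpatialFullConvolution. The reflection-padded 7x7 convolutions and the
-- residual blocks preserve the size, so only the two stride-2 downsampling
-- convs (k=3, p=1) and the two upsampling full convs (k=3, s=2, p=1, adj=1)
-- change it. The output matches the input only when the input size is a
-- multiple of 4.
local function resnet_output_size(n)
  local function conv_out(size, k, s, p)
    return math.floor((size + 2 * p - k) / s) + 1
  end
  local function fullconv_out(size, k, s, p, adj)
    return (size - 1) * s - 2 * p + k + adj
  end
  local h = conv_out(n, 3, 2, 1)      -- e2: downsample
  h = conv_out(h, 3, 2, 1)            -- e3: downsample
  h = fullconv_out(h, 3, 2, 1, 1)     -- d2: upsample
  h = fullconv_out(h, 3, 2, 1, 1)     -- d3: upsample
  return h
end
-- Example: resnet_output_size(256) returns 256; resnet_output_size(250) returns 252.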

function defineD_imageGAN(input_nc, ndf, use_sigmoid)
    local netD = nn.Sequential()

    -- input is (nc) x 256 x 256
    netD:add(nn.SpatialConvolution(input_nc, ndf, 4, 4, 2, 2, 1, 1))
    netD:add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf) x 128 x 128
    netD:add(nn.SpatialConvolution(ndf, ndf * 2, 4, 4, 2, 2, 1, 1))
    netD:add(nn.SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf*2) x 64 x 64
    netD:add(nn.SpatialConvolution(ndf * 2, ndf*4, 4, 4, 2, 2, 1, 1))
    netD:add(nn.SpatialBatchNormalization(ndf * 4)):add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf*4) x 32 x 32
    netD:add(nn.SpatialConvolution(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1))
    netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf*8) x 16 x 16
    netD:add(nn.SpatialConvolution(ndf * 8, ndf * 8, 4, 4, 2, 2, 1, 1))
    netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf*8) x 8 x 8
    netD:add(nn.SpatialConvolution(ndf * 8, ndf * 8, 4, 4, 2, 2, 1, 1))
    netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf*8) x 4 x 4
    netD:add(nn.SpatialConvolution(ndf * 8, 1, 4, 4, 2, 2, 1, 1))
    -- state size: 1 x 2 x 2
    if use_sigmoid then
      netD:add(nn.Sigmoid())
    end

    return netD
end
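
-- Illustrative sketch (hypothetical helper, not part of the original code):
-- computes the spatial output size of defineD_imageGAN for a square input by
-- applying the nn.SpatialConvolution size formula
-- floor((size + 2*pad - k) / stride) + 1 to each of its seven 4x4, stride-2,
-- pad-1 convolutions.
local function imageGAN_output_size(n)
  for _ = 1, 7 do
    n = math.floor((n + 2 * 1 - 4) / 2) + 1
  end
  return n
end
-- Example: imageGAN_output_size(256) returns 2 (256 -> 128 -> ... -> 4 -> 2).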



function defineD_basic(input_nc, ndf, use_sigmoid)
    local n_layers = 3
    return defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid)
end

-- rf=1
function defineD_pixelGAN(input_nc, ndf, use_sigmoid)

    local netD = nn.Sequential()

    -- input is (nc) x 256 x 256
    netD:add(nn.SpatialConvolution(input_nc, ndf, 1, 1, 1, 1, 0, 0))
    netD:add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf) x 256 x 256
    netD:add(nn.SpatialConvolution(ndf, ndf * 2, 1, 1, 1, 1, 0, 0))
    netD:add(normalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf*2) x 256 x 256
    netD:add(nn.SpatialConvolution(ndf * 2, 1, 1, 1, 1, 1, 0, 0))
    -- state size: 1 x 256 x 256
    if use_sigmoid then
      netD:add(nn.Sigmoid())
    end

    return netD
end

-- if n=0, then use pixelGAN (rf=1)
-- else rf is 16 if n=1
--            34 if n=2
--            70 if n=3
--            142 if n=4
--            286 if n=5
--            574 if n=6
function defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid, kw, dropout_ratio)

  if dropout_ratio == nil then
    dropout_ratio = 0.0
  end

  if kw == nil then
    kw = 4
  end
  local padw = math.ceil((kw-1)/2)

    if n_layers==0 then
        return defineD_pixelGAN(input_nc, ndf, use_sigmoid)
    else

        local netD = nn.Sequential()

        -- input is (nc) x 256 x 256
        -- print('input_nc', input_nc)
        netD:add(nn.SpatialConvolution(input_nc, ndf, kw, kw, 2, 2, padw, padw))
        netD:add(nn.LeakyReLU(0.2, true))

        local nf_mult = 1
        local nf_mult_prev = 1
        for n = 1, n_layers-1 do
            nf_mult_prev = nf_mult
            nf_mult = math.min(2^n,8)
            netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 2, 2, padw,padw))
            netD:add(normalization(ndf * nf_mult)):add(nn.Dropout(dropout_ratio))
            netD:add(nn.LeakyReLU(0.2, true))
        end

        -- state size: (ndf*M) x N x N
        nf_mult_prev = nf_mult
        nf_mult = math.min(2^n_layers,8)
        netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 1, 1, padw, padw))
        netD:add(normalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true))
        -- state size: (ndf*nf_mult) x N' x N', where N' = N + 2*padw - kw + 1 (N+1 for the default kw=4)
        netD:add(nn.SpatialConvolution(ndf * nf_mult, 1, kw, kw, 1, 1, padw,padw))
        -- state size: 1 x N'' x N'', where N'' = N' + 2*padw - kw + 1 (N+2 for the default kw=4)
        if use_sigmoid then
          netD:add(nn.Sigmoid())
        end
        return netD
    end
end
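
-- Illustrative sketch (hypothetical helper, not used by the training code):
-- reproduces the receptive-field table above for defineD_n_layers with the
-- default kw=4. Walking backwards from one output unit, a convolution with
-- kernel k and stride s grows the field as rf = (rf - 1) * s + k. The network
-- ends with two stride-1 convolutions, preceded by n_layers stride-2 ones
-- (the first conv plus the n_layers-1 loop iterations). Padding does not
-- affect the receptive field.
local function patchgan_receptive_field(n_layers, kw)
  kw = kw or 4
  if n_layers == 0 then
    return 1  -- defineD_pixelGAN: only 1x1 convolutions
  end
  local rf = 1
  for _ = 1, 2 do          -- the two final stride-1 convolutions
    rf = (rf - 1) * 1 + kw
  end
  for _ = 1, n_layers do   -- the stride-2 convolutions
    rf = (rf - 1) * 2 + kw
  end
  return rf
end
-- Example: patchgan_receptive_field(3) returns 70, matching "70 if n=3".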


================================================
FILE: models/base_model.lua
================================================
--------------------------------------------------------------------------------
-- Base Class for Providing Models
--------------------------------------------------------------------------------

local class = require 'class'

BaseModel = class('BaseModel')

function BaseModel:__init(conf)
   conf = conf or {}
end

-- Returns the name of the model
function BaseModel:model_name()
   return 'DoesNothingModel'
end

-- Defines models and networks
function BaseModel:Initialize(opt)
  local models = {}
  return models
end

-- Runs the forward pass of the network
function BaseModel:Forward(input, opt)
  local output = {}
  return output
end

-- Runs the backprop gradient descent
-- Corresponds to a single batch of data
function BaseModel:OptimizeParameters(opt)
end

-- Re-fetches the flattened parameters and gradients of the networks
-- (subclasses call getParameters() again here)
function BaseModel:RefreshParameters(opt)
end

-- Updates the learning rate (e.g. BiGANModel decays it linearly)
function BaseModel:UpdateLearningRate(opt)
end

-- Save the current model to the file system
function BaseModel:Save(prefix, opt)
end

-- returns a string that describes the current errors
function BaseModel:GetCurrentErrorDescription()
  return "No Error exists in BaseModel"
end

-- returns current errors
function BaseModel:GetCurrentErrors(opt)
  return {}
end

-- returns a table of image/label pairs that describe
-- the current results.
-- |return|: a table of table. List of image/label pairs
function BaseModel:GetCurrentVisuals(opt, size)
  return {}
end

-- returns a string that describes the display plot configuration 
function BaseModel:DisplayPlot(opt)
  return {}
end


================================================
FILE: models/bigan_model.lua
================================================
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
content = paths.dofile('../util/content_loss.lua')

BiGANModel = class('BiGANModel', 'BaseModel')

function BiGANModel:__init(conf)
  BaseModel.__init(self, conf)
  conf = conf or {}
end

function BiGANModel:model_name()
  return 'BiGANModel'
end

function BiGANModel:InitializeStates(use_wgan)
  local optimState = {learningRate=opt.lr, beta1=opt.beta1,}
  return optimState
end
-- Defines models and networks
function BiGANModel:Initialize(opt)
  if opt.test == 0 then
    self.realABPool = ImagePool(opt.pool_size)
    self.fakeABPool = ImagePool(opt.pool_size)
  end
  -- define tensors
  local d_input_nc = opt.input_nc + opt.output_nc
  self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  -- load/define models
  self.criterionGAN = nn.MSECriterion()

  local netG,  netE, netD = nil, nil, nil
  if opt.continue_train == 1 then
    if opt.test == 1 then -- which_epoch option exists in test mode
      netG = util.load_test_model('G', opt)
      netE = util.load_test_model('E', opt)
      netD = util.load_test_model('D', opt)
    else
      netG = util.load_model('G', opt)
      netE = util.load_model('E', opt)
      netD = util.load_model('D', opt)
    end
  else
    -- netG_test = defineG(opt.input_nc, opt.output_nc, opt.ngf, "resnet_unet", opt.arch)
    -- os.exit()
    netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false)  -- no sigmoid layer
    print('netD...', netD)
    netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG...', netG)
    netE = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netE...', netE)

  end

  self.netD = netD
  self.netG = netG
  self.netE = netE

  -- define real/fake labels
  local netD_output_size = self.netD:forward(self.real_AB):size()
  self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
  self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing

  self.optimStateD = self:InitializeStates()
  self.optimStateG = self:InitializeStates()
  self.optimStateE = self:InitializeStates()
  self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
  self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}}
  self:RefreshParameters()

  print('---------- # Learnable Parameters --------------')
  print(('G = %d'):format(self.parametersG:size(1)))
  print(('E = %d'):format(self.parametersE:size(1)))
  print(('D = %d'):format(self.parametersD:size(1)))
  print('------------------------------------------------')
  -- os.exit()
end

-- Runs the forward pass of the network and
-- saves the result to member variables of the class
function BiGANModel:Forward(input, opt)
  if opt.which_direction == 'BtoA' then
    local temp = input.real_A
    input.real_A = input.real_B
    input.real_B = temp
  end
  self.real_AB[self.A_idx]:copy(input.real_A)
  self.fake_AB[self.B_idx]:copy(input.real_B)
  self.real_A = self.real_AB[self.A_idx]
  self.real_B = self.fake_AB[self.B_idx]
  self.fake_B = self.netG:forward(self.real_A):clone()
  self.fake_A = self.netE:forward(self.real_B):clone()
  self.real_AB[self.B_idx]:copy(self.fake_B) -- real_AB: real_A, fake_B  -> real_label
  self.fake_AB[self.A_idx]:copy(self.fake_A) -- fake_AB: fake_A, real_B  -> fake_label
  --  if opt.test == 0 then
  --    self.real_AB = self.realABPool:Query(self.real_AB)   -- batch history
  --    self.fake_AB = self.fakeABPool:Query(self.fake_AB)   -- batch history
  --  end
end

-- create closure to evaluate f(X) and df/dX of discriminator
function BiGANModel:fDx_basic(x, gradParams, netD, real_AB, fake_AB, opt)
  util.BiasZero(netD)
  gradParams:zero()
  -- Real  log(D_A(B))
  local output = netD:forward(real_AB):clone()
  local errD_real = self.criterionGAN:forward(output, self.real_label)
  local df_do = self.criterionGAN:backward(output, self.real_label)
  netD:backward(real_AB, df_do)
  -- Fake  + log(1 - D_A(G(A)))
  output = netD:forward(fake_AB):clone()
  local errD_fake = self.criterionGAN:forward(output, self.fake_label)
  local df_do2 = self.criterionGAN:backward(output, self.fake_label)
  netD:backward(fake_AB, df_do2)
  -- Compute loss
  local errD = (errD_real + errD_fake) / 2.0
  return errD, gradParams
end


function BiGANModel:fDx(x, opt)
  -- use image pool that stores the old fake images
  local real_AB = self.realABPool:Query(self.real_AB)
  local fake_AB = self.fakeABPool:Query(self.fake_AB)
  local gradParams
  self.errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD, real_AB, fake_AB, opt)
  return self.errD, gradParams
end



function BiGANModel:fGx_basic(x, netG, netD, gradParametersG, opt)
  util.BiasZero(netG)
  util.BiasZero(netD)
  gradParametersG:zero()

  -- G(A) should fool the discriminator
  local output = netD:forward(self.real_AB):clone()
  local errG = self.criterionGAN:forward(output, self.fake_label)
  local dgan_loss_dd = self.criterionGAN:backward(output, self.fake_label)
  local dgan_loss_do = netD:updateGradInput(self.real_AB, dgan_loss_dd)
  netG:backward(self.real_A, dgan_loss_do[self.B_idx]) -- real_AB: real_A, fake_B  -> real_label
  return gradParametersG, errG
end


function BiGANModel:fGx(x, opt)
  self.gradParametersG, self.errG =  self:fGx_basic(x, self.netG, self.netD,
             self.gradParametersG, opt)
  return self.errG, self.gradParametersG
end


function BiGANModel:fEx_basic(x, netE, netD, gradParametersE, opt)
  util.BiasZero(netE)
  util.BiasZero(netD)
  gradParametersE:zero()

  -- E(B) should fool the discriminator
  local output = netD:forward(self.fake_AB):clone()
  local errE = self.criterionGAN:forward(output, self.real_label)
  local dgan_loss_dd = self.criterionGAN:backward(output, self.real_label)
  local dgan_loss_do = netD:updateGradInput(self.fake_AB, dgan_loss_dd)
  netE:backward(self.real_B, dgan_loss_do[self.A_idx])-- fake_AB: fake_A, real_B  -> fake_label
  return gradParametersE, errE
end


function BiGANModel:fEx(x, opt)
  self.gradParametersE, self.errE =  self:fEx_basic(x, self.netE, self.netD,
             self.gradParametersE, opt)
  return self.errE, self.gradParametersE
end


function BiGANModel:OptimizeParameters(opt)
  local fG = function(x) return self:fGx(x, opt) end
  local fE = function(x) return self:fEx(x, opt) end
  local fD = function(x) return self:fDx(x, opt) end
  optim.adam(fD, self.parametersD, self.optimStateD)
  optim.adam(fG, self.parametersG, self.optimStateG)
  optim.adam(fE, self.parametersE, self.optimStateE)
end

function BiGANModel:RefreshParameters()
  self.parametersD, self.gradParametersD = nil, nil -- nil them to avoid spiking memory
  self.parametersG, self.gradParametersG = nil, nil
  self.parametersE, self.gradParametersE = nil, nil
  -- define parameters of optimization
  self.parametersD, self.gradParametersD = self.netD:getParameters()
  self.parametersG, self.gradParametersG = self.netG:getParameters()
  self.parametersE, self.gradParametersE = self.netE:getParameters()
end

function BiGANModel:Save(prefix, opt)
  util.save_model(self.netG, prefix .. '_net_G.t7', 1)
  util.save_model(self.netE, prefix .. '_net_E.t7', 1)
  util.save_model(self.netD, prefix .. '_net_D.t7', 1)
end

function BiGANModel:GetCurrentErrorDescription()
  local description = ('D: %.4f  G: %.4f  E: %.4f'):format(
                         self.errD and self.errD or -1,
                         self.errG and self.errG or -1,
                         self.errE and self.errE or -1)
  return description
end

function BiGANModel:GetCurrentErrors()
  local errors = {errD=self.errD, errG=self.errG, errE=self.errE}
  return errors
end

-- returns a string that describes the display plot configuration
function BiGANModel:DisplayPlot(opt)
  return 'errD,errG,errE'
end
function BiGANModel:UpdateLearningRate(opt)
  local lrd = opt.lr / opt.niter_decay
  local old_lr = self.optimStateD['learningRate']
  local lr =  old_lr - lrd
  self.optimStateD['learningRate'] = lr
  self.optimStateG['learningRate'] = lr
  self.optimStateE['learningRate'] = lr
  print(('update learning rate: %f -> %f'):format(old_lr, lr))
end
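
-- Illustrative sketch (hypothetical helper, not part of the original code):
-- UpdateLearningRate above subtracts opt.lr / opt.niter_decay on every call,
-- so after k calls the rate is opt.lr * (1 - k / opt.niter_decay) and it
-- reaches zero after niter_decay calls. This pure-Lua loop reproduces that
-- arithmetic and returns the rate after each call.
local function linear_decay_schedule(lr, niter_decay)
  local schedule = {}
  local lrd = lr / niter_decay
  local cur = lr
  for k = 1, niter_decay do
    cur = cur - lrd
    schedule[k] = cur
  end
  return schedule
end
-- Example: with lr = 0.0002 and niter_decay = 100, the rate after 50 calls is 0.0001.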

local function MakeIm3(im)
  -- print('before im_size', im:size())
  local im3 = nil
  if im:size(2) == 1 then
    im3 = torch.repeatTensor(im, 1,3,1,1)
  else
    im3 = im
  end
  -- print('after im_size', im:size())
  -- print('after im3_size', im3:size())
  return im3
end
function BiGANModel:GetCurrentVisuals(opt, size)
  if not size then
    size = opt.display_winsize
  end

  local visuals = {}
  table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'})
  table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'})
  table.insert(visuals, {img=MakeIm3(self.real_B), label='real_B'})
  table.insert(visuals, {img=MakeIm3(self.fake_A), label='fake_A'})
  return visuals
end


================================================
FILE: models/content_gan_model.lua
================================================
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
content = paths.dofile('../util/content_loss.lua')

ContentGANModel = class('ContentGANModel', 'BaseModel')

function ContentGANModel:__init(conf)
  BaseModel.__init(self, conf)
  conf = conf or {}
end

function ContentGANModel:model_name()
  return 'ContentGANModel'
end

function ContentGANModel:InitializeStates()
  local optimState = {learningRate=opt.lr, beta1=opt.beta1,}
  return optimState
end
-- Defines models and networks
function ContentGANModel:Initialize(opt)
  if opt.test == 0 then
    self.fakePool = ImagePool(opt.pool_size)
  end
  -- define tensors
  self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
  self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
  self.real_B = self.fake_B:clone() --torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)

  -- load/define models
  self.criterionGAN = nn.MSECriterion()
  self.criterionContent = nn.AbsCriterion()
  self.contentFunc = content.defineContent(opt.content_loss, opt.layer_name)
  self.netG, self.netD = nil, nil
  if opt.continue_train == 1 then
    if opt.which_epoch then -- which_epoch option exists in test mode
      self.netG = util.load_test_model('G_A', opt)
      self.netD = util.load_test_model('D_A', opt)
    else
      self.netG = util.load_model('G_A', opt)
      self.netD = util.load_model('D_A', opt)
    end
  else
    self.netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
    print('netG...', self.netG)
    self.netD = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false)
    print('netD...', self.netD)
  end
  -- define real/fake labels
  local netD_output_size = self.netD:forward(self.real_B):size()  -- netD takes opt.output_nc channels (domain B)
  self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
  self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing
  self.optimStateD = self:InitializeStates()
  self.optimStateG = self:InitializeStates()
  self:RefreshParameters()
  print('---------- # Learnable Parameters --------------')
  print(('G = %d'):format(self.parametersG:size(1)))
  print(('D = %d'):format(self.parametersD:size(1)))
  print('------------------------------------------------')
  -- os.exit()
end

-- Runs the forward pass of the network and
-- saves the result to member variables of the class
function ContentGANModel:Forward(input, opt)
  if opt.which_direction == 'BtoA' then
    local temp = input.real_A
    input.real_A = input.real_B
    input.real_B = temp
  end

  self.real_A:copy(input.real_A)
  self.real_B:copy(input.real_B)
  self.fake_B = self.netG:forward(self.real_A):clone()
  -- output = {self.fake_B}
  local output = {}
  -- if opt.test == 1 then

  -- end
  return output
end

-- create closure to evaluate f(X) and df/dX of discriminator
function ContentGANModel:fDx_basic(x, gradParams, netD, netG,
                                   real_target, fake_target, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()

  local errD_real, errD_rec, errD_fake, errD = 0, 0, 0, 0
  -- Real  log(D_A(B))
  local output = netD:forward(real_target)
  errD_real = self.criterionGAN:forward(output, self.real_label)
  local df_do = self.criterionGAN:backward(output, self.real_label)
  netD:backward(real_target, df_do)

  -- Fake  + log(1 - D_A(G_A(A)))
  output = netD:forward(fake_target)
  errD_fake = self.criterionGAN:forward(output, self.fake_label)
  df_do = self.criterionGAN:backward(output, self.fake_label)
  netD:backward(fake_target, df_do)
  errD = (errD_real + errD_fake) / 2.0
  -- print('errD', errD
  return errD, gradParams
end


function ContentGANModel:fDx(x, opt)
  local fake_B = self.fakePool:Query(self.fake_B)
  local gradParams
  self.errD, gradParams = self:fDx_basic(x, self.gradparametersD, self.netD, self.netG,
                                     self.real_B, fake_B, opt)
  return self.errD, gradParams
end

function ContentGANModel:fGx_basic(x, netG_source, netD_source, real_source, real_target, fake_target,
                                   gradParametersG_source, opt)
  util.BiasZero(netD_source)
  util.BiasZero(netG_source)
  gradParametersG_source:zero()
  -- GAN loss
  -- Domain GAN loss: D_A(G_A(A)). Run the forward pass on fake_target again:
  -- the last forward in fDx may have used an older fake from the image pool,
  -- and netD has been updated since, so netD_source.output would not match.
  local output = netD_source:forward(fake_target)
  local errGAN = self.criterionGAN:forward(output, self.real_label)
  local df_do = self.criterionGAN:backward(output, self.real_label)
  local df_d_GAN = netD_source:updateGradInput(fake_target, df_do)

  -- content loss
  -- print('content_loss', opt.content_loss)
  -- function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight)
  local errContent, df_d_content = content.lossUpdate(self.criterionContent, real_source, fake_target, self.contentFunc, opt.content_loss, opt.lambda_A)
  netG_source:forward(real_source)
  netG_source:backward(real_source, df_d_GAN  + df_d_content)
  -- print('errD', errGAN)
  return gradParametersG_source, errGAN, errContent
end

function ContentGANModel:fGx(x, opt)
  self.gradparametersG, self.errG, self.errCont =
  self:fGx_basic(x, self.netG, self.netD,
             self.real_A, self.real_B, self.fake_B,
             self.gradparametersG, opt)
  return self.errG, self.gradparametersG
end

function ContentGANModel:OptimizeParameters(opt)
  local fDx = function(x) return self:fDx(x, opt) end
  local fGx = function(x) return self:fGx(x, opt) end
  optim.adam(fDx, self.parametersD, self.optimStateD)
  optim.adam(fGx, self.parametersG, self.optimStateG)
end

function ContentGANModel:RefreshParameters()
  self.parametersD, self.gradparametersD = nil, nil -- nil them to avoid spiking memory
  self.parametersG, self.gradparametersG = nil, nil
  -- define parameters of optimization
  self.parametersG, self.gradparametersG = self.netG:getParameters()
  self.parametersD, self.gradparametersD = self.netD:getParameters()
end

function ContentGANModel:Save(prefix, opt)
  util.save_model(self.netG, prefix .. '_net_G_A.t7', 1.0)
  util.save_model(self.netD, prefix .. '_net_D_A.t7', 1.0)
end

function ContentGANModel:GetCurrentErrorDescription()
  local description = ('G: %.4f  D: %.4f  Content: %.4f'):format(self.errG and self.errG or -1,
                         self.errD and self.errD or -1,
                         self.errCont and self.errCont or -1)
  return description
end


function ContentGANModel:GetCurrentErrors()
  local errors = {errG=self.errG and self.errG or -1, errD=self.errD and self.errD or -1,
  errCont=self.errCont and self.errCont or -1}
  return errors
end

-- returns a string that describes the display plot configuration
function ContentGANModel:DisplayPlot(opt)
  return 'errG,errD,errCont'
end


function ContentGANModel:GetCurrentVisuals(opt, size)
  if not size then
    size = opt.display_winsize
  end

  local visuals = {}
  table.insert(visuals, {img=self.real_A, label='real_A'})
  table.insert(visuals, {img=self.fake_B, label='fake_B'})
  table.insert(visuals, {img=self.real_B, label='real_B'})
  return visuals
end


================================================
FILE: models/cycle_gan_model.lua
================================================
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'

util = paths.dofile('../util/util.lua')
CycleGANModel = class('CycleGANModel', 'BaseModel')

function CycleGANModel:__init(conf)
  BaseModel.__init(self, conf)
  conf = conf or {}
end

function CycleGANModel:model_name()
  return 'CycleGANModel'
end

function CycleGANModel:InitializeStates(use_wgan)
  local optimState = {learningRate=opt.lr, beta1=opt.beta1,}
  return optimState
end
-- Defines models and networks
function CycleGANModel:Initialize(opt)
  if opt.test == 0 then
    self.fakeAPool = ImagePool(opt.pool_size)
    self.fakeBPool = ImagePool(opt.pool_size)
  end
  -- define tensors
  if opt.test == 0 then  -- allocate tensors for training
    self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
    self.real_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
    self.fake_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
    self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
    self.rec_A  = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
    self.rec_B  = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
  end
  -- load/define models
  local use_lsgan = ((opt.use_lsgan ~= nil) and (opt.use_lsgan == 1))
  if not use_lsgan then
    self.criterionGAN = nn.BCECriterion()
  else
    self.criterionGAN = nn.MSECriterion()
  end
  self.criterionRec = nn.AbsCriterion()

  local netG_A, netD_A, netG_B, netD_B = nil, nil, nil, nil
  if opt.continue_train == 1 then
    if opt.test == 1 then -- test mode
      netG_A = util.load_test_model('G_A', opt)
      netG_B = util.load_test_model('G_B', opt)

      --setup optnet to save a little bit of memory
      if opt.use_optnet == 1 then
        local sample_input = torch.randn(1, opt.input_nc, 2, 2)
        local optnet = require 'optnet'
        optnet.optimizeMemory(netG_A, sample_input, {inplace=true, reuseBuffers=true})
        optnet.optimizeMemory(netG_B, sample_input, {inplace=true, reuseBuffers=true})
      end
    else
      netG_A = util.load_model('G_A', opt)
      netG_B = util.load_model('G_B', opt)
      netD_A = util.load_model('D_A', opt)
      netD_B = util.load_model('D_B', opt)
    end
  else
    local use_sigmoid = (not use_lsgan)
    -- netG_test = defineG(opt.input_nc, opt.output_nc, opt.ngf, "resnet_unet", opt.arch)
    -- os.exit()
    netG_A = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG_A...', netG_A)
    netD_A = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid)  -- sigmoid output only when not using LSGAN
    print('netD_A...', netD_A)
    netG_B = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG_B...', netG_B)
    netD_B = defineD(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid)  -- sigmoid output only when not using LSGAN
    print('netD_B...', netD_B)
  end

  self.netD_A = netD_A
  self.netG_A = netG_A
  self.netG_B = netG_B
  self.netD_B = netD_B

  -- define real/fake labels
  if opt.test == 0 then
    local D_A_size = self.netD_A:forward(self.real_B):size()  -- label size = D_A output size on a B-domain batch
    self.fake_label_A = torch.Tensor(D_A_size):fill(0.0)
    self.real_label_A = torch.Tensor(D_A_size):fill(1.0) -- no soft smoothing
    local D_B_size = self.netD_B:forward(self.real_A):size()  -- label size = D_B output size on an A-domain batch
    self.fake_label_B = torch.Tensor(D_B_size):fill(0.0)
    self.real_label_B = torch.Tensor(D_B_size):fill(1.0) -- no soft smoothing
    self.optimStateD_A = self:InitializeStates()
    self.optimStateG_A = self:InitializeStates()
    self.optimStateD_B = self:InitializeStates()
    self.optimStateG_B = self:InitializeStates()
    self:RefreshParameters()
    print('---------- # Learnable Parameters --------------')
    print(('G_A = %d'):format(self.parametersG_A:size(1)))
    print(('D_A = %d'):format(self.parametersD_A:size(1)))
    print(('G_B = %d'):format(self.parametersG_B:size(1)))
    print(('D_B = %d'):format(self.parametersD_B:size(1)))
    print('------------------------------------------------')
  end
end

-- Runs the forward pass of the network and
-- saves the result to member variables of the class
function CycleGANModel:Forward(input, opt)
  if opt.which_direction == 'BtoA' then
    local temp = input.real_A:clone()
    input.real_A = input.real_B:clone()
    input.real_B = temp
  end

  if opt.test == 0 then
    self.real_A:copy(input.real_A)
    self.real_B:copy(input.real_B)
  end

  if opt.test == 1 then  -- forward for test
    if opt.gpu > 0 then
      self.real_A = input.real_A:cuda()
      self.real_B = input.real_B:cuda()
    else
      self.real_A = input.real_A:clone()
      self.real_B = input.real_B:clone()
    end
    self.fake_B = self.netG_A:forward(self.real_A):clone()
    self.fake_A = self.netG_B:forward(self.real_B):clone()
    self.rec_A  = self.netG_B:forward(self.fake_B):clone()
    self.rec_B  = self.netG_A:forward(self.fake_A):clone()
  end
end

-- create closure to evaluate f(X) and df/dX of discriminator
function CycleGANModel:fDx_basic(x, gradParams, netD, netG, real, fake, real_label, fake_label, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()
  -- Real  log(D_A(B))
  local output = netD:forward(real)
  local errD_real = self.criterionGAN:forward(output, real_label)
  local df_do = self.criterionGAN:backward(output, real_label)
  netD:backward(real, df_do)
  -- Fake  + log(1 - D_A(G_A(A)))
  output = netD:forward(fake)
  local errD_fake = self.criterionGAN:forward(output, fake_label)
  local df_do2 = self.criterionGAN:backward(output, fake_label)
  netD:backward(fake, df_do2)
  -- Compute loss
  local errD = (errD_real + errD_fake) / 2.0
  return errD, gradParams
end
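
-- Illustrative sketch (hypothetical helper on plain Lua arrays, not tensors):
-- with use_lsgan == 1, criterionGAN is nn.MSECriterion, which by default
-- averages over all elements, so fDx_basic computes
--   errD = (mean((D(real) - 1)^2) + mean((D(fake) - 0)^2)) / 2
-- where D(.) is the per-patch discriminator output. (With use_lsgan ~= 1 the
-- criterion is nn.BCECriterion instead, which this sketch does not cover.)
local function lsgan_d_loss(d_real, d_fake)
  local function mse(t, target)
    local s = 0
    for i = 1, #t do
      s = s + (t[i] - target) ^ 2
    end
    return s / #t
  end
  return (mse(d_real, 1.0) + mse(d_fake, 0.0)) / 2.0
end
-- Example: lsgan_d_loss({1, 1}, {0, 0}) returns 0 (a perfect discriminator);
-- lsgan_d_loss({0.5, 0.5}, {0.5, 0.5}) returns 0.25.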


function CycleGANModel:fDAx(x, opt)
  -- use image pool that stores the old fake images
  local fake_B = self.fakeBPool:Query(self.fake_B)
  local gradParams
  self.errD_A, gradParams = self:fDx_basic(x, self.gradparametersD_A, self.netD_A, self.netG_A,
                            self.real_B, fake_B, self.real_label_A, self.fake_label_A, opt)
  return self.errD_A, gradParams
end


function CycleGANModel:fDBx(x, opt)
  -- use image pool that stores the old fake images
  local fake_A = self.fakeAPool:Query(self.fake_A)
  local gradParams
  self.errD_B, gradParams = self:fDx_basic(x, self.gradparametersD_B, self.netD_B, self.netG_B,
                            self.real_A, fake_A, self.real_label_B, self.fake_label_B, opt)
  return self.errD_B, gradParams
end


function CycleGANModel:fGx_basic(x, gradParams, netG, netD, netE, real, real2, real_label, lambda1, lambda2, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  util.BiasZero(netE)  -- inverse mapping
  gradParams:zero()

  -- G should be identity if real2 is fed.
  local errI = nil
  local identity = nil
  if opt.lambda_identity > 0 then
    identity = netG:forward(real2):clone()
    errI = self.criterionRec:forward(identity, real2) * lambda2 * opt.lambda_identity
    local didentity_loss_do = self.criterionRec:backward(identity, real2):mul(lambda2):mul(opt.lambda_identity)
    netG:backward(real2, didentity_loss_do)
  end

  --- GAN loss: D_A(G_A(A))
  local fake = netG:forward(real):clone()
  local output = netD:forward(fake)
  local errG = self.criterionGAN:forward(output, real_label)
  local df_do1 = self.criterionGAN:backward(output, real_label)
  local df_d_GAN = netD:updateGradInput(fake, df_do1)

  -- forward cycle loss
  local rec = netE:forward(fake):clone()
  local errRec = self.criterionRec:forward(rec, real) * lambda1
  local df_do2 = self.criterionRec:backward(rec, real):mul(lambda1)
  local df_do_rec = netE:updateGradInput(fake, df_do2)

  netG:backward(real, df_d_GAN + df_do_rec)

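  -- backward cycle loss: netG(netE(real2)) should reconstruct real2.
  -- Only netG's gradients are accumulated here; netE is updated by its own closure.
  -- rec2 is consumed before netG runs forward again, so it does not need to be cloned.
  local fake2 = netE:forward(real2)
  local rec2 = netG:forward(fake2)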
  local errAdapt = self.criterionRec:forward(rec2, real2) * lambda2
  local df_do_coadapt = self.criterionRec:backward(rec2, real2):mul(lambda2)
  netG:backward(fake2, df_do_coadapt)

  return gradParams, errG, errRec, errI, fake, rec, identity
end
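-- Illustrative sketch (hypothetical helper, not called anywhere): how fGx_basic
-- weights the terms whose gradients reach netG. lambda1 weights the forward
-- cycle loss (real -> netG -> netE), lambda2 the backward cycle loss
-- (real2 -> netE -> netG), and the identity term uses lambda2 * lambda_identity.
-- The inputs are assumed to be unweighted per-term losses given as plain numbers.
local function sketch_generator_objective(l_gan, l_cyc_fwd, l_cyc_bwd, l_idt,
                                          lambda1, lambda2, lambda_identity)
  local total = l_gan + lambda1 * l_cyc_fwd + lambda2 * l_cyc_bwd
  if lambda_identity > 0 then
    -- identity term only exists when opt.lambda_identity > 0
    total = total + lambda2 * lambda_identity * l_idt
  end
  return total
end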

function CycleGANModel:fGAx(x, opt)
  self.gradparametersG_A, self.errG_A, self.errRec_A, self.errI_A, self.fake_B, self.rec_A, self.identity_B =
  self:fGx_basic(x, self.gradparametersG_A, self.netG_A, self.netD_A, self.netG_B, self.real_A, self.real_B,
  self.real_label_A, opt.lambda_A, opt.lambda_B, opt)
  return self.errG_A, self.gradparametersG_A
end

function CycleGANModel:fGBx(x, opt)
  self.gradparametersG_B, self.errG_B, self.errRec_B, self.errI_B, self.fake_A, self.rec_B, self.identity_A =
  self:fGx_basic(x, self.gradparametersG_B, self.netG_B, self.netD_B, self.netG_A, self.real_B, self.real_A,
  self.real_label_B, opt.lambda_B, opt.lambda_A, opt)
  return self.errG_B, self.gradparametersG_B
end


function CycleGANModel:OptimizeParameters(opt)
  local fDA = function(x) return self:fDAx(x, opt) end
  local fGA = function(x) return self:fGAx(x, opt) end
  local fDB = function(x) return self:fDBx(x, opt) end
  local fGB = function(x) return self:fGBx(x, opt) end

  optim.adam(fGA, self.parametersG_A, self.optimStateG_A)
  optim.adam(fDA, self.parametersD_A, self.optimStateD_A)
  optim.adam(fGB, self.parametersG_B, self.optimStateG_B)
  optim.adam(fDB, self.parametersD_B, self.optimStateD_B)
end

function CycleGANModel:RefreshParameters()
  self.parametersD_A, self.gradparametersD_A = nil, nil -- nil them to avoid spiking memory
  self.parametersG_A, self.gradparametersG_A = nil, nil
  self.parametersG_B, self.gradparametersG_B = nil, nil
  self.parametersD_B, self.gradparametersD_B = nil, nil
  -- define parameters of optimization
  self.parametersG_A, self.gradparametersG_A = self.netG_A:getParameters()
  self.parametersD_A, self.gradparametersD_A = self.netD_A:getParameters()
  self.parametersG_B, self.gradparametersG_B = self.netG_B:getParameters()
  self.parametersD_B, self.gradparametersD_B = self.netD_B:getParameters()
end

function CycleGANModel:Save(prefix, opt)
  util.save_model(self.netG_A, prefix .. '_net_G_A.t7', 1)
  util.save_model(self.netD_A, prefix .. '_net_D_A.t7', 1)
  util.save_model(self.netG_B, prefix .. '_net_G_B.t7', 1)
  util.save_model(self.netD_B, prefix .. '_net_D_B.t7', 1)
end

function CycleGANModel:GetCurrentErrorDescription()
  local description = ('[A] G: %.4f  D: %.4f  Rec: %.4f  I: %.4f || [B] G: %.4f  D: %.4f  Rec: %.4f  I: %.4f'):format(
                         self.errG_A or -1,
                         self.errD_A or -1,
                         self.errRec_A or -1,
                         self.errI_A or -1,
                         self.errG_B or -1,
                         self.errD_B or -1,
                         self.errRec_B or -1,
                         self.errI_B or -1)
  return description
end

function CycleGANModel:GetCurrentErrors()
  local errors = {errG_A=self.errG_A, errD_A=self.errD_A, errRec_A=self.errRec_A, errI_A=self.errI_A,
                  errG_B=self.errG_B, errD_B=self.errD_B, errRec_B=self.errRec_B, errI_B=self.errI_B}
  return errors
end

-- returns a string that describes the display plot configuration
function CycleGANModel:DisplayPlot(opt)
  if opt.lambda_identity > 0 then
    return 'errG_A,errD_A,errRec_A,errI_A,errG_B,errD_B,errRec_B,errI_B'
  else
    return 'errG_A,errD_A,errRec_A,errG_B,errD_B,errRec_B'
  end
end

-- Linearly decays the learning rate of all four optimizers by opt.lr / opt.niter_decay per call
function CycleGANModel:UpdateLearningRate(opt)
  local lrd = opt.lr / opt.niter_decay
  local old_lr = self.optimStateD_A['learningRate']
  local lr =  old_lr - lrd
  self.optimStateD_A['learningRate'] = lr
  self.optimStateD_B['learningRate'] = lr
  self.optimStateG_A['learningRate'] = lr
  self.optimStateG_B['learningRate'] = lr
  print(('update learning rate: %f -> %f'):format(old_lr, lr))
end
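-- Illustrative sketch (hypothetical helper, not used by the model): the value
-- the learning rate takes after `k` calls to UpdateLearningRate. Each call
-- subtracts lr / niter_decay, so after niter_decay calls the rate is zero up to
-- floating-point error. Unlike the method above, this sketch clamps at zero so
-- that extra calls cannot produce a negative rate.
local function sketch_linear_decay_lr(lr, niter_decay, k)
  local step = lr / niter_decay
  return math.max(lr - step * k, 0)
end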

local function MakeIm3(im)
  if im:size(2) == 1 then
    local im3 = torch.repeatTensor(im, 1,3,1,1)
    return im3
  else
    return im
  end
end

function CycleGANModel:GetCurrentVisuals(opt, size)
  local visuals = {}
  table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'})
  table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'})
  table.insert(visuals, {img=MakeIm3(self.rec_A), label='rec_A'})
  if opt.test == 0 and opt.lambda_identity > 0 then
    table.insert(visuals, {img=MakeIm3(self.identity_A), label='identity_A'})
  end
  table.insert(visuals, {img=MakeIm3(self.real_B), label='real_B'})
  table.insert(visuals, {img=MakeIm3(self.fake_A), label='fake_A'})
  table.insert(visuals, {img=MakeIm3(self.rec_B), label='rec_B'})
  if opt.test == 0 and opt.lambda_identity > 0 then
    table.insert(visuals, {img=MakeIm3(self.identity_B), label='identity_B'})
  end
  return visuals
end


================================================
FILE: models/one_direction_test_model.lua
================================================
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'

util = paths.dofile('../util/util.lua')
OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel')

function OneDirectionTestModel:__init(conf)
  conf = conf or {}
  BaseModel.__init(self, conf)
end

function OneDirectionTestModel:model_name()
  return 'OneDirectionTestModel'
end

-- Defines models and networks
function OneDirectionTestModel:Initialize(opt)
  -- define tensors
  self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)

  -- load/define models
  self.netG_A = util.load_test_model('G', opt)

  -- setup optnet to save a bit of memory
  if opt.use_optnet == 1 then
    local optnet = require 'optnet'
    local sample_input = torch.randn(1, opt.input_nc, 2, 2)
    optnet.optimizeMemory(self.netG_A, sample_input, {inplace=true, reuseBuffers=true})
  end

  self:RefreshParameters()

  print('---------- # Learnable Parameters --------------')
  print(('G_A = %d'):format(self.parametersG_A:size(1)))
  print('------------------------------------------------')
end

-- Runs the forward pass of the network and
-- saves the result to member variables of the class
function OneDirectionTestModel:Forward(input, opt)
  if opt.which_direction == 'BtoA' then
    input.real_A = input.real_B:clone()
  end

  self.real_A = input.real_A:clone()
  if opt.gpu > 0 then
    self.real_A = self.real_A:cuda()
  end

  self.fake_B = self.netG_A:forward(self.real_A):clone()
end

function OneDirectionTestModel:RefreshParameters()
  self.parametersG_A, self.gradparametersG_A = nil, nil
  self.parametersG_A, self.gradparametersG_A = self.netG_A:getParameters()
end


local function MakeIm3(im)
  if im:size(2) == 1 then
    local im3 = torch.repeatTensor(im, 1,3,1,1)
    return im3
  else
    return im
  end
end

function OneDirectionTestModel:GetCurrentVisuals(opt, size)
  if not size then
    size = opt.display_winsize
  end

  local visuals = {}
  table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'})
  table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'})
  return visuals
end


================================================
FILE: models/pix2pix_model.lua
================================================
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
Pix2PixModel = class('Pix2PixModel', 'BaseModel')

function Pix2PixModel:__init(conf)
   conf = conf or {}
end

-- Returns the name of the model
function Pix2PixModel:model_name()
   return 'Pix2PixModel'
end

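-- Returns a fresh Adam state; falls back to the global `opt` set by options.parse_options when none is passed
function Pix2PixModel:InitializeStates(opt)
  opt = opt or _G.opt
  return {learningRate=opt.lr, beta1=opt.beta1}
end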

-- Defines models and networks
function Pix2PixModel:Initialize(opt)  -- use lsgan
  -- define tensors
  local d_input_nc = opt.input_nc + opt.output_nc
  self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  if opt.test == 0 then
    self.fakeABPool = ImagePool(opt.pool_size)
  end
  -- load/define models
  self.criterionGAN = nn.MSECriterion()
  self.criterionL1 = nn.AbsCriterion()

  local netG, netD = nil, nil
  if opt.continue_train == 1 then
    if opt.test == 1 then -- only load model G for test
      netG = util.load_test_model('G', opt)
    else
      netG = util.load_model('G', opt)
      netD = util.load_model('D', opt)
    end
  else
    netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
    netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false)  -- no sigmoid: the MSE (LSGAN) loss is applied to raw outputs
  end

  self.netD = netD
  self.netG = netG

  -- define real/fake labels
  if opt.test == 0 then
    local netD_output_size = self.netD:forward(self.real_AB):size()
    self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
    self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing

    self.optimStateD = self:InitializeStates()
    self.optimStateG = self:InitializeStates()

    self:RefreshParameters()

    print('---------- # Learnable Parameters --------------')
    print(('G = %d'):format(self.parametersG:size(1)))
    print(('D = %d'):format(self.parametersD:size(1)))
    print('------------------------------------------------')
  end

  self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
  self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}}
end
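-- Illustrative sketch (hypothetical helper, not used by the model): the channel
-- ranges behind A_idx and B_idx. The conditional discriminator sees A and B
-- concatenated along the channel dimension (dim 2 of a batch x channels x H x W
-- tensor), so channels 1..input_nc hold A and input_nc+1..input_nc+output_nc
-- hold B. Returns the A range, the B range and the total channel count.
local function sketch_ab_channel_ranges(input_nc, output_nc)
  local a_range = {1, input_nc}
  local b_range = {input_nc + 1, input_nc + output_nc}
  return a_range, b_range, input_nc + output_nc
end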

-- Runs the forward pass of the network
function Pix2PixModel:Forward(input, opt)
  if opt.which_direction == 'BtoA' then
    -- swap the roles of A and B
    input.real_A, input.real_B = input.real_B, input.real_A
  end

  if opt.test == 0 then
    self.real_AB[self.A_idx]:copy(input.real_A)
    self.real_AB[self.B_idx]:copy(input.real_B)
    self.real_A = self.real_AB[self.A_idx]
    self.real_B = self.real_AB[self.B_idx]

    self.fake_AB[self.A_idx]:copy(self.real_A)
    self.fake_B = self.netG:forward(self.real_A):clone()
    self.fake_AB[self.B_idx]:copy(self.fake_B)
  else
    if opt.gpu > 0 then
      self.real_A = input.real_A:cuda()
      self.real_B = input.real_B:cuda()
    else
      self.real_A = input.real_A:clone()
      self.real_B = input.real_B:clone()
    end
    self.fake_B = self.netG:forward(self.real_A):clone()
  end
end

-- create closure to evaluate f(X) and df/dX of discriminator
function Pix2PixModel:fDx_basic(x, gradParams, netD, netG, real, fake, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()

  -- Real: D(real) should match real_label
  local output = netD:forward(real)
  local errD_real = self.criterionGAN:forward(output, self.real_label)
  local df_do = self.criterionGAN:backward(output, self.real_label)
  netD:backward(real, df_do)
  -- Fake: D(fake) should match fake_label
  output = netD:forward(fake)
  local errD_fake = self.criterionGAN:forward(output, self.fake_label)
  local df_do2 = self.criterionGAN:backward(output, self.fake_label)
  netD:backward(fake, df_do2)
  -- calculate loss
  local errD = (errD_real + errD_fake) / 2.0
  return errD, gradParams
end


function Pix2PixModel:fDx(x, opt)
  -- query the image pool, which may return a previously generated fake pair
  local fake_AB = self.fakeABPool:Query(self.fake_AB)
  local gradParams
  self.errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD, self.netG,
                                     self.real_AB, fake_AB, opt)
  return self.errD, gradParams
end

function Pix2PixModel:fGx_basic(x, netG, netD, real, fake, gradParametersG, opt)
  util.BiasZero(netG)
  util.BiasZero(netD)
  gradParametersG:zero()

  -- First, G(A) should fool the discriminator: D(fake) should match real_label
  local output = netD:forward(fake)
  local errG = self.criterionGAN:forward(output, self.real_label)
  local dgan_loss_dd = self.criterionGAN:backward(output, self.real_label)
  local dgan_loss_do = netD:updateGradInput(fake, dgan_loss_dd)

  -- Second, G(A) should be close to the real B (L1 loss weighted by opt.lambda_A)
  local real_B = real[self.B_idx]
  local real_A = real[self.A_idx]
  local fake_B = fake[self.B_idx]
  local errL1 = self.criterionL1:forward(fake_B, real_B) * opt.lambda_A
  local dl1_loss_do = self.criterionL1:backward(fake_B, real_B) * opt.lambda_A
  netG:backward(real_A, dgan_loss_do[self.B_idx] + dl1_loss_do)

  return gradParametersG, errG, errL1
end

function Pix2PixModel:fGx(x, opt)
  self.gradParametersG, self.errG, self.errL1 =  self:fGx_basic(x, self.netG, self.netD,
             self.real_AB, self.fake_AB, self.gradParametersG, opt)
  return self.errG, self.gradParametersG
end

-- Runs the backprop gradient descent
-- Corresponds to a single batch of data
function Pix2PixModel:OptimizeParameters(opt)
  local fD = function(x) return self:fDx(x, opt) end
  local fG = function(x) return self:fGx(x, opt) end
  optim.adam(fD, self.parametersD, self.optimStateD)
  optim.adam(fG, self.parametersG, self.optimStateG)
end

-- This function can be used to reset momentum after each epoch
function Pix2PixModel:RefreshParameters()
  self.parametersD, self.gradParametersD = nil, nil -- nil them to avoid spiking memory
  self.parametersG, self.gradParametersG = nil, nil

  -- define parameters of optimization
  self.parametersG, self.gradParametersG = self.netG:getParameters()
  self.parametersD, self.gradParametersD = self.netD:getParameters()
end

-- Updates the learning rate: lr stays at opt.lr for the first opt.niter iterations, then gradually decreases to 0 over the next opt.niter_decay iterations
function Pix2PixModel:UpdateLearningRate(opt)
  local lrd = opt.lr / opt.niter_decay
  local old_lr = self.optimStateD['learningRate']
  local lr =  old_lr - lrd
  self.optimStateD['learningRate'] = lr
  self.optimStateG['learningRate'] = lr
  print(('update learning rate: %f -> %f'):format(old_lr, lr))
end


-- Save the current model to the file system
function Pix2PixModel:Save(prefix, opt)
  util.save_model(self.netG, prefix .. '_net_G.t7', 1.0)
  util.save_model(self.netD, prefix .. '_net_D.t7', 1.0)
end

-- returns a string that describes the current errors
function Pix2PixModel:GetCurrentErrorDescription()
  local description = ('G: %.4f  D: %.4f  L1: %.4f'):format(
      self.errG or -1, self.errD or -1, self.errL1 or -1)
  return description
end


-- returns a string that describes the display plot configuration
function Pix2PixModel:DisplayPlot(opt)
  return 'errG,errD,errL1'
end


-- returns current errors
function Pix2PixModel:GetCurrentErrors()
  local errors = {errG=self.errG, errD=self.errD, errL1=self.errL1}
  return errors
end

-- returns a table of image/label pairs that describe
-- the current results.
-- |return|: a table of table. List of image/label pairs
function Pix2PixModel:GetCurrentVisuals(opt, size)
  if not size then
    size = opt.display_winsize
  end

  local visuals = {}
  table.insert(visuals, {img=self.real_A, label='real_A'})
  table.insert(visuals, {img=self.fake_B, label='fake_B'})
  table.insert(visuals, {img=self.real_B, label='real_B'})

  return visuals
end


================================================
FILE: options.lua
================================================
--------------------------------------------------------------------------------
-- Configure options
--------------------------------------------------------------------------------

local options = {}
-- options for train
local opt_train = {
   DATA_ROOT = '',         -- path to images (should have subfolders 'train', 'val', etc)
   batchSize = 1,          -- # images in batch
   loadSize = 143,         -- scale images to this size
   fineSize = 128,         --  then crop to this size
   ngf = 64,               -- #  of gen filters in first conv layer
   ndf = 64,               -- #  of discrim filters in first conv layer
   input_nc = 3,           -- #  of input image channels
   output_nc = 3,          -- #  of output image channels
   niter = 100,            -- #  of iter at starting learning rate
   niter_decay = 100,      --  # of iter to linearly decay learning rate to zero
   lr = 0.0002,            -- initial learning rate for adam
   beta1 = 0.5,            -- momentum term of adam
   ntrain = math.huge,     -- #  of examples per epoch. math.huge for full dataset
   flip = 1,               -- if 1, flip images horizontally for data augmentation
   display_id = 10,        -- display window id.
   display_winsize = 128,  -- display window size
   display_freq = 25,      -- display the current results every display_freq iterations
   gpu = 1,                -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
   name = '',              -- name of the experiment, should generally be passed on the command line
   which_direction = 'AtoB',    -- AtoB or BtoA
   phase = 'train',             -- train, val, test, etc
   nThreads = 2,                -- # threads for loading data
   save_epoch_freq = 1,         -- save a model every save_epoch_freq epochs (does not overwrite previously saved models)
   save_latest_freq = 5000,     -- save the latest model every latest_freq sgd iterations (overwrites the previous latest model)
   print_freq = 50,             -- print the debug information every print_freq iterations
   save_display_freq = 2500,    -- save the current display of results every save_display_freq_iterations
   continue_train = 0,          -- if continue training, load the latest model: 1: true, 0: false
   serial_batches = 0,          -- if 1, takes images in order to make batches, otherwise takes them randomly
   checkpoints_dir = './checkpoints', -- models are saved here
   cache_dir = './cache',             -- cache files are saved here
   cudnn = 1,                         -- set to 0 to not use cudnn
   which_model_netD = 'basic',        -- selects model to use for netD
   which_model_netG = 'resnet_6blocks',   -- selects model to use for netG
   norm = 'instance',             -- batch or instance normalization
   n_layers_D = 3,                -- only used if which_model_netD=='n_layers'
   content_loss = 'pixel',        -- content loss type: pixel, vgg
   layer_name = 'pixel',          -- layer used in content loss (e.g. relu4_2)
   lambda_A = 10.0,               -- weight for cycle loss (A -> B -> A)
   lambda_B = 10.0,               -- weight for cycle loss (B -> A -> B)
   model = 'cycle_gan',           -- which model to run: 'cycle_gan', 'pix2pix', 'bigan', 'content_gan'
   use_lsgan = 1,                 -- if 1, use least-squares GAN; if 0, use vanilla GAN
   align_data = 0,                -- if > 0, use the aligned data loader (paired A/B images, e.g. for pix2pix)
   pool_size = 50,                -- size of the image buffer that stores previously generated images
   resize_or_crop = 'resize_and_crop',  -- resizing/cropping strategy: resize_and_crop | crop | scale_width | scale_height
   lambda_identity = 0.5,                 -- weight of the identity mapping loss relative to the reconstruction loss; 0 disables it. E.g., set 0.1 to make the identity loss weight 10 times smaller than the reconstruction loss weight
   use_optnet = 0,                -- use optnet to save GPU memory during test
}

-- options for test
local opt_test = {
  DATA_ROOT = '',           -- path to images (should have subfolders 'train', 'val', etc)
  loadSize = 128,           -- scale images to this size
  fineSize = 128,           --  then crop to this size
  flip = 0,                  -- horizontal mirroring data augmentation
  display = 1,              -- display samples while training. 0 = false
  display_id = 200,         -- display window id.
  gpu = 1,                  -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
  how_many = 'all',         -- how many test images to run (set to all to run on every image found in the data/phase folder)
  phase = 'test',            -- train, val, test, etc
  aspect_ratio = 1.0,       -- aspect ratio of result images
  norm = 'instance',        -- batch or instance normalization
  name = '',                -- name of the experiment; selects which model to run and should generally be passed on the command line
  input_nc = 3,              -- #  of input image channels
  output_nc = 3,             -- #  of output image channels
  serial_batches = 1,        -- if 1, takes images in order to make batches, otherwise takes them randomly
  cudnn = 1,                 -- set to 0 to not use cudnn (untested)
  checkpoints_dir = './checkpoints', -- loads models from here
  cache_dir = './cache',             -- cache files are saved here
  results_dir='./results/',          -- saves results here
  which_epoch = 'latest',            -- which epoch to test? set to 'latest' to use latest cached model
  model = 'cycle_gan',               -- which model to run: 'cycle_gan', 'pix2pix', 'bigan', 'content_gan'; to use a pretrained model, select `one_direction_test`
  align_data = 0,                    -- if > 0, use the dataloader for pix2pix
  which_direction = 'AtoB',          -- AtoB or BtoA
  resize_or_crop = 'resize_and_crop',  -- resizing/cropping strategy: resize_and_crop | crop | scale_width | scale_height
}

--------------------------------------------------------------------------------
-- util functions
--------------------------------------------------------------------------------
function options.clone(opt)
  local copy = {}
  for orig_key, orig_value in pairs(opt) do
    copy[orig_key] = orig_value
  end
  return copy
end

function options.parse_options(mode)
  if mode == 'train' then
    opt = opt_train
    opt.test = 0
  elseif mode == 'test' then
    opt = opt_test
    opt.test = 1
  else
    print("Invalid option [" .. tostring(mode) .. "]")
    return nil
  end

  -- one-line argument parser: environment variables override the defaults (numeric strings are converted to numbers)
  for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
  if mode == 'test' then
    opt.nThreads = 1
    opt.continue_train = 1
    opt.batchSize = 1  -- test code only supports batchSize=1
  end

  -- print by keys
  local keyset = {}
  for k,v in pairs(opt) do
    table.insert(keyset, k)
  end
  table.sort(keyset)
  print("------------------- Options -------------------")
  for i,k in ipairs(keyset) do
    print(('%+25s: %s'):format(k, opt[k]))
  end
  print("-----------------------------------------------")

  -- save opt to checkpoints
  paths.mkdir(opt.checkpoints_dir)
  paths.mkdir(paths.concat(opt.checkpoints_dir, opt.name))
  opt.visual_dir = paths.concat(opt.checkpoints_dir, opt.name, 'visuals')
  paths.mkdir(opt.visual_dir)
  -- save opt to the disk
  local fd = io.open(paths.concat(opt.checkpoints_dir, opt.name, 'opt_' .. mode .. '.txt'), 'w')
  for i,k in ipairs(keyset) do
    fd:write(("%+25s: %s\n"):format(k, opt[k]))
  end
  fd:close()

  return opt
end
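-- Illustrative sketch (hypothetical helper, not called by parse_options): the
-- override rule applied by the one-line parser above. Each default key is
-- looked up as an environment variable; numeric strings become numbers, other
-- strings are kept as-is, and unset variables fall back to the default. The
-- `getenv` argument is an assumption added so the rule can be exercised
-- without touching the real environment; it defaults to os.getenv.
local function sketch_apply_env_overrides(defaults, getenv)
  getenv = getenv or os.getenv
  local out = {}
  for k, v in pairs(defaults) do
    local raw = getenv(k)
    out[k] = tonumber(raw) or raw or v
  end
  return out
end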


return options


================================================
FILE: pretrained_models/download_model.sh
================================================
FILE=$1

echo "Note: available models are apple2orange, facades_photo2label, map2sat, orange2apple, style_cezanne, style_ukiyoe, summer2winter_yosemite, zebra2horse, facades_label2photo, horse2zebra, monet2photo, sat2map, style_monet, style_vangogh, winter2summer_yosemite, iphone2dslr_flower"

if [ -z "$FILE" ]; then
  echo "Usage: bash ./pretrained_models/download_model.sh <model_name>"
  exit 1
fi

echo "Specified [$FILE]"

mkdir -p ./checkpoints/${FILE}_pretrained
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/models/$FILE.t7
MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.t7
wget -N $URL -O $MODEL_FILE


================================================
FILE: pretrained_models/download_vgg.sh
================================================
URL1=https://people.eecs.berkeley.edu/~taesung_park/projects/CycleGAN/models/places_vgg.caffemodel
MODEL_FILE1=./models/places_vgg.caffemodel
URL2=https://people.eecs.berkeley.edu/~taesung_park/projects/CycleGAN/models/places_vgg.prototxt
MODEL_FILE2=./models/places_vgg.prototxt
wget -N $URL1 -O $MODEL_FILE1
wget -N $URL2 -O $MODEL_FILE2


================================================
FILE: pretrained_models/places_vgg.prototxt
================================================
name: "VGG-Places365"
input: "data"
input_dim: 1
input_dim: 3
input_dim: 224
input_dim: 224
layer {
  name: "conv1_1"
  type: "Convolution"
  bottom: "data"
  top: "conv1_1"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 64
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu1_1"
  type: "ReLU"
  bottom: "conv1_1"
  top: "conv1_1"
}
layer {
  name: "conv1_2"
  type: "Convolution"
  bottom: "conv1_1"
  top: "conv1_2"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 64
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu1_2"
  type: "ReLU"
  bottom: "conv1_2"
  top: "conv1_2"
}
layer {
  name: "pool1"
  type: "Pooling"
  bottom: "conv1_2"
  top: "pool1"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv2_1"
  type: "Convolution"
  bottom: "pool1"
  top: "conv2_1"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 128
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu2_1"
  type: "ReLU"
  bottom: "conv2_1"
  top: "conv2_1"
}
layer {
  name: "conv2_2"
  type: "Convolution"
  bottom: "conv2_1"
  top: "conv2_2"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 128
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu2_2"
  type: "ReLU"
  bottom: "conv2_2"
  top: "conv2_2"
}
layer {
  name: "pool2"
  type: "Pooling"
  bottom: "conv2_2"
  top: "pool2"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv3_1"
  type: "Convolution"
  bottom: "pool2"
  top: "conv3_1"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 256
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu3_1"
  type: "ReLU"
  bottom: "conv3_1"
  top: "conv3_1"
}
layer {
  name: "conv3_2"
  type: "Convolution"
  bottom: "conv3_1"
  top: "conv3_2"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 256
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu3_2"
  type: "ReLU"
  bottom: "conv3_2"
  top: "conv3_2"
}
layer {
  name: "conv3_3"
  type: "Convolution"
  bottom: "conv3_2"
  top: "conv3_3"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 256
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu3_3"
  type: "ReLU"
  bottom: "conv3_3"
  top: "conv3_3"
}
layer {
  name: "pool3"
  type: "Pooling"
  bottom: "conv3_3"
  top: "pool3"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv4_1"
  type: "Convolution"
  bottom: "pool3"
  top: "conv4_1"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 512
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu4_1"
  type: "ReLU"
  bottom: "conv4_1"
  top: "conv4_1"
}
layer {
  name: "conv4_2"
  type: "Convolution"
  bottom: "conv4_1"
  top: "conv4_2"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 512
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu4_2"
  type: "ReLU"
  bottom: "conv4_2"
  top: "conv4_2"
}
layer {
  name: "conv4_3"
  type: "Convolution"
  bottom: "conv4_2"
  top: "conv4_3"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 512
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu4_3"
  type: "ReLU"
  bottom: "conv4_3"
  top: "conv4_3"
}
layer {
  name: "pool4"
  type: "Pooling"
  bottom: "conv4_3"
  top: "pool4"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "conv5_1"
  type: "Convolution"
  bottom: "pool4"
  top: "conv5_1"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 512
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu5_1"
  type: "ReLU"
  bottom: "conv5_1"
  top: "conv5_1"
}
layer {
  name: "conv5_2"
  type: "Convolution"
  bottom: "conv5_1"
  top: "conv5_2"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 512
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu5_2"
  type: "ReLU"
  bottom: "conv5_2"
  top: "conv5_2"
}
layer {
  name: "conv5_3"
  type: "Convolution"
  bottom: "conv5_2"
  top: "conv5_3"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  convolution_param {
    num_output: 512
    pad: 1
    kernel_size: 3
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu5_3"
  type: "ReLU"
  bottom: "conv5_3"
  top: "conv5_3"
}
layer {
  name: "pool5"
  type: "Pooling"
  bottom: "conv5_3"
  top: "pool5"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
layer {
  name: "fc6"
  type: "InnerProduct"
  bottom: "pool5"
  top: "fc6"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  inner_product_param {
    num_output: 4096
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu6"
  type: "ReLU"
  bottom: "fc6"
  top: "fc6"
}
layer {
  name: "drop6"
  type: "Dropout"
  bottom: "fc6"
  top: "fc6"
  dropout_param {
    dropout_ratio: 0.5
  }
}
layer {
  name: "fc7"
  type: "InnerProduct"
  bottom: "fc6"
  top: "fc7"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  inner_product_param {
    num_output: 4096
    weight_filler {
      type: "gaussian"
      std: 0.01
    }
    bias_filler {
      type: "constant"
      value: 0.0
    }
  }
}
layer {
  name: "relu7"
  type: "ReLU"
  bottom: "fc7"
  top: "fc7"
}
layer {
  name: "drop7"
  type: "Dropout"
  bottom: "fc7"
  top: "fc7"
  dropout_param {
    dropout_ratio: 0.5
  }
}
layer {
  name: "fc8a"
  type: "InnerProduct"
  bottom: "fc7"
  top: "fc8a"
  param {
    lr_mult: 1.0
    decay_mult: 1.0
  }
  param {
    lr_mult: 2.0
    decay_mult: 0.0
  }
  inner_product_param {
    num_output: 365
  }
}
layer {
  name: "prob"
  type: "Softmax"
  bottom: "fc8a"
  top: "prob"
}


================================================
FILE: test.lua
================================================
-- usage: DATA_ROOT=/path/to/data/ name=expt1 which_direction=BtoA th test.lua
--
-- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix
require 'image'
require 'nn'
require 'nngraph'
require 'models.architectures'


util = paths.dofile('util/util.lua')
options = require 'options'
opt = options.parse_options('test')

-- initialize torch GPU/CPU mode
if opt.gpu > 0 then
  require 'cutorch'
  require 'cunn'
  cutorch.setDevice(opt.gpu)
  print ("GPU Mode")
  torch.setdefaulttensortype('torch.CudaTensor')
else
  torch.setdefaulttensortype('torch.FloatTensor')
  print ("CPU Mode")
end

-- setup visualization
visualizer = require 'util/visualizer'

function TableConcat(t1,t2)
  for i=1,#t2 do
    t1[#t1+1] = t2[i]
  end
  return t1
end


-- load data
local data_loader = nil
if opt.align_data > 0 then
  require 'data.aligned_data_loader'
  data_loader = AlignedDataLoader()
else
  require 'data.unaligned_data_loader'
  data_loader = UnalignedDataLoader()
end
print( "DataLoader " .. data_loader:name() .. " was created.")
data_loader:Initialize(opt)

if opt.how_many == 'all' then
  opt.how_many = data_loader:size()
end

opt.how_many = math.min(opt.how_many, data_loader:size())

-- set batch/instance normalization
set_normalization(opt.norm)

-- load model
opt.continue_train = 1
-- define model
if opt.model == 'cycle_gan' then
  require 'models.cycle_gan_model'
  model  = CycleGANModel()
elseif opt.model == 'one_direction_test' then
  require 'models.one_direction_test_model'
  model = OneDirectionTestModel()
elseif opt.model == 'pix2pix' then
  require 'models.pix2pix_model'
  model = Pix2PixModel()
elseif opt.model == 'bigan' then
  require 'models.bigan_model'
  model  = BiGANModel()
elseif opt.model == 'content_gan' then
  require 'models.content_gan_model'
  model = ContentGANModel()
else
  error('Please specify a correct model')
end
model:Initialize(opt)

local pathsA = {} -- paths to images A tested on
local pathsB = {} -- paths to images B tested on
local web_dir = paths.concat(opt.results_dir, opt.name .. '/' .. opt.which_epoch .. '_' .. opt.phase)
paths.mkdir(web_dir)
local image_dir = paths.concat(web_dir, 'images')
paths.mkdir(image_dir)
s1 = opt.fineSize
s2 = opt.fineSize / opt.aspect_ratio
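-- Note on the two sizes above (read from the code, not upstream docs):
-- s1 and s2 are later passed to image.scale(img, width, height) inside
-- visualizer.save_images, so saved images are fineSize pixels wide and
-- fineSize / aspect_ratio pixels tall (aspect_ratio is width / height,
-- cf. util.getAspectRatio). When resize_or_crop is 'scale_width' or
-- 'scale_height', both are cleared to nil in the loop below and images
-- are saved at their generated size instead.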

visuals = {}

for n = 1, math.floor(opt.how_many) do
  print('processing batch ' .. n)
  local cur_dataA, cur_dataB, cur_pathsA, cur_pathsB = data_loader:GetNextBatch()

  cur_pathsA = util.basename_batch(cur_pathsA)
  cur_pathsB = util.basename_batch(cur_pathsB)
  print('pathsA', cur_pathsA)
  print('pathsB', cur_pathsB)
  model:Forward({real_A=cur_dataA, real_B=cur_dataB}, opt)

  visuals = model:GetCurrentVisuals(opt, opt.fineSize)

  for i,visual in ipairs(visuals) do
    if opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
      s1 = nil
      s2 = nil
    end
    visualizer.save_images(visual.img, paths.concat(image_dir, visual.label), {string.gsub(cur_pathsA[1],'.jpg','.png')}, s1, s2)
  end


  print('Saved images to: ', image_dir)
  pathsA = TableConcat(pathsA, cur_pathsA)
  pathsB = TableConcat(pathsB, cur_pathsB)
end

labels = {}
for i,visual in ipairs(visuals) do
  table.insert(labels, visual.label)
end

-- make webpage
io.output(paths.concat(web_dir, 'index.html'))
io.write('<table style="text-align:center;">')
io.write('<tr><td> Image </td>')
for i = 1, #labels do
  io.write('<td>' .. labels[i] .. '</td>')
end
io.write('</tr>')

for n = 1,math.floor(opt.how_many) do
  io.write('<tr>')
  io.write('<td>' .. tostring(n) .. '</td>')
  for j = 1, #labels do
    label = labels[j]
    io.write('<td><img src="./images/' .. label .. '/' .. string.gsub(pathsA[n],'.jpg','.png') .. '"/></td>')
  end
  io.write('</tr>')
end

io.write('</table>')


================================================
FILE: train.lua
================================================
-- usage example: DATA_ROOT=/path/to/data/ which_direction=BtoA name=expt1 th train.lua
-- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix

require 'torch'
require 'nn'
require 'optim'
util = paths.dofile('util/util.lua')
content = paths.dofile('util/content_loss.lua')
require 'image'
require 'models.architectures'

-- load configuration file
options = require 'options'
opt = options.parse_options('train')

-- setup visualization
visualizer = require 'util/visualizer'

-- initialize torch GPU/CPU mode
if opt.gpu > 0 then
  require 'cutorch'
  require 'cunn'
  cutorch.setDevice(opt.gpu)
  print ("GPU Mode")
  torch.setdefaulttensortype('torch.CudaTensor')
else
  torch.setdefaulttensortype('torch.FloatTensor')
  print ("CPU Mode")
end

-- load data
local data_loader = nil
if opt.align_data > 0 then
  require 'data.aligned_data_loader'
  data_loader = AlignedDataLoader()
else
  require 'data.unaligned_data_loader'
  data_loader = UnalignedDataLoader()
end
print( "DataLoader " .. data_loader:name() .. " was created.")
data_loader:Initialize(opt)

-- set batch/instance normalization
set_normalization(opt.norm)

--- timer
local epoch_tm = torch.Timer()
local tm = torch.Timer()

-- define model
local model = nil
local display_plot = nil
if opt.model == 'cycle_gan' then
  assert(data_loader:name() == 'UnalignedDataLoader')
  require 'models.cycle_gan_model'
  model = CycleGANModel()
elseif opt.model == 'pix2pix' then
  require 'models.pix2pix_model'
  assert(data_loader:name() == 'AlignedDataLoader')
  model = Pix2PixModel()
elseif opt.model == 'bigan' then
  assert(data_loader:name() == 'UnalignedDataLoader')
  require 'models.bigan_model'
  model = BiGANModel()
elseif opt.model == 'content_gan' then
  require 'models.content_gan_model'
  assert(data_loader:name() == 'UnalignedDataLoader')
  model = ContentGANModel()
else
  error('Please specify a correct model')
end

-- print the model name
print('Model ' .. model:model_name() .. ' was specified.')
model:Initialize(opt)

-- set up the loss plot
require 'util/plot_util'
plotUtil = PlotUtil()
display_plot = model:DisplayPlot(opt)
plotUtil:Initialize(display_plot, opt.display_id, opt.name)

--------------------------------------------------------------------------------
-- Helper Functions
--------------------------------------------------------------------------------
function visualize_current_results()
  local visuals = model:GetCurrentVisuals(opt)
  for i,visual in ipairs(visuals) do
    visualizer.disp_image(visual.img, opt.display_winsize,
                          opt.display_id+i, opt.name .. ' ' .. visual.label)
  end
end

function save_current_results(epoch, counter)
  local visuals = model:GetCurrentVisuals(opt)
  for i,visual in ipairs(visuals) do
    output_path = paths.concat(opt.visual_dir, 'train_epoch' .. epoch .. '_iter' .. counter .. '_' .. visual.label .. '.jpg')
    visualizer.save_results(visual.img, output_path)
  end
end

function print_current_errors(epoch, counter_in_epoch)
  print(('Epoch: [%d][%8d / %8d]\t Time: %.3f  DataTime: %.3f  '
           .. '%s'):
      format(epoch, ((counter_in_epoch-1) / opt.batchSize),
      math.floor(math.min(data_loader:size(), opt.ntrain) / opt.batchSize),
      tm:time().real / opt.batchSize,
      data_loader:time_elapsed_to_fetch_data() / opt.batchSize,
      model:GetCurrentErrorDescription()
  ))
end

function plot_current_errors(epoch, counter_ratio, opt)
  local errs = model:GetCurrentErrors(opt)
  local plot_vals = { epoch + counter_ratio}
  plotUtil:Display(plot_vals, errs)
end

--------------------------------------------------------------------------------
-- Main Training Loop
--------------------------------------------------------------------------------
local counter = 0
local num_batches = math.floor(math.min(data_loader:size(), opt.ntrain) / opt.batchSize)
print('#training epochs: ' .. opt.niter+opt.niter_decay )

for epoch = 1, opt.niter+opt.niter_decay do
    epoch_tm:reset()
    for counter_in_epoch = 1, math.min(data_loader:size(), opt.ntrain), opt.batchSize do
        tm:reset()
        -- load a batch
        local real_dataA, real_dataB, _, _ = data_loader:GetNextBatch()

        -- run forward pass
        model:Forward({real_A=real_dataA, real_B=real_dataB}, opt)
        opt.counter = counter
        -- run backward pass and update parameters
        model:OptimizeParameters(opt)
        -- display on the web server
        if counter % opt.display_freq == 0 and opt.display_id > 0 then
          visualize_current_results()
        end

        -- logging
        if counter % opt.print_freq == 0 then
          print_current_errors(epoch, counter_in_epoch)
          plot_current_errors(epoch, counter_in_epoch/num_batches, opt)
        end

        -- save latest model
        if counter % opt.save_latest_freq == 0 and counter > 0 then
          print(('saving the latest model (epoch %d, iters %d)'):format(epoch, counter))
          model:Save('latest', opt)
        end

        -- save latest results
        if counter % opt.save_display_freq == 0 then
          save_current_results(epoch, counter)
        end
        counter = counter + 1
    end

    -- save model at the end of epoch
    if epoch % opt.save_epoch_freq == 0 then
        print(('saving the model (epoch %d, iters %d)'):format(epoch, counter))
        model:Save('latest', opt)
        model:Save(epoch, opt)
    end
    -- print the timing information after each epoch
    print(('End of epoch %d / %d \t Time Taken: %.3f'):
        format(epoch, opt.niter+opt.niter_decay, epoch_tm:time().real))

    -- update learning rate
    if epoch > opt.niter then
      model:UpdateLearningRate(opt)
    end
    -- refresh parameters
    model:RefreshParameters(opt)
end


================================================
FILE: util/InstanceNormalization.lua
================================================
require 'nn'

--[[
  Implements instance normalization as described in the paper

  Instance Normalization: The Missing Ingredient for Fast Stylization
  Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky
  https://arxiv.org/abs/1607.08022
  This implementation is based on
  https://github.com/DmitryUlyanov/texture_nets
]]
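
-- Sketch of the trick used below (read from the code, not from the paper):
-- an input of size (N, C, H, W) is viewed as (1, N*C, H, W), so every
-- (sample, channel) pair becomes its own "channel" for the wrapped
-- nn.SpatialBatchNormalization. Batch normalization then computes the mean
-- and variance over H*W separately for each sample and channel, which is
-- instance normalization. The per-channel affine weight and bias are
-- repeated N times so that they match the N*C channels.
-- evaluate() and training() are no-ops, so the wrapped module stays in
-- training mode and per-instance statistics are also used at test time.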

local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 'nn.Module')

function InstanceNormalization:__init(nOutput, eps, momentum, affine)
   parent.__init(self)
   self.running_mean = torch.zeros(nOutput)
   self.running_var = torch.ones(nOutput)

   self.eps = eps or 1e-5
   self.momentum = momentum or 0.0
   if affine ~= nil then
      assert(type(affine) == 'boolean', 'affine has to be true/false')
      self.affine = affine
   else
      self.affine = true
   end

   self.nOutput = nOutput
   self.prev_batch_size = -1

   if self.affine then
      self.weight = torch.Tensor(nOutput):uniform()
      self.bias = torch.Tensor(nOutput):zero()
      self.gradWeight = torch.Tensor(nOutput)
      self.gradBias = torch.Tensor(nOutput)
   end
end

function InstanceNormalization:updateOutput(input)
   self.output = self.output or input.new()
   assert(input:size(2) == self.nOutput)

   local batch_size = input:size(1)

   if batch_size ~= self.prev_batch_size or (self.bn and self:type() ~= self.bn:type())  then
      self.bn = nn.SpatialBatchNormalization(input:size(1)*input:size(2), self.eps, self.momentum, self.affine)
      self.bn:type(self:type())
      self.bn.running_mean:copy(self.running_mean:repeatTensor(batch_size))
      self.bn.running_var:copy(self.running_var:repeatTensor(batch_size))

      self.prev_batch_size = input:size(1)
   end

   -- Get statistics
   self.running_mean:copy(self.bn.running_mean:view(input:size(1),self.nOutput):mean(1))
   self.running_var:copy(self.bn.running_var:view(input:size(1),self.nOutput):mean(1))

   -- Set params for BN
   if self.affine then
      self.bn.weight:copy(self.weight:repeatTensor(batch_size))
      self.bn.bias:copy(self.bias:repeatTensor(batch_size))
   end

   local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
   self.output = self.bn:forward(input_1obj):viewAs(input)

   return self.output
end

function InstanceNormalization:updateGradInput(input, gradOutput)
   self.gradInput = self.gradInput or gradOutput.new()

   assert(self.bn)

   local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))
   local gradOutput_1obj = gradOutput:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4))

   if self.affine then
      self.bn.gradWeight:zero()
      self.bn.gradBias:zero()
   end

   self.gradInput = self.bn:backward(input_1obj, gradOutput_1obj):viewAs(input)

   if self.affine then
      self.gradWeight:add(self.bn.gradWeight:view(input:size(1),self.nOutput):sum(1))
      self.gradBias:add(self.bn.gradBias:view(input:size(1),self.nOutput):sum(1))
   end
   return self.gradInput
end

function InstanceNormalization:clearState()
   self.output = self.output.new()
   self.gradInput = self.gradInput.new()

   if self.bn then
     self.bn:clearState()
   end
end

function InstanceNormalization:evaluate()
end

function InstanceNormalization:training()
end


================================================
FILE: util/VGG_preprocess.lua
================================================
-- define nn module for VGG postprocessing
local VGG_postprocess, parent = torch.class('nn.VGG_postprocess', 'nn.Module')

function VGG_postprocess:__init()
	parent.__init(self)
end

function VGG_postprocess:updateOutput(input)
  self.output = input:add(1):mul(127.5)
	-- print(self.output:max(), self.output:min())
	if self.output:max() > 255 or self.output:min() < 0 then
		print(self.output:min(), self.output:max())
	end
	-- assert(self.output:min()>=0,"badly scaled inputs")
  -- assert(self.output:max()<=255,"badly scaled inputs")

	local mean_pixel = torch.FloatTensor({103.939, 116.779, 123.68})
	mean_pixel = mean_pixel:reshape(1,3,1,1)
	mean_pixel = mean_pixel:repeatTensor(input:size(1), 1, input:size(3), input:size(4)):cuda()
	self.output:add(-1, mean_pixel)
	return self.output
end
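
-- What updateOutput computes, per pixel (read from the code above):
--   y = 127.5 * (x + 1) - mean_bgr,  mean_bgr = (103.939, 116.779, 123.68)
-- A [-1,1] image in BGR order (see util.preprocess) is therefore mapped to
-- [0,255] and mean-subtracted, which is the input convention of the Caffe
-- VGG models. The input tensor is modified in place, and mean_pixel is
-- moved with :cuda(), so this module assumes GPU tensors.
-- The exact derivative is dy/dx = 127.5, whereas updateGradInput below
-- divides gradOutput by 127.5. This scales the back-propagated gradient
-- by 1/127.5^2 relative to the exact derivative.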

function VGG_postprocess:updateGradInput(input, gradOutput)
	self.gradInput = gradOutput:div(127.5)
	return self.gradInput
end


================================================
FILE: util/content_loss.lua
================================================
require 'torch'
require 'nn'
local content = {}

function content.defineVGG(content_layer)
  local contentFunc = nn.Sequential()
  require 'loadcaffe'
  require 'util/VGG_preprocess'
  cnn = loadcaffe.load('../models/vgg.prototxt', '../models/vgg.caffemodel', 'cudnn')
  contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224}))
  contentFunc:add(nn.VGG_postprocess())
  for i = 1, #cnn do
    local layer = cnn:get(i):clone()
    local name = layer.name
    local layer_type = torch.type(layer)
    contentFunc:add(layer)
    if name == content_layer then
      print("Setting up content layer: ", layer.name)
      break
    end
  end
  cnn = nil
  collectgarbage()
  print(contentFunc)
  return contentFunc
end

function content.defineAlexNet(content_layer)
  local contentFunc = nn.Sequential()
  require 'loadcaffe'
  require 'util/VGG_preprocess'
  cnn = loadcaffe.load('../models/alexnet.prototxt', '../models/alexnet.caffemodel', 'cudnn')
  contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224}))
  contentFunc:add(nn.VGG_postprocess())
  for i = 1, #cnn do
    local layer = cnn:get(i):clone()
    local name = layer.name
    local layer_type = torch.type(layer)
    contentFunc:add(layer)
    if name == content_layer then
      print("Setting up content layer: ", layer.name)
      break
    end
  end
  cnn = nil
  collectgarbage()
  print(contentFunc)
  return contentFunc
end



function content.defineContent(content_loss, layer_name)
  -- print('content_loss_define', content_loss)
  if content_loss == 'pixel' or content_loss == 'none' then
    return nil
  elseif content_loss == 'vgg' then
    return content.defineVGG(layer_name)
  else
    print("unsupported content loss")
    return nil
  end
end


function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight)
  if loss_type == 'none' then
    local errCont = 0.0
    local df_d_content = torch.zeros(fake_target:size())
    return errCont, df_d_content
  elseif loss_type == 'pixel' then
    local errCont = criterionContent:forward(fake_target, real_source) * weight
    local df_do_content = criterionContent:backward(fake_target, real_source)*weight
    return errCont, df_do_content
  elseif loss_type == 'vgg' then
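    local f_real = contentFunc:forward(real_source):clone()
    -- forward fake_target last so that state cached by the layers
    -- (e.g. max-pooling indices) matches the input given to updateGradInput
    local f_fake = contentFunc:forward(fake_target):clone()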
    local errCont = criterionContent:forward(f_fake, f_real) * weight
    local df_do_tmp = criterionContent:backward(f_fake, f_real) * weight
    local df_do_content = contentFunc:updateGradInput(fake_target, df_do_tmp)--:mul(weight)
    return errCont, df_do_content
  else error("unsupported content loss")
  end
end
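
-- Summary of lossUpdate (read from the code above):
--   'none'  : zero loss and a zero gradient shaped like fake_target.
--   'pixel' : criterionContent applied directly to fake_target vs.
--             real_source, with loss and gradient scaled by weight.
--   'vgg'   : both images go through contentFunc (bilinear upsampling to
--             224x224, nn.VGG_postprocess, then the network's layers up to
--             layer_name). The criterion compares the two feature maps, and
--             the gradient w.r.t. fake_target is obtained with
--             contentFunc:updateGradInput only. No parameter gradients are
--             accumulated, so the pretrained network is left unchanged.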


return content


================================================
FILE: util/cudnn_convert_custom.lua
================================================
-- modified from https://github.com/NVIDIA/torch-cudnn/blob/master/convert.lua
-- removed error on nngraph

-- modules that can be converted to nn seamlessly
local layer_list = {
  'BatchNormalization',
  'SpatialBatchNormalization',
  'SpatialConvolution',
  'SpatialCrossMapLRN',
  'SpatialFullConvolution',
  'SpatialMaxPooling',
  'SpatialAveragePooling',
  'ReLU',
  'Tanh',
  'Sigmoid',
  'SoftMax',
  'LogSoftMax',
  'VolumetricBatchNormalization',
  'VolumetricConvolution',
  'VolumetricFullConvolution',
  'VolumetricMaxPooling',
  'VolumetricAveragePooling',
}

-- goes over a given net and converts all layers to dst backend
-- for example: net = cudnn_convert_custom(net, cudnn)
-- same as cudnn.convert with gModule check commented out
function cudnn_convert_custom(net, dst, exclusion_fn)
  return net:replace(function(x)
    --if torch.type(x) == 'nn.gModule' then
    --  io.stderr:write('Warning: cudnn.convert does not work with nngraph yet. Ignoring nn.gModule')
    --  return x
    --end
    local y = 0
    local src = dst == nn and cudnn or nn
    local src_prefix = src == nn and 'nn.' or 'cudnn.'
    local dst_prefix = dst == nn and 'nn.' or 'cudnn.'

    local function convert(v)
      local y = {}
      torch.setmetatable(y, dst_prefix..v)
      if v == 'ReLU' then y = dst.ReLU() end -- because parameters
      for k,u in pairs(x) do y[k] = u end
      if src == cudnn and x.clearDesc then x.clearDesc(y) end
      if src == cudnn and v == 'SpatialAveragePooling' then
        y.divide = true
        y.count_include_pad = x.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
      end
      if src == nn and string.find(v, 'Convolution') then
         y.groups = 1
      end
      return y
    end

    if exclusion_fn and exclusion_fn(x) then
      return x
    end
    local t = torch.typename(x)
    if t == 'nn.SpatialConvolutionMM' then
      y = convert('SpatialConvolution')
    elseif t == 'inn.SpatialCrossResponseNormalization' then
      y = convert('SpatialCrossMapLRN')
    else
      for i,v in ipairs(layer_list) do
        if torch.typename(x) == src_prefix..v then
          y = convert(v)
        end
      end
    end
    return y == 0 and x or y
  end)
end


================================================
FILE: util/image_pool.lua
================================================
local class = require 'class'
ImagePool= class('ImagePool')

require 'torch'
require 'image'

function ImagePool:__init(pool_size)
  self.pool_size = pool_size
  if pool_size > 0 then
    self.num_imgs = 0
    self.images = {}
  end
end

function ImagePool:model_name()
  return 'ImagePool'
end
-- 
-- function ImagePool:Initialize(pool_size)
--   -- torch.manualSeed(0)
--   -- assert(pool_size > 0)
--   self.pool_size = pool_size
--   if pool_size > 0 then
--     self.num_imgs = 0
--     self.images = {}
--   end
-- end

function ImagePool:Query(image)
  -- print('query image')
  if self.pool_size == 0 then
    -- print('get identical image')
    return image
  end
  if self.num_imgs < self.pool_size then
    -- self.images.insert(image:clone())
    self.num_imgs = self.num_imgs + 1
    self.images[self.num_imgs] = image
    return image
  else
    local p = math.random()
    -- print('p' ,p)
    -- os.exit()
    if p > 0.5 then
      -- print('use old image')
      -- random_id = torch.Tensor(1)
      -- random_id:random(1, self.pool_size)
      local random_id = math.random(self.pool_size)
      -- print('random_id', random_id)
      local tmp = self.images[random_id]:clone()
      self.images[random_id] = image:clone()
      return tmp
    else
      return image
    end

  end

end
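
-- How Query behaves (read from the code above):
--   * pool_size == 0: the image is passed through unchanged.
--   * fewer than pool_size images stored: the image is stored and
--     returned unchanged.
--   * pool full: with probability ~0.5, a randomly chosen stored image is
--     returned and replaced by a copy of the new one. Otherwise the new
--     image is returned and the pool is left as is.
-- A caller that routes generator outputs through Query can therefore show
-- the discriminator a mix of current and earlier fakes rather than only
-- the latest batch.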


================================================
FILE: util/plot_util.lua
================================================
local class = require 'class'
PlotUtil = class('PlotUtil')


require 'torch'
disp = require 'display'
util = require 'util/util'
require 'image'

local unpack = unpack or table.unpack

function PlotUtil:__init(conf)
  conf = conf or {}
end

function PlotUtil:model_name()
  return 'PlotUtil'
end

function PlotUtil:Initialize(display_plot, display_id, name)
  self.display_plot = string.split(string.gsub(display_plot, "%s+", ""), ",")

  self.plot_config = {
    title = name .. ' loss over time',
    labels = {'epoch', unpack(self.display_plot)},
    ylabel = 'loss',
    win  = display_id,
  }

  self.plot_data = {}
  print('display_opt', self.display_plot)
end


function PlotUtil:Display(plot_vals, loss)
  for k, v in ipairs(self.display_plot) do
    if loss[v] ~= nil then
      plot_vals[#plot_vals + 1] = loss[v]
    end
  end

  table.insert(self.plot_data, plot_vals)
  disp.plot(self.plot_data, self.plot_config)
end


================================================
FILE: util/util.lua
================================================
--
-- code derived from https://github.com/soumith/dcgan.torch
--

local util = {}

require 'torch'


function util.BiasZero(net)
  net:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
end


function util.checkEqual(A, B, name)
  local dif = (A:float()-B:float()):abs():mean()
  print(name, dif)
end

function util.containsValue(table, value)
  for k, v in pairs(table) do
    if v == value then return true end
  end
  return false
end


function util.CheckTensor(A, name)
  print(name, A:min(), A:max(), A:mean())
end


function util.normalize(img)
  -- rescale image to 0 .. 1
  local min = img:min()
  local max = img:max()

  img = torch.FloatTensor(img:size()):copy(img)
  img:add(-min):mul(1/(max-min))
  return img
end

function util.normalizeBatch(batch)
	for i = 1, batch:size(1) do
		batch[i] = util.normalize(batch[i]:squeeze())
	end
	return batch
end

function util.basename_batch(batch)
	for i = 1, #batch do
		batch[i] = paths.basename(batch[i])
	end
	return batch
end



-- default preprocessing
--
-- Preprocesses an image before passing it to a net
-- Converts from RGB to BGR and rescales from [0,1] to [-1,1]
function util.preprocess(img)
    -- RGB to BGR
    if img:size(1) == 3 then
      local perm = torch.LongTensor{3, 2, 1}
      img = img:index(1, perm)
    end
    -- [0,1] to [-1,1]
    img = img:mul(2):add(-1)

    -- check that input is in expected range
    assert(img:max()<=1,"badly scaled inputs")
    assert(img:min()>=-1,"badly scaled inputs")

    return img
end

-- Undo the above preprocessing.
function util.deprocess(img)
    -- BGR to RGB
    if img:size(1) == 3 then
      local perm = torch.LongTensor{3, 2, 1}
      img = img:index(1, perm)
    end

    -- [-1,1] to [0,1]
    img = img:add(1):div(2)

    return img
end

function util.preprocess_batch(batch)
	for i = 1, batch:size(1) do
		batch[i] = util.preprocess(batch[i]:squeeze())
	end
	return batch
end

function util.print_tensor(name, x)
  print(name, x:size(), x:min(), x:mean(), x:max())
end

function util.deprocess_batch(batch)
	for i = 1, batch:size(1) do
		batch[i] = util.deprocess(batch[i]:squeeze())
	end
	return batch
end


function util.scaleBatch(batch,s1,s2)
  -- print('s1', s1)
  -- print('s2', s2)
	local scaled_batch = torch.Tensor(batch:size(1),batch:size(2),s1,s2)
	for i = 1, batch:size(1) do
		scaled_batch[i] = image.scale(batch[i],s1,s2):squeeze()
	end
	return scaled_batch
end



function util.toTrivialBatch(input)
  return input:reshape(1,input:size(1),input:size(2),input:size(3))
end
function util.fromTrivialBatch(input)
    return input[1]
end

-- input is between -1 and 1
function util.jitter(input)
  local noise = torch.rand(input:size())/256.0
  input:add(1.0):mul(0.5*255.0/256.0):add(noise):add(-0.5):mul(2.0)
  --local scaled = (input+1.0)*0.5
  --local jittered = scaled*255.0/256.0 + torch.rand(input:size())/256.0
  --local scaled_back = (jittered-0.5)*2.0
  --return scaled_back
end
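-- Worked form of the in-place update in util.jitter (u ~ U[0,1) per element):
--   x' = 2 * ((x + 1) * 0.5 * 255/256 + u/256 - 0.5)
--      = (x + 1) * 255/256 + u/128 - 1
-- For x = -1, x' lies in [-1, -1 + 1/128). For x = 1, x' lies in
-- [1 - 1/128, 1). The output therefore stays inside [-1, 1).
-- The function modifies input in place and returns nothing.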

function util.scaleImage(input, loadSize)

    -- replicate bw images to 3 channels
    if input:size(1)==1 then
    	input = torch.repeatTensor(input,3,1,1)
    end

    input = image.scale(input, loadSize, loadSize)

    return input
end

function util.getAspectRatio(path)
	local input = image.load(path, 3, 'float')
	local ar = input:size(3)/input:size(2)
	return ar
end

function util.loadImage(path, loadSize, nc)
  local input = image.load(path, 3, 'float')
  input= util.preprocess(util.scaleImage(input, loadSize))

  if nc == 1 then
    input = input[{{1}, {}, {}}]
  end

  return input
end

function file_exists(filename)
   local f = io.open(filename,"r")
   if f ~= nil then io.close(f) return true else return false end
end

-- TO DO: loading code is rather hacky; clean it up and make sure it works on all types of nets / cpu/gpu configurations
function load_helper(filename, opt)
  fileExists = file_exists(filename)
  if not fileExists then
    print('model not found!    ' .. filename)
    return nil
  end
  print(('loading previously trained model (%s)'):format(filename))
	if opt.norm == 'instance' then
	  print('use InstanceNormalization')
	  require 'util.InstanceNormalization'
	end

	if opt.cudnn>0 then
		require 'cudnn'
	end

	local net = torch.load(filename)
	if opt.gpu > 0 then
		require 'cunn'
		net:cuda()

		-- calling cuda on cudnn saved nngraphs doesn't change all variables to cuda, so do it below
		if net.forwardnodes then
			for i=1,#net.forwardnodes do
				if net.forwardnodes[i].data.module then
					net.forwardnodes[i].data.module:cuda()
				end
			end
		end

	else
		net:float()
	end
	net:apply(function(m) if m.weight then
	    m.gradWeight = m.weight:clone():zero();
	    m.gradBias = m.bias:clone():zero(); end end)
	return net
end

function util.load_model(name, opt)
  -- if opt['lambda_'.. name] > 0.0 then
  -- print('not loading model '.. opt.checkpoints_dir .. opt.name ..
  --         'latest_net_' .. name .. '.t7' .. ' because opt.lambda is not greater than zero')
  return load_helper(paths.concat(opt.checkpoints_dir, opt.name,
                                'latest_net_' .. name .. '.t7'), opt)
  -- end
end

function util.load_test_model(name, opt)
  return load_helper(paths.concat(opt.checkpoints_dir, opt.name,
                                    opt.which_epoch .. '_net_' .. name .. '.t7'), opt)
end


-- load dataset from the file system
-- |name|: name of the dataset. It's currently either 'A' or 'B'
-- function util.load_dataset(name, nc, opt, nc)
--   local tensortype = torch.getdefaulttensortype()
--   torch.setdefaulttensortype('torch.FloatTensor')
--
--   local new_opt = options.clone(opt)
--   new_opt.manualSeed = torch.random(1, 10000) -- fix seed
--   new_opt.nc = nc
--   torch.manualSeed(new_opt.manualSeed)
--   local data_loader = paths.dofile('../data/data.lua')
--   new_opt.phase = new_opt.phase .. name
--   local data = data_loader.new(new_opt.nThreads, new_opt)
--   print("Dataset Size " .. name .. ": ", data:size())
--
--   torch.setdefaulttensortype(tensortype)
--   return data
-- end



function util.cudnn(net)
	require 'cudnn'
	require 'util/cudnn_convert_custom'
	return cudnn_convert_custom(net, cudnn)
end

function util.save_model(net, net_name, weight)
  if weight > 0.0 then
	   torch.save(paths.concat(opt.checkpoints_dir, opt.name, net_name), net:clearState())
  end
end




return util


================================================
FILE: util/visualizer.lua
================================================
-------------------------------------------------------------
-- Various utilities for visualization through the web server
-------------------------------------------------------------

local visualizer = {}

require 'torch'
disp = nil
print(opt)
if opt.display_id > 0 then -- [hack]: assume that opt already existed
  disp = require 'display'
end
util = require 'util/util'
require 'image'

-- function visualizer
function visualizer.disp_image(img_data, win_size, display_id, title)
  images = util.deprocess_batch(util.scaleBatch(img_data:float(),win_size,win_size))
  disp.image(images, {win=display_id, title=title})
end

function visualizer.save_results(img_data, output_path)
  local tensortype = torch.getdefaulttensortype()
  torch.setdefaulttensortype('torch.FloatTensor')
  local image_out = nil
  local win_size = opt.display_winsize
  images = torch.squeeze(util.deprocess_batch(util.scaleBatch(img_data:float(), win_size, win_size)))

  if images:dim() == 3 then
    image_out = images
  else
    for i = 1,images:size(1) do
      img = images[i]
      if image_out == nil then
        image_out = img
      else
        image_out = torch.cat(image_out, img)
      end
    end
  end
  image.save(output_path, image_out)
  torch.setdefaulttensortype(tensortype)
end

function visualizer.save_images(imgs, save_dir, impaths, s1, s2)
  local tensortype = torch.getdefaulttensortype()
  torch.setdefaulttensortype('torch.FloatTensor')
  batchSize = imgs:size(1)
  imgs_f = util.deprocess_batch(imgs):float()
  paths.mkdir(save_dir)
  for i = 1, batchSize do -- imgs_f[i]:size(2), imgs_f[i]:size(3)/opt.aspect_ratio
    if s1 ~= nil and s2 ~= nil then
      im_s = image.scale(imgs_f[i], s1, s2):float()
    else
      im_s = imgs_f[i]:float()
    end
    img_to_save = torch.FloatTensor(im_s:size()):copy(im_s)
    image.save(paths.concat(save_dir, impaths[i]), img_to_save)
  end
  torch.setdefaulttensortype(tensortype)
end

return visualizer