Repository: junyanz/CycleGAN Branch: master Commit: 40b4498526de Files: 33 Total size: 150.2 KB Directory structure: gitextract_7vccy3wl/ ├── .gitignore ├── LICENSE ├── README.md ├── data/ │ ├── aligned_data_loader.lua │ ├── base_data_loader.lua │ ├── data.lua │ ├── data_util.lua │ ├── dataset.lua │ ├── donkey_folder.lua │ └── unaligned_data_loader.lua ├── examples/ │ ├── test_vangogh_style_on_ae_photos.sh │ └── train_maps.sh ├── models/ │ ├── architectures.lua │ ├── base_model.lua │ ├── bigan_model.lua │ ├── content_gan_model.lua │ ├── cycle_gan_model.lua │ ├── one_direction_test_model.lua │ └── pix2pix_model.lua ├── options.lua ├── pretrained_models/ │ ├── download_model.sh │ ├── download_vgg.sh │ └── places_vgg.prototxt ├── test.lua ├── train.lua └── util/ ├── InstanceNormalization.lua ├── VGG_preprocess.lua ├── content_loss.lua ├── cudnn_convert_custom.lua ├── image_pool.lua ├── plot_util.lua ├── util.lua └── visualizer.lua ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ datasets/ checkpoints/ results/ build/ dist/ *.png torch.egg-info/ */**/__pycache__ torch/version.py torch/csrc/generic/TensorMethods.cpp torch/lib/*.so* torch/lib/*.dylib* torch/lib/*.h torch/lib/build torch/lib/tmp_install torch/lib/include torch/lib/torch_shm_manager torch/csrc/cudnn/cuDNN.cpp torch/csrc/nn/THNN.cwrap torch/csrc/nn/THNN.cpp torch/csrc/nn/THCUNN.cwrap torch/csrc/nn/THCUNN.cpp torch/csrc/nn/THNN_generic.cwrap torch/csrc/nn/THNN_generic.cpp torch/csrc/nn/THNN_generic.h docs/src/**/* test/data/legacy_modules.t7 test/data/gpu_tensors.pt test/htmlcov test/.coverage */*.pyc */**/*.pyc */**/**/*.pyc */**/**/**/*.pyc */**/**/**/**/*.pyc */*.so* */**/*.so* */**/*.dylib* test/data/legacy_serialized.pt *~ ================================================ FILE: LICENSE 
================================================ Copyright (c) 2017, Jun-Yan Zhu and Taesung Park All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. --------------------------- LICENSE FOR pix2pix -------------------------------- BSD License For pix2pix software Copyright (c) 2016, Phillip Isola and Jun-Yan Zhu All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. ----------------------------- LICENSE FOR DCGAN -------------------------------- BSD License For dcgan.torch software Copyright (c) 2015, Facebook, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. Neither the name Facebook nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: README.md ================================================


# CycleGAN ### [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) | [project page](https://junyanz.github.io/CycleGAN/) | [paper](https://arxiv.org/pdf/1703.10593.pdf) Torch implementation for learning an image-to-image translation (i.e. [pix2pix](https://github.com/phillipi/pix2pix)) **without** input-output pairs, for example: **New**: Please check out [contrastive-unpaired-translation](https://github.com/taesungp/contrastive-unpaired-translation) (CUT), our new unpaired image-to-image translation model that enables fast and memory-efficient training. [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://junyanz.github.io/CycleGAN/) [Jun-Yan Zhu](https://people.eecs.berkeley.edu/~junyanz/)\*, [Taesung Park](https://taesung.me/)\*, [Phillip Isola](http://web.mit.edu/phillipi/), [Alexei A. Efros](https://people.eecs.berkeley.edu/~efros/) Berkeley AI Research Lab, UC Berkeley In ICCV 2017. (* equal contributions) This package includes CycleGAN, [pix2pix](https://github.com/phillipi/pix2pix), as well as other methods like [BiGAN](https://arxiv.org/abs/1605.09782)/[ALI](https://ishmaelbelghazi.github.io/ALI/) and Apple's paper [S+U learning](https://arxiv.org/pdf/1612.07828.pdf). The code was written by [Jun-Yan Zhu](https://github.com/junyanz) and [Taesung Park](https://github.com/taesung). **Update**: Please check out [PyTorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation for CycleGAN and pix2pix. The PyTorch version is under active development and can produce results comparable or better than this Torch version. ## Other implementations:

[Tensorflow] (by Harry Yang), [Tensorflow] (by Archit Rathore), [Tensorflow] (by Van Huy), [Tensorflow] (by Xiaowei Hu), [Tensorflow-simple] (by Zhenliang He), [TensorLayer] (by luoxier), [Chainer] (by Yanghua Jin), [Minimal PyTorch] (by yunjey), [Mxnet] (by Ldpe2G), [lasagne/Keras] (by tjwei), [Keras] (by Simon Karlsson)

## Applications ### Monet Paintings to Photos ### Collection Style Transfer ### Object Transfiguration ### Season Transfer ### Photo Enhancement: Narrow depth of field ## Prerequisites - Linux or OSX - NVIDIA GPU + CUDA CuDNN (CPU mode and CUDA without CuDNN may work with minimal modification, but untested) - For MAC users, you need the Linux/GNU commands `gfind` and `gwc`, which can be installed with `brew install findutils coreutils`. ## Getting Started ### Installation - Install torch and dependencies from https://github.com/torch/distro - Install torch packages `nngraph`, `class`, `display` ```bash luarocks install nngraph luarocks install class luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec ``` - Clone this repo: ```bash git clone https://github.com/junyanz/CycleGAN cd CycleGAN ``` ### Apply a Pre-trained Model - Download the test photos (taken by [Alexei Efros](https://www.flickr.com/photos/aaefros)): ``` bash ./datasets/download_dataset.sh ae_photos ``` - Download the pre-trained model `style_cezanne` (For CPU model, use `style_cezanne_cpu`): ``` bash ./pretrained_models/download_model.sh style_cezanne ``` - Now, let's generate Paul Cézanne style images: ``` DATA_ROOT=./datasets/ae_photos name=style_cezanne_pretrained model=one_direction_test phase=test loadSize=256 fineSize=256 resize_or_crop="scale_width" th test.lua ``` The test results will be saved to `./results/style_cezanne_pretrained/latest_test/index.html`. Please refer to [Model Zoo](#model-zoo) for more pre-trained models. `./examples/test_vangogh_style_on_ae_photos.sh` is an example script that downloads the pretrained Van Gogh style network and runs it on Efros's photos. ### Train - Download a dataset (e.g. 
zebra and horse images from ImageNet): ```bash bash ./datasets/download_dataset.sh horse2zebra ``` - Train a model: ```bash DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model th train.lua ``` - (CPU only) The same training command without using a GPU or CUDNN. Setting the environment variables ```gpu=0 cudnn=0``` forces CPU only ```bash DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model gpu=0 cudnn=0 th train.lua ``` - (Optionally) start the display server to view results as the model trains. (See [Display UI](#display-ui) for more details): ```bash th -ldisplay.start 8000 0.0.0.0 ``` ### Test - Finally, test the model: ```bash DATA_ROOT=./datasets/horse2zebra name=horse2zebra_model phase=test th test.lua ``` The test results will be saved to an HTML file here: `./results/horse2zebra_model/latest_test/index.html`. ## Model Zoo Download the pre-trained models with the following script. The model will be saved to `./checkpoints/model_name/latest_net_G.t7`. ```bash bash ./pretrained_models/download_model.sh model_name ``` - `orange2apple` (orange -> apple) and `apple2orange`: trained on ImageNet categories `apple` and `orange`. - `horse2zebra` (horse -> zebra) and `zebra2horse` (zebra -> horse): trained on ImageNet categories `horse` and `zebra`. - `style_monet` (landscape photo -> Monet painting style), `style_vangogh` (landscape photo -> Van Gogh painting style), `style_ukiyoe` (landscape photo -> Ukiyo-e painting style), `style_cezanne` (landscape photo -> Cezanne painting style): trained on paintings and Flickr landscape photos. - `monet2photo` (Monet paintings -> real landscape): trained on paintings and Flickr landscape photographs. - `cityscapes_photo2label` (street scene -> label) and `cityscapes_label2photo` (label -> street scene): trained on the Cityscapes dataset. - `map2sat` (map -> aerial photo) and `sat2map` (aerial photo -> map): trained on Google maps. 
- `iphone2dslr_flower` (iPhone photos of flowers -> DSLR photos of flowers): trained on Flickr photos. CPU models can be downloaded using: ```bash bash pretrained_models/download_model.sh <name>_cpu ``` , where `<name>` can be `horse2zebra`, `style_monet`, etc. You just need to append `_cpu` to the target model. ## Training and Test Details To train a model, ```bash DATA_ROOT=/path/to/data/ name=expt_name th train.lua ``` Models are saved to `./checkpoints/expt_name` (can be changed by passing `checkpoint_dir=your_dir` in train.lua). See `opt_train` in `options.lua` for additional training options. To test the model, ```bash DATA_ROOT=/path/to/data/ name=expt_name phase=test th test.lua ``` This will run the model named `expt_name` in both directions on all images in `/path/to/data/testA` and `/path/to/data/testB`. A webpage with result images will be saved to `./results/expt_name` (can be changed by passing `results_dir=your_dir` in test.lua). See `opt_test` in `options.lua` for additional test options. Please use `model=one_direction_test` if you would like to generate outputs of the trained network in only one direction, and specify `which_direction=AtoB` or `which_direction=BtoA` to set the direction. There are other options that can be used. For example, you can specify `resize_or_crop=crop` option to avoid resizing the image to squares. This is indeed how we trained GTA2Cityscapes model in the project [webpage](https://junyanz.github.io/CycleGAN/) and [Cycada](https://arxiv.org/pdf/1711.03213.pdf) model. We prepared the images at 1024px resolution, and used `resize_or_crop=crop fineSize=360` to work with the cropped images of size 360x360. We also used `lambda_identity=1.0`. ## Datasets Download the datasets using the following script. Many of the datasets were collected by other researchers. Please cite their papers if you use the data.
```bash bash ./datasets/download_dataset.sh dataset_name ``` - `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). [[Citation](datasets/bibtex/facades.tex)] - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). [[Citation](datasets/bibtex/cityscapes.tex)]. Note: Due to license issue, we do not host the dataset on our repo. Please download the dataset directly from the Cityscapes webpage. Please refer to `./datasets/prepare_cityscapes_dataset.py` for more detail. - `maps`: 1096 training images scraped from Google Maps. - `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `wild horse` and `zebra` - `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org/) using the keywords `apple` and `navel orange`. - `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images were downloaded using Flickr API. See more details in our paper. - `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos are downloaded from Flickr using the combination of the tags *landscape* and *landscapephotography*. The training set size of each class is Monet:1074, Cezanne:584, Van Gogh:401, Ukiyo-e:1433, Photographs:6853. - `iphone2dslr_flower`: both classes of images were downloaded from Flickr. The training set size of each class is iPhone:1813, DSLR:3316. See more details in our paper. ## Display UI Optionally, for displaying images during training and test, use the [display package](https://github.com/szym/display). 
- Install it with: `luarocks install https://raw.githubusercontent.com/szym/display/master/display-scm-0.rockspec` - Then start the server with: `th -ldisplay.start` - Open this URL in your browser: [http://localhost:8000](http://localhost:8000) By default, the server listens on localhost. Pass `0.0.0.0` to allow external connections on any interface: ```bash th -ldisplay.start 8000 0.0.0.0 ``` Then open `http://(hostname):(port)/` in your browser to load the remote desktop. ## Setup Training and Test data To train CycleGAN model on your own datasets, you need to create a data folder with two subdirectories `trainA` and `trainB` that contain images from domain A and B. You can test your model on your training set by setting ``phase='train'`` in `test.lua`. You can also create subdirectories `testA` and `testB` if you have test data. You should **not** expect our method to work on just any random combination of input and output datasets (e.g. `cats<->keyboards`). From our experiments, we find it works better if two datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting <-> landscape photographs`. `zebras<->horses` achieves compelling results while `cats<->dogs` completely fails. See the following section for more discussion. ## Failure cases Our model does not work well when the test image is rather different from the images on which the model is trained, as is the case in the figure to the left (we trained on horses and zebras without riders, but test here on a horse with a rider). See additional typical failure cases [here](https://junyanz.github.io/CycleGAN/images/failures.jpg). On translation tasks that involve color and texture changes, like many of those reported above, the method often succeeds. We have also explored tasks that require geometric changes, with little success.
For example, on the task of `dog<->cat` transfiguration, the learned translation degenerates into making minimal changes to the input. We also observe a lingering gap between the results achievable with paired training data and those achieved by our unpaired method. In some cases, this gap may be very hard -- or even impossible -- to close: for example, our method sometimes permutes the labels for tree and building in the output of the cityscapes photos->labels task. ## Citation If you use this code for your research, please cite our [paper](https://junyanz.github.io/CycleGAN/): ``` @inproceedings{CycleGAN2017, title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks}, author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A}, booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on}, year={2017} } ``` ## Related Projects: **[contrastive-unpaired-translation](https://github.com/taesungp/contrastive-unpaired-translation) (CUT)**
**[pix2pix-Torch](https://github.com/phillipi/pix2pix) | [pix2pixHD](https://github.com/NVIDIA/pix2pixHD) | [BicycleGAN](https://github.com/junyanz/BicycleGAN) | [vid2vid](https://tcwang0509.github.io/vid2vid/) | [SPADE/GauGAN](https://github.com/NVlabs/SPADE)**
**[iGAN](https://github.com/junyanz/iGAN) | [GAN Dissection](https://github.com/CSAILVision/GANDissect) | [GAN Paint](http://ganpaint.io/)** ## Cat Paper Collection If you love cats, and love reading cool graphics, vision, and ML papers, please check out the Cat Paper [Collection](https://github.com/junyanz/CatPapers). ## Acknowledgments Code borrows from [pix2pix](https://github.com/phillipi/pix2pix) and [DCGAN](https://github.com/soumith/dcgan.torch). The data loader is modified from [DCGAN](https://github.com/soumith/dcgan.torch) and [Context-Encoder](https://github.com/pathak22/context-encoder). The generative network is adopted from [neural-style](https://github.com/jcjohnson/neural-style) with [Instance Normalization](https://github.com/DmitryUlyanov/texture_nets/blob/master/InstanceNormalization.lua). ================================================ FILE: data/aligned_data_loader.lua ================================================ -------------------------------------------------------------------------------- -- Subclass of BaseDataLoader that provides data from two datasets. 
-- The samples from the datasets are aligned
-- The datasets are of the same size
--------------------------------------------------------------------------------
require 'data.base_data_loader'

local class = require 'class'
data_util = paths.dofile('data_util.lua')

AlignedDataLoader = class('AlignedDataLoader', 'BaseDataLoader')

-- Constructor: delegates to the base class, which creates the fetch timer.
function AlignedDataLoader:__init(conf)
  BaseDataLoader.__init(self, conf)
  conf = conf or {}
end

-- Returns the class name of this loader.
function AlignedDataLoader:name()
  return 'AlignedDataLoader'
end

-- Prepares the loader from the option table.
-- NOTE: mutates the caller's |opt| by setting opt.align_data = 1.
-- idx_A / idx_B are the channel ranges used to slice a loaded batch into
-- its A part (first opt.input_nc channels) and B part (next opt.output_nc).
function AlignedDataLoader:Initialize(opt)
  opt.align_data = 1
  self.idx_A = {1, opt.input_nc}
  self.idx_B = {opt.input_nc+1, opt.input_nc+opt.output_nc}
  local nc = 3--opt.input_nc + opt.output_nc
  self.data = data_util.load_dataset('', opt, nc)
end

-- actually fetches the data
-- |return|: batchA, batchB, pathA, pathB. Both batches are channel slices
-- (dimension 2) of the same loaded batch, so the same path list is
-- returned for A and B.
function AlignedDataLoader:LoadBatchForAllDatasets()
  local batch_data, path = self.data:getBatch()
  local batchA = batch_data[{ {}, self.idx_A, {}, {} }]
  local batchB = batch_data[{ {}, self.idx_B, {}, {} }]

  return batchA, batchB, path, path
end

-- returns the size of each dataset (|dataset| is ignored: A and B share one set)
function AlignedDataLoader:size(dataset)
  return self.data:size()
end

================================================
FILE: data/base_data_loader.lua
================================================
--------------------------------------------------------------------------------
-- Base Class for Providing Data
--------------------------------------------------------------------------------

local class = require 'class'
require 'torch'

BaseDataLoader = class('BaseDataLoader')

-- Constructor: creates the timer used to measure batch fetch time.
function BaseDataLoader:__init(conf)
  conf = conf or {}
  self.data_tm = torch.Timer()
end

-- Returns the class name of this loader.
function BaseDataLoader:name()
  return 'BaseDataLoader'
end

-- Hook for subclasses; the base implementation does nothing.
function BaseDataLoader:Initialize(opt)
end

-- actually fetches the data
-- |return|: a table of two tables, each corresponding to
-- the batch for dataset A and dataset B
-- (the base implementation returns four empty tables)
function BaseDataLoader:LoadBatchForAllDatasets()
  return {},{},{},{}
end

-- returns the next batch
-- a wrapper of LoadBatchForAllDatasets(), which is meant to be overridden
-- by subclasses; the call is timed with self.data_tm
-- |return|: dataA, dataB, pathA, pathB
function BaseDataLoader:GetNextBatch()
  self.data_tm:reset()
  self.data_tm:resume()
  local dataA, dataB, pathA, pathB = self:LoadBatchForAllDatasets()
  self.data_tm:stop()
  return dataA, dataB, pathA, pathB
end

-- Returns the wall-clock seconds spent in the last GetNextBatch() fetch.
function BaseDataLoader:time_elapsed_to_fetch_data()
  return self.data_tm:time().real
end

-- returns the size of each dataset
function BaseDataLoader:size(dataset)
  return 0
end

================================================
FILE: data/data.lua
================================================
--[[
    This data loader is a modified version of the one from dcgan.torch
    (see https://github.com/soumith/dcgan.torch/blob/master/data/data.lua).

    Copyright (c) 2016, Deepak Pathak [See LICENSE file for details]
]]--

local Threads = require 'threads'
Threads.serialization('threads.sharedserialize')

local data = {}

-- Single-slot mailbox: _pushResult stores the latest finished batch here and
-- getBatch consumes it.
local result = {}
local unpack = unpack and unpack or table.unpack

-- Creates a loader object.
-- |n|: number of worker threads; with n <= 0 the donkey file is loaded in the
--      calling thread and jobs run synchronously through a stub thread pool.
-- |opt_|: option table made visible to each worker as the global `opt`.
function data.new(n, opt_)
   opt_ = opt_ or {}
   local self = {}
   for k,v in pairs(data) do
      self[k] = v
   end

   local donkey_file = 'donkey_folder.lua'
   -- print('n..' .. n)
   if n > 0 then
      local options = opt_
      self.threads = Threads(n,
         function() require 'torch' end,
         function(idx)
            -- `opt` and `tid` are deliberately globals of the worker thread
            opt = options
            tid = idx
            local seed = (opt.manualSeed and opt.manualSeed or 0) + idx
            torch.manualSeed(seed)
            torch.setnumthreads(1)
            print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
            assert(options, 'options not found')
            assert(opt, 'opt not given')
            print(opt)
            paths.dofile(donkey_file)
         end
      )
   else
      if donkey_file then paths.dofile(donkey_file) end
      -- print('empty threads')
      -- synchronous stand-in exposing the subset of the Threads API used here
      self.threads = {}
      function self.threads:addjob(f1, f2) f2(f1()) end
      function self.threads:dojob() end
      function self.threads:synchronize() end
   end

   local nSamples = 0
   self.threads:addjob(function() return trainLoader:size() end,
                       function(c) nSamples = c end)
   self.threads:synchronize()
   self._size = nSamples

   -- prefetch: queue one job per worker
   for i = 1, n do
      self.threads:addjob(self._getFromThreads, self._pushResult)
   end
   -- print(self.threads)
   return self
end

-- Job body run by a worker: samples one batch of opt.batchSize items.
function data._getFromThreads()
   assert(opt.batchSize, 'opt.batchSize not found')
   return trainLoader:sample(opt.batchSize)
end

-- End-callback run in the main thread: packs the job's return values and
-- stores them for getBatch().
-- (The former `if res == nil` branch was removed: a table constructor never
-- yields nil, and the branch referenced an undefined `self`.)
function data._pushResult(...)
   result[1] = {...}
end

-- Queues a new sampling job, waits for one finished job and returns its batch.
-- |return|: img_data, img_paths. They are the 1st and 3rd values returned by
-- trainLoader:sample(); if the 1st value is a table, only its first element
-- is returned.
function data:getBatch()
   -- queue another job
   self.threads:addjob(self._getFromThreads, self._pushResult)
   self.threads:dojob()
   local res = result[1]
   assert(res, 'data:getBatch(): no batch was produced by the loader job')
   -- locals: these used to leak into the global environment
   local img_data = res[1]
   local img_paths = res[3]
   result[1] = nil
   if torch.type(img_data) == 'table' then
      img_data = unpack(img_data)
   end
   return img_data, img_paths
end

-- Returns the number of samples reported by trainLoader:size() at creation.
function data:size()
   return self._size
end

return data

================================================
FILE: data/data_util.lua
================================================
local data_util = {}

require 'torch'
-- options = require '../options.lua'
-- load dataset from the file system
-- |name|: name of the dataset.
It's currently either 'A' or 'B' function data_util.load_dataset(name, opt, nc) local tensortype = torch.getdefaulttensortype() torch.setdefaulttensortype('torch.FloatTensor') local new_opt = options.clone(opt) new_opt.manualSeed = torch.random(1, 10000) -- fix seed new_opt.nc = nc torch.manualSeed(new_opt.manualSeed) local data_loader = paths.dofile('../data/data.lua') new_opt.phase = new_opt.phase .. name local data = data_loader.new(new_opt.nThreads, new_opt) print("Dataset Size " .. name .. ": ", data:size()) torch.setdefaulttensortype(tensortype) return data end return data_util ================================================ FILE: data/dataset.lua ================================================ --[[ Copyright (c) 2015-present, Facebook, Inc. All rights reserved. This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. An additional grant of patent rights can be found in the PATENTS file in the same directory. ]]-- require 'torch' torch.setdefaulttensortype('torch.FloatTensor') local ffi = require 'ffi' local class = require('pl.class') local dir = require 'pl.dir' local tablex = require 'pl.tablex' local argcheck = require 'argcheck' require 'sys' require 'xlua' require 'image' local dataset = torch.class('dataLoader') local initcheck = argcheck{ pack=true, help=[[ A dataset class for images in a flat folder structure (folder-name is class-name). Optimized for extremely large datasets (upwards of 14 million images). 
Tested only on Linux (as it uses command-line linux utilities to scale up) ]], {check=function(paths) local out = true; for k,v in ipairs(paths) do if type(v) ~= 'string' then print('paths can only be of string input'); out = false end end return out end, name="paths", type="table", help="Multiple paths of directories with images"}, {name="sampleSize", type="table", help="a consistent sample size to resize the images"}, {name="split", type="number", help="Percentage of split to go to Training" }, {name="serial_batches", type="number", help="if randomly sample training images"}, {name="samplingMode", type="string", help="Sampling mode: random | balanced ", default = "balanced"}, {name="verbose", type="boolean", help="Verbose mode during initialization", default = false}, {name="loadSize", type="table", help="a size to load the images to, initially", opt = true}, {name="forceClasses", type="table", help="If you want this loader to map certain classes to certain indices, " .. "pass a classes table that has {classname : classindex} pairs." .. " For example: {3 : 'dog', 5 : 'cat'}" .. "This function is very useful when you want two loaders to have the same " .. "class indices (trainLoader/testLoader for example)", opt = true}, {name="sampleHookTrain", type="function", help="applied to sample during training(ex: for lighting jitter). " .. "It takes the image path as input", opt = true}, {name="sampleHookTest", type="function", help="applied to sample during testing", opt = true}, } function dataset:__init(...) -- argcheck local args = initcheck(...) 
print(args) for k,v in pairs(args) do self[k] = v end if not self.loadSize then self.loadSize = self.sampleSize; end if not self.sampleHookTrain then self.sampleHookTrain = self.defaultSampleHook end if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end self.image_count = 1 -- print('image_count_init', self.image_count) -- find class names self.classes = {} local classPaths = {} if self.forceClasses then for k,v in pairs(self.forceClasses) do self.classes[k] = v classPaths[k] = {} end end local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end -- loop over each paths folder, get list of unique class names, -- also store the directory paths per class -- for each class, for k,path in ipairs(self.paths) do -- print('path', path) local dirs = {} -- hack dirs[1] = path -- local dirs = dir.getdirectories(path); for k,dirpath in ipairs(dirs) do local class = paths.basename(dirpath) local idx = tableFind(self.classes, class) -- print(class) -- print(idx) if not idx then table.insert(self.classes, class) idx = #self.classes classPaths[idx] = {} end if not tableFind(classPaths[idx], dirpath) then table.insert(classPaths[idx], dirpath); end end end self.classIndices = {} for k,v in ipairs(self.classes) do self.classIndices[v] = k end -- define command-line tools, try your best to maintain OSX compatibility local wc = 'wc' local cut = 'cut' local find = 'find -H' -- if folder name is symlink, do find inside it after dereferencing if ffi.os == 'OSX' then wc = 'gwc' cut = 'gcut' find = 'gfind' end ---------------------------------------------------------------------- -- Options for the GNU find command local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'} local findOptions = ' -iname "*.' .. extensionList[1] .. '"' for i=2,#extensionList do findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. 
'"' end -- find the image path names self.imagePath = torch.CharTensor() -- path to each image in dataset self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes) self.classList = {} -- index of imageList to each image of a particular class self.classListSample = self.classList -- the main list used when sampling data print('running "find" on each class directory, and concatenate all' .. ' those filenames into a single file containing all image paths for a given class') -- so, generates one file per class local classFindFiles = {} for i=1,#self.classes do classFindFiles[i] = os.tmpname() end local combinedFindList = os.tmpname(); local tmpfile = os.tmpname() local tmphandle = assert(io.open(tmpfile, 'w')) -- iterate over classes for i, class in ipairs(self.classes) do -- iterate over classPaths for j,path in ipairs(classPaths[i]) do local command = find .. ' "' .. path .. '" ' .. findOptions .. ' >>"' .. classFindFiles[i] .. '" \n' tmphandle:write(command) end end io.close(tmphandle) os.execute('bash ' .. tmpfile) os.execute('rm -f ' .. tmpfile) print('now combine all the files to a single large file') local tmpfile = os.tmpname() local tmphandle = assert(io.open(tmpfile, 'w')) -- concat all finds to a single large file in the order of self.classes for i=1,#self.classes do local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n' tmphandle:write(command) end io.close(tmphandle) os.execute('bash ' .. tmpfile) os.execute('rm -f ' .. tmpfile) --========================================================================== print('load the large concatenated list of sample paths to self.imagePath') local cmd = wc .. " -L '" .. combinedFindList .. "' |" .. cut .. " -f1 -d' '" print('cmd..' .. cmd) local maxPathLength = tonumber(sys.fexecute(wc .. " -L '" .. combinedFindList .. "' |" .. cut .. " -f1 -d' '")) + 1 local length = tonumber(sys.fexecute(wc .. " -l '" .. combinedFindList .. "' |" .. cut .. 
" -f1 -d' '")) assert(length > 0, "Could not find any image file in the given input paths") assert(maxPathLength > 0, "paths of files are length 0?") self.imagePath:resize(length, maxPathLength):fill(0) local s_data = self.imagePath:data() local count = 0 for line in io.lines(combinedFindList) do ffi.copy(s_data, line) s_data = s_data + maxPathLength if self.verbose and count % 10000 == 0 then xlua.progress(count, length) end; count = count + 1 end self.numSamples = self.imagePath:size(1) if self.verbose then print(self.numSamples .. ' samples found.') end --========================================================================== print('Updating classList and imageClass appropriately') self.imageClass:resize(self.numSamples) local runningIndex = 0 for i=1,#self.classes do if self.verbose then xlua.progress(i, #(self.classes)) end local length = tonumber(sys.fexecute(wc .. " -l '" .. classFindFiles[i] .. "' |" .. cut .. " -f1 -d' '")) if length == 0 then error('Class has zero samples') else self.classList[i] = torch.linspace(runningIndex + 1, runningIndex + length, length):long() self.imageClass[{{runningIndex + 1, runningIndex + length}}]:fill(i) end runningIndex = runningIndex + length end --========================================================================== -- clean up temporary files print('Cleaning up temporary files') local tmpfilelistall = '' for i=1,#(classFindFiles) do tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"' if i % 1000 == 0 then os.execute('rm -f ' .. tmpfilelistall) tmpfilelistall = '' end end os.execute('rm -f ' .. tmpfilelistall) os.execute('rm -f "' .. combinedFindList .. '"') --========================================================================== if self.split == 100 then self.testIndicesSize = 0 else print('Splitting training and test sets to a ratio of ' .. self.split .. '/' .. 
(100-self.split))
      self.classListTrain = {}
      self.classListTest  = {}
      self.classListSample = self.classListTrain
      local totalTestSamples = 0
      -- split the classList into classListTrain and classListTest
      for i=1,#self.classes do
         local list = self.classList[i]
         local count = self.classList[i]:size(1)
         local splitidx = math.floor((count * self.split / 100) + 0.5) -- +round
         local perm = torch.randperm(count)
         self.classListTrain[i] = torch.LongTensor(splitidx)
         for j=1,splitidx do
            self.classListTrain[i][j] = list[perm[j]]
         end
         if splitidx == count then -- all samples were allocated to train set
            self.classListTest[i]  = torch.LongTensor()
         else
            self.classListTest[i]  = torch.LongTensor(count-splitidx)
            totalTestSamples = totalTestSamples + self.classListTest[i]:size(1)
            local idx = 1
            for j=splitidx+1,count do
               self.classListTest[i][idx] = list[perm[j]]
               idx = idx + 1
            end
         end
      end
      -- Now combine classListTest into a single tensor
      self.testIndices = torch.LongTensor(totalTestSamples)
      self.testIndicesSize = totalTestSamples
      local tdata = self.testIndices:data()
      local tidx = 0
      for i=1,#self.classes do
         local list = self.classListTest[i]
         if list:dim() ~= 0 then
            local ldata = list:data()
            for j=0,list:size(1)-1 do
               tdata[tidx] = ldata[j]
               tidx = tidx + 1
            end
         end
      end
   end
end

-- size(), size(class)
-- Returns the number of samples.
--   class: nil        -> total number of samples (self.numSamples)
--          string     -> size of the class with that name (looked up in self.classIndices)
--          number     -> size of the class with that index
--   list:  optional per-class index table; defaults to self.classList
-- Returns nil when `class` is of any other type.
function dataset:size(class, list)
   list = list or self.classList
   if not class then
      return self.numSamples
   elseif type(class) == 'string' then
      return list[self.classIndices[class]]:size(1)
   elseif type(class) == 'number' then
      return list[class]:size(1)
   end
end

-- getByClass
-- Picks one image of the given class index and runs it through self:sampleHookTrain.
-- With self.serial_batches == 1 the images are visited in order (self.image_count is
-- advanced on every call); otherwise an image is drawn uniformly at random.
-- Returns: the hook output and the image path.
function dataset:getByClass(class)
   local index = 0
   local classSize = self.classListSample[class]:nElement()
   if self.serial_batches == 1 then
      index = math.fmod(self.image_count-1, classSize)+1
      self.image_count = self.image_count +1
   else
      -- torch.uniform() samples from [0, 1); a draw of exactly 0 would give
      -- ceil(0) == 0, which is not a valid 1-based index, so clamp to 1.
      index = math.max(1, math.ceil(torch.uniform() * classSize))
   end
   local imgpath = ffi.string(torch.data(self.imagePath[self.classListSample[class][index]]))
   return self:sampleHookTrain(imgpath), imgpath
end

-- converts a table of samples (and corresponding labels) to a clean tensor
--   dataTable:   list of image tensors
--   scalarTable: list of class indices, one per image
-- Reads the global `opt.resize_or_crop`: for 'crop', 'scale_width' and 'scale_height'
-- exactly one sample is expected and the batch takes the size of that sample;
-- otherwise the batch is allocated with self.sampleSize.
-- Returns: data tensor, LongTensor of labels.
local function tableToOutput(self, dataTable, scalarTable)
   local data, scalarLabels
   if opt.resize_or_crop == 'crop' or opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
      assert(#scalarTable == 1)
      data = torch.Tensor(1,
                          dataTable[1]:size(1), dataTable[1]:size(2), dataTable[1]:size(3))
      data[1]:copy(dataTable[1])
      scalarLabels = torch.LongTensor(#scalarTable):fill(-1111)
      -- keep the label, as the fixed-size branch below does
      scalarLabels[1] = scalarTable[1]
   else
      local quantity = #scalarTable
      data = torch.Tensor(quantity,
                          self.sampleSize[1], self.sampleSize[2], self.sampleSize[3])
      scalarLabels = torch.LongTensor(quantity):fill(-1111)
      for i=1,#dataTable do
         data[i]:copy(dataTable[i])
         scalarLabels[i] = scalarTable[i]
      end
   end
   return data, scalarLabels
end

-- sampler, samples from the training set.
-- Draws `quantity` samples, each from a class chosen uniformly at random.
-- Returns: data tensor, label tensor, list of image paths.
function dataset:sample(quantity)
   assert(quantity)
   local dataTable = {}
   local scalarTable = {}
   local samplePaths = {}
   for i=1,quantity do
      local class = torch.random(1, #self.classes)
      local out, imgpath = self:getByClass(class)
      table.insert(dataTable, out)
      table.insert(scalarTable, class)
      samplePaths[i] = imgpath
   end
   local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
   return data, scalarLabels, samplePaths-- filePaths
end

-- Loads the samples with global indices i1..i2 (inclusive) through self:sampleHookTest.
-- Returns: data tensor, label tensor.
function dataset:get(i1, i2)
   local indices = torch.range(i1, i2);
   local quantity = i2 - i1 + 1;
   assert(quantity > 0)
   -- now that indices has been initialized, get the samples
   local dataTable = {}
   local scalarTable = {}
   for i=1,quantity do
      -- load the sample
      local imgpath = ffi.string(torch.data(self.imagePath[indices[i]]))
      local out = self:sampleHookTest(imgpath)
      table.insert(dataTable, out)
      table.insert(scalarTable, self.imageClass[indices[i]])
   end
   local data, scalarLabels = tableToOutput(self, dataTable, scalarTable)
   return data, scalarLabels
end

return dataset

================================================
FILE: data/donkey_folder.lua
================================================
--[[
    This data loader is a modified version
of the one from dcgan.torch (see https://github.com/soumith/dcgan.torch/blob/master/data/donkey_folder.lua). Copyright (c) 2016, Deepak Pathak [See LICENSE file for details] Copyright (c) 2015-present, Facebook, Inc. All rights reserved. This source code is licensed under the BSD-style license found in the LICENSE file in the root directory of this source tree. An additional grant of patent rights can be found in the PATENTS file in the same directory. ]]-- require 'image' paths.dofile('dataset.lua') -- This file contains the data-loading logic and details. -- It is run by each data-loader thread. ------------------------------------------ -------- COMMON CACHES and PATHS -- Check for existence of opt.data if opt.DATA_ROOT then opt.data = paths.concat(opt.DATA_ROOT, opt.phase) else print(os.getenv('DATA_ROOT')) opt.data = paths.concat(os.getenv('DATA_ROOT'), opt.phase) end if not paths.dirp(opt.data) then error('Did not find directory: ' .. opt.data) end -- a cache file of the training metadata (if doesnt exist, will be created) local cache_prefix = opt.data:gsub('/', '_') os.execute(('mkdir -p %s'):format(opt.cache_dir)) local trainCache = paths.concat(opt.cache_dir, cache_prefix .. 
'_trainCache.t7') -------------------------------------------------------------------------------------------- local input_nc = opt.nc -- input channels local loadSize = {input_nc, opt.loadSize} local sampleSize = {input_nc, opt.fineSize} local function loadImage(path) local input = image.load(path, 3, 'float') local h = input:size(2) local w = input:size(3) local imA = image.crop(input, 0, 0, w/2, h) imA = image.scale(imA, loadSize[2], loadSize[2]) local imB = image.crop(input, w/2, 0, w, h) imB = image.scale(imB, loadSize[2], loadSize[2]) local perm = torch.LongTensor{3, 2, 1} imA = imA:index(1, perm) imA = imA:mul(2):add(-1) imB = imB:index(1, perm) imB = imB:mul(2):add(-1) assert(imA:max()<=1,"A: badly scaled inputs") assert(imA:min()>=-1,"A: badly scaled inputs") assert(imB:max()<=1,"B: badly scaled inputs") assert(imB:min()>=-1,"B: badly scaled inputs") local oW = sampleSize[2] local oH = sampleSize[2] local iH = imA:size(2) local iW = imA:size(3) if iH~=oH then h1 = math.ceil(torch.uniform(1e-2, iH-oH)) end if iW~=oW then w1 = math.ceil(torch.uniform(1e-2, iW-oW)) end if iH ~= oH or iW ~= oW then imA = image.crop(imA, w1, h1, w1 + oW, h1 + oH) imB = image.crop(imB, w1, h1, w1 + oW, h1 + oH) end if opt.flip == 1 and torch.uniform() > 0.5 then imA = image.hflip(imA) imB = image.hflip(imB) end local concatenated = torch.cat(imA,imB,1) return concatenated end local function loadSingleImage(path) local im = image.load(path, input_nc, 'float') if opt.resize_or_crop == 'resize_and_crop' then im = image.scale(im, loadSize[2], loadSize[2]) end if input_nc == 3 then local perm = torch.LongTensor{3, 2, 1} im = im:index(1, perm)--:mul(256.0): brg, rgb im = im:mul(2):add(-1) end assert(im:max()<=1,"A: badly scaled inputs") assert(im:min()>=-1,"A: badly scaled inputs") local oW = sampleSize[2] local oH = sampleSize[2] local iH = im:size(2) local iW = im:size(3) if (opt.resize_or_crop == 'resize_and_crop' ) then local h1, w1 = 0, 0 if iH~=oH then h1 = 
math.ceil(torch.uniform(1e-2, iH-oH)) end if iW~=oW then w1 = math.ceil(torch.uniform(1e-2, iW-oW)) end if iH ~= oH or iW ~= oW then im = image.crop(im, w1, h1, w1 + oW, h1 + oH) end elseif (opt.resize_or_crop == 'combined') then local sH = math.min(math.ceil(oH * torch.uniform(1+1e-2, 2.0-1e-2)), iH-1e-2) local sW = math.min(math.ceil(oW * torch.uniform(1+1e-2, 2.0-1e-2)), iW-1e-2) local h1 = math.ceil(torch.uniform(1e-2, iH-sH)) local w1 = math.ceil(torch.uniform(1e-2, iW-sW)) im = image.crop(im, w1, h1, w1 + sW, h1 + sH) im = image.scale(im, oW, oH) elseif (opt.resize_or_crop == 'crop') then local w = math.min(math.min(oH, iH),iW) w = math.floor(w/4)*4 local x = math.floor(torch.uniform(0, iW - w)) local y = math.floor(torch.uniform(0, iH - w)) im = image.crop(im, x, y, x+w, y+w) elseif (opt.resize_or_crop == 'scale_width') then w = oW h = torch.floor(iH * oW/iW) im = image.scale(im, w, h) elseif (opt.resize_or_crop == 'scale_height') then h = oH w = torch.floor(iW * oH / iH) im = image.scale(im, w, h) end if opt.flip == 1 and torch.uniform() > 0.5 then im = image.hflip(im) end return im end -- channel-wise mean and std. Calculate or load them from disk later in the script. local mean,std -------------------------------------------------------------------------------- -- Hooks that are used for each image that is loaded -- function to load the image, jitter it appropriately (random crops etc.) 
local function trainHook_singleimage(self, path)
  collectgarbage()
  -- print('load single image')
  return loadSingleImage(path)
end

-- function that loads images that have juxtaposition
-- of two images from two domains
local function trainHook_doubleimage(self, path)
  -- print('load double image')
  collectgarbage()
  return loadImage(path)
end

-- Aligned data carries both domains in one sample, so it has twice the channels.
local use_aligned = opt.align_data > 0
sample_nc = use_aligned and input_nc*2 or input_nc
trainHook = use_aligned and trainHook_doubleimage or trainHook_singleimage

-- Build the loader over opt.data; all samples go to the training split (split = 100).
trainLoader = dataLoader{
  paths = {opt.data},
  loadSize = {input_nc, loadSize[2], loadSize[2]},
  sampleSize = {sample_nc, sampleSize[2], sampleSize[2]},
  split = 100,
  serial_batches = opt.serial_batches,
  verbose = true
}

trainLoader.sampleHookTrain = trainHook
collectgarbage()

-- do some sanity checks on trainLoader
-- every stored class index must lie in 1..#classes
do
  local classIndex = trainLoader.imageClass
  local numClasses = #trainLoader.classes
  assert(classIndex:max() <= numClasses, "class logic has error")
  assert(classIndex:min() >= 1, "class logic has error")
end

================================================
FILE: data/unaligned_data_loader.lua
================================================
--------------------------------------------------------------------------------
-- Subclass of BaseDataLoader that provides data from two datasets.
-- The samples from the datasets are not aligned.
-- The datasets can have different sizes
--------------------------------------------------------------------------------
require 'data.base_data_loader'

local class = require 'class'
data_util = paths.dofile('data_util.lua')

UnalignedDataLoader = class('UnalignedDataLoader', 'BaseDataLoader')

-- Constructor: defers entirely to BaseDataLoader.
function UnalignedDataLoader:__init(conf)
  BaseDataLoader.__init(self, conf)
end

-- Name of this loader.
function UnalignedDataLoader:name()
  return 'UnalignedDataLoader'
end

-- Loads dataset 'A' (with opt.input_nc) and dataset 'B' (with opt.output_nc)
-- independently. Note that this sets opt.align_data to 0 on the caller's opt.
function UnalignedDataLoader:Initialize(opt)
  opt.align_data = 0
  self.dataA = data_util.load_dataset('A', opt, opt.input_nc)
  self.dataB = data_util.load_dataset('B', opt, opt.output_nc)
end

-- actually fetches the data
-- |return|: four values: the batch of dataset A, the batch of dataset B,
--           and the paths reported for each of the two batches
function UnalignedDataLoader:LoadBatchForAllDatasets()
  local batchA, pathA = self.dataA:getBatch()
  local batchB, pathB = self.dataB:getBatch()
  return batchA, batchB, pathA, pathB
end

-- returns the size of each dataset
-- 'A' or 'B' selects one dataset; anything else yields the larger of the two sizes
function UnalignedDataLoader:size(dataset)
  if dataset == 'A' then
    return self.dataA:size()
  elseif dataset == 'B' then
    return self.dataB:size()
  else
    -- return the size of the largest dataset by default
    return math.max(self.dataA:size(), self.dataB:size())
  end
end

================================================
FILE: examples/test_vangogh_style_on_ae_photos.sh
================================================
#!/bin/sh

## This script downloads the dataset and pre-trained network,
## and generates style transferred images.

# Download the dataset. The downloaded dataset is stored in ./datasets/${DATASET_NAME}
DATASET_NAME='ae_photos'
bash ./datasets/download_dataset.sh $DATASET_NAME

# Download the pre-trained model.
The downloaded model is stored in ./models/${MODEL_NAME}_pretrained/latest_net_G.t7 MODEL_NAME='style_vangogh' bash ./pretrained_models/download_model.sh $MODEL_NAME # Run style transfer using the downloaded dataset and model DATA_ROOT=./datasets/$DATASET_NAME name=${MODEL_NAME}_pretrained model=one_direction_test phase=test how_many='all' loadSize=256 fineSize=256 resize_or_crop='scale_width' th test.lua if [ $? == 0 ]; then echo "The result can be viewed at ./results/${MODEL_NAME}_pretrained/latest_test/index.html" fi ================================================ FILE: examples/train_maps.sh ================================================ DB_NAME='maps' GPU_ID=1 DISPLAY_ID=1 NET_G=resnet_6blocks NET_D=basic MODEL=cycle_gan SAVE_EPOCH=5 ALIGN_DATA=0 LAMBDA=10 NF=64 EXPR_NAME=${DB_NAME}_${MODEL}_${LAMBDA} CHECKPOINT_DIR=./checkpoints/ LOG_FILE=${CHECKPOINT_DIR}${EXPR_NAME}/log.txt mkdir -p ${CHECKPOINT_DIR}${EXPR_NAME} DATA_ROOT=./datasets/$DB_NAME align_data=$ALIGN_DATA use_lsgan=1 \ which_direction='AtoB' display_plot=$PLOT pool_size=50 niter=100 niter_decay=100 \ which_model_netG=$NET_G which_model_netD=$NET_D model=$MODEL lr=0.0002 print_freq=200 lambda_A=$LAMBDA lambda_B=$LAMBDA \ loadSize=143 fineSize=128 gpu=$GPU_ID display_winsize=128 \ name=$EXPR_NAME flip=1 save_epoch_freq=$SAVE_EPOCH \ continue_train=0 display_id=$DISPLAY_ID \ checkpoints_dir=$CHECKPOINT_DIR\ th train.lua | tee -a $LOG_FILE ================================================ FILE: models/architectures.lua ================================================ require 'nngraph' ---------------------------------------------------------------------------- local function weights_init(m) local name = torch.type(m) if name:find('Convolution') then m.weight:normal(0.0, 0.02) m.bias:fill(0) elseif name:find('Normalization') then if m.weight then m.weight:normal(1.0, 0.02) end if m.bias then m.bias:fill(0) end end end normalization = nil function set_normalization(norm) if norm == 'instance' then 
require 'util.InstanceNormalization'
    print('use InstanceNormalization')
    normalization = nn.InstanceNormalization
  elseif norm == 'batch' then
    print('use SpatialBatchNormalization')
    normalization = nn.SpatialBatchNormalization
  end
end

-- Builds the generator named by which_model_netG and initializes its weights.
-- `nz` and `arch` are accepted but not used here.
-- Raises "unsupported netG model" for an unknown name.
function defineG(input_nc, output_nc, ngf, which_model_netG, nz, arch)
  local generator_builders = {
    encoder_decoder = defineG_encoder_decoder,
    unet128 = defineG_unet128,
    unet256 = defineG_unet256,
    resnet_6blocks = defineG_resnet_6blocks,
    resnet_9blocks = defineG_resnet_9blocks,
  }
  local build = generator_builders[which_model_netG]
  if build == nil then
    error("unsupported netG model")
  end
  local netG = build(input_nc, output_nc, ngf)
  netG:apply(weights_init)
  return netG
end

-- Builds the discriminator named by which_model_netD and initializes its weights.
-- n_layers_D is only used by the "n_layers" model.
-- Raises "unsupported netD model" for an unknown name.
function defineD(input_nc, ndf, which_model_netD, n_layers_D, use_sigmoid)
  local discriminator_builders = {
    basic = function() return defineD_basic(input_nc, ndf, use_sigmoid) end,
    imageGAN = function() return defineD_imageGAN(input_nc, ndf, use_sigmoid) end,
    n_layers = function() return defineD_n_layers(input_nc, ndf, n_layers_D, use_sigmoid) end,
  }
  local build = discriminator_builders[which_model_netD]
  if build == nil then
    error("unsupported netD model")
  end
  local netD = build()
  netD:apply(weights_init)
  return netD
end

-- Encoder-decoder generator without skip connections.
-- Eight stride-2 4x4 convolutions followed by eight stride-2 4x4 full
-- (transposed) convolutions and a final Tanh. The size comments below follow
-- the original annotations, which assume a 256 x 256 input.
function defineG_encoder_decoder(input_nc, output_nc, ngf)
  -- input is (nc) x 256 x 256; the first layer has no activation/normalization
  local input_node = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)

  -- {input planes, output planes, followed by normalization?}
  local encoder_spec = {
    {ngf,     ngf * 2, true},   -- input is (ngf) x 128 x 128
    {ngf * 2, ngf * 4, true},   -- input is (ngf * 2) x 64 x 64
    {ngf * 4, ngf * 8, true},   -- input is (ngf * 4) x 32 x 32
    {ngf * 8, ngf * 8, true},   -- input is (ngf * 8) x 16 x 16
    {ngf * 8, ngf * 8, true},   -- input is (ngf * 8) x 8 x 8
    {ngf * 8, ngf * 8, true},   -- input is (ngf * 8) x 4 x 4
    {ngf * 8, ngf * 8, false},  -- input is (ngf * 8) x 2 x 2; bottleneck is not normalized
  }
  local node = input_node
  for _, spec in ipairs(encoder_spec) do
    node = node - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(spec[1], spec[2], 4, 4, 2, 2, 1, 1)
    if spec[3] then
      node = node - normalization(spec[2])
    end
  end

  -- {input planes, output planes, followed by dropout?}; every layer here is normalized
  local decoder_spec = {
    {ngf * 8, ngf * 8, true},   -- input is (ngf * 8) x 1 x 1
    {ngf * 8, ngf * 8, true},   -- input is (ngf * 8) x 2 x 2
    {ngf * 8, ngf * 8, true},   -- input is (ngf * 8) x 4 x 4
    {ngf * 8, ngf * 8, false},  -- input is (ngf * 8) x 8 x 8
    {ngf * 8, ngf * 4, false},  -- input is (ngf * 8) x 16 x 16
    {ngf * 4, ngf * 2, false},  -- input is (ngf * 4) x 32 x 32
    {ngf * 2, ngf,     false},  -- input is (ngf * 2) x 64 x 64
  }
  for _, spec in ipairs(decoder_spec) do
    node = node - nn.ReLU(true) - nn.SpatialFullConvolution(spec[1], spec[2], 4, 4, 2, 2, 1, 1) - normalization(spec[2])
    if spec[3] then
      node = node - nn.Dropout(0.5)
    end
  end

  -- input is (ngf) x128 x 128; last layer maps to output_nc without normalization
  node = node - nn.ReLU(true) - nn.SpatialFullConvolution(ngf, output_nc, 4, 4, 2, 2, 1, 1)
  -- input is (nc) x 256 x 256
  local output_node = node - nn.Tanh()

  return nn.gModule({input_node},{output_node})
end


function defineG_unet128(input_nc, output_nc, ngf)
  local netG = nil
  --
-- input is (nc) x 128 x 128
  -- (body of defineG_unet128) Encoder: e1..e7 are stride-2 4x4 convolutions.
  -- Decoder: each dN_ is a stride-2 4x4 full convolution whose output is joined
  -- along dimension 2 (nn.JoinTable(2)) with the mirrored encoder output.
  local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
  -- input is (ngf) x 64 x 64
  local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
  -- input is (ngf * 2) x 32 x 32
  local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
  -- input is (ngf * 4) x 16 x 16
  local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- input is (ngf * 8) x 8 x 8
  local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- input is (ngf * 8) x 4 x 4
  local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- input is (ngf * 8) x 2 x 2
  -- the bottleneck convolution is deliberately left without normalization
  local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- normalization(ngf * 8)
  -- input is (ngf * 8) x 1 x 1

  local d1_ = e7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
  -- input is (ngf * 8) x 2 x 2
  local d1 = {d1_,e6} - nn.JoinTable(2)
  local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
  -- input is (ngf * 8) x 4 x 4
  local d2 = {d2_,e5} - nn.JoinTable(2)
  local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
  -- input is (ngf * 8) x 8 x 8
  local d3 = {d3_,e4} - nn.JoinTable(2)
  local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
  -- input is (ngf * 8) x 16 x 16
  local d4 = {d4_,e3} - nn.JoinTable(2)
  local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
  -- input is (ngf * 4) x 32 x 32
  local d5 = {d5_,e2} - nn.JoinTable(2)
  local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
  -- input is (ngf * 2) x 64 x 64
  local d6 = {d6_,e1} - nn.JoinTable(2)

  local d7 = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
  -- input is (nc) x 128 x 128

  local o1 = d7 - nn.Tanh()
  local netG = nn.gModule({e1},{o1})
  return netG
end

-- U-Net generator with eight down-sampling and eight up-sampling layers.
-- Same construction as defineG_unet128 with one extra encoder/decoder level;
-- the size comments are the original annotations for a 256 x 256 input.
-- Returns an nn.gModule whose input node is e1 and whose output node is o1.
function defineG_unet256(input_nc, output_nc, ngf)
  local netG = nil
  -- input is (nc) x 256 x 256
  local e1 = - nn.SpatialConvolution(input_nc, ngf, 4, 4, 2, 2, 1, 1)
  -- input is (ngf) x 128 x 128
  local e2 = e1 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
  -- input is (ngf * 2) x 64 x 64
  local e3 = e2 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
  -- input is (ngf * 4) x 32 x 32
  local e4 = e3 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 4, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- input is (ngf * 8) x 16 x 16
  local e5 = e4 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- input is (ngf * 8) x 8 x 8
  local e6 = e5 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- input is (ngf * 8) x 4 x 4
  local e7 = e6 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- input is (ngf * 8) x 2 x 2
  -- the bottleneck convolution is deliberately left without normalization
  local e8 = e7 - nn.LeakyReLU(0.2, true) - nn.SpatialConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) -- - normalization(ngf * 8)
  -- input is (ngf * 8) x 1 x 1

  local d1_ = e8 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
  -- input is (ngf * 8) x 2 x 2
  local d1 = {d1_,e7} - nn.JoinTable(2)
  local d2_ = d1 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
  -- input is (ngf * 8) x 4 x 4
  local d2 = {d2_,e6} - nn.JoinTable(2)
  local d3_ = d2 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8) - nn.Dropout(0.5)
  -- input is (ngf * 8) x 8 x 8
  local d3 = {d3_,e5} - nn.JoinTable(2)
  local d4_ = d3 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 8, 4, 4, 2, 2, 1, 1) - normalization(ngf * 8)
  -- input is (ngf * 8) x 16 x 16
  local d4 = {d4_,e4} - nn.JoinTable(2)
  local d5_ = d4 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 8 * 2, ngf * 4, 4, 4, 2, 2, 1, 1) - normalization(ngf * 4)
  -- input is (ngf * 4) x 32 x 32
  local d5 = {d5_,e3} - nn.JoinTable(2)
  local d6_ = d5 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 4 * 2, ngf * 2, 4, 4, 2, 2, 1, 1) - normalization(ngf * 2)
  -- input is (ngf * 2) x 64 x 64
  local d6 = {d6_,e2} - nn.JoinTable(2)
  local d7_ = d6 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2 * 2, ngf, 4, 4, 2, 2, 1, 1) - normalization(ngf)
  -- input is (ngf) x128 x 128
  local d7 = {d7_,e1} - nn.JoinTable(2)
  local d8 = d7 - nn.ReLU(true) - nn.SpatialFullConvolution(ngf * 2, output_nc, 4, 4, 2, 2, 1, 1)
  -- input is (nc) x 256 x 256

  local o1 = d8 - nn.Tanh()
  local netG = nn.gModule({e1},{o1})
  return netG
end

--------------------------------------------------------------------------------
-- Justin Johnson's model from https://github.com/jcjohnson/fast-neural-style/
--------------------------------------------------------------------------------

-- Builds the convolutional branch of a residual block: two 3x3 convolutions on
-- `dim` planes, each followed by normalization, with a ReLU in between.
-- padding_type selects how borders are handled: 'reflect' / 'replicate' insert
-- an explicit padding layer before each convolution, 'zero' uses the
-- convolution's own padding of 1.
local function build_conv_block(dim, padding_type)
  local conv_block = nn.Sequential()
  local p = 0
  if padding_type == 'reflect' then
    conv_block:add(nn.SpatialReflectionPadding(1, 1, 1, 1))
  elseif padding_type == 'replicate' then
    conv_block:add(nn.SpatialReplicationPadding(1, 1, 1, 1))
  elseif padding_type == 'zero' then
    p = 1
  end
  conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p))
conv_block:add(normalization(dim))
  conv_block:add(nn.ReLU(true))
  if padding_type == 'reflect' then
    conv_block:add(nn.SpatialReflectionPadding(1, 1, 1, 1))
  elseif padding_type == 'replicate' then
    conv_block:add(nn.SpatialReplicationPadding(1, 1, 1, 1))
  end
  conv_block:add(nn.SpatialConvolution(dim, dim, 3, 3, 1, 1, p, p))
  conv_block:add(normalization(dim))
  return conv_block
end

-- Residual block: the output of build_conv_block(dim, padding_type) is added
-- element-wise (nn.CAddTable) to the unchanged input (nn.Identity).
local function build_res_block(dim, padding_type)
  local conv_block = build_conv_block(dim, padding_type)
  local res_block = nn.Sequential()
  local concat = nn.ConcatTable()
  concat:add(conv_block)
  concat:add(nn.Identity())

  res_block:add(concat):add(nn.CAddTable())
  return res_block
end

-- Shared builder for the resnet generators. Layout:
--   7x7 convolution with reflection padding -> two stride-2 3x3 convolutions
--   -> n_blocks residual blocks on ngf*4 planes -> two stride-2 3x3 full
--   convolutions -> 7x7 convolution with reflection padding -> Tanh.
-- padding_type is local here; the previous implementations assigned it as a
-- global on every call.
local function build_resnet_generator(input_nc, output_nc, ngf, n_blocks)
  local padding_type = 'reflect'
  local ks = 3
  local f = 7
  local p = (f - 1) / 2

  local data = -nn.Identity()
  local e1 = data - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(input_nc, ngf, f, f, 1, 1) - normalization(ngf) - nn.ReLU(true)
  local e2 = e1 - nn.SpatialConvolution(ngf, ngf*2, ks, ks, 2, 2, 1, 1) - normalization(ngf*2) - nn.ReLU(true)
  local e3 = e2 - nn.SpatialConvolution(ngf*2, ngf*4, ks, ks, 2, 2, 1, 1) - normalization(ngf*4) - nn.ReLU(true)

  -- chain the residual blocks one after another
  local d1 = e3
  for _ = 1, n_blocks do
    d1 = d1 - build_res_block(ngf*4, padding_type)
  end

  local d2 = d1 - nn.SpatialFullConvolution(ngf*4, ngf*2, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf*2) - nn.ReLU(true)
  local d3 = d2 - nn.SpatialFullConvolution(ngf*2, ngf, ks, ks, 2, 2, 1, 1,1,1) - normalization(ngf) - nn.ReLU(true)
  local d4 = d3 - nn.SpatialReflectionPadding(p, p, p, p) - nn.SpatialConvolution(ngf, output_nc, f, f, 1, 1) - nn.Tanh()
  return nn.gModule({data},{d4})
end

-- Resnet generator with 6 residual blocks.
function defineG_resnet_6blocks(input_nc, output_nc, ngf)
  return build_resnet_generator(input_nc, output_nc, ngf, 6)
end

-- Resnet generator with 9 residual blocks.
function defineG_resnet_9blocks(input_nc, output_nc, ngf)
  return build_resnet_generator(input_nc, output_nc, ngf, 9)
end

-- Discriminator made of seven stride-2 4x4 convolutions, with
-- SpatialBatchNormalization and LeakyReLU in between. A Sigmoid is appended
-- when use_sigmoid is true. The size comments assume a 256 x 256 input.
function defineD_imageGAN(input_nc, ndf, use_sigmoid)
  local netD = nn.Sequential()

  -- input is (nc) x 256 x 256
  netD:add(nn.SpatialConvolution(input_nc, ndf, 4, 4, 2, 2, 1, 1))
  netD:add(nn.LeakyReLU(0.2, true))
  -- state size: (ndf) x 128 x 128
  netD:add(nn.SpatialConvolution(ndf, ndf * 2, 4, 4, 2, 2, 1, 1))
  netD:add(nn.SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
  -- state size: (ndf*2) x 64 x 64
  netD:add(nn.SpatialConvolution(ndf * 2, ndf*4, 4, 4, 2, 2, 1, 1))
  netD:add(nn.SpatialBatchNormalization(ndf * 4)):add(nn.LeakyReLU(0.2, true))
  -- state size: (ndf*4) x 32 x 32
  netD:add(nn.SpatialConvolution(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1))
  netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
  -- state size: (ndf*8) x 16 x 16
  netD:add(nn.SpatialConvolution(ndf * 8, ndf * 8, 4, 4, 2, 2, 1, 1))
netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
  -- state size: (ndf*8) x 8 x 8
  netD:add(nn.SpatialConvolution(ndf * 8, ndf * 8, 4, 4, 2, 2, 1, 1))
  netD:add(nn.SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
  -- state size: (ndf*8) x 4 x 4
  netD:add(nn.SpatialConvolution(ndf * 8, 1, 4, 4, 2, 2, 1, 1))
  -- state size: 1 x 2 x 2 (a 4x4 kernel with stride 2 and padding 1 halves the 4 x 4 map)
  if use_sigmoid then
    netD:add(nn.Sigmoid())
  end

  return netD
end


-- The "basic" discriminator is the n-layer discriminator with 3 layers.
function defineD_basic(input_nc, ndf, use_sigmoid)
  -- local: the previous code assigned the global `n_layers` here
  local n_layers = 3
  return defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid)
end

-- rf=1
-- Per-pixel discriminator: only 1x1 convolutions with stride 1, so the output
-- keeps the spatial size of the input. A Sigmoid is appended when use_sigmoid
-- is true.
function defineD_pixelGAN(input_nc, ndf, use_sigmoid)
  local netD = nn.Sequential()

  -- input is (nc) x 256 x 256
  netD:add(nn.SpatialConvolution(input_nc, ndf, 1, 1, 1, 1, 0, 0))
  netD:add(nn.LeakyReLU(0.2, true))
  -- state size: (ndf) x 256 x 256
  netD:add(nn.SpatialConvolution(ndf, ndf * 2, 1, 1, 1, 1, 0, 0))
  netD:add(normalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
  -- state size: (ndf*2) x 256 x 256
  netD:add(nn.SpatialConvolution(ndf * 2, 1, 1, 1, 1, 1, 0, 0))
  -- state size: 1 x 256 x 256
  if use_sigmoid then
    netD:add(nn.Sigmoid())
    -- state size: 1 x 256 x 256
  end

  return netD
end

-- if n=0, then use pixelGAN (rf=1)
-- else rf is 16 if n=1
--            34 if n=2
--            70 if n=3
--            142 if n=4
--            286 if n=5
--            574 if n=6
--
-- Parameters:
--   input_nc       number of input planes
--   ndf            number of filters in the first convolution
--   n_layers       number of stride-2 convolutions (0 selects defineD_pixelGAN)
--   use_sigmoid    append nn.Sigmoid() when true
--   kw             kernel width, defaults to 4
--   dropout_ratio  probability passed to nn.Dropout, defaults to 0.0
function defineD_n_layers(input_nc, ndf, n_layers, use_sigmoid, kw, dropout_ratio)
  if dropout_ratio == nil then
    dropout_ratio = 0.0
  end

  if kw == nil then
    kw = 4
  end
  -- local: the previous code assigned the global `padw` here
  local padw = math.ceil((kw-1)/2)

  if n_layers==0 then
    return defineD_pixelGAN(input_nc, ndf, use_sigmoid)
  else

    local netD = nn.Sequential()

    -- input is (nc) x 256 x 256
    -- print('input_nc', input_nc)
    netD:add(nn.SpatialConvolution(input_nc, ndf, kw, kw, 2, 2, padw, padw))
    netD:add(nn.LeakyReLU(0.2, true))

    -- the filter multiplier doubles with every layer and is capped at 8
    local nf_mult = 1
    local nf_mult_prev = 1
    for n = 1, n_layers-1 do
      nf_mult_prev = nf_mult
      nf_mult = math.min(2^n,8)
      netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 2, 2, padw,padw))
      netD:add(normalization(ndf * nf_mult)):add(nn.Dropout(dropout_ratio))
      netD:add(nn.LeakyReLU(0.2, true))
    end

    -- state size: (ndf*M) x N x N
    nf_mult_prev = nf_mult
    nf_mult = math.min(2^n_layers,8)
    netD:add(nn.SpatialConvolution(ndf * nf_mult_prev, ndf * nf_mult, kw, kw, 1, 1, padw, padw))
    netD:add(normalization(ndf * nf_mult)):add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf*M*2) x (N-1) x (N-1)
    netD:add(nn.SpatialConvolution(ndf * nf_mult, 1, kw, kw, 1, 1, padw,padw))
    -- state size: 1 x (N-2) x (N-2)
    if use_sigmoid then
      netD:add(nn.Sigmoid())
    end
    -- state size: 1 x (N-2) x (N-2)
    return netD
  end
end

================================================
FILE: models/base_model.lua
================================================
--------------------------------------------------------------------------------
-- Base Class for Providing Models
--------------------------------------------------------------------------------

local class = require 'class'

BaseModel = class('BaseModel')

function BaseModel:__init(conf)
   conf = conf or {}
end

-- Returns the name of the model
function BaseModel:model_name()
   return 'DoesNothingModel'
end

-- Defines models and networks
function BaseModel:Initialize(opt)
  -- local: the previous code created the global `models`
  local models = {}
  return models
end

-- Runs the forward pass of the network
function BaseModel:Forward(input, opt)
  -- local: the previous code created the global `output`
  local output = {}
  return output
end

-- Runs the backprop gradient descent
-- Corresponds to a single batch of data
function BaseModel:OptimizeParameters(opt)
end

-- This function can be used to reset momentum after each epoch
function BaseModel:RefreshParameters(opt)
end

-- This function can be used to update the learning rate; no-op in the base class
function BaseModel:UpdateLearningRate(opt)
end

-- Save the current model to the file system
function BaseModel:Save(prefix, opt)
end

-- returns a string that describes the current errors
function BaseModel:GetCurrentErrorDescription()
  return "No Error exists in BaseModel"
end

-- returns current errors
function BaseModel:GetCurrentErrors(opt)
  return {}
end

-- returns a table of
image/label pairs that describe -- the current results. -- |return|: a table of table. List of image/label pairs function BaseModel:GetCurrentVisuals(opt, size) return {} end -- returns a string that describes the display plot configuration function BaseModel:DisplayPlot(opt) return {} end ================================================ FILE: models/bigan_model.lua ================================================ local class = require 'class' require 'models.base_model' require 'models.architectures' require 'util.image_pool' util = paths.dofile('../util/util.lua') content = paths.dofile('../util/content_loss.lua') BiGANModel = class('BiGANModel', 'BaseModel') function BiGANModel:__init(conf) BaseModel.__init(self, conf) conf = conf or {} end function BiGANModel:model_name() return 'BiGANModel' end function BiGANModel:InitializeStates(use_wgan) optimState = {learningRate=opt.lr, beta1=opt.beta1,} return optimState end -- Defines models and networks function BiGANModel:Initialize(opt) if opt.test == 0 then self.realABPool = ImagePool(opt.pool_size) self.fakeABPool = ImagePool(opt.pool_size) end -- define tensors local d_input_nc = opt.input_nc + opt.output_nc self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize) self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize) -- load/define models self.criterionGAN = nn.MSECriterion() local netG, netE, netD = nil, nil, nil if opt.continue_train == 1 then if opt.test == 1 then -- which_epoch option exists in test mode netG = util.load_test_model('G', opt) netE = util.load_test_model('E', opt) netD = util.load_test_model('D', opt) else netG = util.load_model('G', opt) netE = util.load_model('E', opt) netD = util.load_model('D', opt) end else -- netG_test = defineG(opt.input_nc, opt.output_nc, opt.ngf, "resnet_unet", opt.arch) -- os.exit() netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false) -- no sigmoid layer print('netD...', netD) netG = 
defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG...', netG)
    -- netE is built with the channel counts swapped relative to netG
    netE = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netE...', netE)
  end
  self.netD = netD
  self.netG = netG
  self.netE = netE

  -- define real/fake labels
  -- one forward pass through D gives the shape of its output; the labels copy that shape.
  -- local: the previous code leaked `netD_output_size` into the global namespace
  local netD_output_size = self.netD:forward(self.real_AB):size()
  self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
  self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing
  self.optimStateD = self:InitializeStates()
  self.optimStateG = self:InitializeStates()
  self.optimStateE = self:InitializeStates()
  -- index tables selecting the A channels / B channels of a concatenated AB tensor
  self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
  self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}}
  self:RefreshParameters()

  print('---------- # Learnable Parameters --------------')
  print(('G = %d'):format(self.parametersG:size(1)))
  print(('E = %d'):format(self.parametersE:size(1)))
  print(('D = %d'):format(self.parametersD:size(1)))
  print('------------------------------------------------')
  -- os.exit()
end

-- Runs the forward pass of the network and
-- saves the result to member variables of the class
-- input: table with fields real_A and real_B; when opt.which_direction == 'BtoA'
-- the two fields are swapped in place on the caller's table.
function BiGANModel:Forward(input, opt)
  if opt.which_direction == 'BtoA' then
    local temp = input.real_A
    input.real_A = input.real_B
    input.real_B = temp
  end

  -- real_A lives in the A channels of real_AB, real_B in the B channels of fake_AB
  self.real_AB[self.A_idx]:copy(input.real_A)
  self.fake_AB[self.B_idx]:copy(input.real_B)
  self.real_A = self.real_AB[self.A_idx]
  self.real_B = self.fake_AB[self.B_idx]

  self.fake_B = self.netG:forward(self.real_A):clone()
  self.fake_A = self.netE:forward(self.real_B):clone()

  self.real_AB[self.B_idx]:copy(self.fake_B) -- real_AB: real_A, fake_B -> real_label
  self.fake_AB[self.A_idx]:copy(self.fake_A) -- fake_AB: fake_A, real_B -> fake_label
  -- if opt.test == 0 then
  --   self.real_AB = self.realABPool:Query(self.real_AB) -- batch history
  --   self.fake_AB = self.fakeABPool:Query(self.fake_AB) -- batch history
  -- end
end

-- create closure to evaluate f(X) and df/dX of discriminator
function
BiGANModel:fDx_basic(x, gradParams, netD, real_AB, fake_AB, opt)
  util.BiasZero(netD)
  gradParams:zero()

  -- Real log(D_A(B))
  local output = netD:forward(real_AB):clone()
  local errD_real = self.criterionGAN:forward(output, self.real_label)
  local df_do = self.criterionGAN:backward(output, self.real_label)
  netD:backward(real_AB, df_do)
  -- Fake + log(1 - D_A(G(A)))
  output = netD:forward(fake_AB):clone()
  local errD_fake = self.criterionGAN:forward(output, self.fake_label)
  local df_do2 = self.criterionGAN:backward(output, self.fake_label)
  netD:backward(fake_AB, df_do2)
  -- Compute loss
  -- average of the two criterion values
  local errD = (errD_real + errD_fake) / 2.0
  return errD, gradParams
end

-- Discriminator objective in the (loss, gradient) form used by optim.
-- Stores the loss in self.errD and returns it with the D gradient tensor.
function BiGANModel:fDx(x, opt)
  -- use image pool that stores the old fake images
  -- locals: the previous code leaked `real_AB`, `fake_AB` and `gradParams` as globals
  local real_AB = self.realABPool:Query(self.real_AB)
  local fake_AB = self.fakeABPool:Query(self.fake_AB)
  local errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD, real_AB, fake_AB, opt)
  self.errD = errD
  return errD, gradParams
end

-- Generator (netG) update: D's prediction on real_AB (real_A, fake_B) is scored
-- against the fake label, and the gradient for the B channels is backpropagated into netG.
-- Returns gradParametersG, errG.
function BiGANModel:fGx_basic(x, netG, netD, gradParametersG, opt)
  util.BiasZero(netG)
  util.BiasZero(netD)
  gradParametersG:zero()

  -- First. G(A) should fake the discriminator
  local output = netD:forward(self.real_AB):clone()
  local errG = self.criterionGAN:forward(output, self.fake_label)
  local dgan_loss_dd = self.criterionGAN:backward(output, self.fake_label)
  -- updateGradInput only: D's own parameter gradients are not accumulated here
  local dgan_loss_do = netD:updateGradInput(self.real_AB, dgan_loss_dd)
  netG:backward(self.real_A, dgan_loss_do[self.B_idx]) -- real_AB: real_A, fake_B -> real_label
  return gradParametersG, errG
end

-- netG objective in the (loss, gradient) form used by optim; stores the loss in self.errG.
function BiGANModel:fGx(x, opt)
  self.gradParametersG, self.errG = self:fGx_basic(x, self.netG,
                                             self.netD, self.gradParametersG, opt)
  return self.errG, self.gradParametersG
end

function BiGANModel:fEx_basic(x, netE, netD, gradParametersE, opt)
  util.BiasZero(netE)
  util.BiasZero(netD)
  gradParametersE:zero()

  -- First.
-- G(A) should fake the discriminator (the network being updated here is netE)
  local pred = netD:forward(self.fake_AB):clone()
  local errE = self.criterionGAN:forward(pred, self.real_label)
  local grad_pred = self.criterionGAN:backward(pred, self.real_label)
  local grad_pair = netD:updateGradInput(self.fake_AB, grad_pred)
  -- fake_AB: fake_A, real_B -> fake_label
  -- only the A channels of D's input gradient are backpropagated into netE
  netE:backward(self.real_B, grad_pair[self.A_idx])
  return gradParametersE, errE
end

-- netE objective in the (loss, gradient) form used by optim; stores the loss in self.errE.
function BiGANModel:fEx(x, opt)
  local grads, err = self:fEx_basic(x, self.netE, self.netD, self.gradParametersE, opt)
  self.gradParametersE, self.errE = grads, err
  return err, grads
end

-- One optimization step per network, in the order D, G, E.
function BiGANModel:OptimizeParameters(opt)
  local function evalG(x) return self:fGx(x, opt) end
  local function evalE(x) return self:fEx(x, opt) end
  local function evalD(x) return self:fDx(x, opt) end

  optim.adam(evalD, self.parametersD, self.optimStateD)
  optim.adam(evalG, self.parametersG, self.optimStateG)
  optim.adam(evalE, self.parametersE, self.optimStateE)
end

-- Re-fetches the flattened parameter / gradient tensors of the three networks.
function BiGANModel:RefreshParameters()
  -- nil them to avoid spiking memory
  self.parametersD, self.gradParametersD = nil, nil
  self.parametersG, self.gradParametersG = nil, nil
  self.parametersE, self.gradParametersE = nil, nil
  -- define parameters of optimization
  self.parametersD, self.gradParametersD = self.netD:getParameters()
  self.parametersG, self.gradParametersG = self.netG:getParameters()
  self.parametersE, self.gradParametersE = self.netE:getParameters()
end

-- Writes the three networks to files named after `prefix`.
function BiGANModel:Save(prefix, opt)
  util.save_model(self.netG, prefix .. '_net_G.t7', 1)
  util.save_model(self.netE, prefix .. '_net_E.t7', 1)
  util.save_model(self.netD, prefix ..
'_net_D.t7', 1)
end

-- returns a string that describes the current errors; -1 stands in for any
-- error that has not been computed yet
function BiGANModel:GetCurrentErrorDescription()
  -- local: the previous code leaked `description` into the global namespace
  local description = ('D: %.4f G: %.4f E: %.4f'):format(
                         self.errD or -1,
                         self.errG or -1,
                         self.errE or -1)
  return description
end

-- returns current errors as a table keyed by errD / errG / errE
function BiGANModel:GetCurrentErrors()
  local errors = {errD=self.errD, errG=self.errG, errE=self.errE}
  return errors
end

-- returns a string that describes the display plot configuration
function BiGANModel:DisplayPlot(opt)
  return 'errD,errG,errE'
end

-- Lowers the learning rate of all three optimizers by opt.lr / opt.niter_decay.
function BiGANModel:UpdateLearningRate(opt)
  local lrd = opt.lr / opt.niter_decay
  local old_lr = self.optimStateD['learningRate']
  local lr = old_lr - lrd
  self.optimStateD['learningRate'] = lr
  self.optimStateG['learningRate'] = lr
  self.optimStateE['learningRate'] = lr
  print(('update learning rate: %f -> %f'):format(old_lr, lr))
end

-- Repeats dimension 2 three times when it has size 1; otherwise returns `im` itself.
local function MakeIm3(im)
  -- print('before im_size', im:size())
  local im3 = nil
  if im:size(2) == 1 then
    im3 = torch.repeatTensor(im, 1,3,1,1)
  else
    im3 = im
  end
  -- print('after im_size', im:size())
  -- print('after im3_size', im3:size())
  return im3
end

-- returns a list of {img=..., label=...} entries for real_A, fake_B, real_B, fake_A
function BiGANModel:GetCurrentVisuals(opt, size)
  if not size then
    size = opt.display_winsize
  end

  local visuals = {}
  table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'})
  table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'})
  table.insert(visuals, {img=MakeIm3(self.real_B), label='real_B'})
  table.insert(visuals, {img=MakeIm3(self.fake_A), label='fake_A'})
  return visuals
end

================================================
FILE: models/content_gan_model.lua
================================================
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
content = paths.dofile('../util/content_loss.lua')

ContentGANModel = class('ContentGANModel', 'BaseModel')

-- conf: optional configuration table, forwarded to BaseModel
function ContentGANModel:__init(conf)
  BaseModel.__init(self, conf)
  conf = conf or {}
end

-- Returns the name of the model
function ContentGANModel:model_name()
  return
'ContentGANModel'
end

-- Builds a fresh optimizer state table {learningRate, beta1}.
-- NOTE: there is no `opt` parameter; the values are read from the `opt` variable
-- visible at file scope (a global as far as this file shows).
function ContentGANModel:InitializeStates()
  local optimState = {learningRate=opt.lr, beta1=opt.beta1,}
  return optimState
end

-- Defines models and networks
function ContentGANModel:Initialize(opt)
  -- the fake-image pool is only created when training
  if opt.test == 0 then
    self.fakePool = ImagePool(opt.pool_size)
  end
  -- define tensors
  self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
  self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
  self.real_B = self.fake_B:clone() --torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)

  -- load/define models
  self.criterionGAN = nn.MSECriterion()
  self.criterionContent = nn.AbsCriterion()
  self.contentFunc = content.defineContent(opt.content_loss, opt.layer_name)
  self.netG, self.netD = nil, nil

  if opt.continue_train == 1 then
    if opt.which_epoch then -- which_epoch option exists in test mode
      self.netG = util.load_test_model('G_A', opt)
      self.netD = util.load_test_model('D_A', opt)
    else
      self.netG = util.load_model('G_A', opt)
      self.netD = util.load_model('D_A', opt)
    end
  else
    self.netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
    print('netG...', self.netG)
    self.netD = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false)
    print('netD...', self.netD)
  end
  -- define real/fake labels
  -- Probe D with real_B: D is defined above with opt.output_nc input channels and is
  -- fed real_B / fake_B during training. The previous probe used real_A, which has
  -- opt.input_nc channels and only matches D when input_nc == output_nc.
  -- local: the previous code also leaked `netD_output_size` as a global
  local netD_output_size = self.netD:forward(self.real_B):size()
  self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
  self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing
  self.optimStateD = self:InitializeStates()
  self.optimStateG = self:InitializeStates()
  self:RefreshParameters()
  print('---------- # Learnable Parameters --------------')
  print(('G = %d'):format(self.parametersG:size(1)))
  print(('D = %d'):format(self.parametersD:size(1)))
  print('------------------------------------------------')
  -- os.exit()
end

-- Runs the forward pass of the network and
-- saves the result to member variables of the class
function
ContentGANModel:Forward(input, opt)
  -- when opt.which_direction == 'BtoA' the two fields are swapped in place on the caller's table
  if opt.which_direction == 'BtoA' then
    local temp = input.real_A
    input.real_A = input.real_B
    input.real_B = temp
  end

  self.real_A:copy(input.real_A)
  self.real_B:copy(input.real_B)
  self.fake_B = self.netG:forward(self.real_A):clone()
  -- output = {self.fake_B}
  -- local: the previous code leaked `output` into the global namespace
  local output = {}
  -- if opt.test == 1 then

  -- end
  return output
end

-- create closure to evaluate f(X) and df/dX of discriminator
-- Accumulates D's gradients for one real and one fake batch into gradParams and
-- returns the averaged criterion value together with gradParams.
function ContentGANModel:fDx_basic(x, gradParams, netD, netG, real_target, fake_target, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()

  -- Real log(D_A(B))
  local output = netD:forward(real_target)
  local errD_real = self.criterionGAN:forward(output, self.real_label)
  -- local: the previous code leaked `df_do` into the global namespace
  local df_do = self.criterionGAN:backward(output, self.real_label)
  netD:backward(real_target, df_do)

  -- Fake + log(1 - D_A(G_A(A)))
  output = netD:forward(fake_target)
  local errD_fake = self.criterionGAN:forward(output, self.fake_label)
  df_do = self.criterionGAN:backward(output, self.fake_label)
  netD:backward(fake_target, df_do)
  local errD = (errD_real + errD_fake) / 2.0
  return errD, gradParams
end

-- Discriminator objective in the (loss, gradient) form used by optim.
-- Stores the loss in self.errD.
function ContentGANModel:fDx(x, opt)
  -- locals: the previous code leaked `fake_B` and `gradParams` as globals
  local pooled_fake_B = self.fakePool:Query(self.fake_B)
  local errD, gradParams = self:fDx_basic(x, self.gradparametersD, self.netD, self.netG,
                                          self.real_B, pooled_fake_B, opt)
  self.errD = errD
  return errD, gradParams
end

-- Generator objective: GAN loss on fake_target plus the content loss.
-- Returns gradParametersG_source, errGAN, errContent.
function ContentGANModel:fGx_basic(x, netG_source, netD_source, real_source, real_target, fake_target,
                                   gradParametersG_source, opt)
  util.BiasZero(netD_source)
  util.BiasZero(netG_source)
  gradParametersG_source:zero()
  -- GAN loss
  -- local df_d_GAN = torch.zeros(fake_target:size())
  -- local errGAN = 0
  -- local errRec = 0
  --- Domain GAN loss: D_A(G_A(A))
  -- Run D on fake_target explicitly. The previous code reused netD_source.output left
  -- over from fDx, but fDx forwards whatever self.fakePool:Query returns (not
  -- necessarily the current fake) and D's parameters are updated by optim.adam before
  -- this closure runs, so the cached activations need not correspond to fake_target.
  local output = netD_source:forward(fake_target)
  local errGAN = self.criterionGAN:forward(output, self.real_label)
  local df_do = self.criterionGAN:backward(output, self.real_label)
  local
grad_gan = netD_source:updateGradInput(fake_target, df_do) ---:narrow(2,fake_AB:size(2)-output_nc+1, output_nc)

  -- content loss
  -- print('content_loss', opt.content_loss)
  -- function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight)
  local errContent, grad_content = content.lossUpdate(self.criterionContent, real_source, fake_target,
                                                      self.contentFunc, opt.content_loss, opt.lambda_A)
  -- forward once more so the generator's cached state belongs to real_source, then
  -- backpropagate the sum of both loss gradients
  netG_source:forward(real_source)
  netG_source:backward(real_source, grad_gan + grad_content)
  -- print('errD', errGAN)
  return gradParametersG_source, errGAN, errContent
end

-- Generator objective in the (loss, gradient) form used by optim.
-- Stores the GAN loss in self.errG and the content loss in self.errCont.
function ContentGANModel:fGx(x, opt)
  local grads, errGAN, errContent = self:fGx_basic(x, self.netG, self.netD,
                                                   self.real_A, self.real_B, self.fake_B,
                                                   self.gradparametersG, opt)
  self.gradparametersG, self.errG, self.errCont = grads, errGAN, errContent
  return errGAN, grads
end

-- One optimization step for D followed by one for G.
function ContentGANModel:OptimizeParameters(opt)
  local function evalD(x) return self:fDx(x, opt) end
  local function evalG(x) return self:fGx(x, opt) end
  optim.adam(evalD, self.parametersD, self.optimStateD)
  optim.adam(evalG, self.parametersG, self.optimStateG)
end

-- Re-fetches the flattened parameter / gradient tensors of both networks.
function ContentGANModel:RefreshParameters()
  -- nil them to avoid spiking memory
  self.parametersD, self.gradparametersD = nil, nil
  self.parametersG, self.gradparametersG = nil, nil
  -- define parameters of optimization
  self.parametersG, self.gradparametersG = self.netG:getParameters()
  self.parametersD, self.gradparametersD = self.netD:getParameters()
end

-- Writes both networks to files named after `prefix`.
function ContentGANModel:Save(prefix, opt)
  util.save_model(self.netG, prefix .. '_net_G_A.t7', 1.0)
  util.save_model(self.netD, prefix ..
'_net_D_A.t7', 1.0)
end

-- returns a string that describes the current errors; -1 stands in for any
-- error that has not been computed yet
function ContentGANModel:GetCurrentErrorDescription()
  -- local: the previous code leaked `description` into the global namespace
  local description = ('G: %.4f D: %.4f Content: %.4f'):format(self.errG or -1,
                         self.errD or -1,
                         self.errCont or -1)
  return description
end

-- returns current errors, with -1 for any error that has not been computed yet
function ContentGANModel:GetCurrentErrors()
  local errors = {errG=self.errG and self.errG or -1, errD=self.errD and self.errD or -1,
                  errCont=self.errCont and self.errCont or -1}
  return errors
end

-- returns a string that describes the display plot configuration
function ContentGANModel:DisplayPlot(opt)
  return 'errG,errD,errCont'
end

-- returns a list of {img=..., label=...} entries for real_A, fake_B, real_B
function ContentGANModel:GetCurrentVisuals(opt, size)
  if not size then
    size = opt.display_winsize
  end

  local visuals = {}
  table.insert(visuals, {img=self.real_A, label='real_A'})
  table.insert(visuals, {img=self.fake_B, label='fake_B'})
  table.insert(visuals, {img=self.real_B, label='real_B'})
  return visuals
end

================================================
FILE: models/cycle_gan_model.lua
================================================
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'

util = paths.dofile('../util/util.lua')
CycleGANModel = class('CycleGANModel', 'BaseModel')

-- conf: optional configuration table, forwarded to BaseModel
function CycleGANModel:__init(conf)
  BaseModel.__init(self, conf)
  conf = conf or {}
end

-- Returns the name of the model
function CycleGANModel:model_name()
  return 'CycleGANModel'
end

-- Builds a fresh optimizer state table {learningRate, beta1}.
-- NOTE: there is no `opt` parameter; the values are read from the `opt` variable
-- visible at file scope (a global as far as this file shows).
-- use_wgan: accepted but not used.
function CycleGANModel:InitializeStates(use_wgan)
  -- local: the previous code leaked `optimState` into the global namespace
  local optimState = {learningRate=opt.lr, beta1=opt.beta1,}
  return optimState
end

-- Defines models and networks
function CycleGANModel:Initialize(opt)
  -- image pools are only created when training
  if opt.test == 0 then
    self.fakeAPool = ImagePool(opt.pool_size)
    self.fakeBPool = ImagePool(opt.pool_size)
  end
  -- define tensors
  if opt.test == 0 then -- allocate tensors for training
    self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
    self.real_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
    self.fake_A = torch.Tensor(opt.batchSize, opt.input_nc,
opt.fineSize, opt.fineSize)
    self.fake_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
    self.rec_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
    self.rec_B = torch.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)
  end
  -- load/define models
  -- BCE criterion unless opt.use_lsgan == 1, in which case MSE is used
  local use_lsgan = ((opt.use_lsgan ~= nil) and (opt.use_lsgan == 1))
  if not use_lsgan then
    self.criterionGAN = nn.BCECriterion()
  else
    self.criterionGAN = nn.MSECriterion()
  end
  self.criterionRec = nn.AbsCriterion()

  local netG_A, netD_A, netG_B, netD_B = nil, nil, nil, nil
  if opt.continue_train == 1 then
    if opt.test == 1 then -- test mode
      -- only the two generators are loaded for testing
      netG_A = util.load_test_model('G_A', opt)
      netG_B = util.load_test_model('G_B', opt)

      --setup optnet to save a little bit of memory
      if opt.use_optnet == 1 then
        local optnet = require 'optnet'
        -- Each generator gets a sample with the channel count it consumes: netG_A is
        -- fed real_A (opt.input_nc channels) and netG_B is fed real_B (opt.output_nc
        -- channels). The previous code passed an input_nc sample to both, which only
        -- matches netG_B when input_nc == output_nc.
        local sample_input_A = torch.randn(1, opt.input_nc, 2, 2)
        local sample_input_B = torch.randn(1, opt.output_nc, 2, 2)
        optnet.optimizeMemory(netG_A, sample_input_A, {inplace=true, reuseBuffers=true})
        optnet.optimizeMemory(netG_B, sample_input_B, {inplace=true, reuseBuffers=true})
      end
    else
      netG_A = util.load_model('G_A', opt)
      netG_B = util.load_model('G_B', opt)
      netD_A = util.load_model('D_A', opt)
      netD_B = util.load_model('D_B', opt)
    end
  else
    -- a sigmoid is requested from defineD only when the BCE criterion is in use
    local use_sigmoid = (not use_lsgan)
    -- netG_test = defineG(opt.input_nc, opt.output_nc, opt.ngf, "resnet_unet", opt.arch)
    -- os.exit()
    netG_A = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG_A...', netG_A)
    netD_A = defineD(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid) -- sigmoid only without lsgan
    print('netD_A...', netD_A)
    netG_B = defineG(opt.output_nc, opt.input_nc, opt.ngf, opt.which_model_netG, opt.arch)
    print('netG_B...', netG_B)
    netD_B = defineD(opt.input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, use_sigmoid) -- sigmoid only without lsgan
    print('netD_B', netD_B)
  end

  self.netD_A = netD_A
  self.netG_A = netG_A
  self.netG_B = netG_B
  self.netD_B = netD_B

  -- define real/fake labels
  if opt.test == 0 then
    local
out_size_A = self.netD_A:forward(self.real_B):size() -- hack: assume D_size_A = D_size_B
    -- labels take the shape of the corresponding discriminator output
    self.fake_label_A = torch.Tensor(out_size_A):fill(0.0)
    self.real_label_A = torch.Tensor(out_size_A):fill(1.0) -- no soft smoothing

    local out_size_B = self.netD_B:forward(self.real_A):size() -- hack: assume D_size_A = D_size_B
    self.fake_label_B = torch.Tensor(out_size_B):fill(0.0)
    self.real_label_B = torch.Tensor(out_size_B):fill(1.0) -- no soft smoothing

    -- one optimizer state per network
    self.optimStateD_A = self:InitializeStates()
    self.optimStateG_A = self:InitializeStates()
    self.optimStateD_B = self:InitializeStates()
    self.optimStateG_B = self:InitializeStates()

    self:RefreshParameters()

    print('---------- # Learnable Parameters --------------')
    print(('G_A = %d'):format(self.parametersG_A:size(1)))
    print(('D_A = %d'):format(self.parametersD_A:size(1)))
    print(('G_B = %d'):format(self.parametersG_B:size(1)))
    print(('D_B = %d'):format(self.parametersD_B:size(1)))
    print('------------------------------------------------')
  end
end

-- Runs the forward pass of the network and
-- saves the result to member variables of the class
-- Training (opt.test == 0): only copies the inputs into real_A / real_B.
-- Testing (opt.test == 1): also computes fake_B, fake_A, rec_A and rec_B.
function CycleGANModel:Forward(input, opt)
  if opt.which_direction == 'BtoA' then
    -- swap the two domains on the caller's table, cloning both tensors
    input.real_A, input.real_B = input.real_B:clone(), input.real_A:clone()
  end

  if opt.test == 0 then
    self.real_A:copy(input.real_A)
    self.real_B:copy(input.real_B)
  elseif opt.test == 1 then -- forward for test
    if opt.gpu > 0 then
      self.real_A = input.real_A:cuda()
      self.real_B = input.real_B:cuda()
    else
      self.real_A = input.real_A:clone()
      self.real_B = input.real_B:clone()
    end
    self.fake_B = self.netG_A:forward(self.real_A):clone()
    self.fake_A = self.netG_B:forward(self.real_B):clone()
    self.rec_A = self.netG_B:forward(self.fake_B):clone()
    self.rec_B = self.netG_A:forward(self.fake_A):clone()
  end
end

-- create closure to evaluate f(X) and df/dX of discriminator
function CycleGANModel:fDx_basic(x, gradParams, netD, netG, real, fake, real_label, fake_label, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()
-- Real log(D_A(B))
  local output = netD:forward(real)
  local errD_real = self.criterionGAN:forward(output, real_label)
  local df_do = self.criterionGAN:backward(output, real_label)
  netD:backward(real, df_do)
  -- Fake + log(1 - D_A(G_A(A)))
  output = netD:forward(fake)
  local errD_fake = self.criterionGAN:forward(output, fake_label)
  local df_do2 = self.criterionGAN:backward(output, fake_label)
  netD:backward(fake, df_do2)
  -- Compute loss
  -- average of the two criterion values
  local errD = (errD_real + errD_fake) / 2.0
  return errD, gradParams
end

-- D_A objective in the (loss, gradient) form used by optim.
-- Stores the loss in self.errD_A.
function CycleGANModel:fDAx(x, opt)
  -- use image pool that stores the old fake images
  -- locals: the previous code leaked `fake_B` and `gradParams` as globals
  local fake_B = self.fakeBPool:Query(self.fake_B)
  local errD_A, gradParams = self:fDx_basic(x, self.gradparametersD_A, self.netD_A, self.netG_A,
                                            self.real_B, fake_B, self.real_label_A, self.fake_label_A, opt)
  self.errD_A = errD_A
  return errD_A, gradParams
end

-- D_B objective in the (loss, gradient) form used by optim.
-- Stores the loss in self.errD_B.
function CycleGANModel:fDBx(x, opt)
  -- use image pool that stores the old fake images
  -- locals: the previous code leaked `fake_A` and `gradParams` as globals
  local fake_A = self.fakeAPool:Query(self.fake_A)
  local errD_B, gradParams = self:fDx_basic(x, self.gradparametersD_B, self.netD_B, self.netG_B,
                                            self.real_A, fake_A, self.real_label_B, self.fake_label_B, opt)
  self.errD_B = errD_B
  return errD_B, gradParams
end

function CycleGANModel:fGx_basic(x, gradParams, netG, netD, netE, real, real2, real_label,
                                 lambda1, lambda2, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  util.BiasZero(netE) -- inverse mapping
  gradParams:zero()

  -- G should be identity if real2 is fed.
local errI = nil
  local identity = nil
  -- identity term is optional: errI and identity stay nil when opt.lambda_identity <= 0
  if opt.lambda_identity > 0 then
    identity = netG:forward(real2):clone()
    errI = self.criterionRec:forward(identity, real2) * lambda2 * opt.lambda_identity
    local didentity_loss_do = self.criterionRec:backward(identity, real2):mul(lambda2):mul(opt.lambda_identity)
    netG:backward(real2, didentity_loss_do)
  end

  --- GAN loss: D_A(G_A(A))
  local fake = netG:forward(real):clone()
  local output = netD:forward(fake)
  local errG = self.criterionGAN:forward(output, real_label)
  local df_do1 = self.criterionGAN:backward(output, real_label)
  -- updateGradInput only: the gradient w.r.t. `fake` is needed, netD's parameter
  -- gradients are not accumulated here
  local df_d_GAN = netD:updateGradInput(fake, df_do1)

  -- forward cycle loss
  -- rec = netE(netG(real)) is compared with `real`, weighted by lambda1
  local rec = netE:forward(fake):clone()
  local errRec = self.criterionRec:forward(rec, real) * lambda1
  local df_do2 = self.criterionRec:backward(rec, real):mul(lambda1)
  local df_do_rec = netE:updateGradInput(fake, df_do2)

  -- GAN and forward-cycle gradients w.r.t. `fake` are summed and pushed through netG
  -- in a single backward call (netG's most recent forward above was on `real`)
  netG:backward(real, df_d_GAN + df_do_rec)

  -- backward cycle loss
  -- rec2 = netG(netE(real2)) is compared with real2, weighted by lambda2; only
  -- netG:backward is called, so nothing is backpropagated into netE for this term
  local fake2 = netE:forward(real2)--:clone()
  local rec2 = netG:forward(fake2)--:clone()
  -- NOTE(review): errAdapt is computed but neither returned nor used below
  local errAdapt = self.criterionRec:forward(rec2, real2) * lambda2
  local df_do_coadapt = self.criterionRec:backward(rec2, real2):mul(lambda2)
  netG:backward(fake2, df_do_coadapt)

  return gradParams, errG, errRec, errI, fake, rec, identity
end

-- G_A objective in the (loss, gradient) form used by optim.
-- Calls fGx_basic with netG = netG_A, netD = netD_A, netE = netG_B, real = real_A,
-- real2 = real_B, and stores its results in errG_A, errRec_A, errI_A, fake_B, rec_A
-- and identity_B.
function CycleGANModel:fGAx(x, opt)
  self.gradparametersG_A, self.errG_A, self.errRec_A, self.errI_A, self.fake_B, self.rec_A, self.identity_B =
  self:fGx_basic(x, self.gradparametersG_A, self.netG_A, self.netD_A, self.netG_B, self.real_A, self.real_B,
  self.real_label_A, opt.lambda_A, opt.lambda_B, opt)
  return self.errG_A, self.gradparametersG_A
end

-- G_B objective in the (loss, gradient) form used by optim.
-- Mirror of fGAx: netG = netG_B, netD = netD_B, netE = netG_A, real = real_B,
-- real2 = real_A, with lambda_B / lambda_A swapped; results go to errG_B, errRec_B,
-- errI_B, fake_A, rec_B and identity_A.
function CycleGANModel:fGBx(x, opt)
  self.gradparametersG_B, self.errG_B, self.errRec_B, self.errI_B, self.fake_A, self.rec_B, self.identity_A =
  self:fGx_basic(x, self.gradparametersG_B, self.netG_B, self.netD_B, self.netG_A, self.real_B, self.real_A,
  self.real_label_B, opt.lambda_B, opt.lambda_A, opt)
  return self.errG_B, self.gradparametersG_B
end

-- One optimization step per network, in the order G_A, D_A, G_B, D_B.
function CycleGANModel:OptimizeParameters(opt)
local function evalDA(x) return self:fDAx(x, opt) end
  local function evalGA(x) return self:fGAx(x, opt) end
  local function evalDB(x) return self:fDBx(x, opt) end
  local function evalGB(x) return self:fGBx(x, opt) end

  -- update order: G_A, D_A, G_B, D_B
  optim.adam(evalGA, self.parametersG_A, self.optimStateG_A)
  optim.adam(evalDA, self.parametersD_A, self.optimStateD_A)
  optim.adam(evalGB, self.parametersG_B, self.optimStateG_B)
  optim.adam(evalDB, self.parametersD_B, self.optimStateD_B)
end

-- Re-fetches the flattened parameter / gradient tensors of the four networks.
function CycleGANModel:RefreshParameters()
  -- nil them to avoid spiking memory
  self.parametersD_A, self.gradparametersD_A = nil, nil
  self.parametersG_A, self.gradparametersG_A = nil, nil
  self.parametersG_B, self.gradparametersG_B = nil, nil
  self.parametersD_B, self.gradparametersD_B = nil, nil
  -- define parameters of optimization
  self.parametersG_A, self.gradparametersG_A = self.netG_A:getParameters()
  self.parametersD_A, self.gradparametersD_A = self.netD_A:getParameters()
  self.parametersG_B, self.gradparametersG_B = self.netG_B:getParameters()
  self.parametersD_B, self.gradparametersD_B = self.netD_B:getParameters()
end

-- Writes the four networks to files named after `prefix`.
function CycleGANModel:Save(prefix, opt)
  util.save_model(self.netG_A, prefix .. '_net_G_A.t7', 1)
  util.save_model(self.netD_A, prefix .. '_net_D_A.t7', 1)
  util.save_model(self.netG_B, prefix .. '_net_G_B.t7', 1)
  util.save_model(self.netD_B, prefix ..
'_net_D_B.t7', 1)
end

-- returns a string that describes the current errors of both directions;
-- -1 stands in for any error that has not been computed yet
function CycleGANModel:GetCurrentErrorDescription()
  -- local: the previous code leaked `description` into the global namespace
  local description = ('[A] G: %.4f D: %.4f Rec: %.4f I: %.4f || [B] G: %.4f D: %.4f Rec: %.4f I:%.4f'):format(
                         self.errG_A or -1,
                         self.errD_A or -1,
                         self.errRec_A or -1,
                         self.errI_A or -1,
                         self.errG_B or -1,
                         self.errD_B or -1,
                         self.errRec_B or -1,
                         self.errI_B or -1)
  return description
end

-- returns current errors as a table keyed by error name
function CycleGANModel:GetCurrentErrors()
  local errors = {errG_A=self.errG_A, errD_A=self.errD_A, errRec_A=self.errRec_A, errI_A=self.errI_A,
                  errG_B=self.errG_B, errD_B=self.errD_B, errRec_B=self.errRec_B, errI_B=self.errI_B}
  return errors
end

-- returns a string that describes the display plot configuration
-- the identity errors are listed only when opt.lambda_identity > 0
function CycleGANModel:DisplayPlot(opt)
  if opt.lambda_identity > 0 then
    return 'errG_A,errD_A,errRec_A,errI_A,errG_B,errD_B,errRec_B,errI_B'
  else
    return 'errG_A,errD_A,errRec_A,errG_B,errD_B,errRec_B'
  end
end

-- Lowers the learning rate of all four optimizers by opt.lr / opt.niter_decay.
function CycleGANModel:UpdateLearningRate(opt)
  local lrd = opt.lr / opt.niter_decay
  local old_lr = self.optimStateD_A['learningRate']
  local lr = old_lr - lrd
  self.optimStateD_A['learningRate'] = lr
  self.optimStateD_B['learningRate'] = lr
  self.optimStateG_A['learningRate'] = lr
  self.optimStateG_B['learningRate'] = lr
  print(('update learning rate: %f -> %f'):format(old_lr, lr))
end

-- Repeats dimension 2 three times when it has size 1; otherwise returns `im` itself.
local function MakeIm3(im)
  if im:size(2) == 1 then
    local im3 = torch.repeatTensor(im, 1,3,1,1)
    return im3
  else
    return im
  end
end

-- returns a list of {img=..., label=...} entries; the identity images are included
-- only when training with opt.lambda_identity > 0
function CycleGANModel:GetCurrentVisuals(opt, size)
  local visuals = {}
  table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'})
  table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'})
  table.insert(visuals, {img=MakeIm3(self.rec_A), label='rec_A'})
  if opt.test == 0 and opt.lambda_identity > 0 then
    table.insert(visuals, {img=MakeIm3(self.identity_A), label='identity_A'})
  end
  table.insert(visuals, {img=MakeIm3(self.real_B), label='real_B'})
table.insert(visuals, {img=MakeIm3(self.fake_A), label='fake_A'})
  table.insert(visuals, {img=MakeIm3(self.rec_B), label='rec_B'})
  local show_identity = (opt.test == 0) and (opt.lambda_identity > 0)
  if show_identity then
    table.insert(visuals, {img=MakeIm3(self.identity_B), label='identity_B'})
  end
  return visuals
end

================================================
FILE: models/one_direction_test_model.lua
================================================
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'

util = paths.dofile('../util/util.lua')
OneDirectionTestModel = class('OneDirectionTestModel', 'BaseModel')

-- conf: optional configuration table, forwarded to BaseModel
function OneDirectionTestModel:__init(conf)
  BaseModel.__init(self, conf)
  conf = conf or {}
end

-- Returns the name of the model
function OneDirectionTestModel:model_name()
  return 'OneDirectionTestModel'
end

-- Defines models and networks
-- Loads a single generator through util.load_test_model('G', opt).
function OneDirectionTestModel:Initialize(opt)
  -- define tensors
  self.real_A = torch.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)

  -- load/define models
  self.netG_A = util.load_test_model('G', opt)

  -- setup optnet to save a bit of memory
  if opt.use_optnet == 1 then
    local optnet = require 'optnet'
    local sample_input = torch.randn(1, opt.input_nc, 2, 2)
    local optnet_opts = {inplace=true, reuseBuffers=true}
    optnet.optimizeMemory(self.netG_A, sample_input, optnet_opts)
  end

  self:RefreshParameters()

  print('---------- # Learnable Parameters --------------')
  print(('G_A = %d'):format(self.parametersG_A:size(1)))
  print('------------------------------------------------')
end

-- Runs the forward pass of the network and
-- saves the result to member variables of the class
-- With opt.which_direction == 'BtoA', input.real_A is overwritten by a clone of input.real_B.
function OneDirectionTestModel:Forward(input, opt)
  if opt.which_direction == 'BtoA' then
    input.real_A = input.real_B:clone()
  end

  local source = input.real_A:clone()
  self.real_A = source
  if opt.gpu > 0 then
    self.real_A = self.real_A:cuda()
  end

  self.fake_B = self.netG_A:forward(self.real_A):clone()
end

-- Re-fetches the flattened parameter / gradient tensors of the generator.
function OneDirectionTestModel:RefreshParameters()
  self.parametersG_A, self.gradparametersG_A = nil, nil
self.parametersG_A, self.gradparametersG_A = self.netG_A:getParameters()
end

-- Repeats dimension 2 three times when it has size 1; otherwise returns `im` itself.
local function MakeIm3(im)
  if im:size(2) == 1 then
    local im3 = torch.repeatTensor(im, 1,3,1,1)
    return im3
  else
    return im
  end
end

-- returns a list of {img=..., label=...} entries for real_A and fake_B
function OneDirectionTestModel:GetCurrentVisuals(opt, size)
  if not size then
    size = opt.display_winsize
  end

  local visuals = {}
  table.insert(visuals, {img=MakeIm3(self.real_A), label='real_A'})
  table.insert(visuals, {img=MakeIm3(self.fake_B), label='fake_B'})
  return visuals
end

================================================
FILE: models/pix2pix_model.lua
================================================
local class = require 'class'
require 'models.base_model'
require 'models.architectures'
require 'util.image_pool'
util = paths.dofile('../util/util.lua')
Pix2PixModel = class('Pix2PixModel', 'BaseModel')

-- conf: optional configuration table, forwarded to BaseModel
function Pix2PixModel:__init(conf)
  -- chain to the base constructor like the other BaseModel subclasses do
  BaseModel.__init(self, conf)
  conf = conf or {}
end

-- Returns the name of the model
function Pix2PixModel:model_name()
  return 'Pix2PixModel'
end

-- Builds a fresh optimizer state table {learningRate, beta1}.
-- NOTE: there is no `opt` parameter; the values are read from the `opt` variable
-- visible at file scope (a global as far as this file shows).
function Pix2PixModel:InitializeStates()
  return {learningRate=opt.lr, beta1=opt.beta1,}
end

-- Defines models and networks
function Pix2PixModel:Initialize(opt) -- use lsgan
  -- define tensors
  -- the discriminator sees A and B concatenated along the channel dimension
  local d_input_nc = opt.input_nc + opt.output_nc
  self.real_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  self.fake_AB = torch.Tensor(opt.batchSize, d_input_nc, opt.fineSize, opt.fineSize)
  -- the fake-pair pool is only created when training
  if opt.test == 0 then
    self.fakeABPool = ImagePool(opt.pool_size)
  end
  -- load/define models
  self.criterionGAN = nn.MSECriterion()
  self.criterionL1 = nn.AbsCriterion()

  local netG, netD = nil, nil
  if opt.continue_train == 1 then
    if opt.test == 1 then -- only load model G for test
      netG = util.load_test_model('G', opt)
    else
      netG = util.load_model('G', opt)
      netD = util.load_model('D', opt)
    end
  else
    netG = defineG(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG)
    -- `false` is the use_sigmoid argument: no sigmoid, matching the MSE criterion above
    netD = defineD(d_input_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, false)
  end

  self.netD = netD
  self.netG = netG

  -- define
-- real/fake labels
  if opt.test == 0 then
    -- one forward pass through D gives the shape of its output; the labels copy that shape.
    -- local: the previous code leaked `netD_output_size` into the global namespace
    local netD_output_size = self.netD:forward(self.real_AB):size()
    self.fake_label = torch.Tensor(netD_output_size):fill(0.0)
    self.real_label = torch.Tensor(netD_output_size):fill(1.0) -- no soft smoothing

    self.optimStateD = self:InitializeStates()
    self.optimStateG = self:InitializeStates()
    self:RefreshParameters()
    print('---------- # Learnable Parameters --------------')
    print(('G = %d'):format(self.parametersG:size(1)))
    print(('D = %d'):format(self.parametersD:size(1)))
    print('------------------------------------------------')
  end

  -- index tables selecting the A channels / B channels of a concatenated AB tensor
  self.A_idx = {{}, {1, opt.input_nc}, {}, {}}
  self.B_idx = {{}, {opt.input_nc+1, opt.input_nc+opt.output_nc}, {}, {}}
end

-- Runs the forward pass of the network
-- Training (opt.test == 0): fills real_AB / fake_AB and computes fake_B from real_A.
-- Otherwise: copies the inputs (to the GPU when opt.gpu > 0) and computes fake_B.
-- When opt.which_direction == 'BtoA' the two input fields are swapped in place on
-- the caller's table.
function Pix2PixModel:Forward(input, opt)
  if opt.which_direction == 'BtoA' then
    local temp = input.real_A
    input.real_A = input.real_B
    input.real_B = temp
  end

  if opt.test == 0 then
    self.real_AB[self.A_idx]:copy(input.real_A)
    self.real_AB[self.B_idx]:copy(input.real_B)
    self.real_A = self.real_AB[self.A_idx]
    self.real_B = self.real_AB[self.B_idx]

    -- fake_AB pairs the same real_A with the generated B
    self.fake_AB[self.A_idx]:copy(self.real_A)
    self.fake_B = self.netG:forward(self.real_A):clone()
    self.fake_AB[self.B_idx]:copy(self.fake_B)
  else
    if opt.gpu > 0 then
      self.real_A = input.real_A:cuda()
      self.real_B = input.real_B:cuda()
    else
      self.real_A = input.real_A:clone()
      self.real_B = input.real_B:clone()
    end
    self.fake_B = self.netG:forward(self.real_A):clone()
  end
end

-- create closure to evaluate f(X) and df/dX of discriminator
function Pix2PixModel:fDx_basic(x, gradParams, netD, netG, real, fake, opt)
  util.BiasZero(netD)
  util.BiasZero(netG)
  gradParams:zero()

  -- Real log(D(B))
  local output = netD:forward(real)
  local errD_real = self.criterionGAN:forward(output, self.real_label)
  local df_do = self.criterionGAN:backward(output, self.real_label)
  netD:backward(real, df_do)
  -- Fake + log(1 - D(G(A)))
  output = netD:forward(fake)
  local errD_fake = self.criterionGAN:forward(output, self.fake_label)
  local df_do2 =
self.criterionGAN:backward(output, self.fake_label)
  netD:backward(fake, df_do2)
  -- calculate loss
  -- average of the two criterion values
  local errD = (errD_real + errD_fake) / 2.0
  return errD, gradParams
end

-- Discriminator objective in the (loss, gradient) form used by optim.
-- Stores the loss in self.errD.
function Pix2PixModel:fDx(x, opt)
  -- locals: the previous code leaked `fake_AB` and `gradParams` as globals
  local fake_AB = self.fakeABPool:Query(self.fake_AB)
  local errD, gradParams = self:fDx_basic(x, self.gradParametersD, self.netD, self.netG,
                                          self.real_AB, fake_AB, opt)
  self.errD = errD
  return errD, gradParams
end

-- Generator objective: GAN loss on the fake pair plus an L1 term weighted by opt.lambda_A.
-- real / fake: concatenated AB tensors; the A and B channel ranges are taken with
-- self.A_idx / self.B_idx.
-- Returns gradParametersG, errG, errL1.
function Pix2PixModel:fGx_basic(x, netG, netD, real, fake, gradParametersG, opt)
  util.BiasZero(netG)
  util.BiasZero(netD)
  gradParametersG:zero()

  -- First. G(A) should fake the discriminator
  local output = netD:forward(fake)
  local errG = self.criterionGAN:forward(output, self.real_label)
  local dgan_loss_dd = self.criterionGAN:backward(output, self.real_label)
  -- updateGradInput only: D's own parameter gradients are not accumulated here
  local dgan_loss_do = netD:updateGradInput(fake, dgan_loss_dd)

  -- Second. G(A) should be close to the real
  -- locals: the previous code leaked `real_B`, `real_A` and `fake_B` as globals
  local real_B = real[self.B_idx]
  local real_A = real[self.A_idx]
  local fake_B = fake[self.B_idx]
  local errL1 = self.criterionL1:forward(fake_B, real_B) * opt.lambda_A
  local dl1_loss_do = self.criterionL1:backward(fake_B, real_B) * opt.lambda_A
  -- only the B channels of D's input gradient belong to the generator output
  netG:backward(real_A, dgan_loss_do[self.B_idx] + dl1_loss_do)

  return gradParametersG, errG, errL1
end

-- Generator objective in the (loss, gradient) form used by optim.
-- Stores the GAN loss in self.errG and the L1 loss in self.errL1.
function Pix2PixModel:fGx(x, opt)
  self.gradParametersG, self.errG, self.errL1 = self:fGx_basic(x, self.netG, self.netD,
             self.real_AB, self.fake_AB, self.gradParametersG, opt)
  return self.errG, self.gradParametersG
end

-- Runs the backprop gradient descent
-- Corresponds to a single batch of data
function Pix2PixModel:OptimizeParameters(opt)
  local fD = function(x) return self:fDx(x, opt) end
  local fG = function(x) return self:fGx(x, opt) end
  optim.adam(fD, self.parametersD, self.optimStateD)
  optim.adam(fG, self.parametersG, self.optimStateG)
end

-- This function can be used to reset momentum after each epoch
function Pix2PixModel:RefreshParameters()
  self.parametersD, self.gradParametersD = nil, nil -- nil them to avoid spiking memory
  self.parametersG, self.gradParametersG = nil, nil

  -- define
parameters of optimization self.parametersG, self.gradParametersG = self.netG:getParameters() self.parametersD, self.gradParametersD = self.netD:getParameters() end -- This function updates the learning rate; lr for the first opt.niter iterations; graduatlly decreases the lr to 0 for the next opt.niter_decay iterations function Pix2PixModel:UpdateLearningRate(opt) local lrd = opt.lr / opt.niter_decay local old_lr = self.optimStateD['learningRate'] local lr = old_lr - lrd self.optimStateD['learningRate'] = lr self.optimStateG['learningRate'] = lr print(('update learning rate: %f -> %f'):format(old_lr, lr)) end -- Save the current model to the file system function Pix2PixModel:Save(prefix, opt) util.save_model(self.netG, prefix .. '_net_G.t7', 1.0) util.save_model(self.netD, prefix .. '_net_D.t7', 1.0) end -- returns a string that describes the current errors function Pix2PixModel:GetCurrentErrorDescription() description = ('G: %.4f D: %.4f L1: %.4f'):format( self.errG and self.errG or -1, self.errD and self.errD or -1, self.errL1 and self.errL1 or -1) return description end -- returns a string that describes the display plot configuration function Pix2PixModel:DisplayPlot(opt) return 'errG,errD,errL1' end -- returns current errors function Pix2PixModel:GetCurrentErrors() local errors = {errG=self.errG, errD=self.errD, errL1=self.errL1} return errors end -- returns a table of image/label pairs that describe -- the current results. -- |return|: a table of table. 
List of image/label pairs function Pix2PixModel:GetCurrentVisuals(opt, size) if not size then size = opt.display_winsize end local visuals = {} table.insert(visuals, {img=self.real_A, label='real_A'}) table.insert(visuals, {img=self.fake_B, label='fake_B'}) table.insert(visuals, {img=self.real_B, label='real_B'}) return visuals end ================================================ FILE: options.lua ================================================ -------------------------------------------------------------------------------- -- Configure options -------------------------------------------------------------------------------- local options = {} -- options for train local opt_train = { DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc) batchSize = 1, -- # images in batch loadSize = 143, -- scale images to this size fineSize = 128, -- then crop to this size ngf = 64, -- # of gen filters in first conv layer ndf = 64, -- # of discrim filters in first conv layer input_nc = 3, -- # of input image channels output_nc = 3, -- # of output image channels niter = 100, -- # of iter at starting learning rate niter_decay = 100, -- # of iter to linearly decay learning rate to zero lr = 0.0002, -- initial learning rate for adam beta1 = 0.5, -- momentum term of adam ntrain = math.huge, -- # of examples per epoch. math.huge for full dataset flip = 1, -- if flip the images for data argumentation display_id = 10, -- display window id. display_winsize = 128, -- display window size display_freq = 25, -- display the current results every display_freq iterations gpu = 1, -- gpu = 0 is CPU mode. 
gpu=X is GPU mode on GPU X name = '', -- name of the experiment, should generally be passed on the command line which_direction = 'AtoB', -- AtoB or BtoA phase = 'train', -- train, val, test, etc nThreads = 2, -- # threads for loading data save_epoch_freq = 1, -- save a model every save_epoch_freq epochs (does not overwrite previously saved models) save_latest_freq = 5000, -- save the latest model every latest_freq sgd iterations (overwrites the previous latest model) print_freq = 50, -- print the debug information every print_freq iterations save_display_freq = 2500, -- save the current display of results every save_display_freq_iterations continue_train = 0, -- if continue training, load the latest model: 1: true, 0: false serial_batches = 0, -- if 1, takes images in order to make batches, otherwise takes them randomly checkpoints_dir = './checkpoints', -- models are saved here cache_dir = './cache', -- cache files are saved here cudnn = 1, -- set to 0 to not use cudnn which_model_netD = 'basic', -- selects model to use for netD which_model_netG = 'resnet_6blocks', -- selects model to use for netG norm = 'instance', -- batch or instance normalization n_layers_D = 3, -- only used if which_model_netD=='n_layers' content_loss = 'pixel', -- content loss type: pixel, vgg layer_name = 'pixel', -- layer used in content loss (e.g. relu4_2) lambda_A = 10.0, -- weight for cycle loss (A -> B -> A) lambda_B = 10.0, -- weight for cycle loss (B -> A -> B) model = 'cycle_gan', -- which mode to run. 'cycle_gan', 'pix2pix', 'bigan', 'content_gan' use_lsgan = 1, -- if 1, use least square GAN, if 0, use vanilla GAN align_data = 0, -- if > 0, use the dataloader for where the images are aligned pool_size = 50, -- the size of image buffer that stores previously generated images resize_or_crop = 'resize_and_crop', -- resizing/cropping strategy: resize_and_crop | crop | scale_width | scale_height lambda_identity = 0.5, -- use identity mapping. 
Setting opt.lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set opt.lambda_identity = 0.1 use_optnet = 0, -- use optnet to save GPU memory during test } -- options for test local opt_test = { DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc) loadSize = 128, -- scale images to this size fineSize = 128, -- then crop to this size flip = 0, -- horizontal mirroring data augmentation display = 1, -- display samples while training. 0 = false display_id = 200, -- display window id. gpu = 1, -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X how_many = 'all', -- how many test images to run (set to all to run on every image found in the data/phase folder) phase = 'test', -- train, val, test, etc aspect_ratio = 1.0, -- aspect ratio of result images norm = 'instance', -- batchnorm or isntance norm name = '', -- name of experiment, selects which model to run, should generally should be passed on command line input_nc = 3, -- # of input image channels output_nc = 3, -- # of output image channels serial_batches = 1, -- if 1, takes images in order to make batches, otherwise takes them randomly cudnn = 1, -- set to 0 to not use cudnn (untested) checkpoints_dir = './checkpoints', -- loads models from here cache_dir = './cache', -- cache files are saved here results_dir='./results/', -- saves results here which_epoch = 'latest', -- which epoch to test? set to 'latest' to use latest cached model model = 'cycle_gan', -- which mode to run. 
'cycle_gan', 'pix2pix', 'bigan', 'content_gan'; to use pretrained model, select `one_direction_test`
  align_data = 0,               -- if > 0, use the dataloader for pix2pix
  which_direction = 'AtoB',     -- AtoB or BtoA
  resize_or_crop = 'resize_and_crop',  -- resizing/cropping strategy: resize_and_crop | crop | scale_width | scale_height
}

--------------------------------------------------------------------------------
-- util functions
--------------------------------------------------------------------------------

-- Returns a shallow copy of |opt| (values are not copied recursively).
function options.clone(opt)
  local copy = {}
  for orig_key, orig_value in pairs(opt) do
    copy[orig_key] = orig_value
  end
  return copy
end

-- Builds the option table for |mode| ('train' or 'test').
-- Starts from the matching default table (which is modified in place), sets
-- opt.test to 0/1, overrides every known key from the environment variable of
-- the same name (converted with tonumber when possible), prints the options,
-- creates <checkpoints_dir>/<name>/visuals and writes the options to
-- <checkpoints_dir>/<name>/opt_<mode>.txt.
-- |return|: the option table, or nil (after printing a message) when |mode|
--           is not recognised. Raises an error if the option file cannot be
--           opened for writing.
function options.parse_options(mode)
  -- local: previously `opt`, `keyset` and `fd` leaked into the global environment
  local opt
  if mode == 'train' then
    opt = opt_train
    opt.test = 0
  elseif mode == 'test' then
    opt = opt_test
    opt.test = 1
  else
    -- tostring so that a nil mode reports cleanly instead of failing on concatenation
    print("Invalid option [" .. tostring(mode) .. "]")
    return nil
  end

  -- one-line argument parser. parses environment variables to override the defaults
  -- (only existing keys are assigned, which is allowed while traversing with pairs)
  for k, _ in pairs(opt) do
    local env_value = os.getenv(k)
    if env_value ~= nil then
      opt[k] = tonumber(env_value) or env_value
    end
  end

  if mode == 'test' then
    opt.nThreads = 1
    opt.continue_train = 1
    opt.batchSize = 1 -- test code only supports batchSize=1
  end

  -- print by keys
  local keyset = {}
  for k, _ in pairs(opt) do
    table.insert(keyset, k)
  end
  table.sort(keyset)

  print("------------------- Options -------------------")
  for _, k in ipairs(keyset) do
    print(('%+25s: %s'):format(k, tostring(opt[k])))
  end
  print("-----------------------------------------------")

  -- save opt to checkpoints
  paths.mkdir(opt.checkpoints_dir)
  paths.mkdir(paths.concat(opt.checkpoints_dir, opt.name))
  opt.visual_dir = paths.concat(opt.checkpoints_dir, opt.name, 'visuals')
  paths.mkdir(opt.visual_dir)

  -- save opt to the disk
  local opt_path = paths.concat(opt.checkpoints_dir, opt.name, 'opt_' .. mode .. '.txt')
  local fd, err = io.open(opt_path, 'w')
  if not fd then
    error(('cannot write options file %s: %s'):format(opt_path, tostring(err)))
  end
  for _, k in ipairs(keyset) do
    fd:write(("%+25s: %s\n"):format(k, tostring(opt[k])))
  end
  fd:close()
  return opt
end

return options

================================================
FILE: pretrained_models/download_model.sh
================================================
FILE=$1

echo "Note: available models are apple2orange, facades_photo2label, map2sat, orange2apple, style_cezanne, style_ukiyoe, summer2winter_yosemite, zebra2horse, facades_label2photo, horse2zebra,monet2photo, sat2map, style_monet,style_vangogh, winter2summer_yosemite, iphone2dslr_flower"

echo "Specified [$FILE]"

mkdir -p ./checkpoints/${FILE}_pretrained
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/models/$FILE.t7
MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.t7
wget -N $URL -O $MODEL_FILE

================================================
FILE: pretrained_models/download_vgg.sh
================================================
URL1=https://people.eecs.berkeley.edu/~taesung_park/projects/CycleGAN/models/places_vgg.caffemodel
MODEL_FILE1=./models/places_vgg.caffemodel
URL2=https://people.eecs.berkeley.edu/~taesung_park/projects/CycleGAN/models/places_vgg.prototxt
MODEL_FILE2=./models/places_vgg.prototxt
wget -N $URL1 -O $MODEL_FILE1
wget -N $URL2 -O $MODEL_FILE2

================================================
FILE: pretrained_models/places_vgg.prototxt
================================================
name: "VGG-Places365"
input: "data"
input_dim: 1
input_dim: 3
input_dim: 224
input_dim: 224
layer { name: "conv1_1" type: "Convolution" bottom: "data" top: "conv1_1" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 64 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } }
layer { name: "relu1_1" type: "ReLU" bottom: "conv1_1" top: "conv1_1" }
layer { name: "conv1_2" type: "Convolution" bottom: "conv1_1" top: "conv1_2" param { lr_mult: 1.0 decay_mult: 1.0 }
param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 64 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu1_2" type: "ReLU" bottom: "conv1_2" top: "conv1_2" } layer { name: "pool1" type: "Pooling" bottom: "conv1_2" top: "pool1" pooling_param { pool: MAX kernel_size: 2 stride: 2 } } layer { name: "conv2_1" type: "Convolution" bottom: "pool1" top: "conv2_1" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 128 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu2_1" type: "ReLU" bottom: "conv2_1" top: "conv2_1" } layer { name: "conv2_2" type: "Convolution" bottom: "conv2_1" top: "conv2_2" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 128 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu2_2" type: "ReLU" bottom: "conv2_2" top: "conv2_2" } layer { name: "pool2" type: "Pooling" bottom: "conv2_2" top: "pool2" pooling_param { pool: MAX kernel_size: 2 stride: 2 } } layer { name: "conv3_1" type: "Convolution" bottom: "pool2" top: "conv3_1" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 256 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu3_1" type: "ReLU" bottom: "conv3_1" top: "conv3_1" } layer { name: "conv3_2" type: "Convolution" bottom: "conv3_1" top: "conv3_2" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 256 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu3_2" type: "ReLU" bottom: "conv3_2" 
top: "conv3_2" } layer { name: "conv3_3" type: "Convolution" bottom: "conv3_2" top: "conv3_3" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 256 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu3_3" type: "ReLU" bottom: "conv3_3" top: "conv3_3" } layer { name: "pool3" type: "Pooling" bottom: "conv3_3" top: "pool3" pooling_param { pool: MAX kernel_size: 2 stride: 2 } } layer { name: "conv4_1" type: "Convolution" bottom: "pool3" top: "conv4_1" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 512 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu4_1" type: "ReLU" bottom: "conv4_1" top: "conv4_1" } layer { name: "conv4_2" type: "Convolution" bottom: "conv4_1" top: "conv4_2" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 512 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu4_2" type: "ReLU" bottom: "conv4_2" top: "conv4_2" } layer { name: "conv4_3" type: "Convolution" bottom: "conv4_2" top: "conv4_3" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 512 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu4_3" type: "ReLU" bottom: "conv4_3" top: "conv4_3" } layer { name: "pool4" type: "Pooling" bottom: "conv4_3" top: "pool4" pooling_param { pool: MAX kernel_size: 2 stride: 2 } } layer { name: "conv5_1" type: "Convolution" bottom: "pool4" top: "conv5_1" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 512 pad: 1 kernel_size: 3 weight_filler 
{ type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu5_1" type: "ReLU" bottom: "conv5_1" top: "conv5_1" } layer { name: "conv5_2" type: "Convolution" bottom: "conv5_1" top: "conv5_2" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 512 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu5_2" type: "ReLU" bottom: "conv5_2" top: "conv5_2" } layer { name: "conv5_3" type: "Convolution" bottom: "conv5_2" top: "conv5_3" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } convolution_param { num_output: 512 pad: 1 kernel_size: 3 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu5_3" type: "ReLU" bottom: "conv5_3" top: "conv5_3" } layer { name: "pool5" type: "Pooling" bottom: "conv5_3" top: "pool5" pooling_param { pool: MAX kernel_size: 2 stride: 2 } } layer { name: "fc6" type: "InnerProduct" bottom: "pool5" top: "fc6" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } inner_product_param { num_output: 4096 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu6" type: "ReLU" bottom: "fc6" top: "fc6" } layer { name: "drop6" type: "Dropout" bottom: "fc6" top: "fc6" dropout_param { dropout_ratio: 0.5 } } layer { name: "fc7" type: "InnerProduct" bottom: "fc6" top: "fc7" param { lr_mult: 1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } inner_product_param { num_output: 4096 weight_filler { type: "gaussian" std: 0.01 } bias_filler { type: "constant" value: 0.0 } } } layer { name: "relu7" type: "ReLU" bottom: "fc7" top: "fc7" } layer { name: "drop7" type: "Dropout" bottom: "fc7" top: "fc7" dropout_param { dropout_ratio: 0.5 } } layer { name: "fc8a" type: "InnerProduct" bottom: "fc7" top: "fc8a" param { lr_mult: 
1.0 decay_mult: 1.0 } param { lr_mult: 2.0 decay_mult: 0.0 } inner_product_param { num_output: 365 } } layer { name: "prob" type: "Softmax" bottom: "fc8a" top: "prob" } ================================================ FILE: test.lua ================================================ -- usage: DATA_ROOT=/path/to/data/ name=expt1 which_direction=BtoA th test.lua -- -- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix require 'image' require 'nn' require 'nngraph' require 'models.architectures' util = paths.dofile('util/util.lua') options = require 'options' opt = options.parse_options('test') -- initialize torch GPU/CPU mode if opt.gpu > 0 then require 'cutorch' require 'cunn' cutorch.setDevice(opt.gpu) print ("GPU Mode") torch.setdefaulttensortype('torch.CudaTensor') else torch.setdefaulttensortype('torch.FloatTensor') print ("CPU Mode") end -- setup visualization visualizer = require 'util/visualizer' function TableConcat(t1,t2) for i=1,#t2 do t1[#t1+1] = t2[i] end return t1 end -- load data local data_loader = nil if opt.align_data > 0 then require 'data.aligned_data_loader' data_loader = AlignedDataLoader() else require 'data.unaligned_data_loader' data_loader = UnalignedDataLoader() end print( "DataLoader " .. data_loader:name() .. 
" was created.") data_loader:Initialize(opt) if opt.how_many == 'all' then opt.how_many = data_loader:size() end opt.how_many = math.min(opt.how_many, data_loader:size()) -- set batch/instance normalization set_normalization(opt.norm) -- load model opt.continue_train = 1 -- define model if opt.model == 'cycle_gan' then require 'models.cycle_gan_model' model = CycleGANModel() elseif opt.model == 'one_direction_test' then require 'models.one_direction_test_model' model = OneDirectionTestModel() elseif opt.model == 'pix2pix' then require 'models.pix2pix_model' model = Pix2PixModel() elseif opt.model == 'bigan' then require 'models.bigan_model' model = BiGANModel() elseif opt.model == 'content_gan' then require 'models.content_gan_model' model = ContentGANModel() else error('Please specify a correct model') end model:Initialize(opt) local pathsA = {} -- paths to images A tested on local pathsB = {} -- paths to images B tested on local web_dir = paths.concat(opt.results_dir, opt.name .. '/' .. opt.which_epoch .. '_' .. opt.phase) paths.mkdir(web_dir) local image_dir = paths.concat(web_dir, 'images') paths.mkdir(image_dir) s1 = opt.fineSize s2 = opt.fineSize / opt.aspect_ratio visuals = {} for n = 1, math.floor(opt.how_many) do print('processing batch ' .. 
n)

  -- Fetch the next A/B pair and reduce the paths to base names.
  local cur_dataA, cur_dataB, cur_pathsA, cur_pathsB = data_loader:GetNextBatch()
  cur_pathsA = util.basename_batch(cur_pathsA)
  cur_pathsB = util.basename_batch(cur_pathsB)
  print('pathsA', cur_pathsA)
  print('pathsB', cur_pathsB) -- was `cur_PathsB`, an undefined global that always printed nil
  model:Forward({real_A=cur_dataA, real_B=cur_dataB}, opt)

  visuals = model:GetCurrentVisuals(opt, opt.fineSize)

  -- Output file name: swap a trailing ".jpg" for ".png".
  -- The dot is escaped (a bare '.' matches any character in Lua patterns) and
  -- the call is parenthesised so gsub's second return value (the match count)
  -- does not end up as an extra entry in the name table.
  local out_name = (string.gsub(cur_pathsA[1], '%.jpg$', '.png'))

  for i,visual in ipairs(visuals) do
    -- when only one side is rescaled, let save_images keep the native size
    if opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
      s1 = nil
      s2 = nil
    end
    visualizer.save_images(visual.img, paths.concat(image_dir, visual.label), {out_name}, s1, s2)
  end

  print('Saved images to: ', image_dir)
  pathsA = TableConcat(pathsA, cur_pathsA)
  pathsB = TableConcat(pathsB, cur_pathsB)
end

-- collect the labels of the last batch's visuals; used as table columns below
labels = {}
for i,visual in ipairs(visuals) do
  table.insert(labels, visual.label)
end

-- make webpage
-- NOTE(review): the string literals passed to io.write below look like HTML that was
-- stripped when this listing was extracted; they are left exactly as found.
io.output(paths.concat(web_dir, 'index.html'))
io.write('') io.write('') for i = 1, #labels do io.write('') end io.write('') for n = 1,math.floor(opt.how_many) do io.write('') io.write('') for j = 1, #labels do label = labels[j] io.write('') end io.write('') end io.write('
Image ' .. labels[i] .. '
' .. tostring(n) .. '
') ================================================ FILE: train.lua ================================================ -- usage example: DATA_ROOT=/path/to/data/ which_direction=BtoA name=expt1 th train.lua -- code derived from https://github.com/soumith/dcgan.torch and https://github.com/phillipi/pix2pix require 'torch' require 'nn' require 'optim' util = paths.dofile('util/util.lua') content = paths.dofile('util/content_loss.lua') require 'image' require 'models.architectures' -- load configuration file options = require 'options' opt = options.parse_options('train') -- setup visualization visualizer = require 'util/visualizer' -- initialize torch GPU/CPU mode if opt.gpu > 0 then require 'cutorch' require 'cunn' cutorch.setDevice(opt.gpu) print ("GPU Mode") torch.setdefaulttensortype('torch.CudaTensor') else torch.setdefaulttensortype('torch.FloatTensor') print ("CPU Mode") end -- load data local data_loader = nil if opt.align_data > 0 then require 'data.aligned_data_loader' data_loader = AlignedDataLoader() else require 'data.unaligned_data_loader' data_loader = UnalignedDataLoader() end print( "DataLoader " .. data_loader:name() .. 
" was created.") data_loader:Initialize(opt) -- set batch/instance normalization set_normalization(opt.norm) --- timer local epoch_tm = torch.Timer() local tm = torch.Timer() -- define model local model = nil local display_plot = nil if opt.model == 'cycle_gan' then assert(data_loader:name() == 'UnalignedDataLoader') require 'models.cycle_gan_model' model = CycleGANModel() elseif opt.model == 'pix2pix' then require 'models.pix2pix_model' assert(data_loader:name() == 'AlignedDataLoader') model = Pix2PixModel() elseif opt.model == 'bigan' then assert(data_loader:name() == 'UnalignedDataLoader') require 'models.bigan_model' model = BiGANModel() elseif opt.model == 'content_gan' then require 'models.content_gan_model' assert(data_loader:name() == 'UnalignedDataLoader') model = ContentGANModel() else error('Please specify a correct model') end -- print the model name print('Model ' .. model:model_name() .. ' was specified.') model:Initialize(opt) -- set up the loss plot require 'util/plot_util' plotUtil = PlotUtil() display_plot = model:DisplayPlot(opt) plotUtil:Initialize(display_plot, opt.display_id, opt.name) -------------------------------------------------------------------------------- -- Helper Functions -------------------------------------------------------------------------------- function visualize_current_results() local visuals = model:GetCurrentVisuals(opt) for i,visual in ipairs(visuals) do visualizer.disp_image(visual.img, opt.display_winsize, opt.display_id+i, opt.name .. ' ' .. visual.label) end end function save_current_results(epoch, counter) local visuals = model:GetCurrentVisuals(opt) for i,visual in ipairs(visuals) do output_path = paths.concat(opt.visual_dir, 'train_epoch' .. epoch .. '_iter' .. counter .. '_' .. visual.label .. '.jpg') visualizer.save_results(visual.img, output_path) end end function print_current_errors(epoch, counter_in_epoch) print(('Epoch: [%d][%8d / %8d]\t Time: %.3f DataTime: %.3f ' .. 
'%s'): format(epoch, ((counter_in_epoch-1) / opt.batchSize), math.floor(math.min(data_loader:size(), opt.ntrain) / opt.batchSize), tm:time().real / opt.batchSize, data_loader:time_elapsed_to_fetch_data() / opt.batchSize, model:GetCurrentErrorDescription() )) end function plot_current_errors(epoch, counter_ratio, opt) local errs = model:GetCurrentErrors(opt) local plot_vals = { epoch + counter_ratio} plotUtil:Display(plot_vals, errs) end -------------------------------------------------------------------------------- -- Main Training Loop -------------------------------------------------------------------------------- local counter = 0 local num_batches = math.floor(math.min(data_loader:size(), opt.ntrain) / opt.batchSize) print('#training iterations: ' .. opt.niter+opt.niter_decay ) for epoch = 1, opt.niter+opt.niter_decay do epoch_tm:reset() for counter_in_epoch = 1, math.min(data_loader:size(), opt.ntrain), opt.batchSize do tm:reset() -- load a batch and run G on that batch local real_dataA, real_dataB, _, _ = data_loader:GetNextBatch() model:Forward({real_A=real_dataA, real_B=real_dataB}, opt) -- run forward pass opt.counter = counter -- run backward pass model:OptimizeParameters(opt) -- display on the web server if counter % opt.display_freq == 0 and opt.display_id > 0 then visualize_current_results() end -- logging if counter % opt.print_freq == 0 then print_current_errors(epoch, counter_in_epoch) plot_current_errors(epoch, counter_in_epoch/num_batches, opt) end -- save latest model if counter % opt.save_latest_freq == 0 and counter > 0 then print(('saving the latest model (epoch %d, iters %d)'):format(epoch, counter)) model:Save('latest', opt) end -- save latest results if counter % opt.save_display_freq == 0 then save_current_results(epoch, counter) end counter = counter + 1 end -- save model at the end of epoch if epoch % opt.save_epoch_freq == 0 then print(('saving the model (epoch %d, iters %d)'):format(epoch, counter)) model:Save('latest', opt) 
model:Save(epoch, opt) end -- print the timing information after each epoch print(('End of epoch %d / %d \t Time Taken: %.3f'): format(epoch, opt.niter+opt.niter_decay, epoch_tm:time().real)) -- update learning rate if epoch > opt.niter then model:UpdateLearningRate(opt) end -- refresh parameters model:RefreshParameters(opt) end ================================================ FILE: util/InstanceNormalization.lua ================================================ require 'nn' --[[ Implements instance normalization as described in the paper Instance Normalization: The Missing Ingredient for Fast Stylization Dmitry Ulyanov, Andrea Vedaldi, Victor Lempitsky https://arxiv.org/abs/1607.08022 This implementation is based on https://github.com/DmitryUlyanov/texture_nets ]] local InstanceNormalization, parent = torch.class('nn.InstanceNormalization', 'nn.Module') function InstanceNormalization:__init(nOutput, eps, momentum, affine) parent.__init(self) self.running_mean = torch.zeros(nOutput) self.running_var = torch.ones(nOutput) self.eps = eps or 1e-5 self.momentum = momentum or 0.0 if affine ~= nil then assert(type(affine) == 'boolean', 'affine has to be true/false') self.affine = affine else self.affine = true end self.nOutput = nOutput self.prev_batch_size = -1 if self.affine then self.weight = torch.Tensor(nOutput):uniform() self.bias = torch.Tensor(nOutput):zero() self.gradWeight = torch.Tensor(nOutput) self.gradBias = torch.Tensor(nOutput) end end function InstanceNormalization:updateOutput(input) self.output = self.output or input.new() assert(input:size(2) == self.nOutput) local batch_size = input:size(1) if batch_size ~= self.prev_batch_size or (self.bn and self:type() ~= self.bn:type()) then self.bn = nn.SpatialBatchNormalization(input:size(1)*input:size(2), self.eps, self.momentum, self.affine) self.bn:type(self:type()) self.bn.running_mean:copy(self.running_mean:repeatTensor(batch_size)) self.bn.running_var:copy(self.running_var:repeatTensor(batch_size)) 
self.prev_batch_size = input:size(1) end -- Get statistics self.running_mean:copy(self.bn.running_mean:view(input:size(1),self.nOutput):mean(1)) self.running_var:copy(self.bn.running_var:view(input:size(1),self.nOutput):mean(1)) -- Set params for BN if self.affine then self.bn.weight:copy(self.weight:repeatTensor(batch_size)) self.bn.bias:copy(self.bias:repeatTensor(batch_size)) end local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4)) self.output = self.bn:forward(input_1obj):viewAs(input) return self.output end function InstanceNormalization:updateGradInput(input, gradOutput) self.gradInput = self.gradInput or gradOutput.new() assert(self.bn) local input_1obj = input:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4)) local gradOutput_1obj = gradOutput:contiguous():view(1,input:size(1)*input:size(2),input:size(3),input:size(4)) if self.affine then self.bn.gradWeight:zero() self.bn.gradBias:zero() end self.gradInput = self.bn:backward(input_1obj, gradOutput_1obj):viewAs(input) if self.affine then self.gradWeight:add(self.bn.gradWeight:view(input:size(1),self.nOutput):sum(1)) self.gradBias:add(self.bn.gradBias:view(input:size(1),self.nOutput):sum(1)) end return self.gradInput end function InstanceNormalization:clearState() self.output = self.output.new() self.gradInput = self.gradInput.new() if self.bn then self.bn:clearState() end end function InstanceNormalization:evaluate() end function InstanceNormalization:training() end ================================================ FILE: util/VGG_preprocess.lua ================================================ -- define nn module for VGG postprocessing local VGG_postprocess, parent = torch.class('nn.VGG_postprocess', 'nn.Module') function VGG_postprocess:__init() parent.__init(self) end function VGG_postprocess:updateOutput(input) self.output = input:add(1):mul(127.5) -- print(self.output:max(), self.output:min()) if self.output:max() > 255 or 
self.output:min() < 0 then print(self.output:min(), self.output:max()) end -- assert(self.output:min()>=0,"badly scaled inputs") -- assert(self.output:max()<=255,"badly scaled inputs") local mean_pixel = torch.FloatTensor({103.939, 116.779, 123.68}) mean_pixel = mean_pixel:reshape(1,3,1,1) mean_pixel = mean_pixel:repeatTensor(input:size(1), 1, input:size(3), input:size(4)):cuda() self.output:add(-1, mean_pixel) return self.output end function VGG_postprocess:updateGradInput(input, gradOutput) self.gradInput = gradOutput:div(127.5) return self.gradInput end ================================================ FILE: util/content_loss.lua ================================================ require 'torch' require 'nn' local content = {} function content.defineVGG(content_layer) local contentFunc = nn.Sequential() require 'loadcaffe' require 'util/VGG_preprocess' cnn = loadcaffe.load('../models/vgg.prototxt', '../models/vgg.caffemodel', 'cudnn') contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224})) contentFunc:add(nn.VGG_postprocess()) for i = 1, #cnn do local layer = cnn:get(i):clone() local name = layer.name local layer_type = torch.type(layer) contentFunc:add(layer) if name == content_layer then print("Setting up content layer: ", layer.name) break end end cnn = nil collectgarbage() print(contentFunc) return contentFunc end function content.defineAlexNet(content_layer) local contentFunc = nn.Sequential() require 'loadcaffe' require 'util/VGG_preprocess' cnn = loadcaffe.load('../models/alexnet.prototxt', '../models/alexnet.caffemodel', 'cudnn') contentFunc:add(nn.SpatialUpSamplingBilinear({oheight=224, owidth=224})) contentFunc:add(nn.VGG_postprocess()) for i = 1, #cnn do local layer = cnn:get(i):clone() local name = layer.name local layer_type = torch.type(layer) contentFunc:add(layer) if name == content_layer then print("Setting up content layer: ", layer.name) break end end cnn = nil collectgarbage() print(contentFunc) return contentFunc end function 
content.defineContent(content_loss, layer_name) -- print('content_loss_define', content_loss) if content_loss == 'pixel' or content_loss == 'none' then return nil elseif content_loss == 'vgg' then return content.defineVGG(layer_name) else print("unsupported content loss") return nil end end function content.lossUpdate(criterionContent, real_source, fake_target, contentFunc, loss_type, weight) if loss_type == 'none' then local errCont = 0.0 local df_d_content = torch.zeros(fake_target:size()) return errCont, df_d_content elseif loss_type == 'pixel' then local errCont = criterionContent:forward(fake_target, real_source) * weight local df_do_content = criterionContent:backward(fake_target, real_source)*weight return errCont, df_do_content elseif loss_type == 'vgg' then local f_fake = contentFunc:forward(fake_target):clone() local f_real = contentFunc:forward(real_source):clone() local errCont = criterionContent:forward(f_fake, f_real) * weight local df_do_tmp = criterionContent:backward(f_fake, f_real) * weight local df_do_content = contentFunc:updateGradInput(fake_target, df_do_tmp)--:mul(weight) return errCont, df_do_content else error("unsupported content loss") end end return content ================================================ FILE: util/cudnn_convert_custom.lua ================================================ -- modified from https://github.com/NVIDIA/torch-cudnn/blob/master/convert.lua -- removed error on nngraph -- modules that can be converted to nn seamlessly local layer_list = { 'BatchNormalization', 'SpatialBatchNormalization', 'SpatialConvolution', 'SpatialCrossMapLRN', 'SpatialFullConvolution', 'SpatialMaxPooling', 'SpatialAveragePooling', 'ReLU', 'Tanh', 'Sigmoid', 'SoftMax', 'LogSoftMax', 'VolumetricBatchNormalization', 'VolumetricConvolution', 'VolumetricFullConvolution', 'VolumetricMaxPooling', 'VolumetricAveragePooling', } -- goes over a given net and converts all layers to dst backend -- for example: net = cudnn_convert_custom(net, cudnn) -- 
-- same as cudnn.convert with gModule check commented out
--
--   net:          network whose modules are replaced (via net:replace)
--   dst:          target backend table, either `nn` or `cudnn`
--   exclusion_fn: optional predicate; modules for which it returns true are
--                 left untouched
-- Returns whatever net:replace returns.
function cudnn_convert_custom(net, dst, exclusion_fn)
  return net:replace(function(x)
    --if torch.type(x) == 'nn.gModule' then
    --  io.stderr:write('Warning: cudnn.convert does not work with nngraph yet. Ignoring nn.gModule')
    --  return x
    --end
    local y = 0 -- sentinel: 0 means "no conversion applied"
    local src = dst == nn and cudnn or nn
    local src_prefix = src == nn and 'nn.' or 'cudnn.'
    local dst_prefix = dst == nn and 'nn.' or 'cudnn.'

    -- builds the dst-backend twin of module x for layer type name `v`
    local function convert(v)
      local y = {}
      torch.setmetatable(y, dst_prefix..v)
      if v == 'ReLU' then y = dst.ReLU() end -- because parameters
      for k,u in pairs(x) do y[k] = u end
      if src == cudnn and x.clearDesc then x.clearDesc(y) end
      if src == cudnn and v == 'SpatialAveragePooling' then
        y.divide = true
        -- read the pooling mode from the source module `x`; `v` is only the
        -- layer name (a string), so `v.mode` was always nil
        y.count_include_pad = x.mode == 'CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING'
      end
      if src == nn and string.find(v, 'Convolution') then
        y.groups = 1
      end
      return y
    end

    if exclusion_fn and exclusion_fn(x) then
      return x
    end
    local t = torch.typename(x)
    if t == 'nn.SpatialConvolutionMM' then
      y = convert('SpatialConvolution')
    elseif t == 'inn.SpatialCrossResponseNormalization' then
      y = convert('SpatialCrossMapLRN')
    else
      for i,v in ipairs(layer_list) do
        if t == src_prefix..v then
          y = convert(v)
        end
      end
    end
    return y == 0 and x or y
  end)
end

================================================
FILE: util/image_pool.lua
================================================
local class = require 'class'
ImagePool= class('ImagePool')

require 'torch'
require 'image'

-- Creates a pool holding up to `pool_size` images.
-- With pool_size == 0 no storage is allocated and Query is a pass-through.
function ImagePool:__init(pool_size)
  self.pool_size = pool_size
  if pool_size > 0 then
    self.num_imgs = 0
    self.images = {}
  end
end

function ImagePool:model_name()
  return 'ImagePool'
end
--
-- function ImagePool:Initialize(pool_size)
--   -- torch.manualSeed(0)
--   -- assert(pool_size > 0)
--   self.pool_size = pool_size
--   if pool_size > 0 then
--     self.num_imgs = 0
--     self.images = {}
--   end
-- end

-- Returns either `image` itself or a previously stored image.
--   * pool disabled (pool_size == 0): returns `image`.
--   * pool not yet full: stores a copy of `image` and returns `image`.
--   * pool full: with probability 0.5 swaps `image` with a random stored
--     entry and returns the old entry, otherwise returns `image`.
function ImagePool:Query(image)
  -- print('query image')
  if self.pool_size == 0 then
    -- print('get identical image')
    return image
  end
  if self.num_imgs < self.pool_size then
    self.num_imgs = self.num_imgs + 1
    -- store a copy: `image` may be a buffer that the caller overwrites later,
    -- which would otherwise make every pool entry alias the same data
    self.images[self.num_imgs] = image:clone()
    return image
  else
    local p = math.random()
    -- print('p' ,p)
    if p > 0.5 then
      -- print('use old image')
      local random_id = math.random(self.pool_size)
      -- print('random_id', random_id)
      local tmp = self.images[random_id]:clone()
      self.images[random_id] = image:clone()
      return tmp
    else
      return image
    end
  end
end

================================================
FILE: util/plot_util.lua
================================================
local class = require 'class'
PlotUtil = class('PlotUtil')

require 'torch'
disp = require 'display'
util = require 'util/util'
require 'image'

local unpack = unpack or table.unpack

function PlotUtil:__init(conf)
  conf = conf or {}
end

function PlotUtil:model_name()
  return 'PlotUtil'
end

-- Configures the loss plot.
--   display_plot: comma separated list of loss names (whitespace is ignored)
--   display_id:   window id passed to the display package
--   name:         experiment name, used in the plot title
function PlotUtil:Initialize(display_plot, display_id, name)
  self.display_plot = string.split(string.gsub(display_plot, "%s+", ""), ",")
  self.plot_config = {
    title = name .. ' loss over time',
    labels = {'epoch', unpack(self.display_plot)},
    ylabel = 'loss',
    win = display_id,
  }
  self.plot_data = {}
  print('display_opt', self.display_plot)
end

-- Appends one row to the plot and redraws it.
--   plot_vals: array that already holds the x value(s); it is extended in place
--   loss:      table mapping loss name -> value
-- NOTE(review): loss names missing from `loss` are skipped, so the row can end
-- up shorter than the label list -- confirm callers always provide every name.
function PlotUtil:Display(plot_vals, loss)
  for _, v in ipairs(self.display_plot) do
    if loss[v] ~= nil then
      plot_vals[#plot_vals + 1] = loss[v]
    end
  end
  table.insert(self.plot_data, plot_vals)
  disp.plot(self.plot_data, self.plot_config)
end

================================================
FILE: util/util.lua
================================================
--
-- code derived from https://github.com/soumith/dcgan.torch
--

local util = {}

require 'torch'

-- Zeroes the bias of every module whose type name contains 'Convolution'.
-- Modules without a bias tensor are skipped.
function util.BiasZero(net)
  net:apply(function(m)
    if torch.type(m):find('Convolution') and m.bias then
      m.bias:zero()
    end
  end)
end

-- Prints `name` and the mean absolute difference between tensors A and B.
function util.checkEqual(A, B, name)
  local dif = (A:float()-B:float()):abs():mean()
  print(name, dif)
end

-- Returns true when `value` is one of the values stored in `tbl`.
function util.containsValue(tbl, value)
  for _, v in pairs(tbl) do
    if v == value then return true end
  end
  return false
end

-- Prints `name` followed by min, max and mean of tensor A.
function util.CheckTensor(A, name)
  print(name, A:min(), A:max(), A:mean())
end

-- Returns a FloatTensor copy of `img` rescaled to 0 .. 1.
-- A constant image (max == min) is returned as all zeros instead of NaN/inf.
function util.normalize(img)
  -- rescale image to 0 .. 1
  local min = img:min()
  local max = img:max()

  img = torch.FloatTensor(img:size()):copy(img)
  img:add(-min)
  if max > min then
    img:mul(1/(max-min))
  end
  return img
end

-- Normalizes every element of `batch` (first dimension) in place.
function util.normalizeBatch(batch)
  for i = 1, batch:size(1) do
    batch[i] = util.normalize(batch[i]:squeeze())
  end
  return batch
end

-- Replaces every path in the array `batch` by its basename (in place).
function util.basename_batch(batch)
  for i = 1, #batch do
    batch[i] = paths.basename(batch[i])
  end
  return batch
end

-- default preprocessing
--
-- Preprocesses an image before passing it to a net
-- Converts from RGB to BGR and rescales from [0,1] to [-1,1]
-- NOTE: the rescaling is done in place; for inputs that do not have exactly
-- 3 channels no copy is made, so the caller's tensor is modified.
function util.preprocess(img)
  -- RGB to BGR
  if img:size(1) == 3 then
    local perm = torch.LongTensor{3, 2, 1}
    img = img:index(1, perm)
  end
  -- [0,1] to [-1,1]
  img = img:mul(2):add(-1)

  -- check that input is in expected range
  assert(img:max()<=1,"badly scaled inputs")
  assert(img:min()>=-1,"badly scaled inputs")

  return img
end

-- Undo the above preprocessing.
function util.deprocess(img)
  -- Converts BGR back to RGB (3-channel inputs only) and rescales [-1,1] to [0,1].
  -- NOTE: the rescaling is in place; inputs without exactly 3 channels are not
  -- copied, so the caller's tensor is modified.
  -- BGR to RGB
  if img:size(1) == 3 then
    local perm = torch.LongTensor{3, 2, 1}
    img = img:index(1, perm)
  end

  -- [-1,1] to [0,1]
  img = img:add(1):div(2)

  return img
end

-- Applies util.preprocess to every element of `batch` (first dimension), in place.
function util.preprocess_batch(batch)
  for i = 1, batch:size(1) do
    batch[i] = util.preprocess(batch[i]:squeeze())
  end
  return batch
end

-- Prints `name` followed by size, min, mean and max of tensor x.
function util.print_tensor(name, x)
  print(name, x:size(), x:min(), x:mean(), x:max())
end

-- Applies util.deprocess to every element of `batch` (first dimension), in place.
function util.deprocess_batch(batch)
  for i = 1, batch:size(1) do
    batch[i] = util.deprocess(batch[i]:squeeze())
  end
  return batch
end

-- Rescales every image of `batch` with image.scale(img, s1, s2) and returns
-- the results in a newly allocated tensor of the default tensor type.
-- s1 is the target width and s2 the target height (image.scale's argument
-- order), so each scaled image is C x s2 x s1.
function util.scaleBatch(batch,s1,s2)
  -- print('s1', s1)
  -- print('s2', s2)
  -- Allocate height-major (s2 x s1) to match what image.scale produces.
  -- Allocating s1 x s2 only worked for square sizes: for s1 ~= s2 the copy
  -- succeeded by element count but scrambled the pixel layout.
  local scaled_batch = torch.Tensor(batch:size(1),batch:size(2),s2,s1)
  for i = 1, batch:size(1) do
    scaled_batch[i] = image.scale(batch[i],s1,s2):squeeze()
  end
  return scaled_batch
end

-- Wraps a 3-D tensor into a batch of size one.
function util.toTrivialBatch(input)
  return input:reshape(1,input:size(1),input:size(2),input:size(3))
end

-- Returns the first element of a batch.
function util.fromTrivialBatch(input)
  return input[1]
end

-- input is between -1 and 1
-- Adds uniform noise of one quantization step (1/256 in [0,1] units) to
-- `input`, in place, and returns the same tensor for convenience.
function util.jitter(input)
  local noise = torch.rand(input:size())/256.0
  input:add(1.0):mul(0.5*255.0/256.0):add(noise):add(-0.5):mul(2.0)
  --local scaled = (input+1.0)*0.5
  --local jittered = scaled*255.0/256.0 + torch.rand(input:size())/256.0
  --local scaled_back = (jittered-0.5)*2.0
  --return scaled_back
  return input
end

-- Scales `input` to loadSize x loadSize, replicating single-channel images
-- to 3 channels first.
function util.scaleImage(input, loadSize)
  -- replicate bw images to 3 channels
  if input:size(1)==1 then
    input = torch.repeatTensor(input,3,1,1)
  end

  input = image.scale(input, loadSize, loadSize)

  return input
end

-- Returns size(3)/size(2) (width over height) of the image stored at `path`.
function util.getAspectRatio(path)
  local input = image.load(path, 3, 'float')
  local ar = input:size(3)/input:size(2)
  return ar
end

-- Loads the image at `path` as 3-channel float, scales it to
-- loadSize x loadSize and preprocesses it; with nc == 1 only the first
-- channel is returned.
function util.loadImage(path, loadSize, nc)
  local input = image.load(path, 3, 'float')
  input= util.preprocess(util.scaleImage(input, loadSize))

  if nc == 1 then
    input = input[{{1}, {}, {}}]
  end

  return input
end

-- Returns true when `filename` can be opened for reading.
function file_exists(filename)
  local f = io.open(filename,"r")
  if f ~= nil then
    io.close(f)
    return true
  else
    return false
  end
end

-- TO
-- DO: loading code is rather hacky; clean it up and make sure it works on all types of nets / cpu/gpu configurations
--
-- Loads a serialized network from `filename`.
--   opt.norm  : 'instance' pulls in util.InstanceNormalization before loading
--   opt.cudnn : > 0 requires cudnn before loading
--   opt.gpu   : > 0 moves the network to the GPU, otherwise converts it to float
-- Gradient buffers are re-allocated (zeroed) for every module that has a weight.
-- Returns the network, or nil (after printing a message) if the file is missing.
function load_helper(filename, opt)
  local fileExists = file_exists(filename)
  if not fileExists then
    print('model not found!    ' .. filename)
    return nil
  end
  print(('loading previously trained model (%s)'):format(filename))
  if opt.norm == 'instance' then
    print('use InstanceNormalization')
    require 'util.InstanceNormalization'
  end

  if opt.cudnn>0 then
    require 'cudnn'
  end

  local net = torch.load(filename)
  if opt.gpu > 0 then
    require 'cunn'
    net:cuda()

    -- calling cuda on cudnn saved nngraphs doesn't change all variables to cuda, so do it below
    if net.forwardnodes then
      for i=1,#net.forwardnodes do
        if net.forwardnodes[i].data.module then
          net.forwardnodes[i].data.module:cuda()
        end
      end
    end
  else
    net:float()
  end

  net:apply(function(m)
    if m.weight then
      m.gradWeight = m.weight:clone():zero();
      -- modules may carry a weight but no bias; only rebuild gradBias when
      -- there is a bias to mirror
      if m.bias then
        m.gradBias = m.bias:clone():zero();
      end
    end
  end)
  return net
end

-- Loads '<checkpoints_dir>/<name>/latest_net_<name>.t7' via load_helper.
function util.load_model(name, opt)
  -- if opt['lambda_'.. name] > 0.0 then
  --   print('not loading model '.. opt.checkpoints_dir .. opt.name ..
  --         'latest_net_' .. name .. '.t7' .. ' because opt.lambda is not greater than zero')
  return load_helper(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_' .. name .. '.t7'), opt)
  -- end
end

-- Loads '<checkpoints_dir>/<name>/<which_epoch>_net_<name>.t7' via load_helper.
function util.load_test_model(name, opt)
  return load_helper(paths.concat(opt.checkpoints_dir, opt.name, opt.which_epoch .. '_net_' .. name .. '.t7'), opt)
end

-- load dataset from the file system
-- |name|: name of the dataset. It's currently either 'A' or 'B'
-- function util.load_dataset(name, nc, opt, nc)
--   local tensortype = torch.getdefaulttensortype()
--   torch.setdefaulttensortype('torch.FloatTensor')
--
--   local new_opt = options.clone(opt)
--   new_opt.manualSeed = torch.random(1, 10000) -- fix seed
--   new_opt.nc = nc
--   torch.manualSeed(new_opt.manualSeed)
--   local data_loader = paths.dofile('../data/data.lua')
--   new_opt.phase = new_opt.phase ..
-- name
--   local data = data_loader.new(new_opt.nThreads, new_opt)
--   print("Dataset Size " .. name .. ": ", data:size())
--
--   torch.setdefaulttensortype(tensortype)
--   return data
-- end

-- Converts `net` to the cudnn backend (see util/cudnn_convert_custom).
function util.cudnn(net)
  require 'cudnn'
  require 'util/cudnn_convert_custom'
  return cudnn_convert_custom(net, cudnn)
end

-- Saves `net` (after clearState) under
-- '<opt.checkpoints_dir>/<opt.name>/<net_name>' when weight > 0.
-- Relies on the global `opt` table.
function util.save_model(net, net_name, weight)
  if weight > 0.0 then
    torch.save(paths.concat(opt.checkpoints_dir, opt.name, net_name), net:clearState())
  end
end

return util

================================================
FILE: util/visualizer.lua
================================================
-------------------------------------------------------------
-- Various utilities for visualization through the web server
-------------------------------------------------------------

local visualizer = {}

require 'torch'
disp = nil
print(opt)
if opt.display_id > 0 then -- [hack]: assume that opt already existed
  disp = require 'display'
end
util = require 'util/util'
require 'image'

-- function visualizer
-- Shows `img_data` (a batch of images) in display window `display_id`.
-- Images are scaled to win_size x win_size and deprocessed first.
-- Requires `disp` to be loaded, i.e. opt.display_id > 0 at load time.
function visualizer.disp_image(img_data, win_size, display_id, title)
  local images = util.deprocess_batch(util.scaleBatch(img_data:float(),win_size,win_size))
  disp.image(images, {win=display_id, title=title})
end

-- Saves the batch `img_data` as one image file at `output_path`.
-- Every image is scaled to opt.display_winsize, deprocessed, and the images
-- are concatenated with torch.cat (along the last dimension).
function visualizer.save_results(img_data, output_path)
  local tensortype = torch.getdefaulttensortype()
  torch.setdefaulttensortype('torch.FloatTensor')
  local image_out = nil
  local win_size = opt.display_winsize
  local images = torch.squeeze(util.deprocess_batch(util.scaleBatch(img_data:float(), win_size, win_size)))

  if images:dim() == 3 then
    -- a single image survived the squeeze
    image_out = images
  else
    for i = 1,images:size(1) do
      local img = images[i]
      if image_out == nil then
        image_out = img
      else
        image_out = torch.cat(image_out, img)
      end
    end
  end

  image.save(output_path, image_out)
  torch.setdefaulttensortype(tensortype)
end

-- Saves every image of the batch `imgs` to save_dir/impaths[i].
--   s1, s2: optional target size passed to image.scale (width, height);
--           when either is nil the images keep their size.
-- `imgs` itself is left untouched: deprocessing runs on a float copy.
function visualizer.save_images(imgs, save_dir, impaths, s1, s2)
  local tensortype = torch.getdefaulttensortype()
  torch.setdefaulttensortype('torch.FloatTensor')
  local batchSize = imgs:size(1)
  -- util.deprocess_batch works in place, so hand it a private float copy
  -- instead of the caller's tensor
  local imgs_f = util.deprocess_batch(torch.FloatTensor(imgs:size()):copy(imgs)):float()
  paths.mkdir(save_dir)
  for i = 1, batchSize do -- imgs_f[i]:size(2), imgs_f[i]:size(3)/opt.aspect_ratio
    local im_s
    if s1 ~= nil and s2 ~= nil then
      im_s = image.scale(imgs_f[i], s1, s2):float()
    else
      im_s = imgs_f[i]:float()
    end

    -- contiguous copy for image.save
    local img_to_save = torch.FloatTensor(im_s:size()):copy(im_s)
    image.save(paths.concat(save_dir, impaths[i]), img_to_save)
  end
  torch.setdefaulttensortype(tensortype)
end

return visualizer