[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 Vikram Voleti\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# self-attention-GAN-pytorch\n\nThis is an almost exact replica in PyTorch of the Tensorflow version of [SAGAN](https://arxiv.org/abs/1805.08318) released by Google Brain [[repo](https://github.com/brain-research/self-attention-gan)] in August 2018.\n\nCode structure is inspired from [this repo](https://github.com/heykeetae/Self-Attention-GAN), but follows the details of [Google Brain's repo](https://github.com/brain-research/self-attention-gan).\n\n## Prerequisites\n\nCheck `requirements.txt`.\n\n* [Python 3.5+](https://www.continuum.io/downloads)\n* [PyTorch 0.4.1](http://pytorch.org/)\n\n## Training\n\n#### 1. Check `parameters.py` for all arguments and their default values\n\n#### 2. Train on custom images in folder a/b/c:\n```bash\n$ python train.py --data_path 'a/b/c' --save_path 'o/p/q' --batch_size 64 --name sagan\n```\n\n(Warning: Works only on *128x128* images, input images are resized to that. Tweak the Generator & Discriminator first if you would like to use some other image size. And then use the `imsize` option:\n```bash\n$ python train.py --data_path 'a/b/c' --save_path 'o/p/q' --batch_size 64 --imsize 32 --name sagan\n```\n)\n\nModel training will be recorded in a new folder inside `--save_path` with the name `<timestamp>_<name>_<basename of data_path>`.\n\nBy default, model weights are saved in a subfolder called `weights`, and train & validation samples during training in a subfolder called `samples` (can be changed in `parameters.py`).\n\n## Testing/Evaluating\n\nCheck `test.py`.\n\n## Self-Attention GAN\n**[Han Zhang, Ian Goodfellow, Dimitris Metaxas and Augustus Odena, \"Self-Attention Generative Adversarial Networks.\" arXiv preprint arXiv:1805.08318 (2018)](https://arxiv.org/abs/1805.08318).**\n\n```\n@article{Zhang2018SelfAttentionGA,\n    title={Self-Attention Generative Adversarial Networks},\n    author={Han Zhang and Ian J. Goodfellow and Dimitris N. Metaxas and Augustus Odena},\n    journal={CoRR},\n    year={2018},\n    volume={abs/1805.08318}\n}\n```\n"
  },
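  {
    "path": "examples/generate_from_final_checkpoint.py",
    "content": "'''Illustrative sketch (not part of the original scripts): one possible way to sample images from\na '*_final_model_ckpt_*.pth' checkpoint written by utils.save_ckpt(..., final=True), which stores\nthe complete Generator under the key 'G'. The checkpoint path and NUM_CLASSES below are\nplaceholders - set them to match your training run, and run from the repo root so that\nsagan_models.py is importable when the pickled model is restored.'''\nimport torch\nimport torchvision.utils as vutils\n\nimport utils\n\nCKPT_PATH = 'path/to/<run_name>_final_model_ckpt_0200000.pth'  # placeholder\nNUM_CLASSES = 10                                               # placeholder: classes used at training time\nN_SAMPLES = 16\n\nif __name__ == '__main__':\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    checkpoint = torch.load(CKPT_PATH, map_location='cpu')\n    G = checkpoint['G'].to(device).eval()\n    z = torch.randn(N_SAMPLES, G.z_dim, device=device)\n    labels = torch.randint(0, NUM_CLASSES, (N_SAMPLES,), device=device)\n    with torch.no_grad():\n        fake_images = G(z, labels)  # N_SAMPLES x 3 x 128 x 128, values in [-1, 1]\n    vutils.save_image(utils.denorm(fake_images), 'samples_from_ckpt.png', nrow=4)\n    print('Saved samples_from_ckpt.png')\n"
  },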
  {
    "path": "parameters.py",
    "content": "import argparse\nimport datetime\nimport os\n\n\ndef get_parameters():\n\n    parser = argparse.ArgumentParser()\n\n    # Images data path & Output path\n    parser.add_argument('--dataset', type=str, default='folder', choices=[\"cifar10\", \"fake\", \"folder\", \"hdf5\", \"imagenet\", \"lfw\", \"lsun\"],\n                        help=\"cifar10 | fake | folder | hdf5 | imagenet | lfw | lsun\")\n    parser.add_argument('--data_path', type=str, default='', help='Path to root of image data (saved in dirs of classes)')\n    parser.add_argument('--save_path', type=str, default='./sagan_models')\n\n    # Training settings\n    parser.add_argument('--batch_size', type=int, default=64)\n    parser.add_argument('--batch_size_in_gpu', type=int, default=0,\n                        help='0 => same as batch_size, else: if using multiple gpu iterations to make an effective batch, e.g. batch_size=32, batch_size_in_gpu=16 => optimizer.step() is run 2 iterations after running loss.backward()')\n    parser.add_argument('--total_step', type=int, default=200000, help='how many iterations')\n    parser.add_argument('--d_steps_per_iter', type=int, default=1, help='how many D updates per iteration')\n    parser.add_argument('--g_steps_per_iter', type=int, default=1, help='how many G updates per iteration')\n    parser.add_argument('--d_lr', type=float, default=0.0004)\n    parser.add_argument('--g_lr', type=float, default=0.0001)\n    parser.add_argument('--beta1', type=float, default=0.0)\n    parser.add_argument('--beta2', type=float, default=0.999)\n\n    # Model hyper-parameters\n    parser.add_argument('--adv_loss', type=str, default='hinge', choices=['hinge', 'dcgan', 'wgan_gp', 'gan'])\n    parser.add_argument('--z_dim', type=int, default=128)\n    parser.add_argument('--g_conv_dim', type=int, default=64)\n    parser.add_argument('--d_conv_dim', type=int, default=64)\n    parser.add_argument('--lambda_gp', type=float, default=10)\n\n    # Instance noise\n    # https://github.com/soumith/ganhacks/issues/14#issuecomment-312509518\n    # https://www.inference.vc/instance-noise-a-trick-for-stabilising-gan-training/\n    parser.add_argument('--inst_noise_sigma', type=float, default=0.0)\n    parser.add_argument('--inst_noise_sigma_iters', type=int, default=2000)\n\n    # Image transforms\n    parser.add_argument('--dont_shuffle', action='store_true')\n    parser.add_argument('--dont_drop_last', action='store_true', help=\"Whether not to drop the last batch in dataset if its size < batch_size\")\n    parser.add_argument('--dont_resize', action='store_true', help=\"Whether not to resize images\")\n    parser.add_argument('--imsize', type=int, default=128)\n    parser.add_argument('--centercrop', action='store_true', help=\"Whether to center crop images\")\n    parser.add_argument('--centercrop_size', type=int, default=128)\n    parser.add_argument('--dont_normalize', action='store_true', help=\"Whether to normalize image values\")\n\n    # Step sizes\n    parser.add_argument('--log_step', type=int, default=10)\n    parser.add_argument('--sample_step', type=int, default=10)\n    parser.add_argument('--model_save_step', type=float, default=50)\n    parser.add_argument('--save_n_images', type=int, default=0,\n                        help='0 => same as batch_size_in_gpu')\n    parser.add_argument('--nrow', type=int, default=10)\n    parser.add_argument('--max_frames_per_gif', type=int, default=100)\n\n    # Pretrained model\n    parser.add_argument('--pretrained_model', type=str, default='')\n    
parser.add_argument('--state_dict_or_model', type=str, default='', help=\"Specify whether .pth pretrained_model is a 'state_dict' or a complete 'model'\")\n\n    # Misc\n    parser.add_argument('--manual_seed', type=int, default=29)\n    parser.add_argument('--disable_cuda', action='store_true', help='Disable CUDA')\n    parser.add_argument('--parallel', action='store_true', help=\"Run on multiple GPUs\")\n    parser.add_argument('--num_workers', type=int, default=4)\n    # parser.add_argument('--use_tensorboard', action='store_true')\n\n    # Output paths\n    parser.add_argument('--model_weights_dir', type=str, default='weights')\n    parser.add_argument('--sample_images_dir', type=str, default='samples')\n\n    # Model name\n    parser.add_argument('--name', type=str, default='sagan')\n\n    args = parser.parse_args()\n\n    if args.batch_size_in_gpu == 0:\n        args.batch_size_in_gpu = args.batch_size\n\n    assert args.batch_size_in_gpu <= args.batch_size, \"ERROR: please make sure batch_size >= batch_size_in_gpu!! Given batch_size: \" + str(args.batch_size) + \" ; batch_size_in_gpu: \" + str(args.batch_size_in_gpu)\n    assert args.batch_size % args.batch_size_in_gpu == 0, \"ERROR: please make sure batch_size_in_gpu divides batch_size!! Given batch_size: \" + str(args.batch_size) + \" ; batch_size_in_gpu: \" + str(args.batch_size_in_gpu)\n\n    args.batch_size_effective = args.batch_size_in_gpu*(args.batch_size//args.batch_size_in_gpu)\n\n    print(\"Effective BATCH SIZE:\", args.batch_size_effective)\n\n    if args.save_n_images == 0:\n        args.save_n_images = args.batch_size_in_gpu\n\n    assert args.save_n_images <= args.batch_size_in_gpu, \"ERROR: please make save_n_images <= batch_size_in_gpu!! Given save_n_images: \" + str(args.save_n_images) + \" ; batch_size_in_gpu: \" + str(args.batch_size_in_gpu)\n\n    # Corrections\n    args.shuffle = not args.dont_shuffle\n    args.drop_last = not args.dont_drop_last\n    args.resize = not args.dont_resize\n    args.normalize = not args.dont_normalize\n\n    args.dataloader_args = {'num_workers':args.num_workers}\n\n    args.name = '{0:%Y%m%d_%H%M%S}_{1}_{2}'.format(datetime.datetime.now(), args.name, os.path.basename(args.data_path))\n\n    args.save_path = os.path.join(args.save_path, args.name)\n    args.model_weights_path = os.path.join(args.save_path, args.model_weights_dir)\n    args.sample_images_path = os.path.join(args.save_path, args.sample_images_dir)\n\n    return args\n"
  },
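  {
    "path": "examples/grad_accumulation_sketch.py",
    "content": "'''Minimal standalone sketch (not used anywhere in the repo) of the gradient-accumulation scheme\nbehind --batch_size vs --batch_size_in_gpu in parameters.py: trainer.py backpropagates several\nsmall GPU batches, each loss scaled by 1/gpu_batches, before a single optimizer.step(). The toy\nlinear model and random data below are placeholders purely for illustration.'''\nimport torch\nimport torch.nn as nn\n\nbatch_size, batch_size_in_gpu = 32, 16         # same semantics as the parameters.py flags\ngpu_batches = batch_size // batch_size_in_gpu  # backward passes accumulated per optimizer step\n\nmodel = nn.Linear(10, 1)\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\ncriterion = nn.MSELoss()\n\nfor step in range(3):\n    optimizer.zero_grad()\n    for _ in range(gpu_batches):\n        x = torch.randn(batch_size_in_gpu, 10)\n        y = torch.randn(batch_size_in_gpu, 1)\n        loss = criterion(model(x), y) / gpu_batches  # scale so the accumulated gradient matches a full batch\n        loss.backward()                              # gradients accumulate in the .grad buffers\n    optimizer.step()                                 # one update for the effective batch of 32\n    print('step', step, 'loss', loss.item())\n"
  },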
  {
    "path": "requirements.txt",
    "content": "matplotlib==3.0.0\ntorchvision==0.2.1\ntorch==2.2.0\nopencv_python==4.2.0.32\nimageio==2.4.1\nnumpy==1.22.0\n"
  },
  {
    "path": "sagan_models.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torch.nn.utils import spectral_norm\nfrom torch.nn.init import xavier_uniform_\n\n\ndef init_weights(m):\n    if type(m) == nn.Linear or type(m) == nn.Conv2d:\n        xavier_uniform_(m.weight)\n        m.bias.data.fill_(0.)\n\n\ndef snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):\n    return spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,\n                                   stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias))\n\n\ndef snlinear(in_features, out_features):\n    return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features))\n\n\ndef sn_embedding(num_embeddings, embedding_dim):\n    return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim))\n\n\nclass Self_Attn(nn.Module):\n    \"\"\" Self attention Layer\"\"\"\n\n    def __init__(self, in_channels):\n        super(Self_Attn, self).__init__()\n        self.in_channels = in_channels\n        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)\n        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8, kernel_size=1, stride=1, padding=0)\n        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=1, stride=1, padding=0)\n        self.snconv1x1_attn = snconv2d(in_channels=in_channels//2, out_channels=in_channels, kernel_size=1, stride=1, padding=0)\n        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)\n        self.softmax  = nn.Softmax(dim=-1)\n        self.sigma = nn.Parameter(torch.zeros(1))\n\n    def forward(self, x):\n        \"\"\"\n            inputs :\n                x : input feature maps(B X C X W X H)\n            returns :\n                out : self attention value + input feature \n                attention: B X N X N (N is Width*Height)\n        \"\"\"\n        _, ch, h, w = x.size()\n        # Theta path\n        theta = self.snconv1x1_theta(x)\n        theta = theta.view(-1, ch//8, h*w)\n        # Phi path\n        phi = self.snconv1x1_phi(x)\n        phi = self.maxpool(phi)\n        phi = phi.view(-1, ch//8, h*w//4)\n        # Attn map\n        attn = torch.bmm(theta.permute(0, 2, 1), phi)\n        attn = self.softmax(attn)\n        # g path\n        g = self.snconv1x1_g(x)\n        g = self.maxpool(g)\n        g = g.view(-1, ch//2, h*w//4)\n        # Attn_g\n        attn_g = torch.bmm(g, attn.permute(0, 2, 1))\n        attn_g = attn_g.view(-1, ch//2, h, w)\n        attn_g = self.snconv1x1_attn(attn_g)\n        # Out\n        out = x + self.sigma*attn_g\n        return out\n\n\nclass ConditionalBatchNorm2d(nn.Module):\n    # https://github.com/pytorch/pytorch/issues/8985#issuecomment-405080775\n    def __init__(self, num_features, num_classes):\n        super().__init__()\n        self.num_features = num_features\n        self.bn = nn.BatchNorm2d(num_features, momentum=0.001, affine=False)\n        self.embed = nn.Embedding(num_classes, num_features * 2)\n        # self.embed.weight.data[:, :num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)\n        self.embed.weight.data[:, :num_features].fill_(1.)  
# Initialize scale to 1\n        self.embed.weight.data[:, num_features:].zero_()  # Initialize bias at 0\n\n    def forward(self, x, y):\n        out = self.bn(x)\n        gamma, beta = self.embed(y).chunk(2, 1)\n        out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1)\n        return out\n\n\nclass GenBlock(nn.Module):\n    def __init__(self, in_channels, out_channels, num_classes):\n        super(GenBlock, self).__init__()\n        self.cond_bn1 = ConditionalBatchNorm2d(in_channels, num_classes)\n        self.relu = nn.ReLU(inplace=True)\n        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)\n        self.cond_bn2 = ConditionalBatchNorm2d(out_channels, num_classes)\n        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)\n        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x, labels):\n        x0 = x\n\n        x = self.cond_bn1(x, labels)\n        x = self.relu(x)\n        x = F.interpolate(x, scale_factor=2, mode='nearest') # upsample\n        x = self.snconv2d1(x)\n        x = self.cond_bn2(x, labels)\n        x = self.relu(x)\n        x = self.snconv2d2(x)\n\n        x0 = F.interpolate(x0, scale_factor=2, mode='nearest') # upsample\n        x0 = self.snconv2d0(x0)\n\n        out = x + x0\n        return out\n\n\nclass Generator(nn.Module):\n    \"\"\"Generator.\"\"\"\n\n    def __init__(self, z_dim, g_conv_dim, num_classes):\n        super(Generator, self).__init__()\n\n        self.z_dim = z_dim\n        self.g_conv_dim = g_conv_dim\n        self.snlinear0 = snlinear(in_features=z_dim, out_features=g_conv_dim*16*4*4)\n        self.block1 = GenBlock(g_conv_dim*16, g_conv_dim*16, num_classes)\n        self.block2 = GenBlock(g_conv_dim*16, g_conv_dim*8, num_classes)\n        self.block3 = GenBlock(g_conv_dim*8, g_conv_dim*4, num_classes)\n        self.self_attn = Self_Attn(g_conv_dim*4)\n        self.block4 = GenBlock(g_conv_dim*4, g_conv_dim*2, num_classes)\n        self.block5 = GenBlock(g_conv_dim*2, g_conv_dim, num_classes)\n        self.bn = nn.BatchNorm2d(g_conv_dim, eps=1e-5, momentum=0.0001, affine=True)\n        self.relu = nn.ReLU(inplace=True)\n        self.snconv2d1 = snconv2d(in_channels=g_conv_dim, out_channels=3, kernel_size=3, stride=1, padding=1)\n        self.tanh = nn.Tanh()\n\n        # Weight init\n        self.apply(init_weights)\n\n    def forward(self, z, labels):\n        # n x z_dim\n        act0 = self.snlinear0(z)            # n x g_conv_dim*16*4*4\n        act0 = act0.view(-1, self.g_conv_dim*16, 4, 4) # n x g_conv_dim*16 x 4 x 4\n        act1 = self.block1(act0, labels)    # n x g_conv_dim*16 x 8 x 8\n        act2 = self.block2(act1, labels)    # n x g_conv_dim*8 x 16 x 16\n        act3 = self.block3(act2, labels)    # n x g_conv_dim*4 x 32 x 32\n        act3 = self.self_attn(act3)         # n x g_conv_dim*4 x 32 x 32\n        act4 = self.block4(act3, labels)    # n x g_conv_dim*2 x 64 x 64\n        act5 = self.block5(act4, labels)    # n x g_conv_dim  x 128 x 128\n        act5 = self.bn(act5)                # n x g_conv_dim  x 128 x 128\n        act5 = self.relu(act5)              # n x g_conv_dim  x 128 x 128\n        act6 = self.snconv2d1(act5)         # n x 3 x 128 x 128\n        act6 = self.tanh(act6)              # n x 3 x 128 x 128\n        return act6\n\n\nclass 
DiscOptBlock(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super(DiscOptBlock, self).__init__()\n        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)\n        self.downsample = nn.AvgPool2d(2)\n        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x):\n        x0 = x\n\n        x = self.snconv2d1(x)\n        x = self.relu(x)\n        x = self.snconv2d2(x)\n        x = self.downsample(x)\n\n        x0 = self.downsample(x0)\n        x0 = self.snconv2d0(x0)\n\n        out = x + x0\n        return out\n\n\nclass DiscBlock(nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super(DiscBlock, self).__init__()\n        self.relu = nn.ReLU(inplace=True)\n        self.snconv2d1 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)\n        self.snconv2d2 = snconv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)\n        self.downsample = nn.AvgPool2d(2)\n        self.ch_mismatch = False\n        if in_channels != out_channels:\n            self.ch_mismatch = True\n        self.snconv2d0 = snconv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)\n\n    def forward(self, x, downsample=True):\n        x0 = x\n\n        x = self.relu(x)\n        x = self.snconv2d1(x)\n        x = self.relu(x)\n        x = self.snconv2d2(x)\n        if downsample:\n            x = self.downsample(x)\n\n        if downsample or self.ch_mismatch:\n            x0 = self.snconv2d0(x0)\n            if downsample:\n                x0 = self.downsample(x0)\n\n        out = x + x0\n        return out\n\n\nclass Discriminator(nn.Module):\n    \"\"\"Discriminator.\"\"\"\n\n    def __init__(self, d_conv_dim, num_classes):\n        super(Discriminator, self).__init__()\n        self.d_conv_dim = d_conv_dim\n        self.opt_block1 = DiscOptBlock(3, d_conv_dim)\n        self.block1 = DiscBlock(d_conv_dim, d_conv_dim*2)\n        self.self_attn = Self_Attn(d_conv_dim*2)\n        self.block2 = DiscBlock(d_conv_dim*2, d_conv_dim*4)\n        self.block3 = DiscBlock(d_conv_dim*4, d_conv_dim*8)\n        self.block4 = DiscBlock(d_conv_dim*8, d_conv_dim*16)\n        self.block5 = DiscBlock(d_conv_dim*16, d_conv_dim*16)\n        self.relu = nn.ReLU(inplace=True)\n        self.snlinear1 = snlinear(in_features=d_conv_dim*16, out_features=1)\n        self.sn_embedding1 = sn_embedding(num_classes, d_conv_dim*16)\n\n        # Weight init\n        self.apply(init_weights)\n        xavier_uniform_(self.sn_embedding1.weight)\n\n    def forward(self, x, labels):\n        # n x 3 x 128 x 128\n        h0 = self.opt_block1(x) # n x d_conv_dim   x 64 x 64\n        h1 = self.block1(h0)    # n x d_conv_dim*2 x 32 x 32\n        h1 = self.self_attn(h1) # n x d_conv_dim*2 x 32 x 32\n        h2 = self.block2(h1)    # n x d_conv_dim*4 x 16 x 16\n        h3 = self.block3(h2)    # n x d_conv_dim*8 x  8 x  8\n        h4 = self.block4(h3)    # n x d_conv_dim*16 x 4 x  4\n        h5 = self.block5(h4, downsample=False)  # n x d_conv_dim*16 x 4 x 4\n        h5 = self.relu(h5)              # n x d_conv_dim*16 x 4 x 4\n        h6 = torch.sum(h5, dim=[2,3])   # n x 
d_conv_dim*16\n        output1 = torch.squeeze(self.snlinear1(h6)) # n\n        # Projection\n        h_labels = self.sn_embedding1(labels)   # n x d_conv_dim*16\n        proj = torch.mul(h6, h_labels)          # n x d_conv_dim*16\n        output2 = torch.sum(proj, dim=[1])      # n\n        # Out\n        output = output1 + output2              # n\n        return output\n"
  },
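  {
    "path": "examples/model_shapes_smoke_test.py",
    "content": "'''Quick shape check for the SAGAN Generator & Discriminator defined in sagan_models.py.\nIllustrative sketch, not part of the original repo; the small conv dims, class count and batch\nsize below are arbitrary values chosen only to keep the forward passes fast on CPU.'''\nimport torch\n\nfrom sagan_models import Generator, Discriminator\n\nif __name__ == '__main__':\n    z_dim, g_conv_dim, d_conv_dim, num_classes, n = 128, 8, 8, 10, 4\n    G = Generator(z_dim, g_conv_dim, num_classes)\n    D = Discriminator(d_conv_dim, num_classes)\n    z = torch.randn(n, z_dim)\n    labels = torch.randint(0, num_classes, (n,))\n    with torch.no_grad():\n        fake = G(z, labels)       # expected: n x 3 x 128 x 128, in [-1, 1] (tanh output)\n        scores = D(fake, labels)  # expected: n (one projection-discriminator score per image)\n    print('G output shape:', tuple(fake.shape))\n    print('D output shape:', tuple(scores.shape))\n"
  },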
  {
    "path": "test.py",
    "content": "import sys\n\nimport utils\n\nfrom parameters import *\nfrom sagan_models import Generator, Discriminator\n\n\nif __name__ == '__main__':\n    config = get_parameters()\n    config.command = 'python ' + ' '.join(sys.argv)\n    print(config)\n    utils.check_for_CUDA(config)\n\n    # Load pretrained model (if provided)\n    if config.pretrained_model != '':\n        utils.load_pretrained_model(config)\n    else:\n        assert config.num_of_classes, \"Please provide number of classes! Eg. python3 test.py --num_of_classes 10\"\n        config.G = Generator(config.z_dim, config.g_conv_dim, config.num_of_classes).to(config.device)\n        config.D = Discriminator(config.d_conv_dim, config.num_of_classes).to(config.device)\n\n    config.G.eval()\n    config.D.eval()\n    print(config.G, config.D)\n"
  },
  {
    "path": "train.py",
    "content": "import sys\n\nimport utils\n\nfrom parameters import *\nfrom trainer import Trainer\n\n\nif __name__ == '__main__':\n    config = get_parameters()\n    config.command = 'python ' + ' '.join(sys.argv)\n    print(config)\n    trainer = Trainer(config)\n    trainer.train()\n    utils.save_ckpt(trainer, final=True)\n"
  },
  {
    "path": "trainer.py",
    "content": "import datetime\nimport numpy as np\nimport os\nimport random\nimport sys\nimport time\nimport torch\nimport torch.nn as nn\nimport torchvision.utils as vutils\n\nfrom torch.backends import cudnn\n\nimport utils\nfrom sagan_models import Generator, Discriminator\n\n\nclass Trainer(object):\n\n    def __init__(self, config):\n\n        # Config\n        self.config = config\n\n        self.start = 0 # Unless using pre-trained model\n\n        # Create directories if not exist\n        utils.make_folder(self.config.save_path)\n        utils.make_folder(self.config.model_weights_path)\n        utils.make_folder(self.config.sample_images_path)\n\n        # Copy files\n        utils.write_config_to_file(self.config, self.config.save_path)\n        utils.copy_scripts(self.config.save_path)\n\n        # Check for CUDA\n        utils.check_for_CUDA(self)\n\n        # Make dataloader\n        self.dataloader, self.num_of_classes = utils.make_dataloader(batch_size=self.config.batch_size_in_gpu,\n                                                                     dataset_type=self.config.dataset,\n                                                                     data_path=self.config.data_path,\n                                                                     shuffle=self.config.shuffle,\n                                                                     drop_last=self.config.drop_last,\n                                                                     dataloader_args=self.config.dataloader_args,\n                                                                     resize=self.config.resize,\n                                                                     imsize=self.config.imsize,\n                                                                     centercrop=self.config.centercrop,\n                                                                     centercrop_size=self.config.centercrop_size,\n                                                                     normalize=self.config.normalize,\n                                                                     )\n\n        # Data iterator\n        self.data_iter = iter(self.dataloader)\n\n        # Build G and D\n        self.build_models()\n\n        if self.config.adv_loss == 'dcgan':\n            self.criterion = nn.BCELoss()\n\n    def train(self):\n\n        # Seed\n        np.random.seed(self.config.manual_seed)\n        random.seed(self.config.manual_seed)\n        torch.manual_seed(self.config.manual_seed)\n\n        # For fast training\n        cudnn.benchmark = True\n\n        # For BatchNorm\n        self.G.train()\n        self.D.train()\n\n        # Fixed noise for sampling from G\n        fixed_noise = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim, device=self.device)\n        if self.num_of_classes < self.config.batch_size_in_gpu:\n            fixed_labels = torch.from_numpy(np.tile(np.arange(self.num_of_classes), self.config.batch_size_in_gpu//self.num_of_classes + 1)[:self.config.batch_size_in_gpu]).to(self.device)\n        else:\n            fixed_labels = torch.from_numpy(np.arange(self.config.batch_size_in_gpu)).to(self.device)\n\n        # For gan loss\n        label = torch.full((self.config.batch_size_in_gpu,), 1, device=self.device)\n        ones = torch.full((self.config.batch_size_in_gpu,), 1, device=self.device)\n\n        # Losses file\n        log_file_name = os.path.join(self.config.save_path, 'log.txt')\n        log_file = open(log_file_name, \"wt\")\n\n        # 
Init\n        start_time = time.time()\n        G_losses = []\n        D_losses_real = []\n        D_losses_fake = []\n        D_losses = []\n        D_xs = []\n        D_Gz_trainDs = []\n        D_Gz_trainGs = []\n\n        # Instance noise - make random noise mean (0) and std for injecting\n        inst_noise_mean = torch.full((self.config.batch_size_in_gpu, 3, self.config.imsize, self.config.imsize), 0, device=self.device)\n        inst_noise_std = torch.full((self.config.batch_size_in_gpu, 3, self.config.imsize, self.config.imsize), self.config.inst_noise_sigma, device=self.device)\n\n        self.gpu_batches = self.config.batch_size//self.config.batch_size_in_gpu\n\n        # Start training\n        for self.step in range(self.start, self.config.total_step):\n\n            # Instance noise std is linearly annealed from self.inst_noise_sigma to 0 thru self.inst_noise_sigma_iters\n            inst_noise_sigma_curr = 0 if self.step > self.config.inst_noise_sigma_iters else (1 - self.step/self.config.inst_noise_sigma_iters)*self.config.inst_noise_sigma\n            inst_noise_std.fill_(inst_noise_sigma_curr)\n\n            # ================== TRAIN D ================== #\n\n            for _ in range(self.config.d_steps_per_iter):\n\n                # Zero grad\n                self.reset_grad()\n\n                # Accumulate losses for full batch_size\n                # while running GPU computations on only batch_size_in_gpu\n                for gpu_batch in range(self.gpu_batches):\n\n                    # TRAIN with REAL\n\n                    # Get real images & real labels\n                    real_images, real_labels = self.get_real_samples()\n\n                    # Get D output for real images & real labels\n                    inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device)\n                    d_out_real = self.D(real_images + inst_noise, real_labels)\n\n                    # Compute D loss with real images & real labels\n                    if self.config.adv_loss == 'hinge':\n                        d_loss_real = torch.nn.ReLU()(ones - d_out_real).mean()\n                    elif self.config.adv_loss == 'wgan_gp':\n                        d_loss_real = -d_out_real.mean()\n                    else:\n                        label.fill_(1)\n                        d_loss_real = self.criterion(d_out_real, label)\n\n                    # Backward\n                    d_loss_real /= self.gpu_batches\n                    d_loss_real.backward()\n\n                    # Delete loss, output\n                    if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1:\n                        del d_out_real, d_loss_real\n\n                    # TRAIN with FAKE\n\n                    # Create random noise\n                    z = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim, device=self.device)\n\n                    # Generate fake images for same real labels\n                    fake_images = self.G(z, real_labels)\n\n                    # Get D output for fake images & same real labels\n                    inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device)\n                    d_out_fake = self.D(fake_images.detach() + inst_noise, real_labels)\n\n                    # Compute D loss with fake images & real labels\n                    if self.config.adv_loss == 'hinge':\n                        d_loss_fake = torch.nn.ReLU()(ones + d_out_fake).mean()\n                    elif 
self.config.adv_loss == 'dcgan':\n                        label.fill_(0)\n                        d_loss_fake = self.criterion(d_out_fake, label)\n                    else:\n                        d_loss_fake = d_out_fake.mean()\n\n                    # If WGAN_GP, compute GP and add to D loss\n                    if self.config.adv_loss == 'wgan_gp':\n                        d_loss_gp = self.config.lambda_gp * self.compute_gradient_penalty(real_images, real_labels, fake_images.detach())\n                        d_loss_fake += d_loss_gp\n\n                    # Backward\n                    d_loss_fake /= self.gpu_batches\n                    d_loss_fake.backward()\n\n                    # Delete loss, output\n                    del fake_images\n                    if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1:\n                        del d_out_fake, d_loss_fake\n\n                # Optimize\n                self.D_optimizer.step()\n\n            # ================== TRAIN G ================== #\n\n            for _ in range(self.config.g_steps_per_iter):\n\n                # Zero grad\n                self.reset_grad()\n\n                # Accumulate losses for full batch_size\n                # while running GPU computations on only batch_size_in_gpu\n                for gpu_batch in range(self.gpu_batches):\n\n                    # Get real images & real labels (only need real labels)\n                    real_images, real_labels = self.get_real_samples()\n\n                    # Create random noise\n                    z = torch.randn(self.config.batch_size_in_gpu, self.config.z_dim).to(self.device)\n\n                    # Generate fake images for same real labels\n                    fake_images = self.G(z, real_labels)\n\n                    # Get D output for fake images & same real labels\n                    inst_noise = torch.normal(mean=inst_noise_mean, std=inst_noise_std).to(self.device)\n                    g_out_fake = self.D(fake_images + inst_noise, real_labels)\n\n                    # Compute G loss with fake images & real labels\n                    if self.config.adv_loss == 'dcgan':\n                        label.fill_(1)\n                        g_loss = self.criterion(g_out_fake, label)\n                    else:\n                        g_loss = -g_out_fake.mean()\n\n                    # Backward\n                    g_loss /= self.gpu_batches\n                    g_loss.backward()\n\n                    # Delete loss, output\n                    del fake_images\n                    if self.step % self.config.log_step != 0 or gpu_batch < self.gpu_batches - 1:\n                        del g_out_fake, g_loss\n\n                # Optimize\n                self.G_optimizer.step()\n\n            # Print out log info\n            if self.step % self.config.log_step == 0:\n                G_losses.append(g_loss.mean().item())\n                D_losses_real.append(d_loss_real.mean().item())\n                D_losses_fake.append(d_loss_fake.mean().item())\n                D_loss = D_losses_real[-1] + D_losses_fake[-1]\n                if self.config.adv_loss == 'wgan_gp':\n                    D_loss += d_loss_gp.mean().item()\n                D_losses.append(D_loss)\n                D_xs.append(d_out_real.mean().item())\n                D_Gz_trainDs.append(d_out_fake.mean().item())\n                D_Gz_trainGs.append(g_out_fake.mean().item())\n                curr_time = time.time()\n                curr_time_str = 
datetime.datetime.fromtimestamp(curr_time).strftime('%Y-%m-%d %H:%M:%S')\n                elapsed = str(datetime.timedelta(seconds=(curr_time - start_time)))\n                log = (\"[{}] : Elapsed [{}], Iter [{} / {}], G_loss: {:.4f}, D_loss: {:.4f}, D_loss_real: {:.4f}, D_loss_fake: {:.4f}, D(x): {:.4f}, D(G(z))_trainD: {:.4f}, D(G(z))_trainG: {:.4f}\\n\".\n                       format(curr_time_str, elapsed, self.step, self.config.total_step,\n                              G_losses[-1], D_losses[-1], D_losses_real[-1], D_losses_fake[-1],\n                              D_xs[-1], D_Gz_trainDs[-1], D_Gz_trainGs[-1]))\n                print('\\n' + log)\n                log_file.write(log)\n                log_file.flush()\n                utils.make_plots(G_losses, D_losses, D_losses_real, D_losses_fake, D_xs, D_Gz_trainDs, D_Gz_trainGs,\n                                 self.config.log_step, self.config.save_path)\n\n                # Delete loss, output\n                del d_out_real, d_loss_real, d_out_fake, d_loss_fake, g_out_fake, g_loss\n\n            # Sample images\n            if self.step % self.config.sample_step == 0:\n                print(\"Saving image samples..\")\n                self.G.eval()\n                fake_images = self.G(fixed_noise, fixed_labels)\n                self.G.train()\n                sample_images = utils.denorm(fake_images.detach()[:self.config.save_n_images])\n                # Save batch images\n                vutils.save_image(sample_images, os.path.join(self.config.sample_images_path, 'fake_{:05d}.png'.format(self.step)), nrow=self.config.nrow)\n                # Save gif\n                utils.make_gif(sample_images[0].cpu().numpy().transpose(1, 2, 0)*255, self.step,\n                               self.config.sample_images_path, self.config.name, max_frames_per_gif=self.config.max_frames_per_gif)\n                # Delete output\n                del fake_images\n\n            # Save model\n            if self.step % self.config.model_save_step == 0:\n                utils.save_ckpt(self)\n\n    def build_models(self):\n        self.G = Generator(self.config.z_dim, self.config.g_conv_dim, self.num_of_classes).to(self.device)\n        self.D = Discriminator(self.config.d_conv_dim, self.num_of_classes).to(self.device)\n\n        # Loss and optimizer\n        # self.G_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])\n        self.G_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.G.parameters()), self.config.g_lr, [self.config.beta1, self.config.beta2])\n        self.D_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.D.parameters()), self.config.d_lr, [self.config.beta1, self.config.beta2])\n\n        # Start with pretrained model (if it exists)\n        if self.config.pretrained_model != '':\n            utils.load_pretrained_model(self)\n\n        if 'cuda' in self.device.type and self.config.parallel and torch.cuda.device_count() > 1:\n            self.G = nn.DataParallel(self.G)\n            self.D = nn.DataParallel(self.D)\n\n        # print networks\n        print(self.G)\n        print(self.D)\n\n    def reset_grad(self):\n        self.G_optimizer.zero_grad()\n        self.D_optimizer.zero_grad()\n\n    def get_real_samples(self):\n        try:\n            real_images, real_labels = next(self.data_iter)\n        except:\n            self.data_iter = iter(self.dataloader)\n            real_images, real_labels = next(self.data_iter)\n\n        real_images, 
real_labels = real_images.to(self.device), real_labels.to(self.device)\n        return real_images, real_labels\n\n    def compute_gradient_penalty(self, real_images, real_labels, fake_images):\n        # Compute gradient penalty\n        alpha = torch.rand(real_images.size(0), 1, 1, 1).expand_as(real_images).to(self.device)\n        interpolated = torch.tensor(alpha * real_images + (1 - alpha) * fake_images, requires_grad=True)\n        out = self.D(interpolated, real_labels)\n        exp_grad = torch.ones(out.size()).to(self.device)\n        grad = torch.autograd.grad(outputs=out,\n                                   inputs=interpolated,\n                                   grad_outputs=exp_grad,\n                                   retain_graph=True,\n                                   create_graph=True,\n                                   only_inputs=True)[0]\n        grad = grad.view(grad.size(0), -1)\n        grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))\n        d_loss_gp = torch.mean((grad_l2norm - 1) ** 2)\n        return d_loss_gp\n"
  },
  {
    "path": "utils.py",
    "content": "import cv2\nimport glob\nimport imageio\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport os\nimport shutil\nimport torch\nimport torchvision.datasets as dset\n\nfrom torchvision import transforms\n\n\ndef make_folder(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n\n\ndef denorm(x):\n    out = (x + 1) / 2\n    return out.clamp_(0, 1)\n\n\ndef write_config_to_file(config, save_path):\n    with open(os.path.join(save_path, 'config.txt'), 'w') as file:\n        for arg in vars(config):\n            file.write(str(arg) + ': ' + str(getattr(config, arg)) + '\\n')\n\n\ndef copy_scripts(dst):\n    for file in glob.glob('*.py'):\n        shutil.copy(file, dst)\n\n    for d in glob.glob('*/'):\n        if '__' not in d and d[0] != '.':\n            shutil.copytree(d, os.path.join(dst, d))\n\n\ndef make_transform(resize=True, imsize=128, centercrop=False, centercrop_size=128,\n                   totensor=True, normalize=True, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):\n        options = []\n        if resize:\n            options.append(transforms.Resize((imsize)))\n        if centercrop:\n            options.append(transforms.CenterCrop(centercrop_size))\n        if totensor:\n            options.append(transforms.ToTensor())\n        if normalize:\n            options.append(transforms.Normalize(norm_mean, norm_std))\n        transform = transforms.Compose(options)\n        return transform\n\n\ndef make_dataloader(batch_size, dataset_type, data_path, shuffle=True, drop_last=True, dataloader_args={},\n                    resize=True, imsize=128, centercrop=False, centercrop_size=128, totensor=True,\n                    normalize=True, norm_mean=(0.5, 0.5, 0.5), norm_std=(0.5, 0.5, 0.5)):\n    # Make transform\n    transform = make_transform(resize=resize, imsize=imsize,\n                               centercrop=centercrop, centercrop_size=centercrop_size,\n                               totensor=totensor, normalize=normalize, norm_mean=norm_mean, norm_std=norm_std)\n    # Make dataset\n    if dataset_type in ['folder', 'imagenet', 'lfw']:\n        # folder dataset\n        assert os.path.exists(data_path), \"data_path does not exist! Given: \" + data_path\n        dataset = dset.ImageFolder(root=data_path, transform=transform)\n    elif dataset_type == 'lsun':\n        assert os.path.exists(data_path), \"data_path does not exist! Given: \" + data_path\n        dataset = dset.LSUN(root=data_path, classes=['bedroom_train'], transform=transform)\n    elif dataset_type == 'cifar10':\n        if not os.path.exists(data_path):\n            print(\"data_path does not exist! Given: {}\\nDownloading CIFAR10 dataset...\".format(data_path))\n        dataset = dset.CIFAR10(root=data_path, download=True, transform=transform)\n    elif dataset_type == 'fake':\n        dataset = dset.FakeData(image_size=(3, centercrop_size, centercrop_size), transform=transforms.ToTensor())\n    assert dataset\n    num_of_classes = len(dataset.classes)\n    print(\"Data found!  
# of images =\", len(dataset), \", # of classes =\", num_of_classes, \", classes:\", dataset.classes)\n    # Make dataloader from dataset\n    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **dataloader_args)\n    return dataloader, num_of_classes\n\n\ndef make_gif(image, iteration_number, save_path, model_name, max_frames_per_gif=100):\n\n    # Make gif\n    gif_frames = []\n\n    # Read old gif frames\n    try:\n        gif_frames_reader = imageio.get_reader(os.path.join(save_path, model_name + \".gif\"))\n        for frame in gif_frames_reader:\n            gif_frames.append(frame[:, :, :3])\n    except:\n        pass\n\n    # Append new frame\n    im = cv2.putText(np.concatenate((np.zeros((32, image.shape[1], image.shape[2])), image), axis=0),\n                     'iter %s' % str(iteration_number), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, .5, (255, 255, 255), 1, cv2.LINE_AA).astype('uint8')\n    gif_frames.append(im)\n\n    # If frames exceeds, save as different file\n    if len(gif_frames) > max_frames_per_gif:\n        print(\"Splitting the GIF...\")\n        gif_frames_00 = gif_frames[:max_frames_per_gif]\n        num_of_gifs_already_saved = len(glob.glob(os.path.join(save_path, model_name + \"_*.gif\")))\n        print(\"Saving\", os.path.join(save_path, model_name + \"_%05d.gif\" % (num_of_gifs_already_saved)))\n        imageio.mimsave(os.path.join(save_path, model_name + \"_%05d.gif\" % (num_of_gifs_already_saved)), gif_frames_00)\n        gif_frames = gif_frames[max_frames_per_gif:]\n\n    # Save gif\n    # print(\"Saving\", os.path.join(save_path, model_name + \".gif\"))\n    imageio.mimsave(os.path.join(save_path, model_name + \".gif\"), gif_frames)\n\n\ndef make_plots(G_losses, D_losses, D_losses_real, D_losses_fake, D_xs, D_Gz_trainDs, D_Gz_trainGs, log_step, save_path, init_epoch=0):\n    iters = np.arange(len(D_losses))*log_step + init_epoch\n    fig = plt.figure(figsize=(20, 20))\n    plt.subplot(311)\n    plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5)\n    plt.plot(iters, G_losses, color='C0', label='G')\n    plt.legend()\n    plt.title(\"Generator loss\")\n    plt.xlabel(\"Iterations\")\n    plt.subplot(312)\n    plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5)\n    plt.plot(iters, D_losses_real, color='C1', alpha=0.7, label='D_real')\n    plt.plot(iters, D_losses_fake, color='C2', alpha=0.7, label='D_fake')\n    plt.plot(iters, D_losses, color='C0', alpha=0.7, label='D')\n    plt.legend()\n    plt.title(\"Discriminator loss\")\n    plt.xlabel(\"Iterations\")\n    plt.subplot(313)\n    plt.plot(iters, np.zeros(iters.shape), 'k--', alpha=0.5)\n    plt.plot(iters, np.ones(iters.shape), 'k--', alpha=0.5)\n    plt.plot(iters, D_xs, alpha=0.7, label='D(x)')\n    plt.plot(iters, D_Gz_trainDs, alpha=0.7, label='D(G(z))_trainD')\n    plt.plot(iters, D_Gz_trainGs, alpha=0.7, label='D(G(z))_trainG')\n    plt.legend()\n    plt.title(\"D(x), D(G(z))\")\n    plt.xlabel(\"Iterations\")\n    plt.savefig(os.path.join(save_path, \"plots.png\"))\n    plt.clf()\n    plt.close()\n\n\ndef save_ckpt(sagan_obj, model=False, final=False):\n    print(\"Saving ckpt...\")\n\n    if final:\n        # Save final - both model and state_dict\n        torch.save({\n                    'step': sagan_obj.step,\n                    'G_state_dict': sagan_obj.G.module.state_dict() if hasattr(sagan_obj.G, \"module\") else sagan_obj.G.state_dict(),    # \"module\" in case DataParallel is used\n                    
'G_optimizer_state_dict': sagan_obj.G_optimizer.state_dict(),\n                    'D_state_dict': sagan_obj.D.module.state_dict() if hasattr(sagan_obj.D, \"module\") else sagan_obj.D.state_dict(),    # \"module\" in case DataParallel is used,\n                    'D_optimizer_state_dict': sagan_obj.D_optimizer.state_dict(),\n                    }, os.path.join(sagan_obj.config.model_weights_path, '{}_final_state_dict_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step)))\n        torch.save({\n                    'step': sagan_obj.step,\n                    'G': sagan_obj.G.module if hasattr(sagan_obj.G, \"module\") else sagan_obj.G,\n                    'G_optimizer': sagan_obj.G_optimizer,\n                    'D': sagan_obj.D.module if hasattr(sagan_obj.D, \"module\") else sagan_obj.D,\n                    'D_optimizer': sagan_obj.D_optimizer,\n                    }, os.path.join(sagan_obj.config.model_weights_path, '{}_final_model_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step)))\n\n    elif model:\n        # Save full model (not state_dict)\n        torch.save({\n                    'step': sagan_obj.step,\n                    'G': sagan_obj.G.module if hasattr(sagan_obj.G, \"module\") else sagan_obj.G,     # \"module\" in case DataParallel is used\n                    'G_optimizer': sagan_obj.G_optimizer,\n                    'D': sagan_obj.D.module if hasattr(sagan_obj.D, \"module\") else sagan_obj.D,     # \"module\" in case DataParallel is used\n                    'D_optimizer': sagan_obj.D_optimizer,\n                    }, os.path.join(sagan_obj.config.model_weights_path, '{}_model_ckpt_{:07d}.pth'.format(sagan_obj.config.name, sagan_obj.step)))\n\n    else:\n        # Save state_dict\n        torch.save({\n                    'step': sagan_obj.step,\n                    'G_state_dict': sagan_obj.G.module.state_dict() if hasattr(sagan_obj.G, \"module\") else sagan_obj.G.state_dict(),\n                    'G_optimizer_state_dict': sagan_obj.G_optimizer.state_dict(),\n                    'D_state_dict': sagan_obj.D.module.state_dict() if hasattr(sagan_obj.D, \"module\") else sagan_obj.D.state_dict(),\n                    'D_optimizer_state_dict': sagan_obj.D_optimizer.state_dict(),\n                    }, os.path.join(sagan_obj.config.model_weights_path, 'ckpt_{:07d}.pth'.format(sagan_obj.step)))\n\n\ndef load_pretrained_model(sagan_obj):\n    print(\"Loading pretrained_model\", sagan_obj.config.pretrained_model, \"...\")\n    # Check for path\n    assert os.path.exists(sagan_obj.config.pretrained_model), \"Path of .pth pretrained_model doesn't exist! 
Given: \" + sagan_obj.config.pretrained_model\n    checkpoint = torch.load(sagan_obj.config.pretrained_model)\n    # If we know it is a state_dict (instead of complete model)\n    if sagan_obj.config.state_dict_or_model == 'state_dict':\n        sagan_obj.start = checkpoint['step'] + 1\n        sagan_obj.G.load_state_dict(checkpoint['G_state_dict'])\n        sagan_obj.G_optimizer.load_state_dict(checkpoint['G_optimizer_state_dict'])\n        sagan_obj.D.load_state_dict(checkpoint['D_state_dict'])\n        sagan_obj.D_optimizer.load_state_dict(checkpoint['D_optimizer_state_dict'])\n    # Else, if we know it is a complete model (and not just state_dict)\n    elif sagan_obj.config.state_dict_or_model == 'model':\n        sagan_obj.start = checkpoint['step'] + 1\n        sagan_obj.G = torch.load(checkpoint['G']).to(sagan_obj.device)\n        sagan_obj.G_optimizer = torch.load(checkpoint['G_optimizer'])\n        sagan_obj.D = torch.load(checkpoint['D']).to(sagan_obj.device)\n        sagan_obj.D_optimizer = torch.load(checkpoint['D_optimizer'])\n    # Else try for complete model, then try for state_dict\n    else:\n        try:\n            sagan_obj.start = checkpoint['step'] + 1\n            sagan_obj.G.load_state_dict(checkpoint['G_state_dict'])\n            sagan_obj.G_optimizer.load_state_dict(checkpoint['G_optimizer_state_dict'])\n            sagan_obj.D.load_state_dict(checkpoint['D_state_dict'])\n            sagan_obj.D_optimizer.load_state_dict(checkpoint['D_optimizer_state_dict'])\n        except:\n            sagan_obj.start = checkpoint['step'] + 1\n            sagan_obj.G = torch.load(checkpoint['G']).to(sagan_obj.device)\n            sagan_obj.G_optimizer = torch.load(checkpoint['G_optimizer'])\n            sagan_obj.D = torch.load(checkpoint['D']).to(sagan_obj.device)\n            sagan_obj.D_optimizer = torch.load(checkpoint['D_optimizer'])\n\n\ndef check_for_CUDA(sagan_obj):\n    if not sagan_obj.config.disable_cuda and torch.cuda.is_available():\n        print(\"CUDA is available!\")\n        sagan_obj.device = torch.device('cuda')\n        sagan_obj.config.dataloader_args['pin_memory'] = True\n    else:\n        print(\"Cuda is NOT available, running on CPU.\")\n        sagan_obj.device = torch.device('cpu')\n\n    if torch.cuda.is_available() and sagan_obj.config.disable_cuda:\n        print(\"WARNING: You have a CUDA device, so you should probably run without --disable_cuda\")\n"
  }
]