[
  {
    "path": ".gitignore",
    "content": "*/**/__pycache__\n*.egg-info\n/dist\n*.pyc\n.idea"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Pablo Pernías\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\nFooter\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include LICENSE\ninclude readme.md\nrecursive-include torchtools/transforms/models *"
  },
  {
    "path": "readme.md",
    "content": "# Pytorch Tools\n\n## Install\n\nRequirements:\n```\nPyTorch >= 1.0.0\nTorchivision\nNumpy >= 1.0.0\n```\n\n```\n# In order to install the latest (beta) use\npip install git+https://github.com/pabloppp/pytorch-tools -U\n\n# if you want to install a specific version to avoid breaking changes (for example, v0.3.5), use \npip install git+https://github.com/pabloppp/pytorch-tools@0.3.5 -U\n```\n\n# Current available tools\n\n## Optimizers\n\nComparison table taken from https://github.com/mgrankin/over9000\nAnd the article explaining this recent improvements https://medium.com/@lessw/how-we-beat-the-fastai-leaderboard-score-by-19-77-a-cbb2338fab5c\n\nDataset                               | LR Schedule| Imagenette size 128, 5 epoch | Imagewoof size 128, 5 epoch\n---                                   | -- | ---                          | ---\nAdam - baseline                |OneCycle| 0.8493                       | 0.6125\nRangerLars (RAdam + LARS + Lookahead) |Flat and anneal| 0.8732                       | 0.6523\nRalamb (RAdam + LARS)                 |Flat and anneal| 0.8675                       | 0.6367\nRanger (RAdam + Lookahead)            |Flat and anneal| 0.8594                       | 0.5946\nNovograd                              |Flat and anneal| 0.8711                       | 0.6126\nRadam                                 |Flat and anneal| 0.8444                       | 0.537\nLookahead                             |OneCycle| 0.8578                       | 0.6106\nLamb                                  |OneCycle| 0.8400                       | 0.5597\n\n### Ranger\nTaken as is from https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer  \nBlog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d\n\nExample of use:\n```python\nfrom torchtools.optim import Ranger\n\noptimizer = Ranger(model.parameters())\n```\n\n### RAdam\nTaken as is from https://github.com/LiyuanLucasLiu/RAdam  \nBlog post: https://medium.com/@lessw/new-state-of-the-art-ai-optimizer-rectified-adam-radam-5d854730807b  \nOriginal Paper: https://arxiv.org/abs/1908.03265  \n\nExample of use:\n```python\nfrom torchtools.optim import RAdam, PlainRAdam, AdamW\n\noptimizer = RAdam(model.parameters()) \n# optimizer = PlainRAdam(model.parameters()) \n# optimizer = AdamW(model.parameters()) \n```\n\n### RangerLars (Over9000) \nTaken as is from https://github.com/mgrankin/over9000\n\nExample of use:\n```python\nfrom torchtools.optim import RangerLars # Over9000\n\noptimizer = RangerLars(model.parameters())\n```\n\n### Novograd \nTaken as is from https://github.com/mgrankin/over9000\n\nExample of use:\n```python\nfrom torchtools.optim import Novograd\n\noptimizer = Novograd(model.parameters())\n```\n\n### Ralamb \nTaken as is from https://github.com/mgrankin/over9000\n\nExample of use:\n```python\nfrom torchtools.optim import Ralamb\n\noptimizer = Ralamb(model.parameters())\n```\n \n### Lookahead\nTaken as is from https://github.com/lonePatient/lookahead_pytorch  \nOriginal Paper: https://arxiv.org/abs/1907.08610  \n\nThis lookahead can be used with any optimizer\n\nExample of use:\n```python\nfrom torch import optim\nfrom torchtools.optim import Lookahead\n\noptimizer = optim.Adam(model.parameters(), lr=0.001)\noptimizer = Lookahead(base_optimizer=optimizer, k=5, alpha=0.5)\n\n# for a base Lookahead + Adam you can just do:\n# \n# from torchtools.optim import LookaheadAdam\n```\n\n### Lamb\nTaken as is from 
https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py\nOriginal Paper: https://arxiv.org/abs/1904.00962\n\nExample of use:\n```python\nfrom torchtools.optim import Lamb\n\noptimizer = Lamb(model.parameters())\n```\n\n## LR Schedulers\n\n### Delayed LR\nAllows for a customizable number of initial steps where the learning rate remains fixed.  \nAfter those steps, the learning rate will be updated according to the supplied scheduler's policy.\n\nExample of use:\n```python\nfrom torch import optim, nn\nfrom torchtools.lr_scheduler import DelayerScheduler\n\noptimizer = optim.Adam(model.parameters(), lr=0.001) # define your optimizer here, the lr that you set will be the one used for the initial delay steps\n\ndelay_epochs = 10\ntotal_epochs = 20\nbase_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epochs - delay_epochs) # the scheduler to apply once the delay is over\ndelayed_scheduler = DelayerScheduler(optimizer, delay_epochs, base_scheduler) # keep the initial lr for the first 10 epochs\n\nfor epoch in range(total_epochs):\n\t# train(...)\n\tdelayed_scheduler.step()\n\n\t# The lr will be 0.001 for the first 10 epochs, then will follow the policy of the base_scheduler for the rest of the epochs\n\n\n# for a base DelayerScheduler + CosineAnnealingLR you can just do:\n#\n# from torchtools.lr_scheduler import DelayedCosineAnnealingLR\n# scheduler = DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs) # the sum of both must be the total number of epochs\n```\n
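\n### Inverse Square Root LR\nA linear warmup followed by an inverse square root decay, similar to the schedule used in the original Transformer paper. During the first `warmup_steps` steps the lr ramps linearly from `pre_warmup_lr` (which defaults to `lr`) up to `lr`, and afterwards it decays proportionally to `lr * sqrt(warmup_steps / step)` (see `torchtools/lr_scheduler/inverse_sqrt.py`).\n\nA minimal sketch of how it can be used (the model and the step count are placeholders):\n```python\nfrom torch import optim\nfrom torchtools.lr_scheduler import InverseSqrtLR\n\noptimizer = optim.Adam(model.parameters(), lr=1e-4) # the scheduler will override this lr\nscheduler = InverseSqrtLR(optimizer, lr=1e-4, warmup_steps=4000)\n\nfor step in range(total_steps):\n\t# train(...)\n\tscheduler.step()\n```\n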
\n## Activations\n\n### Mish\nOriginal implementation: https://github.com/digantamisra98/Mish  \nOriginal Paper: https://arxiv.org/abs/1908.08681v1  \nImplementation taken as is from https://github.com/lessw2020/mish  \n\nExample of use:\n```python\nfrom torchtools.nn import Mish\n\nmish = Mish()\n# Then you can just use Mish as a replacement for any activation function, like ReLU\n```\n\n### AliasFreeActivation\nImplementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L225 by Rosinality.\nI modularized this activation so it can be easily used inside any model without having to deal with complex initialization.\n\nThis activation takes a lot of responsibility: it internally defines the channels and size of the input based on a set of parameters instead of receiving them as arguments, which means that the rest of the layers (convolutions, positional embedding, etc...) must adapt to it.\n\nExample of use:\n```python\nfrom torchtools.nn.alias_free_activation import AliasFreeActivation\nfrom torchtools.nn import EqualLeakyReLU\n\n# We can use the static function to get the filter parameters for a specific level. \n# It can be especially useful to obtain the initial size and channels.\nmax_size, max_channels = 256, 512\nfirst_channels, first_size = AliasFreeActivation.alias_level_params(\n\t0, max_levels=14, max_size=max_size, max_channels=max_channels\n)[-2:]\n\nclass MyModel(nn.Module):\n\tdef __init__(self, level, max_levels=14, max_size=256, max_channels=512, margin=10):\n\t\t...\n\t\tleaky_relu = EqualLeakyReLU(negative_slope=0.2)\n\t\tself.activation = AliasFreeActivation(\n\t\t\tleaky_relu, level, max_levels=max_levels, max_size=max_size, max_channels=max_channels, margin=margin\n\t\t)\n\t\tself.conv = nn.Conv2d(self.activation.channels_prev, self.activation.channels, kernel_size=3, padding=1)\n\t\t...\n\t\n\tdef forward(self, x): # the channels and size of x depend on the level of this module\n\t\t...\n\t\tx = self.conv(x)\n\t\tx = self.activation(x)\n\t\t...\n\n```\n\n## Layers\n\n### SimpleSelfAttention\nImplementation taken as is from https://github.com/sdoria/SimpleSelfAttention\n\nExample of use:\n```python\nfrom torchtools.nn import SimpleSelfAttention\n\nattn = SimpleSelfAttention(64) # n_in: the number of input channels\ny = attn(x) # e.g. x with shape (B, 64, H, W)\n\n# The input of the layer has to have at least 3 dimensions (B, C, N), \n# the attention will be performed on the 2nd dimension.\n# \n# For images, the input will be internally reshaped to 3 dimensions,\n# and reshaped back to the original shape before returning it\n```\n\n### PixelNorm\nInspired by https://github.com/github-pengge/PyTorch-progressive_growing_of_gans\n\nExample of use:\n```python\nfrom torchtools.nn import PixelNorm\n\nmodel = nn.Sequential(\n\tnn.Conv2d(...),\n\tPixelNorm(),\n\tnn.ReLU()\n)\n\n# It doesn't require any parameter, it just performs a simple element-wise normalization\n# x / (torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True)) + eps)\n#\n# Just use it as a regular layer, generally after convolutions and before ReLU\n# (warning) since it performs a square root it's pretty slow if the layer sizes are big\n```\n\n### Adaptive Instance Normalization - AdaIN\nImplementation based on https://github.com/SiskonEmilia/StyleGAN-PyTorch  \nOriginal Paper https://arxiv.org/abs/1703.06868  \n\nExample of use:  \n```python\nfrom torchtools.nn import AdaIN\n\nclass MyModel(nn.Module):\n\tdef __init__(self, n_channels):\n\t\t...\n\t\t# AdaIN will require the style vector to be 2*size\n\t\tself.style = nn.Linear(input_size, output_size*2) \n\t\tself.adain = AdaIN(output_size)\n\t\t...\n\t\n\tdef forward(self, x, w):\n\t\t...\n\t\tx = self.adain(x, self.style(w))\n\t\t...\n\n# AdaIN will \"transfer\" a style encoded in a latent vector w into any tensor x.\n# To do this, w first needs to be passed through a linear layer that returns one tensor of twice the required size, which is then split in two.\n# AdaIN then performs an instance normalization to \"whiten\" the tensor, followed by a de-normalization using the values generated by the linear layer, thus encoding the style vector w into the tensor.\n```\n\n### EvoNorm\nImplementation taken as is from https://github.com/digantamisra98/EvoNorm, all credit goes to digantamisra98.\nOriginal Paper https://arxiv.org/abs/2004.02967\n\nExample of use:  \n```python\nfrom torchtools.nn import EvoNorm2D\n\nmodel = nn.Sequential(\n\tnn.Conv2d(...),\n\tEvoNorm2D(c_hidden), # For S0 version \n\t# evoB0 = EvoNorm2D(input, affine = True, version = 'B0', training = True) # For B0 version\n\tnn.ReLU()\n)\n```\n
\n### GPT Transformer Encoder Layer\nImplementation based on MinGPT https://github.com/karpathy/minGPT by Andrej Karpathy\n\nIt can be used as a drop-in replacement for `torch.nn.TransformerEncoderLayer`\n\nExample of use: \n```python\nfrom torchtools.nn import GPTTransformerEncoderLayer\n\nclass MyTransformer(nn.Module):\n\tdef __init__(self, n_channels):\n\t\t...\n\t\tencoder_layer = GPTTransformerEncoderLayer(d_model=512, nhead=8)\n\t\tself.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)\n\t\t...\n\n```\n\n### Stylegan2 ModulatedConv2d\nImplementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L143 by Rosinality\n\nIt extends `torch.nn.Conv2d` so you can use it as a drop-in replacement; the only Conv2d parameter that you cannot use is 'groups', since it will be overridden for this to work.\n\nIt also includes a parameter `ema_decay` that will add the EMA normalization used in Alias-Free GAN (defaults to 1, meaning that it's disabled)\n\nExample of use: \n```python\nfrom torchtools.nn import ModulatedConv2d\n\nclass MyModel(nn.Module):\n\tdef __init__(self):\n\t\t...\n\t\tself.conv = ModulatedConv2d(16, 32, kernel_size=3, padding=1) \n\t\t# SUGGESTIONS: \n\t\t#   set bias=False if you want to handle bias on your own\n\t\t#   set demodulate=False for RGB output\n\t\t#   set ema_decay=0.9989 to imitate the alias-free gan setup\n\t\t...\n\n\tdef forward(self, x, w):\n\t\t...\n\t\tx = self.conv(x, w) # 'x' is a 4D tensor (B x C x H x W) and 'w' is a 2D tensor (B x C)\n\t\t...\n```\n\n### Equal Layers (EqualLinear, EqualConv2d, EqualLeakyReLU)\nImplementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/stylegan2/model.py#L94\n\nThey extend the base classes (nn.Linear, nn.Conv2d, nn.LeakyReLU) so you can use them as drop-in replacements, although they include some optional extra parameters.\n\nExample of use: \n```python\nfrom torchtools.nn import EqualLinear, EqualLeakyReLU, EqualConv2d\n\nclass MyModel(nn.Module):\n\tdef __init__(self):\n\t\t...\n\t\tself.linear = EqualLinear(16, 32, bias_init=1, lr_mul=0.01) # bias_init and lr_mul are extra optional params\n\t\tself.leaky_relu = EqualLeakyReLU(negative_slope=0.2)\n\t\tself.conv = EqualConv2d(16, 32, kernel_size=3, padding=1)\n\t\t# Since these classes extend the base classes, you can use all parameters from the original classes.\n\t\t...\n\n```\n\n### FourierFeatures2d\nImplementation inspired by https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L88\nbut improved using my own understanding of how this should work...\n\nIt creates a 2d tensor of embeddings following a Fourier series based on the parameters you provide. These features are dynamic, meaning that affine transformations can be applied to them in order to shift, rotate, and even scale them (experimental).\n\n```python\nfrom torchtools.nn import FourierFeatures2d\n\nclass MyModel(nn.Module):\n\tdef __init__(self, dim=256, margin=10, cutoff=2):\n\t\t...\n\t\tself.feats = FourierFeatures2d(4+margin*2, dim, cutoff) # optionally enable scaling with allow_scaling=True\n\t\t# Also, you can randomize the frequencies if you plan on keeping them fixed, setting w_scale to any value > 0\n\t\t...\n\n\tdef forward(self, affine):\n\t\t...\n\t\tembds = self.feats(affine) # 'affine' should be a Bx4 tensor, or Bx6 if scaling is enabled...\n\t\t# the default or initial affine values should be [1, 0, 0, 0, 1, 1] => ([1, 0]: rotation, [0, 0]: shift, [1, 1]: scale)\n\t\t...\n\n```\n
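\n### Haar DWT (HaarForward, HaarInverse)\nTaken almost as is from https://github.com/bes-dev/haar_pytorch  \nForward/inverse 2d discrete wavelet transform using Haar wavelets. With the default `beta=2` the transform is magnitude preserving; set `beta=1` for the regular Haar DWT.\n\nA small sketch of the round trip (shapes follow the docstrings in `torchtools/nn/haar_dwt.py`):\n```python\nimport torch\nfrom torchtools.nn import HaarForward, HaarInverse\n\ndwt, idwt = HaarForward(), HaarInverse()\n\nx = torch.randn(1, 3, 32, 32)\ny = dwt(x) # (1, 12, 16, 16): the ll, lh, hl, hh bands stacked on the channel dimension\nx_rec = idwt(y) # (1, 3, 32, 32), reconstructs x\n```\n\n### Magnitude preserving ops (MP_GELU, MP_SiLU, Gain, mp_sum, mp_cat)\nMagnitude preserving variants of GELU/SiLU (the output is divided by an empirical constant so that, for roughly unit-variance inputs, the output magnitude stays close to 1), a learnable scalar `Gain` initialized at 0, and the functional `mp_sum` / `mp_cat` for magnitude preserving sums and concatenations.\n\nA minimal sketch:\n```python\nimport torch\nfrom torchtools.nn import MP_SiLU, Gain\nfrom torchtools.nn.functional import mp_sum\n\nact, gain = MP_SiLU(), Gain()\n\na, b = torch.randn(8, 64), torch.randn(8, 64)\nc = mp_sum(a, b, t=0.3) # weighted sum, rescaled to preserve the magnitude\nres = gain(act(c)) # Gain starts at 0, so e.g. a residual branch starts disabled\n```\n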
\n\n## Criterions\n\n### Gradient Penalty (for WGAN-GP)\nImplementation taken with minor changes from https://github.com/caogang/wgan-gp  \nOriginal paper https://arxiv.org/pdf/1704.00028.pdf\n\nExample of use:\n```python\nfrom torchtools.nn import GPLoss\n# This criterion defines the gradient penalty for WGAN GP\n# For an example of a training cycle refer to this repo https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py#L185\n\ndiscriminator = ...\ngpcriterion = GPLoss(discriminator) # l = 10 by default\n\ngradient_penalty = gpcriterion(real_data, fake_data)\ndiscriminator_loss = ... + gradient_penalty # add the gp component to the Wasserstein loss\n```\n\n### Total Variation Loss\nTotal Variation denoising https://www.wikiwand.com/en/Total_variation_denoising  \n\nExample of use:\n```python\n# This loss (or regularization) is useful for removing artifacts and noise in generated images.  \n# It's widely used in style transfer.\nfrom torchtools.nn import TVLoss\n\ntvcriterion = TVLoss() # reduction = 'sum' and alpha = 1e-4 by default\n\nG = ... # output image\ntv_loss = tvcriterion(G)\nloss = ... + tv_loss # add the tv loss component to your reconstruction loss\n```\n\n\n## Vector Quantization\n### VectorQuantize: Encoding-based quantization [(source)](torchtools/nn/vq.py#L5)\nThis transforms any tensor to its quantized version using a codebook of embeddings.  \nIt uses a straight-through approach for applying the gradients.  \nPassing a tensor through the **VectorQuantize** module will return a new tensor with the same shape, but replacing each vector along the last dimension with its nearest neighbor from the codebook, which has a limited number of values, thus quantizing the tensor.\n\nFor the quantization it relies on a differentiable function that you can see [here](torchtools/nn/functional/vq.py#L4)\n\nThe output of the module is a quantized tensor, as well as a tuple of the loss components of the codebook (needed for training), and the indices of the quantized vectors, in the form: `qx, (vq_loss, commit_loss), indices`\n\nWhen **creating a new instance of the module**, it accepts the following parameters: \n  - **embedding_size**: the size of the embeddings used in the codebook, should match the last dimension of the tensor you want to quantize\n  - **k**: the size of the codebook, or number of embeddings. 
\n  - **ema_decay** (default=0.99): the Exponential Moving Average decay used (only used if ema_loss is True)\n  - **ema_loss** (default=False): enables the Exponential Moving Average update of the codebook (instead of relying on gradient descent, as EMA converges faster) \n\nWhen **calling the forward method** of the module, it accepts the following parameters:\n  - **x**: the tensor you want to quantize, make sure the dimension that you want to quantize (by default the last one) matches the embedding_size defined when instantiating the module\n  - **get_losses** (default=True): when False, the vq_loss and commit_loss components of the output will both be None, which should speed up the model a little when used for inference.\n  - **dim** (default=-1): the dimension across which the input should be quantized.\n\nExample of use:\n```python\nfrom torchtools.nn import VectorQuantize\n\ne = torch.randn(1, 16, 16, 8) # create a random tensor with 8 as its last dimension size\nvquantizer = VectorQuantize(8, k=32, ema_loss=True) # we create the module with embedding size of 8, a codebook of size 32 and make the codebook update using EMA\nqe, (vq_loss, commit_loss), indices = vquantizer(e) # we quantize our tensor while also getting the loss components and the indices\n\n# NOTE While the model is in training mode, the codebook will always be updated when calling the forward method; in order to freeze the codebook for inference put it in evaluation mode with 'vquantizer.eval()'\n\n# NOTE 2 In order to update the module properly, add the loss components to the final model loss before calling backward(). If you set ema_loss to True you only need to add the commit_loss to the total loss, and it's usually multiplied by a value between 0.1 and 2, with 0.25 being a good default value\n\nloss = ... 
# whatever loss you have for your final output\nloss += commit_loss * 0.25\n# loss += vq_loss # only if you didn't set ema_loss to True\n\n...\nloss.backward()\noptimizer.step()\n\n```\n\n--- \n\n### Binarize: binarize the input tensor [(source)](torchtools/nn/vq.py#L55)\nThis transforms the values of a tensor into 0s and 1s depending on whether they're above or below a specified threshold.\nIt uses a straight-through approach for applying the gradients, so it's effectively differentiable.\n\nFor the binarization it relies on a differentiable function that you can see [here](torchtools/nn/functional/vq.py#L36)\n\nExample of use:\n```python\nfrom torchtools.nn import Binarize\n\ne = torch.randn(8, 16) # create a random tensor with any dimension\n\nbinarizer = Binarize(threshold=0.5) # you can set the threshold you want, for example if your output was passed through a tanh activation, 0 might be a better threshold since tanh outputs values between -1 and 1\n\nbq = binarizer(e) # will return a tensor with the same shape as e, but full of 0s and 1s\n```\n\n## Embeddings\n\n### RotaryEmbedding\nImplementation taken as is from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L161\n\nExample of use:\n```python\nfrom torchtools.nn import RotaryEmbedding\nfrom torchtools.nn.pos_embeddings import apply_rotary_pos_emb\n\nclass MyModel(nn.Module):\n\tdef __init__(self, dim):\n\t\t...\n\t\tself.rotary_pos_embd = RotaryEmbedding(dim)\n\t\t...\n\t\n\tdef forward(self, x):\n\t\tcos, sin = self.rotary_pos_embd(x) # returns the cached cos/sin embeddings for x's sequence length\n\t\t# q, k = apply_rotary_pos_emb(q, k, cos, sin) # then rotate your queries and keys with them\n\t\t...\n\n\n```\n\n## Diffusion\n\n### Diffuzz\nCustom (non-cached) continuous forward/backward diffusion.\nIt's not SUPER performant since it calculates all the required values on the fly instead of caching them (although I added a very simple cache where you can specify the number of steps that you want to cache). In general I think this will not make an extremely big difference in terms of performance, it simplifies the code a lot, and it removes the concept of having a fixed number of timesteps for the forward diffusion (I always found it weird to train a model assuming 1000 forward diffusion steps, and then use way fewer steps during inference): instead, a continuous value between 0 and 1 decides how much noise we'll be adding to the output (1 being pure gaussian noise).\n\nDuring sampling the same applies: instead of having a fixed number of steps, the diffuzz module will accept a noised input, a pair of values t & t_prev (between 0 and 1) and a predicted noise, and it will try to remove that noise at a scale such as to go from step t to step t_prev. So if we want to denoise in 10 steps we'll tell it to go from 1.0 to 0.9, then to 0.8, etc., while if we want to denoise in 100 steps we'll start at 1.0 and go to 0.99, then to 0.98, etc. 
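\n\nFor example, the (t, t_prev) pairs for a 10-step schedule can be built with plain PyTorch; this snippet is just an illustration and not part of the library:\n```python\nimport torch\n\ntimesteps = 10\nts = torch.linspace(1.0, 0.0, timesteps + 1) # [1.0, 0.9, ..., 0.1, 0.0]\nfor t, t_prev in zip(ts[:-1], ts[1:]):\n\tpass # denoise from t to t_prev at each iteration\n```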
\n\nExample of use during training:\n```python\nfrom torchtools.utils import Diffuzz\ndevice = \"cuda\"\n\ndiffuzz = Diffuzz(device=device)\n# diffuzz = Diffuzz(device=device, cache_steps=10000) # optionally you can pass a 'cache_steps' parameter to speed up the noising process\ncustom_unet = CustomUnet().to(device) # Custom model with output size = input size\n\ninput_tensor = torch.randn(8, 3, 16, 16, device=device) # an image, audio signal, or whatever...\n\nt = torch.rand(input_tensor.size(0), device=device) # get a tensor with batch_size values between 0 and 1\nnoised_tensor, noise = diffuzz.diffuse(input_tensor, t)\n\npredicted_noise = custom_unet(noised_tensor, t)\nloss = nn.functional.mse_loss(predicted_noise, noise)\n\n# Optionally the diffuzz module provides p2 loss weighting (untested), but for this to work the loss \n# should not be averaged on the batch dimension before applying it.\n\n# loss = nn.functional.mse_loss(predicted_noise, noise, reduction='none').mean(dim=[1, 2, 3])\n# loss = (loss * diffuzz.p2_weight(t)).mean()\n\n```\n\nExample of use for sampling (reusing the diffuzz and custom_unet from the training example; 'conditioning' is whatever conditioning your model expects):\n```python\nfrom torchtools.utils import Diffuzz\ndevice = \"cuda\"\n\nsampled = diffuzz.sample(\n\tcustom_unet, {'c': conditioning}, \n\t(conditioning.size(0), 3, 16, 16),\n\ttimesteps=20, sampler='ddim'\n)[-1]\n```\n\nThe `sample` method accepts a `sampler` parameter; currently only `ddpm` (default) and `ddim` are implemented, but I'm planning on adding more, very likely by borrowing (and appropriately citing) code from this repo: https://github.com/ozanciga/diffusion-for-beginners\n\n"
  },
  {
    "path": "setup.cfg",
    "content": "[metadata]\ndescription-file = readme.md"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nsetup(\n    name='torchtools',\n    packages=find_packages(),\n    description='PyTorch useful tools',\n    version='0.3.5',\n    url='https://github.com/pabloppp/pytorch-tools',\n    author='Pablo Pernías',\n    author_email='pablo@pernias.com',\n    keywords=['pip', 'pytorch', 'tools', 'RAdam', 'Lookahead', 'RALamb', 'quantization'],\n    zip_safe=False,\n    install_requires=[\n        'torch>=1.6',\n        'torchvision',\n        'numpy>=1.0',\n        'ninja>=1.0'\n    ],\n    package_data={\n        'stylegan2.tools': ['torchtools/nn/stylegan2/*'],\n        'transforms.models': ['torchtools/transforms/models/*']\n    },\n    include_package_data=True,\n)\n"
  },
  {
    "path": "torchtools/__init__.py",
    "content": ""
  },
  {
    "path": "torchtools/lr_scheduler/__init__.py",
    "content": "from .delayed import DelayerScheduler, DelayedCosineAnnealingLR\nfrom .inverse_sqrt import InverseSqrtLR\n"
  },
  {
    "path": "torchtools/lr_scheduler/delayed.py",
    "content": "from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR\n\nclass DelayerScheduler(_LRScheduler):\n\t\"\"\" Starts with a flat lr schedule until it reaches N epochs the applies a scheduler \n\n\tArgs:\n\t\toptimizer (Optimizer): Wrapped optimizer.\n\t\tdelay_epochs: number of epochs to keep the initial lr until starting aplying the scheduler\n\t\tafter_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)\n\t\"\"\"\n\n\tdef __init__(self, optimizer, delay_epochs, after_scheduler):\n\t\tself.delay_epochs = delay_epochs\n\t\tself.after_scheduler = after_scheduler\n\t\tself.finished = False\n\t\tsuper().__init__(optimizer)\n\n\tdef get_lr(self):\n\t\tif self.last_epoch >= self.delay_epochs:\n\t\t\tif not self.finished:\n\t\t\t\tself.after_scheduler.base_lrs = self.base_lrs\n\t\t\t\tself.finished = True\n\t\t\treturn self.after_scheduler.get_lr()\n\n\t\treturn self.base_lrs\n\n\tdef step(self, epoch=None):\n\t\tif self.finished:\n\t\t\tif epoch is None:\n\t\t\t\tself.after_scheduler.step(None)\n\t\t\telse:\n\t\t\t\tself.after_scheduler.step(epoch - self.delay_epochs)\n\t\telse:\n\t\t\treturn super(DelayerScheduler, self).step(epoch)\n\ndef DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs):\n\tbase_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs)\n\treturn DelayerScheduler(optimizer, delay_epochs, base_scheduler)"
  },
  {
    "path": "torchtools/lr_scheduler/inverse_sqrt.py",
    "content": "import warnings\n\nfrom torch.optim.lr_scheduler import LRScheduler\n\n\nclass InverseSqrtLR(LRScheduler):\n    def __init__(self, optimizer, lr, warmup_steps, pre_warmup_lr=None, last_epoch=-1, verbose=False):\n        warmup_steps = max(warmup_steps, 1)\n        self.lr = lr * warmup_steps**0.5\n        self.warmup_steps = warmup_steps\n        self.pre_warmup_lr = pre_warmup_lr if pre_warmup_lr is not None else lr\n        super().__init__(optimizer, last_epoch, verbose)\n\n    def _process_lr(self, _):\n        warmup_factor = min(self.last_epoch/self.warmup_steps, 1) # this grows linearly from 0 to 1 during the warmup\n        base_lr = self.lr / max(self.last_epoch, self.warmup_steps)**0.5\n        return warmup_factor * base_lr + (1-warmup_factor)*self.pre_warmup_lr\n\n    def get_lr(self):\n        if not self._get_lr_called_within_step:\n            warnings.warn(\"To get the last learning rate computed by the scheduler, please use `get_last_lr()`.\", UserWarning)\n\n        lr = self._process_lr(self.lr)\n        return [lr for _ in self.optimizer.param_groups]\n\n    def _get_closed_form_lr(self):\n        return [self._process_lr(base_lr) for base_lr in self.base_lrs]"
  },
  {
    "path": "torchtools/nn/__init__.py",
    "content": "from .mish import Mish\nfrom .simple_self_attention import SimpleSelfAttention\nfrom .vq import VectorQuantize, Binarize, FSQ\nfrom .gp_loss import GPLoss\nfrom .pixel_normalzation import PixelNorm\nfrom .perceptual import TVLoss\nfrom .adain import AdaIN\nfrom .transformers import GPTTransformerEncoderLayer\nfrom .evonorm2d import EvoNorm2D\nfrom .pos_embeddings import RotaryEmbedding\nfrom .modulation import ModulatedConv2d\nfrom .equal_layers import EqualConv2d, EqualLeakyReLU, EqualLinear\nfrom .fourier_features import FourierFeatures2d\n# from .alias_free_activation import AliasFreeActivation\nfrom .magnitude_preserving import MP_GELU, MP_SiLU, Gain\nfrom .haar_dwt import HaarForward, HaarInverse"
  },
  {
    "path": "torchtools/nn/adain.py",
    "content": "import torch\nfrom torch import nn\n\nclass AdaIN(nn.Module):\n    def __init__(self, n_channels):\n        super(AdaIN, self).__init__()\n        self.norm = nn.InstanceNorm2d(n_channels)\n\n    def forward(self, image, style):\n        factor, bias = style.view(style.size(0), style.size(1), 1, 1).chunk(2, dim=1)\n        result = self.norm(image) * factor + bias\n        return result"
  },
  {
    "path": "torchtools/nn/alias_free_activation.py",
    "content": "import torch\nfrom torch import nn\nimport math\nfrom .stylegan2 import upfirdn2d\n\n####\n# TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM \n# https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L225\n# But I simplified it into a single (almots) self-contained module. \n# Probably I give this module too much reponsibility but meh...\n####\n\nclass AliasFreeActivation(nn.Module):\n    def __init__(self, activation, level, max_levels, max_size, max_channels, margin, start_cutoff=2, critical_layers=2, window_size=6):\n        super().__init__()\n        self.activation = activation\n\n        # Filter features\n        self.cutoff, self.stopband, self.band_half, self.channels, self.size = self.alias_level_params(\n            level, max_levels, max_size, max_channels, start_cutoff, critical_layers\n        )\n        self.cutoff_prev, self.stopband_prev, self.band_half_prev, self.channels_prev, self.size_prev = self.alias_level_params(\n            max(level-1, 0), max_levels, max_size, max_channels, start_cutoff, critical_layers\n        )\n\n        # Filters\n        self.scale_factor = 2 if self.size_prev < self.size else 1\n        up_filter = self._lowpass_filter(\n            window_size * self.scale_factor * 2, self.cutoff_prev, self.band_half_prev, self.size * self.scale_factor * 2\n        )\n        self.register_buffer(\"up_filter\", (up_filter / up_filter.sum()) * 2 * self.scale_factor)\n\n        down_filter = self._lowpass_filter(\n            window_size * self.scale_factor, self.cutoff, self.band_half, self.size * self.scale_factor * 2\n        )\n        self.register_buffer(\"down_filter\", down_filter / down_filter.sum())\n\n        p = self.up_filter.shape[0] - (2*self.scale_factor)\n        self.up_pad = ((p + 1) // 2 + (2*self.scale_factor) - 1, p // 2)\n\n        p = self.down_filter.shape[0] - 2\n        self.down_pad = ((p + 1) // 2, p // 2)\n        self.margin = margin\n\n    @staticmethod\n    def alias_level_params(level, max_levels, max_size, max_channels, start_cutoff=2, critical_layers=2, base_channels=2**14):\n        end_cutoff = max_size//2\n        cutoff = start_cutoff * (end_cutoff / start_cutoff) ** min(level / (max_levels - critical_layers), 1)\n\n        start_stopband = start_cutoff ** 2.1\n        end_stopband = end_cutoff * (2 ** 0.3)\n        stopband = start_stopband * (end_stopband/start_stopband) ** min(level / (max_levels - critical_layers), 1)\n\n        size = 2 ** math.ceil(math.log(min(2 * stopband, max_size), 2))\n        band_half = max(stopband, size / 2) - cutoff\n        channels = min(round(base_channels / size), max_channels)\n\n        return cutoff, stopband, band_half, channels, size\n\n    def _lowpass_filter(self, n_taps, cutoff, band_half, sr):\n        window = self._kaiser_window(n_taps, band_half, sr)\n        ind = torch.arange(n_taps) - (n_taps - 1) / 2\n        lowpass = 2 * cutoff / sr * torch.sinc(2 * cutoff / sr * ind) * window\n\n        return lowpass\n\n    def _kaiser_window(self, n_taps, f_h, sr):\n        beta = self._kaiser_beta(n_taps, f_h, sr)\n        ind = torch.arange(n_taps) - (n_taps - 1) / 2\n        return torch.i0(beta * torch.sqrt(1 - ((2 * ind) / (n_taps - 1)) ** 2)) / torch.i0(torch.tensor(beta))\n\n    def _kaiser_attenuation(self, n_taps, f_h, sr):\n        df = (2 * f_h) / (sr / 2)\n        return 2.285 * (n_taps - 1) * math.pi * df + 7.95\n\n\n    def _kaiser_beta(self, n_taps, f_h, sr):\n        atten = self._kaiser_attenuation(n_taps, f_h, sr)\n        
if atten > 50:\n            return 0.1102 * (atten - 8.7)\n\n        elif 50 >= atten >= 21:\n            return 0.5842 * (atten - 21) ** 0.4 + 0.07886 * (atten - 21)\n        else:\n            return 0.0\n\n    def forward(self, x):\n        x = self._upsample(x, self.up_filter, 2*self.scale_factor, pad=self.up_pad)\n        x = self.activation(x)\n        x = self._downsample(x, self.down_filter, 2, pad=self.down_pad)\n        if self.scale_factor > 1 and self.margin > 0:\n            m = self.scale_factor * self.margin // 2\n            x = x[:, :, m:-m, m:-m]\n        return x\n\n    def _upsample(self, x, kernel, factor, pad=(0, 0)):\n        x = upfirdn2d(x, kernel.unsqueeze(0), up=(factor, 1), pad=(*pad, 0, 0))\n        x = upfirdn2d(x, kernel.unsqueeze(1), up=(1, factor), pad=(0, 0, *pad))\n        return x\n\n    def _downsample(self, x, kernel, factor, pad=(0, 0)):\n        x = upfirdn2d(x, kernel.unsqueeze(0), down=(factor, 1), pad=(*pad, 0, 0))\n        x = upfirdn2d(x, kernel.unsqueeze(1), down=(1, factor), pad=(0, 0, *pad))\n        return x\n\n    def extra_repr(self):\n        info_string = f'cutoff={self.cutoff}, stopband={self.stopband}, band_half={self.band_half}, channels={self.channels}, size={self.size}'\n        return info_string "
  },
  {
    "path": "torchtools/nn/equal_layers.py",
    "content": "\nimport torch\nfrom torch import nn\nimport math\n\n####\n# TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM \n# https://github.com/rosinality/alias-free-gan-pytorch/blob/main/stylegan2/model.py#L94\n# But made it extend from the base modules to avoid some boilerplate\n####\n\nclass EqualLinear(nn.Linear):\n    def __init__(self, *args, bias_init=0, lr_mul=1, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        self.scale = (1 / math.sqrt(self.in_features)) * lr_mul\n        self.lr_mul = lr_mul\n\n        nn.init.normal_(self.weight, std=1/lr_mul)\n        if self.bias is not None:\n            nn.init.constant_(self.bias, bias_init)\n\n    def forward(self, x):\n        return nn.functional.linear(x, self.weight * self.scale, self.bias * self.lr_mul)\n\n\nclass EqualConv2d(nn.Conv2d):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        fan_in = self.in_channels * self.kernel_size[0] ** 2\n        self.scale = 1 / math.sqrt(fan_in)\n\n        nn.init.normal_(self.weight)\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n\n    def forward(self, x):\n        return self._conv_forward(x, self.weight * self.scale, self.bias)\n\n\nclass EqualLeakyReLU(nn.LeakyReLU):\n    def __init__(self, *args, scale=2**0.5, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.scale = scale\n    \n    def forward(self, x):\n        return super().forward(x) * self.scale"
  },
  {
    "path": "torchtools/nn/evonorm2d.py",
    "content": "import torch\nimport torch.nn as nn\n\n\n## Taken as is from https://github.com/digantamisra98/EvoNorm all credit goes to digantamisra98 \nclass SwishImplementation(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, i):\n        ctx.save_for_backward(i)\n        return i * torch.sigmoid(i)\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        sigmoid_i = torch.sigmoid(ctx.saved_variables[0])\n        return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i)))\n\n\nclass MemoryEfficientSwish(nn.Module):\n    def forward(self, x):\n        return SwishImplementation.apply(x)\n\ndef instance_std(x, eps=1e-5):\n    var = torch.var(x, dim = (2, 3), keepdim=True).expand_as(x)\n    if torch.isnan(var).any():\n        var = torch.zeros(var.shape)\n    return torch.sqrt(var + eps)\n\ndef group_std(x, groups = 32, eps = 1e-5):\n    N, C, H, W = x.size()\n    x = torch.reshape(x, (N, groups, C // groups, H, W))\n    var = torch.var(x, dim = (2, 3, 4), keepdim = True).expand_as(x)\n    return torch.reshape(torch.sqrt(var + eps), (N, C, H, W))\n\nclass EvoNorm2D(nn.Module):\n\n    def __init__(self, input, non_linear = True, version = 'S0', efficient = False, affine = True, momentum = 0.9, eps = 1e-5, groups = 32, training = True):\n        super(EvoNorm2D, self).__init__()\n        self.non_linear = non_linear\n        self.version = version\n        self.training = training\n        self.momentum = momentum\n        self.efficient = efficient\n        if self.version == 'S0':\n            self.swish = MemoryEfficientSwish()\n        self.groups = groups\n        self.eps = eps\n        if self.version not in ['B0', 'S0']:\n            raise ValueError(\"Invalid EvoNorm version\")\n        self.insize = input\n        self.affine = affine\n\n        if self.affine:\n            self.gamma = nn.Parameter(torch.ones(1, self.insize, 1, 1))\n            self.beta = nn.Parameter(torch.zeros(1, self.insize, 1, 1))\n            if self.non_linear:\n                self.v = nn.Parameter(torch.ones(1,self.insize,1,1))\n        else:\n            self.register_parameter('gamma', None)\n            self.register_parameter('beta', None)\n            self.register_buffer('v', None)\n        self.register_buffer('running_var', torch.ones(1, self.insize, 1, 1))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.running_var.fill_(1)\n\n    def _check_input_dim(self, x):\n        if x.dim() != 4:\n            raise ValueError('expected 4D input (got {}D input)'\n                             .format(x.dim()))\n    \n    def forward(self, x):\n        self._check_input_dim(x)\n        if self.version == 'S0':\n            if self.non_linear:\n                if not self.efficient:\n                    num = x * torch.sigmoid(self.v * x)   # Original Swish Implementation, however memory intensive.\n                else:\n                    num = self.swish(x)    # Experimental Memory Efficient Variant of Swish\n                return num / group_std(x, groups = self.groups, eps = self.eps) * self.gamma + self.beta\n            else:\n                return x * self.gamma + self.beta\n        if self.version == 'B0':\n            if self.training:\n                var = torch.var(x, dim = (0, 2, 3), unbiased = False, keepdim = True)\n                self.running_var.mul_(self.momentum)\n                self.running_var.add_((1 - self.momentum) * var)\n            else:\n                var = self.running_var\n\n            
if self.non_linear:\n                den = torch.max((var+self.eps).sqrt(), self.v * x + instance_std(x, eps = self.eps))\n                return x / den * self.gamma + self.beta\n            else:\n                return x * self.gamma + self.beta"
  },
  {
    "path": "torchtools/nn/fourier_features.py",
    "content": "import torch\nfrom torch import nn\nimport math\n\nclass FourierFeatures2d(nn.Module):\n    def __init__(self, size, dim, cutoff, affine_eps=1e-8, freq_range=[-0.5, 0.5], w_scale=0, allow_scaling=False, op_order=['r', 't', 's']):\n        super().__init__()\n        self.size = size\n        self.dim = dim\n        self.cutoff = cutoff\n        self.freq_range = freq_range\n        self.affine_eps = affine_eps\n        self.w_scale = w_scale\n        coords = torch.linspace(freq_range[0], freq_range[1], size+1)[:-1]\n        freqs = torch.linspace(0, cutoff, dim // 4)\n        if w_scale > 0:\n            freqs = freqs @ (torch.randn(dim // 4, dim // 4) * w_scale)\n        coord_map = torch.outer(freqs, coords)\n        coord_map = 2 * math.pi * coord_map\n        self.register_buffer(\"coord_h\", coord_map.view(freqs.shape[0], 1, size))\n        self.register_buffer(\"coord_w\", self.coord_h.transpose(1, 2).detach())\n        self.register_buffer(\"lf\", freqs.view(1, dim // 4, 1, 1) * 2*math.pi * 2/size)\n        self.allow_scaling = allow_scaling\n        for op in op_order:\n            assert op in ['r', 't', 's'], f\"Operation not valid: {op}\"\n        self.op_order = op_order\n\n    def forward(self, affine):\n        norm = ((affine[:, 0:1].pow(2) + affine[:, 1:2].pow(2)).sqrt() + self.affine_eps).expand(affine.size(0), 4)\n        if self.allow_scaling:\n            assert affine.size(-1) == 6, f\"If scaling is enabled, 2 extra values must be passed for a total of 6, and not {affine.size(-1)}\"\n            norm = torch.cat([norm, norm.new_ones(affine.size(0), 2)], dim=1)\n        else:\n            assert affine.size(-1) == 4, f\"If scaling is disabled, 4 affine values should be passed, and not {affine.size(-1)}\"\n        affine = affine / norm\n        affine = affine[:, :, None, None, None]\n\n        coord_h, coord_w = self.coord_h.unsqueeze(0), self.coord_w.unsqueeze(0)\n\n        for op in reversed(self.op_order):\n            if op == 's' and self.allow_scaling:\n                coord_h = coord_h / nn.functional.threshold(affine[:, 5], 1.0, 1.0) # scale\n                coord_w = coord_w / nn.functional.threshold(affine[:, 4], 1.0, 1.0)\n\n            elif op == 't':\n                coord_h = coord_h - (affine[:, 3] * self.lf) # shift\n                coord_w = coord_w - (affine[:, 2] * self.lf) \n            \n            elif op == 'r':\n                _coord_h = -coord_w * affine[:, 1] + coord_h * affine[:, 0] # rotation\n                coord_w = coord_w * affine[:, 0] + coord_h * affine[:, 1]\n                coord_h = _coord_h\n\n        coord_h = torch.cat((torch.sin(coord_h), torch.cos(coord_h)), 1)\n        coord_w = torch.cat((torch.sin(coord_w), torch.cos(coord_w)), 1)\n\n        coords = torch.cat((coord_h, coord_w), 1)\n        return coords\n\n    def extra_repr(self):\n        info_string = f'size={self.size}, dim={self.dim}, cutoff={self.cutoff}, freq_range={self.freq_range}'\n        if self.w_scale > 0:\n            info_string += f', w_scale={self.w_scale}'\n        if self.allow_scaling:\n            info_string += f', allow_scaling={self.allow_scaling}'\n        return info_string \n"
  },
  {
    "path": "torchtools/nn/functional/__init__.py",
    "content": "from .vq import vector_quantize, binarize\nfrom .gradient_penalty import gradient_penalty\nfrom .perceptual import total_variation\nfrom .magnitude_preserving import mp_cat, mp_sum\n"
  },
  {
    "path": "torchtools/nn/functional/gradient_penalty.py",
    "content": "####\n# CODE TAKEN WITH FEW MODIFICATIONS FROM https://github.com/caogang/wgan-gp\n# ORIGINAL PAPER https://arxiv.org/pdf/1704.00028.pdf\n####\n\nimport torch\nfrom torch import autograd\n\ndef gradient_penalty(netD, real_data, fake_data, l=10):\n    batch_size = real_data.size(0)\n    alpha = real_data.new_empty((batch_size, 1, 1, 1)).uniform_(0, 1)\n    alpha = alpha.expand_as(real_data)\n\n    interpolates = alpha * real_data + ((1 - alpha) * fake_data)\n    interpolates = autograd.Variable(interpolates, requires_grad=True)\n\n    disc_interpolates = netD(interpolates)\n\n    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,\n                              grad_outputs=real_data.new_ones(disc_interpolates.size()),\n                              create_graph=True, retain_graph=True, only_inputs=True)[0]\n\n    gradients = gradients.view(gradients.size(0), -1)\n    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)\n    gradient_penalty = ((gradients_norm - 1) ** 2).mean() * l\n\n    return gradient_penalty"
  },
  {
    "path": "torchtools/nn/functional/magnitude_preserving.py",
    "content": "import torch\n\ndef mp_cat(*args, dim=1, t=0.5):\n    if isinstance(t, float):\n        t = [1-t, t]\n    assert len(args) == len(t), \"t must be a single scalar or a list of scalars of length len(args)\"\n    \n    w = [m/a.size(dim)**0.5 for a, m in zip(args, t)]\n    C = (sum([a.size(dim) for a in args]) / sum([m**2 for m in t]))**0.5\n\n    return torch.cat([a*v for a, v in zip(args, w)], dim=dim) * C\n\ndef mp_sum(*args, t=0.5):\n    if isinstance(t, float):\n        t = [1-t, t]\n\n    assert len(args) == len(t), \"t must be a single scalar or a list of scalars of length len(args)\"\n    assert abs(sum(t)-1) < 1e-3 , \"the values of t should all add up to one\"\n\n    return sum([a*m for a, m in zip(args, t)]) / sum([m**2 for m in t])**0.5\n"
  },
  {
    "path": "torchtools/nn/functional/perceptual.py",
    "content": "import torch\n\ndef total_variation(X, reduction='sum'):\n\ttv_h = torch.abs(X[:, :, :, 1:] - X[:, :, :, :-1])\n\ttv_v = torch.abs(X[:, :, 1:] - X[:, :, :-1])\n\n\ttv = torch.mean(tv_h) + torch.mean(tv_v) if reduction == 'mean' else torch.sum(tv_h) + torch.sum(tv_v)\n\t\n\treturn tv"
  },
  {
    "path": "torchtools/nn/functional/vq.py",
    "content": "import torch\nfrom torch.autograd import Function\n\nclass vector_quantize(Function):\n\t@staticmethod\n\tdef forward(ctx, x, codebook):\n\t\twith torch.no_grad():\n\t\t\tcodebook_sqr = torch.sum(codebook ** 2, dim=1)\n\t\t\tx_sqr = torch.sum(x ** 2, dim=1, keepdim=True)\n\n\t\t\tdist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)\n\t\t\t_, indices = dist.min(dim=1)\n\t\t\t\n\t\t\tctx.save_for_backward(indices, codebook)\n\t\t\tctx.mark_non_differentiable(indices)\n\n\t\t\tnn = torch.index_select(codebook, 0, indices)\n\t\t\treturn nn, indices\n\t\n\t@staticmethod\n\tdef backward(ctx, grad_output, grad_indices):\n\t\tgrad_inputs, grad_codebook = None, None\n\t\t\n\t\tif ctx.needs_input_grad[0]:\n\t\t\tgrad_inputs = grad_output.clone()\n\t\tif ctx.needs_input_grad[1]:\n\t\t\t# Gradient wrt. the codebook\n\t\t\tindices, codebook = ctx.saved_tensors\n\n\t\t\tgrad_codebook = torch.zeros_like(codebook)\n\t\t\tgrad_codebook.index_add_(0, indices, grad_output)\n\t\t\n\t\treturn (grad_inputs, grad_codebook)\n\n\nclass binarize(Function):\n\t@staticmethod\n\tdef forward(ctx, x, threshold=0.5):\n\t\twith torch.no_grad():\n\t\t\tbinarized = (x > threshold).float()\n\t\t\tctx.mark_non_differentiable(binarized)\n\n\t\t\treturn binarized\n\t\n\t@staticmethod\n\tdef backward(ctx, grad_output):\n\t\tgrad_inputs = None\n\t\t\n\t\tif ctx.needs_input_grad[0]:\n\t\t\tgrad_inputs = grad_output.clone()\n\t\t\n\t\treturn grad_inputs"
  },
  {
    "path": "torchtools/nn/gp_loss.py",
    "content": "import torch\nfrom torch import nn\nfrom .functional import gradient_penalty\n\nclass GPLoss(nn.Module):\n\tdef __init__(self, discriminator, l=10):\n\t\tsuper(GPLoss, self).__init__()\n\t\tself.discriminator = discriminator\n\t\tself.l = l\n\n\tdef forward(self, real_data, fake_data):\n\t\treturn gradient_penalty(self.discriminator, real_data, fake_data, self.l)"
  },
  {
    "path": "torchtools/nn/haar_dwt.py",
    "content": "import torch\nfrom torch import nn\n\n# Taken almost as is from https://github.com/bes-dev/haar_pytorch\nclass HaarForward(nn.Module):\n    \"\"\"\n    Performs a 2d DWT Forward decomposition of an image using Haar Wavelets\n    set beta=1 for regular haard dwt, with beta=2 we make a magnitude preserving dwt\n    \"\"\"\n    def __init__(self, beta=2):\n        super().__init__()\n        self.alpha = 0.5\n        self.beta = beta\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Performs a 2d DWT Forward decomposition of an image using Haar Wavelets\n\n        Arguments:\n            x (torch.Tensor): input tensor of shape [b, c, h, w]\n\n        Returns:\n            out (torch.Tensor): output tensor of shape [b, c * 4, h / 2, w / 2]\n        \"\"\"\n\n        ll = self.alpha/self.beta * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] + x[:,:,1::2,0::2] + x[:,:,1::2,1::2])\n        lh = self.alpha * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] - x[:,:,1::2,0::2] - x[:,:,1::2,1::2])\n        hl = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] + x[:,:,1::2,0::2] - x[:,:,1::2,1::2])\n        hh = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] - x[:,:,1::2,0::2] + x[:,:,1::2,1::2])\n        return torch.cat([ll,lh,hl,hh], axis=1)\n\n\nclass HaarInverse(nn.Module):\n    \"\"\"\n    Performs a 2d DWT Inverse reconstruction of an image using Haar Wavelets\n    set beta=1 for regular haard dwt, with beta=2 we make a magnitude preserving dwt\n    \"\"\"\n    def __init__(self, beta=2):\n        super().__init__()\n        self.alpha = 0.5\n        self.beta = beta\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Performs a 2d DWT Inverse reconstruction of an image using Haar Wavelets\n\n        Arguments:\n            x (torch.Tensor): input tensor of shape [b, c, h, w]\n\n        Returns:\n            out (torch.Tensor): output tensor of shape [b, c / 4, h * 2, w * 2]\n        \"\"\"\n        assert x.size(1) % 4 == 0, \"The number of channels must be divisible by 4.\"\n        size = [x.shape[0], x.shape[1] // 4, x.shape[2] * 2, x.shape[3] * 2]\n        f = lambda i: x[:, size[1] * i : size[1] * (i + 1)]\n        out = torch.zeros(size, dtype=x.dtype, device=x.device)\n        out[:,:,0::2,0::2] = self.alpha * (f(0)*self.beta + f(1) + f(2) + f(3))\n        out[:,:,0::2,1::2] = self.alpha * (f(0)*self.beta + f(1) - f(2) - f(3))\n        out[:,:,1::2,0::2] = self.alpha * (f(0)*self.beta - f(1) + f(2) - f(3))\n        out[:,:,1::2,1::2] = self.alpha * (f(0)*self.beta - f(1) - f(2) + f(3))\n        return out"
  },
  {
    "path": "torchtools/nn/magnitude_preserving.py",
    "content": "import torch\nfrom torch import nn\n\nclass MP_GELU(nn.GELU):\n    def forward(self, x):\n        return super().forward(x) / 0.652 # ¯\\_(ツ)_/¯\n\nclass MP_SiLU(nn.SiLU):\n    def forward(self, x):\n        return super().forward(x) / 0.596 # ¯\\_(ツ)_/¯\n    \nclass Gain(nn.Module):\n    def __init__(self, init_w=0.0):\n        super().__init__()\n        self.g = nn.Parameter(torch.tensor([init_w]))\n\n    def forward(self, x):\n        return x * self.g\n\n"
  },
  {
    "path": "torchtools/nn/mish.py",
    "content": "####\n# CODE TAKEN FROM https://github.com/lessw2020/mish\n# ORIGINAL PAPER https://arxiv.org/abs/1908.08681v1\n####\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F  #(uncomment if needed,but you likely already have it)\n\n#Mish - \"Mish: A Self Regularized Non-Monotonic Neural Activation Function\"\n#https://arxiv.org/abs/1908.08681v1\n#implemented for PyTorch / FastAI by lessw2020 \n#github: https://github.com/lessw2020/mish\n\nclass Mish(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        #inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)\n        return x *( torch.tanh(F.softplus(x)))"
  },
  {
    "path": "torchtools/nn/modulation.py",
    "content": "\nimport torch\nfrom torch import nn\nimport math\n\n####\n# TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM \n# https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L143\n# But made it extend from the base Conv2d to avoid some boilerplate\n####\nclass ModulatedConv2d(nn.Conv2d):\n    def __init__(self,  *args, demodulate=True, ema_decay=1.0, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        fan_in = self.in_channels * self.kernel_size[0] ** 2\n        self.scale = 1 / math.sqrt(fan_in)\n\n        self.demodulate = demodulate\n        self.ema_decay = ema_decay\n        self.register_buffer(\"ema_var\", torch.tensor(1.0))\n        nn.init.normal_(self.weight)\n        if self.bias is not None:\n            nn.init.zeros_(self.bias)\n\n    def forward(self, x, w):\n        batch, in_channels, height, width = x.shape\n\n        style = w.view(batch, 1, in_channels, 1, 1)\n        weight = self.scale * self.weight.unsqueeze(0) * style\n\n        if self.demodulate:\n            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)\n            weight = weight * demod.view(batch, self.out_channels, 1, 1, 1)\n\n        weight = weight.view(\n            batch * self.out_channels, in_channels, self.kernel_size[0], self.kernel_size[1]\n        )\n\n        if self.ema_decay < 1:\n            if self.training:\n                var = x.pow(2).mean((0, 1, 2, 3))\n                self.ema_var.mul_(self.ema_decay).add_(var.detach(), alpha=1 - self.ema_decay)\n\n            weight = weight / (torch.sqrt(self.ema_var) + 1e-8)\n\n        input = x.view(1, batch * in_channels, height, width)\n        self.groups = batch\n        out = self._conv_forward(input, weight, None)\n        _, _, height, width = out.shape\n        out = out.view(batch, self.out_channels, height, width)\n        if self.bias is not None:\n            out = out + self.bias.view(1, -1, 1, 1)\n        return out\n"
  },
  {
    "path": "torchtools/nn/perceptual.py",
    "content": "import torch\nimport torch.nn as nn\nfrom .functional import total_variation\n\nclass TVLoss(nn.Module):\n\tdef __init__(self, reduction='sum', alpha=1e-4):\n\t\tsuper(TVLoss, self).__init__()\n\t\tself.reduction = reduction\n\t\tself.alpha = alpha\n\n\tdef forward(self, x):\n\t\treturn total_variation(x, reduction=self.reduction) * self.alpha"
  },
  {
    "path": "torchtools/nn/pixel_normalzation.py",
    "content": "import torch\nfrom torch import nn\n\nclass PixelNorm(nn.Module):\n    def __init__(self, dim=1, eps=1e-4):\n        super().__init__()\n        self.dim = dim\n        self.eps = eps\n \n    def forward(self, x):\n        return x / (torch.sqrt(torch.mean(x ** 2, dim=self.dim, keepdim=True)) + self.eps)"
  },
  {
    "path": "torchtools/nn/pos_embeddings.py",
    "content": "import torch\nfrom torch import nn\n\n####\n# CODE TAKEN FROM https://github.com/lucidrains/x-transformers\n####\n\nclass RotaryEmbedding(nn.Module):\n    def __init__(self, dim, base=10000):\n        super().__init__()\n        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))\n        self.register_buffer('inv_freq', inv_freq)\n        self.seq_len_cached = None\n        self.cos_cached = None\n        self.sin_cached = None\n\n    def forward(self, x, seq_dim=1):\n        seq_len = x.shape[seq_dim]\n        if seq_len != self.seq_len_cached:\n            self.seq_len_cached = seq_len\n            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)\n            freqs = torch.einsum('i,j->ij', t, self.inv_freq)\n            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)\n            self.cos_cached = emb.cos()[:, None, None, :]\n            self.sin_cached = emb.sin()[:, None, None, :]\n        return self.cos_cached, self.sin_cached\n        \n# rotary pos emb helpers:\ndef rotate_half(x):\n    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]\n    return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0\n\n@torch.jit.script\ndef apply_rotary_pos_emb(q, k, cos, sin):\n    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)\n"
  },
  {
    "path": "torchtools/nn/simple_self_attention.py",
    "content": "import torch.nn as nn\nimport torch, math, sys\n\n####\n# CODE TAKEN FROM https://github.com/sdoria/SimpleSelfAttention\n####\n\n#Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py\ndef conv1d(ni, no, ks=1, stride=1, padding=0, bias=False):\n\t\"Create and initialize a `nn.Conv1d` layer with spectral normalization.\"\n\tconv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)\n\tnn.init.kaiming_normal_(conv.weight)\n\tif bias: conv.bias.data.zero_()\n\treturn nn.utils.spectral_norm(conv)\n\n# Adapted from SelfAttention layer at https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py\n# Inspired by https://arxiv.org/pdf/1805.08318.pdf\nclass SimpleSelfAttention(nn.Module):\n\t\n\tdef __init__(self, n_in, ks=1, sym=False):\n\t\tsuper().__init__()        \n\t\tself.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False)            \n\t\tself.gamma = nn.Parameter(torch.Tensor([0.]))      \n\t\tself.sym = sym\n\t\tself.n_in = n_in\n\t\t\n\tdef forward(self, x):\n\t\tif self.sym:\n\t\t\t# symmetry hack by https://github.com/mgrankin\n\t\t\tc = self.conv.weight.view(self.n_in,self.n_in)\n\t\t\tc = (c + c.t())/2\n\t\t\tself.conv.weight = c.view(self.n_in,self.n_in,1)\n\t\t\t\t\n\t\tsize = x.size()  \n\t\tx = x.view(*size[:2],-1)   # (C,N)\n\t\t\n\t\t# changed the order of mutiplication to avoid O(N^2) complexity\n\t\t# (x*xT)*(W*x) instead of (x*(xT*(W*x)))\n\t\t\n\t\tconvx = self.conv(x)   # (C,C) * (C,N) = (C,N)   => O(NC^2)\n\t\txxT = torch.bmm(x, x.permute(0,2,1).contiguous())   # (C,N) * (N,C) = (C,C)   => O(NC^2)  \t\t    \n\t\to = torch.bmm(xxT, convx)   # (C,C) * (C,N) = (C,N)   => O(NC^2)         \n\t\to = self.gamma * o + x  \n\t\t  \n\t\treturn o.view(*size).contiguous()  "
  },
  {
    "path": "torchtools/nn/stylegan2/__init__.py",
    "content": "from .upfirdn2d import upfirdn2d"
  },
  {
    "path": "torchtools/nn/stylegan2/upfirdn2d.cpp",
    "content": "#include <ATen/ATen.h>\r\n#include <torch/extension.h>\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor &input,\r\n                           const torch::Tensor &kernel, int up_x, int up_y,\r\n                           int down_x, int down_y, int pad_x0, int pad_x1,\r\n                           int pad_y0, int pad_y1);\r\n\r\n#define CHECK_CUDA(x)                                                          \\\r\n  TORCH_CHECK(x.type().is_cuda(), #x \" must be a CUDA tensor\")\r\n#define CHECK_CONTIGUOUS(x)                                                    \\\r\n  TORCH_CHECK(x.is_contiguous(), #x \" must be contiguous\")\r\n#define CHECK_INPUT(x)                                                         \\\r\n  CHECK_CUDA(x);                                                               \\\r\n  CHECK_CONTIGUOUS(x)\r\n\r\ntorch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,\r\n                        int up_x, int up_y, int down_x, int down_y, int pad_x0,\r\n                        int pad_x1, int pad_y0, int pad_y1) {\r\n  CHECK_INPUT(input);\r\n  CHECK_INPUT(kernel);\r\n\r\n  at::DeviceGuard guard(input.device());\r\n\r\n  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,\r\n                      pad_y0, pad_y1);\r\n}\r\n\r\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\r\n  m.def(\"upfirdn2d\", &upfirdn2d, \"upfirdn2d (CUDA)\");\r\n}"
  },
  {
    "path": "torchtools/nn/stylegan2/upfirdn2d.py",
    "content": "from collections import abc\r\nimport os\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import Function\r\nfrom torch.utils.cpp_extension import load\r\n\r\n\r\nmodule_path = os.path.dirname(__file__)\r\nupfirdn2d_op = load(\r\n    \"upfirdn2d\",\r\n    sources=[\r\n        os.path.join(module_path, \"upfirdn2d.cpp\"),\r\n        os.path.join(module_path, \"upfirdn2d_kernel.cu\"),\r\n    ],\r\n)\r\n\r\n\r\nclass UpFirDn2dBackward(Function):\r\n    @staticmethod\r\n    def forward(\r\n        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size\r\n    ):\r\n\r\n        up_x, up_y = up\r\n        down_x, down_y = down\r\n        g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad\r\n\r\n        grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)\r\n\r\n        grad_input = upfirdn2d_op.upfirdn2d(\r\n            grad_output,\r\n            grad_kernel,\r\n            down_x,\r\n            down_y,\r\n            up_x,\r\n            up_y,\r\n            g_pad_x0,\r\n            g_pad_x1,\r\n            g_pad_y0,\r\n            g_pad_y1,\r\n        )\r\n        grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])\r\n\r\n        ctx.save_for_backward(kernel)\r\n\r\n        pad_x0, pad_x1, pad_y0, pad_y1 = pad\r\n\r\n        ctx.up_x = up_x\r\n        ctx.up_y = up_y\r\n        ctx.down_x = down_x\r\n        ctx.down_y = down_y\r\n        ctx.pad_x0 = pad_x0\r\n        ctx.pad_x1 = pad_x1\r\n        ctx.pad_y0 = pad_y0\r\n        ctx.pad_y1 = pad_y1\r\n        ctx.in_size = in_size\r\n        ctx.out_size = out_size\r\n\r\n        return grad_input\r\n\r\n    @staticmethod\r\n    def backward(ctx, gradgrad_input):\r\n        kernel, = ctx.saved_tensors\r\n\r\n        gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)\r\n\r\n        gradgrad_out = upfirdn2d_op.upfirdn2d(\r\n            gradgrad_input,\r\n            kernel,\r\n            ctx.up_x,\r\n            ctx.up_y,\r\n            ctx.down_x,\r\n            ctx.down_y,\r\n            ctx.pad_x0,\r\n            ctx.pad_x1,\r\n            ctx.pad_y0,\r\n            ctx.pad_y1,\r\n        )\r\n        # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])\r\n        gradgrad_out = gradgrad_out.view(\r\n            ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]\r\n        )\r\n\r\n        return gradgrad_out, None, None, None, None, None, None, None, None\r\n\r\n\r\nclass UpFirDn2d(Function):\r\n    @staticmethod\r\n    def forward(ctx, input, kernel, up, down, pad):\r\n        up_x, up_y = up\r\n        down_x, down_y = down\r\n        pad_x0, pad_x1, pad_y0, pad_y1 = pad\r\n\r\n        kernel_h, kernel_w = kernel.shape\r\n        batch, channel, in_h, in_w = input.shape\r\n        ctx.in_size = input.shape\r\n\r\n        input = input.reshape(-1, in_h, in_w, 1)\r\n\r\n        ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))\r\n\r\n        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y\r\n        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x\r\n        ctx.out_size = (out_h, out_w)\r\n\r\n        ctx.up = (up_x, up_y)\r\n        ctx.down = (down_x, down_y)\r\n        ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)\r\n\r\n        g_pad_x0 = kernel_w - pad_x0 - 1\r\n        g_pad_y0 = kernel_h - pad_y0 - 1\r\n        g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1\r\n        
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1\r\n\r\n        ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)\r\n\r\n        out = upfirdn2d_op.upfirdn2d(\r\n            input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1\r\n        )\r\n        # out = out.view(major, out_h, out_w, minor)\r\n        out = out.view(-1, channel, out_h, out_w)\r\n\r\n        return out\r\n\r\n    @staticmethod\r\n    def backward(ctx, grad_output):\r\n        kernel, grad_kernel = ctx.saved_tensors\r\n\r\n        grad_input = None\r\n\r\n        if ctx.needs_input_grad[0]:\r\n            grad_input = UpFirDn2dBackward.apply(\r\n                grad_output,\r\n                kernel,\r\n                grad_kernel,\r\n                ctx.up,\r\n                ctx.down,\r\n                ctx.pad,\r\n                ctx.g_pad,\r\n                ctx.in_size,\r\n                ctx.out_size,\r\n            )\r\n\r\n        return grad_input, None, None, None, None\r\n\r\n\r\ndef upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):\r\n    if not isinstance(up, abc.Iterable):\r\n        up = (up, up)\r\n\r\n    if not isinstance(down, abc.Iterable):\r\n        down = (down, down)\r\n\r\n    if len(pad) == 2:\r\n        pad = (pad[0], pad[1], pad[0], pad[1])\r\n\r\n    if input.device.type == \"cpu\":\r\n        out = upfirdn2d_native(input, kernel, *up, *down, *pad)\r\n\r\n    else:\r\n        out = UpFirDn2d.apply(input, kernel, up, down, pad)\r\n\r\n    return out\r\n\r\n\r\ndef upfirdn2d_native(\r\n    input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1\r\n):\r\n    _, channel, in_h, in_w = input.shape\r\n    input = input.reshape(-1, in_h, in_w, 1)\r\n\r\n    _, in_h, in_w, minor = input.shape\r\n    kernel_h, kernel_w = kernel.shape\r\n\r\n    out = input.view(-1, in_h, 1, in_w, 1, minor)\r\n    out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])\r\n    out = out.view(-1, in_h * up_y, in_w * up_x, minor)\r\n\r\n    out = F.pad(\r\n        out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]\r\n    )\r\n    out = out[\r\n        :,\r\n        max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),\r\n        max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),\r\n        :,\r\n    ]\r\n\r\n    out = out.permute(0, 3, 1, 2)\r\n    out = out.reshape(\r\n        [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]\r\n    )\r\n    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)\r\n    out = F.conv2d(out, w)\r\n    out = out.reshape(\r\n        -1,\r\n        minor,\r\n        in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,\r\n        in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,\r\n    )\r\n    out = out.permute(0, 2, 3, 1)\r\n    out = out[:, ::down_y, ::down_x, :]\r\n\r\n    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y\r\n    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x\r\n\r\n    return out.view(-1, channel, out_h, out_w)\r\n"
  },
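  {
    "path": "examples/upfirdn2d_usage.py",
    "content": "\"\"\"Editor-added usage sketch; hypothetical, not part of the original repository.\n\nUpsamples a feature map by 2 while filtering with a binomial kernel, the way\nStyleGAN2 uses upfirdn2d (StyleGAN2 additionally scales the kernel by up**2\nwhen upsampling). Note that importing the module JIT-compiles the CUDA\nextension, so nvcc/ninja must be available; CPU tensors are routed through\nthe pure-PyTorch upfirdn2d_native path.\n\"\"\"\nimport torch\nfrom torchtools.nn.stylegan2 import upfirdn2d\n\nx = torch.randn(1, 3, 32, 32)\n\n# normalized separable binomial (smoothing) kernel\nk = torch.tensor([1., 3., 3., 1.])\nk = torch.outer(k, k)\nk = k / k.sum()\n\n# pad=(2, 1) makes the 2x upsample come out at exactly 64x64:\n# out_h = (in_h*up + pad0 + pad1 - kernel_h + down) // down = (64 + 3 - 4 + 1) // 1\nout = upfirdn2d(x, k, up=2, down=1, pad=(2, 1))\nassert out.shape == (1, 3, 64, 64)\n"
  },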
  {
    "path": "torchtools/nn/stylegan2/upfirdn2d_kernel.cu",
    "content": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Source Code License-NC.\r\n// To view a copy of this license, visit\r\n// https://nvlabs.github.io/stylegan2/license.html\r\n\r\n#include <torch/types.h>\r\n\r\n#include <ATen/ATen.h>\r\n#include <ATen/AccumulateType.h>\r\n#include <ATen/cuda/CUDAApplyUtils.cuh>\r\n#include <ATen/cuda/CUDAContext.h>\r\n\r\n#include <cuda.h>\r\n#include <cuda_runtime.h>\r\n\r\nstatic __device__ __forceinline__ int floor_div(int a, int b) {\r\n  int t = 1 - a / b;\r\n  return (a + t * b) / b - t;\r\n}\r\n\r\nstruct UpFirDn2DKernelParams {\r\n  int up_x;\r\n  int up_y;\r\n  int down_x;\r\n  int down_y;\r\n  int pad_x0;\r\n  int pad_x1;\r\n  int pad_y0;\r\n  int pad_y1;\r\n\r\n  int major_dim;\r\n  int in_h;\r\n  int in_w;\r\n  int minor_dim;\r\n  int kernel_h;\r\n  int kernel_w;\r\n  int out_h;\r\n  int out_w;\r\n  int loop_major;\r\n  int loop_x;\r\n};\r\n\r\ntemplate <typename scalar_t>\r\n__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,\r\n                                       const scalar_t *kernel,\r\n                                       const UpFirDn2DKernelParams p) {\r\n  int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;\r\n  int out_y = minor_idx / p.minor_dim;\r\n  minor_idx -= out_y * p.minor_dim;\r\n  int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;\r\n  int major_idx_base = blockIdx.z * p.loop_major;\r\n\r\n  if (out_x_base >= p.out_w || out_y >= p.out_h ||\r\n      major_idx_base >= p.major_dim) {\r\n    return;\r\n  }\r\n\r\n  int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;\r\n  int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);\r\n  int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;\r\n  int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;\r\n\r\n  for (int loop_major = 0, major_idx = major_idx_base;\r\n       loop_major < p.loop_major && major_idx < p.major_dim;\r\n       loop_major++, major_idx++) {\r\n    for (int loop_x = 0, out_x = out_x_base;\r\n         loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {\r\n      int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;\r\n      int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);\r\n      int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;\r\n      int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;\r\n\r\n      const scalar_t *x_p =\r\n          &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +\r\n                 minor_idx];\r\n      const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];\r\n      int x_px = p.minor_dim;\r\n      int k_px = -p.up_x;\r\n      int x_py = p.in_w * p.minor_dim;\r\n      int k_py = -p.up_y * p.kernel_w;\r\n\r\n      scalar_t v = 0.0f;\r\n\r\n      for (int y = 0; y < h; y++) {\r\n        for (int x = 0; x < w; x++) {\r\n          v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);\r\n          x_p += x_px;\r\n          k_p += k_px;\r\n        }\r\n\r\n        x_p += x_py - w * x_px;\r\n        k_p += k_py - w * k_px;\r\n      }\r\n\r\n      out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +\r\n          minor_idx] = v;\r\n    }\r\n  }\r\n}\r\n\r\ntemplate <typename scalar_t, int up_x, int up_y, int down_x, int down_y,\r\n          int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>\r\n__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,\r\n        
                         const scalar_t *kernel,\r\n                                 const UpFirDn2DKernelParams p) {\r\n  const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;\r\n  const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;\r\n\r\n  __shared__ volatile float sk[kernel_h][kernel_w];\r\n  __shared__ volatile float sx[tile_in_h][tile_in_w];\r\n\r\n  int minor_idx = blockIdx.x;\r\n  int tile_out_y = minor_idx / p.minor_dim;\r\n  minor_idx -= tile_out_y * p.minor_dim;\r\n  tile_out_y *= tile_out_h;\r\n  int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;\r\n  int major_idx_base = blockIdx.z * p.loop_major;\r\n\r\n  if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |\r\n      major_idx_base >= p.major_dim) {\r\n    return;\r\n  }\r\n\r\n  for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;\r\n       tap_idx += blockDim.x) {\r\n    int ky = tap_idx / kernel_w;\r\n    int kx = tap_idx - ky * kernel_w;\r\n    scalar_t v = 0.0;\r\n\r\n    if (kx < p.kernel_w & ky < p.kernel_h) {\r\n      v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];\r\n    }\r\n\r\n    sk[ky][kx] = v;\r\n  }\r\n\r\n  for (int loop_major = 0, major_idx = major_idx_base;\r\n       loop_major < p.loop_major & major_idx < p.major_dim;\r\n       loop_major++, major_idx++) {\r\n    for (int loop_x = 0, tile_out_x = tile_out_x_base;\r\n         loop_x < p.loop_x & tile_out_x < p.out_w;\r\n         loop_x++, tile_out_x += tile_out_w) {\r\n      int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;\r\n      int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;\r\n      int tile_in_x = floor_div(tile_mid_x, up_x);\r\n      int tile_in_y = floor_div(tile_mid_y, up_y);\r\n\r\n      __syncthreads();\r\n\r\n      for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;\r\n           in_idx += blockDim.x) {\r\n        int rel_in_y = in_idx / tile_in_w;\r\n        int rel_in_x = in_idx - rel_in_y * tile_in_w;\r\n        int in_x = rel_in_x + tile_in_x;\r\n        int in_y = rel_in_y + tile_in_y;\r\n\r\n        scalar_t v = 0.0;\r\n\r\n        if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {\r\n          v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *\r\n                        p.minor_dim +\r\n                    minor_idx];\r\n        }\r\n\r\n        sx[rel_in_y][rel_in_x] = v;\r\n      }\r\n\r\n      __syncthreads();\r\n      for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;\r\n           out_idx += blockDim.x) {\r\n        int rel_out_y = out_idx / tile_out_w;\r\n        int rel_out_x = out_idx - rel_out_y * tile_out_w;\r\n        int out_x = rel_out_x + tile_out_x;\r\n        int out_y = rel_out_y + tile_out_y;\r\n\r\n        int mid_x = tile_mid_x + rel_out_x * down_x;\r\n        int mid_y = tile_mid_y + rel_out_y * down_y;\r\n        int in_x = floor_div(mid_x, up_x);\r\n        int in_y = floor_div(mid_y, up_y);\r\n        int rel_in_x = in_x - tile_in_x;\r\n        int rel_in_y = in_y - tile_in_y;\r\n        int kernel_x = (in_x + 1) * up_x - mid_x - 1;\r\n        int kernel_y = (in_y + 1) * up_y - mid_y - 1;\r\n\r\n        if (out_x < p.out_w & out_y < p.out_h) {\r\n          scalar_t v = 0.0;\r\n\r\n#pragma unroll\r\n          for (int y = 0; y < kernel_h / up_y; y++)\r\n#pragma unroll\r\n            for (int x = 0; x < kernel_w / up_x; x++)\r\n              v += sx[rel_in_y + y][rel_in_x + x] *\r\n                   sk[kernel_y + y * up_y][kernel_x + x * up_x];\r\n\r\n         
 out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +\r\n              minor_idx] = v;\r\n        }\r\n      }\r\n    }\r\n  }\r\n}\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor &input,\r\n                           const torch::Tensor &kernel, int up_x, int up_y,\r\n                           int down_x, int down_y, int pad_x0, int pad_x1,\r\n                           int pad_y0, int pad_y1) {\r\n  int curDevice = -1;\r\n  cudaGetDevice(&curDevice);\r\n  cudaStream_t stream = at::cuda::getCurrentCUDAStream();\r\n\r\n  UpFirDn2DKernelParams p;\r\n\r\n  auto x = input.contiguous();\r\n  auto k = kernel.contiguous();\r\n\r\n  p.major_dim = x.size(0);\r\n  p.in_h = x.size(1);\r\n  p.in_w = x.size(2);\r\n  p.minor_dim = x.size(3);\r\n  p.kernel_h = k.size(0);\r\n  p.kernel_w = k.size(1);\r\n  p.up_x = up_x;\r\n  p.up_y = up_y;\r\n  p.down_x = down_x;\r\n  p.down_y = down_y;\r\n  p.pad_x0 = pad_x0;\r\n  p.pad_x1 = pad_x1;\r\n  p.pad_y0 = pad_y0;\r\n  p.pad_y1 = pad_y1;\r\n\r\n  p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /\r\n            p.down_y;\r\n  p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /\r\n            p.down_x;\r\n\r\n  auto out =\r\n      at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());\r\n\r\n  int mode = -1;\r\n\r\n  int tile_out_h = -1;\r\n  int tile_out_w = -1;\r\n\r\n  AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"upfirdn2d_cuda\", [&] {\r\n    void *cuda_kernel = (void *)upfirdn2d_kernel_large<scalar_t>;\r\n\r\n    if (p.up_x == 2 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&\r\n        p.kernel_h <= 1 && p.kernel_w <= 24) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 2, 1, 1, 1, 1, 24, 8, 128>;\r\n      tile_out_h = 8;\r\n      tile_out_w = 128;\r\n    }\r\n\r\n    if (p.up_x == 2 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&\r\n        p.kernel_h <= 1 && p.kernel_w <= 12) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 2, 1, 1, 1, 1, 12, 8, 128>;\r\n      tile_out_h = 8;\r\n      tile_out_w = 128;\r\n    }\r\n\r\n    if (p.up_x == 1 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&\r\n        p.kernel_h <= 24 && p.kernel_w <= 1) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 1, 2, 1, 1, 24, 1, 32, 32>;\r\n      tile_out_h = 32;\r\n      tile_out_w = 32;\r\n    }\r\n\r\n    if (p.up_x == 1 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&\r\n        p.kernel_h <= 12 && p.kernel_w <= 1) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 1, 2, 1, 1, 12, 1, 32, 32>;\r\n      tile_out_h = 32;\r\n      tile_out_w = 32;\r\n    }\r\n\r\n    //\r\n\r\n    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 1 &&\r\n        p.kernel_h <= 1 && p.kernel_w <= 24) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 1, 1, 2, 1, 1, 24, 8, 64>;\r\n      tile_out_h = 8;\r\n      tile_out_w = 64;\r\n    }\r\n\r\n    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 1 &&\r\n        p.kernel_h <= 1 && p.kernel_w <= 12) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 1, 1, 2, 1, 1, 12, 8, 64>;\r\n      tile_out_h = 8;\r\n      tile_out_w = 64;\r\n    }\r\n\r\n    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 2 &&\r\n        p.kernel_h <= 24 && p.kernel_w <= 1) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 1, 1, 1, 2, 24, 1, 16, 32>;\r\n      tile_out_h = 
16;\r\n      tile_out_w = 32;\r\n    }\r\n\r\n    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 2 &&\r\n        p.kernel_h <= 12 && p.kernel_w <= 1) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 1, 1, 1, 2, 12, 1, 16, 32>;\r\n      tile_out_h = 16;\r\n      tile_out_w = 32;\r\n    }\r\n\r\n    //\r\n\r\n    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&\r\n        p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>;\r\n      tile_out_h = 16;\r\n      tile_out_w = 64;\r\n    }\r\n\r\n    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&\r\n        p.kernel_h <= 3 && p.kernel_w <= 3) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>;\r\n      tile_out_h = 16;\r\n      tile_out_w = 64;\r\n    }\r\n\r\n    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&\r\n        p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>;\r\n      tile_out_h = 16;\r\n      tile_out_w = 64;\r\n    }\r\n\r\n    if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&\r\n        p.kernel_h <= 2 && p.kernel_w <= 2) {\r\n      cuda_kernel =\r\n          (void *)upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>;\r\n      tile_out_h = 16;\r\n      tile_out_w = 64;\r\n    }\r\n\r\n    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&\r\n        p.kernel_h <= 4 && p.kernel_w <= 4) {\r\n      cuda_kernel = (void *)upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>;\r\n      tile_out_h = 8;\r\n      tile_out_w = 32;\r\n    }\r\n\r\n    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&\r\n        p.kernel_h <= 2 && p.kernel_w <= 2) {\r\n      cuda_kernel = (void *)upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>;\r\n      tile_out_h = 8;\r\n      tile_out_w = 32;\r\n    }\r\n\r\n    dim3 block_size;\r\n    dim3 grid_size;\r\n\r\n    if (tile_out_h > 0 && tile_out_w > 0) {\r\n      p.loop_major = (p.major_dim - 1) / 16384 + 1;\r\n      p.loop_x = 1;\r\n      block_size = dim3(32 * 8, 1, 1);\r\n      grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,\r\n                       (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,\r\n                       (p.major_dim - 1) / p.loop_major + 1);\r\n    } else {\r\n      p.loop_major = (p.major_dim - 1) / 16384 + 1;\r\n      p.loop_x = 4;\r\n      block_size = dim3(4, 32, 1);\r\n      grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,\r\n                       (p.out_w - 1) / (p.loop_x * block_size.y) + 1,\r\n                       (p.major_dim - 1) / p.loop_major + 1);\r\n    }\r\n\r\n    scalar_t *out_p = out.data_ptr<scalar_t>();\r\n    scalar_t *x_p = x.data_ptr<scalar_t>();\r\n    scalar_t *k_p = k.data_ptr<scalar_t>();\r\n\r\n    void *args[] = {&out_p, &x_p, &k_p, &p};\r\n    AT_CUDA_CHECK(\r\n        cudaLaunchKernel(cuda_kernel, grid_size, block_size, args, 0, stream));\r\n  });\r\n\r\n  return out;\r\n}"
  },
  {
    "path": "torchtools/nn/transformers.py",
    "content": "import torch.nn as nn\n\n\n# Based on the GPT2 implementatyion from MinGPT https://github.com/karpathy/minGPT by Andrej Karpathy\nclass GPTTransformerEncoderLayer(nn.Module):\n    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0):\n        super().__init__()\n        self.ln1 = nn.LayerNorm(d_model)\n        self.ln2 = nn.LayerNorm(d_model)\n        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)\n        self.mlp = nn.Sequential(\n            nn.Linear(d_model, dim_feedforward),\n            nn.GELU(),\n            nn.Linear(dim_feedforward, d_model),\n            nn.Dropout(dropout),\n        )\n\n    def forward(self, x, src_mask=None, src_key_padding_mask=None):\n        x = self.ln1(x)\n        x = x + self.attn(x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]\n        x = x + self.mlp(self.ln2(x))\n        return x"
  },
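  {
    "path": "examples/gpt_transformer_layer_usage.py",
    "content": "\"\"\"Editor-added usage sketch; hypothetical, not part of the original repository.\n\nRuns GPTTransformerEncoderLayer with an additive causal mask. Since the layer\nwraps nn.MultiheadAttention with its default settings, the input uses the\n(seq, batch, embed) layout. The import path is an assumption.\n\"\"\"\nimport torch\nfrom torchtools.nn.transformers import GPTTransformerEncoderLayer  # assumed path\n\nseq_len, batch, d_model = 10, 2, 128\nlayer = GPTTransformerEncoderLayer(d_model, nhead=8, dim_feedforward=512)\n\nx = torch.randn(seq_len, batch, d_model)\n# -inf above the diagonal blocks attention to future positions\ncausal_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), diagonal=1)\nout = layer(x, src_mask=causal_mask)\nassert out.shape == x.shape\n"
  },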
  {
    "path": "torchtools/nn/vq.py",
    "content": "import torch\nfrom torch import nn\nfrom .functional.vq import vector_quantize, binarize\nimport numpy as np\n\nclass VectorQuantize(nn.Module):\n\tdef __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):\n\t\t\"\"\"\n\t\tTakes an input of variable size (as long as the last dimension matches the embedding size).\n\t\tReturns one tensor containing the nearest neigbour embeddings to each of the inputs, \n\t\twith the same size as the input, vq and commitment components for the loss as a touple \n\t\tin the second output and the indices of the quantized vectors in the third: \n\t\tquantized, (vq_loss, commit_loss), indices\n\t\t\"\"\"\n\t\tsuper(VectorQuantize, self).__init__()\n\n\t\tself.codebook = nn.Embedding(k, embedding_size)\n\t\tself.codebook.weight.data.uniform_(-1./k, 1./k)\t\n\t\tself.vq = vector_quantize.apply\n\n\t\tself.ema_decay = ema_decay\n\t\tself.ema_loss = ema_loss\n\t\tif ema_loss:\n\t\t\tself.register_buffer('ema_element_count', torch.ones(k))\n\t\t\tself.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))\t\t\n\n\tdef _laplace_smoothing(self, x, epsilon):\n\t\tn = torch.sum(x)\n\t\treturn ((x + epsilon) / (n + x.size(0) * epsilon) * n)\n\n\tdef _updateEMA(self, z_e_x, indices):\n\t\tmask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()\n\t\telem_count = mask.sum(dim=0)\n\t\tweight_sum = torch.mm(mask.t(), z_e_x)\n\t\t\n\t\tself.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)\n\t\tself.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)\t\t\n\t\tself.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)\n\t\t\n\t\tself.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)\n\n\tdef idx2vq(self, idx, dim=-1):\n\t\tq_idx = self.codebook(idx)\n\t\tif dim != -1:\n\t\t\tq_idx = q_idx.movedim(-1, dim)\n\t\treturn q_idx\n\n\tdef forward(self, x, get_losses=True, dim=-1):\n\t\tif dim != -1:\n\t\t\tx = x.movedim(dim, -1)\n\t\tz_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x\n\t\tz_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())\t\n\t\tvq_loss, commit_loss = None, None\t\n\t\tif self.ema_loss and self.training:\n\t\t\tself._updateEMA(z_e_x.detach(), indices.detach())\n\t\t# pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss\n\t\tz_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) \n\t\tif get_losses:\n\t\t\tvq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()\n\t\t\tcommit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()\n\n\t\tz_q_x = z_q_x.view(x.shape)\n\t\tif dim != -1:\n\t\t\tz_q_x = z_q_x.movedim(-1, dim)\n\t\treturn z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])\n\nclass Binarize(nn.Module):\n\tdef __init__(self, threshold=0.5):\n\t\t\"\"\"\n\t\tTakes an input of any size.\n\t\tReturns an output of the same size but with its values binarized (0 if input is below a threshold, 1 if its above)\n\t\t\"\"\"\n\t\tsuper(Binarize, self).__init__()\n\n\t\tself.bin = binarize.apply\n\t\tself.threshold = threshold\n\n\tdef forward(self, x):\n\t\treturn self.bin(x, self.threshold)\n\t\n# Finite Scalar Quantization: https://arxiv.org/abs/2309.15505\nclass FSQ(nn.Module):\n    def __init__(self, bins, dim=-1, eps=1e-1):\n        super().__init__()\n        self.dim = dim\n        self.eps = eps\n        self.register_buffer('bins', torch.tensor(bins))\n      
  self.register_buffer('bases', torch.tensor([1] + np.cumprod(bins[:-1]).tolist()))\n        self.codebook_size = np.prod(bins)\n        \n        self.in_shift, self.out_shift = None, None\n\n    def _round(self, x, quantize):\n        x = x.sigmoid() * (1-1e-7)\n        if quantize is True:\n            x_rounded = x.sub(1/(self.bins*2)).mul(self.bins).round().div(self.bins).div(1-1/self.bins)\n            x = x + (x_rounded - x).detach()\n        x_sigmoid = x\n        x = (x / (1-1e-7)).logit(eps=self.eps)\n        return x, x_sigmoid\n\n    def vq_to_idx(self, x, is_sigmoid=False):\n        if not is_sigmoid:\n            x = x.sigmoid() * (1-1e-7)\n            x = x.sub(1/(self.bins*2)).mul(self.bins).round().div(self.bins).div(1-1/self.bins)\n        x = x.mul(self.bins-1).long()\n        x = (x * self.bases).sum(dim=-1).long()\n        return x\n\n    def idx_to_vq(self, x):\n        x = x.unsqueeze(-1) // self.bases % self.bins\n        x = x.div(self.bins-1)\n        x = (x / (1-1e-7)).logit(eps=self.eps)\n        if self.dim != -1:\n            x = x.movedim(-1, self.dim)\n        return x\n\n    def forward(self, x, quantize=True):\n        if self.dim != -1:\n            x = x.movedim(self.dim, -1)\n\n        x, x_sigmoid = self._round(x, quantize=quantize)\n        idx = self.vq_to_idx(x_sigmoid, is_sigmoid=True)\n\n        if self.dim != -1:\n            x = x.movedim(-1, self.dim)\n        return x, idx\n"
  },
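  {
    "path": "examples/vq_usage.py",
    "content": "\"\"\"Editor-added usage sketch; hypothetical, not part of the original repository.\n\nShows both quantizers from torchtools/nn/vq.py. The import path is an\nassumption.\n\"\"\"\nimport torch\nfrom torchtools.nn.vq import VectorQuantize, FSQ  # assumed path\n\n# VectorQuantize: nearest-neighbour codebook lookup over the last dimension\nvq = VectorQuantize(embedding_size=64, k=512, ema_loss=True)\nx = torch.randn(8, 16, 64)\nquantized, (vq_loss, commit_loss), indices = vq(x)\nassert quantized.shape == x.shape and indices.shape == (8, 16)\n# typical VQ-VAE objective; with ema_loss=True the codebook is updated via EMA\n# inside forward(), so the vq_loss term is usually left out of the backward pass:\n# loss = recon_loss + 0.25 * commit_loss\n\n# FSQ: per-dimension scalar quantization; codebook_size = 8 * 6 * 5 = 240\nfsq = FSQ(bins=[8, 6, 5])\nz = torch.randn(4, 10, 3)  # last dim must equal len(bins)\nz_q, idx = fsq(z)\nassert z_q.shape == z.shape and int(idx.max()) < fsq.codebook_size\n"
  },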
  {
    "path": "torchtools/optim/__init__.py",
    "content": "from .radam import RAdam, PlainRAdam, AdamW\nfrom .ranger import Ranger\nfrom .lookahead import Lookahead, LookaheadAdam\nfrom .over9000 import Over9000, RangerLars\nfrom .ralamb import Ralamb\nfrom .novograd import Novograd\nfrom .lamb import Lamb\n"
  },
  {
    "path": "torchtools/optim/lamb.py",
    "content": "####\n# CODE TAKEN FROM https://github.com/mgrankin/over9000\n####\n\nimport collections\nimport math\n\nimport torch\nfrom torch.optim import Optimizer\n\ntry: \n    from tensorboardX import SummaryWriter\n\n    def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):\n        \"\"\"Log a histogram of trust ratio scalars in across layers.\"\"\"\n        results = collections.defaultdict(list)\n        for group in optimizer.param_groups:\n            for p in group['params']:\n                state = optimizer.state[p]\n                for i in ('weight_norm', 'adam_norm', 'trust_ratio'):\n                    if i in state:\n                        results[i].append(state[i])\n\n        for k, v in results.items():\n            event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)\nexcept ModuleNotFoundError as e: \n    print(\"To use this log_lamb_rs, please run 'pip install tensorboardx'. Also you must have Tensorboard running to see results\")\n\nclass Lamb(Optimizer):\n    r\"\"\"Implements Lamb algorithm.\n    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        adam (bool, optional): always use trust ratio = 1, which turns this into\n            Adam. Useful for comparison purposes.\n    .. 
_Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:\n        https://arxiv.org/abs/1904.00962\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,\n                 weight_decay=0, adam=False):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n        defaults = dict(lr=lr, betas=betas, eps=eps,\n                        weight_decay=weight_decay)\n        self.adam = adam\n        super(Lamb, self).__init__(params, defaults)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state['step'] = 0\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                state['step'] += 1\n\n                # Decay the first and second moment running average coefficient\n                # m_t\n                exp_avg.mul_(beta1).add_(1 - beta1, grad)\n                # v_t\n                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)\n\n                # Paper v3 does not use debiasing.\n                # bias_correction1 = 1 - beta1 ** state['step']\n                # bias_correction2 = 1 - beta2 ** state['step']\n                # Apply bias to lr to avoid broadcast.\n                step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1\n\n                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)\n\n                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])\n                if group['weight_decay'] != 0:\n                    adam_step.add_(group['weight_decay'], p.data)\n\n                adam_norm = adam_step.pow(2).sum().sqrt()\n                if weight_norm == 0 or adam_norm == 0:\n                    trust_ratio = 1\n                else:\n                    trust_ratio = weight_norm / adam_norm\n                state['weight_norm'] = weight_norm\n                state['adam_norm'] = adam_norm\n                state['trust_ratio'] = trust_ratio\n                if self.adam:\n                    trust_ratio = 1\n\n                p.data.add_(-step_size * trust_ratio, adam_step)\n\n        return 
loss"
  },
  {
    "path": "torchtools/optim/lookahead.py",
    "content": "####\n# CODE TAKEN FROM https://github.com/lonePatient/lookahead_pytorch\n# Original paper: https://arxiv.org/abs/1907.08610\n####\n# Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py\n\n\"\"\" Lookahead Optimizer Wrapper.\nImplementation modified from: https://github.com/alphadl/lookahead.pytorch\nPaper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610\n\"\"\"\nimport torch\nfrom torch.optim.optimizer import Optimizer\nfrom collections import defaultdict\n\n\nclass Lookahead(Optimizer):\n    def __init__(self, base_optimizer, alpha=0.5, k=6):\n        if not 0.0 <= alpha <= 1.0:\n            raise ValueError(f'Invalid slow update rate: {alpha}')\n        if not 1 <= k:\n            raise ValueError(f'Invalid lookahead steps: {k}')\n        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)\n        self.base_optimizer = base_optimizer\n        self.param_groups = self.base_optimizer.param_groups\n        self.defaults = base_optimizer.defaults\n        self.defaults.update(defaults)\n        self.state = defaultdict(dict)\n        # manually add our defaults to the param groups\n        for name, default in defaults.items():\n            for group in self.param_groups:\n                group.setdefault(name, default)\n\n    def update_slow(self, group):\n        for fast_p in group[\"params\"]:\n            if fast_p.grad is None:\n                continue\n            param_state = self.state[fast_p]\n            if 'slow_buffer' not in param_state:\n                param_state['slow_buffer'] = torch.empty_like(fast_p.data)\n                param_state['slow_buffer'].copy_(fast_p.data)\n            slow = param_state['slow_buffer']\n            slow.add_(group['lookahead_alpha'], fast_p.data - slow)\n            fast_p.data.copy_(slow)\n\n    def sync_lookahead(self):\n        for group in self.param_groups:\n            self.update_slow(group)\n\n    def step(self, closure=None):\n        # print(self.k)\n        # assert id(self.param_groups) == id(self.base_optimizer.param_groups)\n        loss = self.base_optimizer.step(closure)\n        for group in self.param_groups:\n            group['lookahead_step'] += 1\n            if group['lookahead_step'] % group['lookahead_k'] == 0:\n                self.update_slow(group)\n        return loss\n\n    def state_dict(self):\n        fast_state_dict = self.base_optimizer.state_dict()\n        slow_state = {\n            (id(k) if isinstance(k, torch.Tensor) else k): v\n            for k, v in self.state.items()\n        }\n        fast_state = fast_state_dict['state']\n        param_groups = fast_state_dict['param_groups']\n        return {\n            'state': fast_state,\n            'slow_state': slow_state,\n            'param_groups': param_groups,\n        }\n\n    def load_state_dict(self, state_dict):\n        fast_state_dict = {\n            'state': state_dict['state'],\n            'param_groups': state_dict['param_groups'],\n        }\n        self.base_optimizer.load_state_dict(fast_state_dict)\n\n        # We want to restore the slow state, but share param_groups reference\n        # with base_optimizer. 
This is a bit redundant, but it keeps the code minimal\n        slow_state_new = False\n        if 'slow_state' not in state_dict:\n            print('Loading state_dict from optimizer without Lookahead applied.')\n            state_dict['slow_state'] = defaultdict(dict)\n            slow_state_new = True\n        slow_state_dict = {\n            'state': state_dict['slow_state'],\n            'param_groups': state_dict['param_groups'],  # this is pointless but saves code\n        }\n        super(Lookahead, self).load_state_dict(slow_state_dict)\n        self.param_groups = self.base_optimizer.param_groups  # make both ref same container\n        if slow_state_new:\n            # reapply defaults to catch missing lookahead specific ones\n            for name, default in self.defaults.items():\n                for group in self.param_groups:\n                    group.setdefault(name, default)\n\n\ndef LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs):\n    # qualify Adam explicitly: it was never imported by name in this module\n    adam = torch.optim.Adam(params, *args, **kwargs)\n    return Lookahead(adam, alpha, k)\n"
  },
  {
    "path": "torchtools/optim/novograd.py",
    "content": "####\n# CODE TAKEN FROM https://github.com/mgrankin/over9000\n####\n\n# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport torch\nfrom torch.optim import Optimizer\nimport math\n\n\nclass AdamW(Optimizer):\n    \"\"\"Implements AdamW algorithm.\n  \n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n  \n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n  \n        Adam: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n        On the Convergence of Adam and Beyond:\n        https://openreview.net/forum?id=ryQu7f-RZ\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,\n                 weight_decay=0, amsgrad=False):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n        defaults = dict(lr=lr, betas=betas, eps=eps,\n                        weight_decay=weight_decay, amsgrad=amsgrad)\n        super(AdamW, self).__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super(AdamW, self).__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault('amsgrad', False)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n  \n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')\n                amsgrad = group['amsgrad']\n\n                state = self.state[p]\n\n                
# State initialization\n                if len(state) == 0:\n                    state['step'] = 0\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n                    if amsgrad:\n                        # Maintains max of all exp. moving avg. of sq. grad. values\n                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                if amsgrad:\n                    max_exp_avg_sq = state['max_exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                state['step'] += 1\n                # Decay the first and second moment running average coefficient\n                exp_avg.mul_(beta1).add_(1 - beta1, grad)\n                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)\n                if amsgrad:\n                    # Maintains the maximum of all 2nd moment running avg. till now\n                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)\n                    # Use the max. for normalizing running avg. of gradient\n                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])\n                else:\n                    denom = exp_avg_sq.sqrt().add_(group['eps'])\n\n                bias_correction1 = 1 - beta1 ** state['step']\n                bias_correction2 = 1 - beta2 ** state['step']\n                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1\n                p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom))\n\n        return loss\n\n\nclass Novograd(Optimizer):\n    \"\"\"\n    Implements Novograd algorithm.\n\n    Args:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.95, 0))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n        grad_averaging: gradient averaging\n        amsgrad (boolean, optional): whether to use the AMSGrad variant of this\n            algorithm from the paper `On the Convergence of Adam and Beyond`_\n            (default: False)\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,\n                 weight_decay=0, grad_averaging=False, amsgrad=False):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n        defaults = dict(lr=lr, betas=betas, eps=eps,\n                        weight_decay=weight_decay,\n                        grad_averaging=grad_averaging,\n                        amsgrad=amsgrad)\n\n        super(Novograd, self).__init__(params, 
defaults)\n\n    def __setstate__(self, state):\n        super(Novograd, self).__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault('amsgrad', False)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n            and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError('Sparse gradients are not supported.')\n                amsgrad = group['amsgrad']\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state['step'] = 0\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)\n                    if amsgrad:\n                        # Maintains max of all exp. moving avg. of sq. grad. values\n                        state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                if amsgrad:\n                    max_exp_avg_sq = state['max_exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                state['step'] += 1\n\n                norm = torch.sum(torch.pow(grad, 2))\n\n                if exp_avg_sq == 0:\n                    exp_avg_sq.copy_(norm)\n                else:\n                    exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)\n\n                if amsgrad:\n                    # Maintains the maximum of all 2nd moment running avg. till now\n                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)\n                    # Use the max. for normalizing running avg. of gradient\n                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])\n                else:\n                    denom = exp_avg_sq.sqrt().add_(group['eps'])\n\n                grad.div_(denom)\n                if group['weight_decay'] != 0:\n                    grad.add_(group['weight_decay'], p.data)\n                if group['grad_averaging']:\n                    grad.mul_(1 - beta1)\n                exp_avg.mul_(beta1).add_(grad)\n\n                p.data.add_(-group['lr'], exp_avg)\n\n        return loss\n"
  },
  {
    "path": "torchtools/optim/over9000.py",
    "content": "####\n# CODE TAKEN FROM https://github.com/mgrankin/over9000\n####\n\nimport torch, math\nfrom torch.optim.optimizer import Optimizer\nimport itertools as it\nfrom .lookahead import Lookahead\nfrom .ralamb import Ralamb\n\n\n# RAdam + LARS + LookAHead\n\n# Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py\n# RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20\n\ndef Over9000(params, alpha=0.5, k=6, *args, **kwargs):\n    ralamb = Ralamb(params, *args, **kwargs)\n    return Lookahead(ralamb, alpha, k)\n\n\nRangerLars = Over9000\n"
  },
  {
    "path": "torchtools/optim/radam.py",
    "content": "####\n# CODE TAKEN FROM https://github.com/LiyuanLucasLiu/RAdam\n# Paper: https://arxiv.org/abs/1908.03265\n####\n\nimport math\nimport torch\nfrom torch.optim.optimizer import Optimizer, required\n\nclass RAdam(Optimizer):\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n        \n        self.degenerated_to_sgd = degenerated_to_sgd\n        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):\n            for param in params:\n                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):\n                    param['buffer'] = [[None, None, None] for _ in range(10)]\n        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])\n        super(RAdam, self).__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super(RAdam, self).__setstate__(state)\n\n    def step(self, closure=None):\n\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data.float()\n                if grad.is_sparse:\n                    raise RuntimeError('RAdam does not support sparse gradients')\n\n                p_data_fp32 = p.data.float()\n\n                state = self.state[p]\n\n                if len(state) == 0:\n                    state['step'] = 0\n                    state['exp_avg'] = torch.zeros_like(p_data_fp32)\n                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)\n                else:\n                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)\n                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)\n                exp_avg.mul_(beta1).add_(1 - beta1, grad)\n\n                state['step'] += 1\n                buffered = group['buffer'][int(state['step'] % 10)]\n                if state['step'] == buffered[0]:\n                    N_sma, step_size = buffered[1], buffered[2]\n                else:\n                    buffered[0] = state['step']\n                    beta2_t = beta2 ** state['step']\n                    N_sma_max = 2 / (1 - beta2) - 1\n                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)\n                    buffered[1] = N_sma\n\n                    # more conservative since it's an approximated value\n                    if N_sma >= 5:\n                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])\n                    elif self.degenerated_to_sgd:\n                        
step_size = 1.0 / (1 - beta1 ** state['step'])\n                    else:\n                        step_size = -1\n                    buffered[2] = step_size\n\n                # more conservative since it's an approximated value\n                if N_sma >= 5:\n                    if group['weight_decay'] != 0:\n                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)\n                    denom = exp_avg_sq.sqrt().add_(group['eps'])\n                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)\n                    p.data.copy_(p_data_fp32)\n                elif step_size > 0:\n                    if group['weight_decay'] != 0:\n                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)\n                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)\n                    p.data.copy_(p_data_fp32)\n\n        return loss\n\nclass PlainRAdam(Optimizer):\n\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n                    \n        self.degenerated_to_sgd = degenerated_to_sgd\n        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)\n\n        super(PlainRAdam, self).__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super(PlainRAdam, self).__setstate__(state)\n\n    def step(self, closure=None):\n\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data.float()\n                if grad.is_sparse:\n                    raise RuntimeError('RAdam does not support sparse gradients')\n\n                p_data_fp32 = p.data.float()\n\n                state = self.state[p]\n\n                if len(state) == 0:\n                    state['step'] = 0\n                    state['exp_avg'] = torch.zeros_like(p_data_fp32)\n                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)\n                else:\n                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)\n                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)\n                exp_avg.mul_(beta1).add_(1 - beta1, grad)\n\n                state['step'] += 1\n                beta2_t = beta2 ** state['step']\n                N_sma_max = 2 / (1 - beta2) - 1\n                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)\n\n\n                # more conservative since it's an approximated value\n                if N_sma >= 5:\n                    if group['weight_decay'] != 0:\n                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)\n                    step_size = group['lr'] * 
math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])\n                    denom = exp_avg_sq.sqrt().add_(group['eps'])\n                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)\n                    p.data.copy_(p_data_fp32)\n                elif self.degenerated_to_sgd:\n                    if group['weight_decay'] != 0:\n                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)\n                    step_size = group['lr'] / (1 - beta1 ** state['step'])\n                    p_data_fp32.add_(-step_size, exp_avg)\n                    p.data.copy_(p_data_fp32)\n\n        return loss\n\n\nclass AdamW(Optimizer):\n\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n        \n        defaults = dict(lr=lr, betas=betas, eps=eps,\n                        weight_decay=weight_decay, warmup = warmup)\n        super(AdamW, self).__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super(AdamW, self).__setstate__(state)\n\n    def step(self, closure=None):\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data.float()\n                if grad.is_sparse:\n                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')\n\n                p_data_fp32 = p.data.float()\n\n                state = self.state[p]\n\n                if len(state) == 0:\n                    state['step'] = 0\n                    state['exp_avg'] = torch.zeros_like(p_data_fp32)\n                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)\n                else:\n                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)\n                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                state['step'] += 1\n\n                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)\n                exp_avg.mul_(beta1).add_(1 - beta1, grad)\n\n                denom = exp_avg_sq.sqrt().add_(group['eps'])\n                bias_correction1 = 1 - beta1 ** state['step']\n                bias_correction2 = 1 - beta2 ** state['step']\n                \n                if group['warmup'] > state['step']:\n                    scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']\n                else:\n                    scheduled_lr = group['lr']\n\n                step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1\n                \n                if group['weight_decay'] != 0:\n                    p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)\n\n                
p_data_fp32.addcdiv_(-step_size, exp_avg, denom)\n\n                p.data.copy_(p_data_fp32)\n\n        return loss"
  },
  {
    "path": "torchtools/optim/ralamb.py",
    "content": "####\n# CODE TAKEN FROM https://github.com/mgrankin/over9000\n####\n\nimport torch, math\nfrom torch.optim.optimizer import Optimizer\n\n# RAdam + LARS\nclass Ralamb(Optimizer):\n\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):\n        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)\n        self.buffer = [[None, None, None] for ind in range(10)]\n        super(Ralamb, self).__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super(Ralamb, self).__setstate__(state)\n\n    def step(self, closure=None):\n\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data.float()\n                if grad.is_sparse:\n                    raise RuntimeError('Ralamb does not support sparse gradients')\n\n                p_data_fp32 = p.data.float()\n\n                state = self.state[p]\n\n                if len(state) == 0:\n                    state['step'] = 0\n                    state['exp_avg'] = torch.zeros_like(p_data_fp32)\n                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)\n                else:\n                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)\n                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                # Decay the first and second moment running average coefficient\n                # m_t\n                exp_avg.mul_(beta1).add_(1 - beta1, grad)\n                # v_t\n                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)\n\n                state['step'] += 1\n                buffered = self.buffer[int(state['step'] % 10)]\n\n                if state['step'] == buffered[0]:\n                    N_sma, radam_step_size = buffered[1], buffered[2]\n                else:\n                    buffered[0] = state['step']\n                    beta2_t = beta2 ** state['step']\n                    N_sma_max = 2 / (1 - beta2) - 1\n                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)\n                    buffered[1] = N_sma\n\n                    # more conservative since it's an approximated value\n                    if N_sma >= 5:\n                        radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])\n                    else:\n                        radam_step_size = 1.0 / (1 - beta1 ** state['step'])\n                    buffered[2] = radam_step_size\n\n                if group['weight_decay'] != 0:\n                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)\n\n                # more conservative since it's an approximated value\n                radam_step = p_data_fp32.clone()\n                if N_sma >= 5:\n                    denom = exp_avg_sq.sqrt().add_(group['eps'])\n                    radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom)\n                else:\n                    radam_step.add_(-radam_step_size * group['lr'], exp_avg)\n\n                radam_norm = radam_step.pow(2).sum().sqrt()\n                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 
10)\n                if weight_norm == 0 or radam_norm == 0:\n                    trust_ratio = 1\n                else:\n                    trust_ratio = weight_norm / radam_norm\n\n                state['weight_norm'] = weight_norm\n                state['adam_norm'] = radam_norm\n                state['trust_ratio'] = trust_ratio\n\n                if N_sma >= 5:\n                    p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom)\n                else:\n                    p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg)\n\n                p.data.copy_(p_data_fp32)\n\n        return loss"
  },
  {
    "path": "torchtools/optim/ranger.py",
    "content": "####\n# CODE TAKEN FROM https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer\n# Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d\n####\n\nimport math\nimport torch\nfrom torch.optim.optimizer import Optimizer, required\nimport itertools as it\nfrom .lookahead import Lookahead\nfrom .radam import RAdam\n\n\ndef Ranger(params, alpha=0.5, k=6, betas=(.95, 0.999), *args, **kwargs):\n    radam = RAdam(params, betas=betas, *args, **kwargs)\n    return Lookahead(radam, alpha, k)\n"
  },
  {
    "path": "torchtools/transforms/__init__.py",
    "content": "from .smart_crop import SmartCrop"
  },
  {
    "path": "torchtools/transforms/models/__init__.py",
    "content": ""
  },
  {
    "path": "torchtools/transforms/smart_crop.py",
    "content": "import torch\nimport torchvision\nfrom torch import nn\nimport numpy as np\nimport os\n\n# MICRO RESNET\nclass ResBlock(nn.Module):\n    def __init__(self, channels):\n        super(ResBlock, self).__init__()\n        \n        self.resblock = nn.Sequential(\n            nn.ReflectionPad2d(1),\n            nn.Conv2d(channels, channels, kernel_size=3),\n            nn.InstanceNorm2d(channels, affine=True),\n            nn.ReLU(),\n            nn.ReflectionPad2d(1),\n            nn.Conv2d(channels, channels, kernel_size=3),\n            nn.InstanceNorm2d(channels, affine=True),\n        )\n        \n    def forward(self, x):\n        out = self.resblock(x)\n        return out + x\n    \nclass Upsample2d(nn.Module):\n    def __init__(self, scale_factor):\n        super(Upsample2d, self).__init__()\n        \n        self.interp = nn.functional.interpolate\n        self.scale_factor = scale_factor\n        \n    def forward(self, x):\n        x = self.interp(x, scale_factor=self.scale_factor, mode='nearest')\n        return x\n\nclass MicroResNet(nn.Module):\n    def __init__(self):\n        super(MicroResNet, self).__init__()\n        \n        self.downsampler = nn.Sequential(\n            nn.ReflectionPad2d(4),\n            nn.Conv2d(3, 8, kernel_size=9, stride=4),\n            nn.InstanceNorm2d(8, affine=True),\n            nn.ReLU(),\n            nn.ReflectionPad2d(1),\n            nn.Conv2d(8, 16, kernel_size=3, stride=2),\n            nn.InstanceNorm2d(16, affine=True),\n            nn.ReLU(),\n            nn.ReflectionPad2d(1),\n            nn.Conv2d(16, 32, kernel_size=3, stride=2),\n            nn.InstanceNorm2d(32, affine=True),\n            nn.ReLU(),\n        )\n        \n        self.residual = nn.Sequential(\n            ResBlock(32),\n            nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32),\n            ResBlock(64),\n        )\n        \n        self.segmentator = nn.Sequential(\n            nn.ReflectionPad2d(1),\n            nn.Conv2d(64, 16, kernel_size=3),\n            nn.InstanceNorm2d(16, affine=True),\n            nn.ReLU(),\n            Upsample2d(scale_factor=2),\n            nn.ReflectionPad2d(4),\n            nn.Conv2d(16, 1, kernel_size=9),\n            nn.Sigmoid()\n        )\n        \n    def forward(self, x):\n        out = self.downsampler(x)\n        out = self.residual(out)\n        out = self.segmentator(out)\n        return out\n\n# SmartCrop module\nclass SmartCrop(nn.Module):\n    def __init__(self, output_size, randomize_p=0.0, randomize_q=0.1, temperature=0.03):\n        super().__init__()\n        self.output_size = output_size\n        self.randomize_p, self.randomize_q = randomize_p, randomize_q\n        self.temperature = temperature\n        if isinstance(self.output_size, int):\n            self.output_size = (self.output_size, self.output_size)\n        self.saliency_model = MicroResNet().eval().requires_grad_(False)\n        checkpoint = torch.load(os.path.dirname(__file__) + \"/models/saliency_model_v9.pt\", map_location=\"cpu\")\n        self.saliency_model.load_state_dict(checkpoint)\n\n    def forward(self, image):\n        is_batch = len(image.shape) == 4\n        if not is_batch:\n            image = image.unsqueeze(0)\n        with torch.no_grad():\n            resized_image = torchvision.transforms.functional.resize(image, 240, antialias=True)\n            saliency_map = self.saliency_model(resized_image)\n            tempered_heatmap = saliency_map.view(saliency_map.size(0), 
-1).div(self.temperature).softmax(-1)\n            tempered_heatmap = tempered_heatmap / tempered_heatmap.sum(dim=-1, keepdim=True)  # keepdim so the normalization broadcasts per sample (fixes batches > 1)\n            tempered_heatmap = (tempered_heatmap > tempered_heatmap.max(dim=-1, keepdim=True)[0]*0.75).float()\n            saliency_map = tempered_heatmap.view(*saliency_map.shape)\n\n        # GET CENTROID\n        coord_space = torch.cat([\n            torch.linspace(0, 1, saliency_map.size(-2))[None, None, :, None].expand(-1, -1, -1, saliency_map.size(-1)),\n            torch.linspace(0, 1, saliency_map.size(-1))[None, None, None, :].expand(-1, -1, saliency_map.size(-2), -1),\n        ], dim=1)\n        centroid = (coord_space * saliency_map).sum(dim=[-1, -2]) / saliency_map.sum(dim=[-1, -2])\n        # CROP\n        crops = []\n        for i in range(image.size(0)):\n            if np.random.rand() < self.randomize_p:\n                centroid[i, 0] += np.random.uniform(-self.randomize_q, self.randomize_q)\n                centroid[i, 1] += np.random.uniform(-self.randomize_q, self.randomize_q)\n            top = (centroid[i, 0]*image.size(-2)-self.output_size[-2]/2).clamp(min=0, max=max(0, image.size(-2)-self.output_size[-2])).int()\n            left = (centroid[i, 1]*image.size(-1)-self.output_size[-1]/2).clamp(min=0, max=max(0, image.size(-1)-self.output_size[-1])).int()\n            bottom, right = top + self.output_size[-2], left + self.output_size[-1]\n            crop = image[i, :, top:bottom, left:right]\n            if crop.size(-2) < self.output_size[-2] or crop.size(-1) < self.output_size[-1]:\n                crop = torchvision.transforms.functional.center_crop(crop, self.output_size)\n            crops.append(crop)\n        if is_batch:\n            crops = torch.stack(crops, dim=0)\n        else:\n            crops = crops[0]\n        return crops"
  },
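  {
    "path": "examples/smart_crop_usage.py",
    "content": "# Illustrative usage sketch added for documentation; this file and its path are\n# hypothetical and not part of the original repository. It assumes the package\n# was installed with the bundled saliency checkpoint (see MANIFEST.in).\n# SmartCrop crops an image (or a batch) around its most salient region, as\n# predicted by the bundled MicroResNet saliency model.\nimport torch\nfrom torchtools.transforms import SmartCrop\n\nsmart_crop = SmartCrop(output_size=192)  # an int yields a square 192x192 crop\n\nimage = torch.rand(3, 480, 640)          # placeholder image in [0, 1]\ncropped = smart_crop(image)              # -> (3, 192, 192)\n\nbatch = torch.rand(4, 3, 480, 640)       # batched input is also supported\ncropped_batch = smart_crop(batch)        # -> (4, 3, 192, 192)\n\n# randomize_p/randomize_q jitter the crop center with the given probability,\n# which can act as a mild augmentation during training:\n# smart_crop = SmartCrop(192, randomize_p=0.5, randomize_q=0.1)"
  },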
  {
    "path": "torchtools/utils/__init__.py",
    "content": "from .diffusion import Diffuzz\nfrom .diffusion2 import Diffuzz2\nfrom .gamma_parametrization import apply_gamma_reparam, gamma_reparam_model, remove_gamma_reparam\nfrom .weight_normalization import apply_weight_norm, weight_norm_model, remove_weight_norm"
  },
  {
    "path": "torchtools/utils/diffusion.py",
    "content": "import torch\n\n# Samplers --------------------------------------------------------------------\nclass SimpleSampler():\n    def __init__(self, diffuzz):\n        self.current_step = -1\n        self.diffuzz = diffuzz\n\n    def __call__(self, *args, **kwargs):\n        self.current_step += 1\n        return self.step(*args, **kwargs)\n\n    def init_x(self, shape):\n        return torch.randn(*shape, device=self.diffuzz.device)\n\n    def step(self, x, t, t_prev, noise):\n        raise NotImplementedError(\"You should override the 'apply' function.\")\n\nclass DDPMSampler(SimpleSampler):\n    def step(self, x, t, t_prev, noise):\n        alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])\n        alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]])\n        alpha = (alpha_cumprod / alpha_cumprod_prev)\n\n        mu = (1.0 / alpha).sqrt() * (x - (1-alpha) * noise / (1-alpha_cumprod).sqrt())\n        std = ((1-alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * torch.randn_like(mu)\n        return mu + std * (t_prev != 0).float().view(t_prev.size(0), *[1 for _ in x.shape[1:]])\n\nclass DDIMSampler(SimpleSampler):\n    def step(self, x, t, t_prev, noise):\n        alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])\n        alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]])\n\n        x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / (alpha_cumprod).sqrt()\n        dp_xt = (1 - alpha_cumprod_prev).sqrt()\n        return (alpha_cumprod_prev).sqrt() * x0 + dp_xt * noise\n\nclass DPMSolverPlusPlusSampler(SimpleSampler):  # FIXME: CURRENTLY NOT WORKING\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.q_ts = {}\n\n    def _get_coef(self, alpha_cumprod):\n        log_alpha_t = alpha_cumprod.log()\n        alpha_t = log_alpha_t.exp()\n        sigma_t = (1-alpha_t ** 2).sqrt()\n        lambda_t = log_alpha_t - sigma_t.log()\n        return alpha_t, sigma_t, lambda_t\n\n    def init_x(self, shape):\n        alpha_cumprod = self.diffuzz._alpha_cumprod(torch.ones(shape[0], device=self.diffuzz.device)).view(-1, *[1 for _ in shape[1:]])\n        return torch.randn(*shape, device=self.diffuzz.device) * self._get_coef(alpha_cumprod)[1]\n\n    def step(self, x, t, t_prev, noise):\n        alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])\n        stride = (t_prev - t)\n        if self.current_step == 0:\n            alpha_t, sigma_t, _ = self._get_coef(alpha_cumprod)\n        elif self.current_step == 1:\n            alpha_cumprod_next = self.diffuzz._alpha_cumprod(t+stride).view(t.size(0), *[1 for _ in x.shape[1:]])\n            alpha_t, sigma_t, lambda_t = self._get_coef(alpha_cumprod)\n            _, sigma_t_next, lambda_t_next = self._get_coef(alpha_cumprod_next)\n            h = lambda_t - lambda_t_next\n            x = sigma_t / sigma_t_next * x - alpha_t * torch.expm1(-h) * self.q_ts[self.current_step-1]\n        else:\n            alpha_cumprod_next = self.diffuzz._alpha_cumprod(t+stride).view(t.size(0), *[1 for _ in x.shape[1:]])\n            alpha_cumprod_next_next = self.diffuzz._alpha_cumprod(t+stride*2).view(t.size(0), *[1 for _ in x.shape[1:]])\n            \n            alpha_t, sigma_t, lambda_t = self._get_coef(alpha_cumprod)\n            _, sigma_t_next, lambda_t_next = 
self._get_coef(alpha_cumprod_next)\n            _, _, lambda_t_next_next = self._get_coef(alpha_cumprod_next_next)\n            \n            h = lambda_t - lambda_t_next\n            h_next = lambda_t_next - lambda_t_next_next\n            \n            r = h_next / h\n            D = (1 + 1 / (2 * r)) * self.q_ts[self.current_step-1] - 1 / (2 * r) * self.q_ts[self.current_step-2]\n            x = sigma_t / sigma_t_next * x - alpha_t * torch.expm1(-h) * D\n        self.q_ts[self.current_step] =  (x - sigma_t * noise) / alpha_t\n        return x\n\nsampler_dict = {\n    'ddpm': DDPMSampler,\n    'ddim': DDIMSampler,\n    'dpmsolver++': DPMSolverPlusPlusSampler,\n}\n\n# Custom simplified foward/backward diffusion (cosine schedule)\nclass Diffuzz():\n    def __init__(self, s=0.008, device=\"cpu\", cache_steps=None, scaler=1, clamp_range=(0.0001, 0.9999)):\n        self.device = device\n        self.s = torch.tensor([s]).to(device)\n        self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2\n        self.scaler = scaler\n        self.cached_steps = None\n        self.clamp_range = clamp_range\n        if cache_steps is not None:\n            self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device))\n\n    def _alpha_cumprod(self, t):\n        if self.cached_steps is None:\n            if self.scaler > 1:\n                t = 1 - (1-t) ** self.scaler\n            elif self.scaler < 1:\n                t = t ** self.scaler\n            alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod\n            return alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1])\n        else:\n            return self.cached_steps[t.mul(len(self.cached_steps)-1).long()]\n\n    def diffuse(self, x, t, noise=None): # t -> [0, 1]\n        if noise is None:\n            noise = torch.randn_like(x)\n        alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])\n        return alpha_cumprod.sqrt() * x + (1-alpha_cumprod).sqrt() * noise, noise\n\n    def undiffuse(self, x, t, t_prev, noise, sampler=None):\n        if sampler is None:\n            sampler = DDPMSampler(self)\n        return sampler(x, t, t_prev, noise)\n\n    def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, unconditional_inputs=None, sampler='ddpm', half=False):\n        r_range = torch.linspace(t_start, t_end, timesteps+1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device)            \n        if isinstance(sampler, str):\n            if sampler in sampler_dict:\n                sampler = sampler_dict[sampler](self)\n            else:\n                raise ValueError(f\"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}\")\n        elif issubclass(sampler, SimpleSampler):\n            sampler =  sampler(self)\n        else:\n            raise ValueError(\"Sampler should be either a string or a SimpleSampler object.\")\n        preds = []\n        x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone()\n        if half:\n            r_range = r_range.half()\n            x = x.half()\n        for i in range(0, timesteps):\n            if mask is not None and x_init is not None:\n                x_renoised, _ = self.diffuse(x_init, r_range[i])\n                x = x * mask + x_renoised * (1-mask)\n            pred_noise = 
model(x, r_range[i], **model_inputs)\n            if cfg is not None:\n                if unconditional_inputs is None:\n                    unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}\n                pred_noise_unconditional = model(x, r_range[i], **unconditional_inputs)\n                pred_noise = torch.lerp(pred_noise_unconditional, pred_noise, cfg)\n            x = self.undiffuse(x, r_range[i], r_range[i+1], pred_noise, sampler=sampler)\n            preds.append(x)\n        return preds\n        \n    def p2_weight(self, t, k=1.0, gamma=1.0):\n        alpha_cumprod = self._alpha_cumprod(t)\n        return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma"
  },
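  {
    "path": "examples/diffuzz_usage.py",
    "content": "# Illustrative usage sketch added for documentation; this file, its path and\n# DummyModel are hypothetical and not part of the original repository.\n# Diffuzz implements a simplified cosine-schedule diffusion: `diffuse` noises a\n# clean batch at continuous timesteps t in [0, 1], and `sample` runs the\n# reverse process with a pluggable sampler ('ddpm', 'ddim', ...).\nimport torch\nfrom torchtools.utils import Diffuzz\n\ndiffuzz = Diffuzz(device='cpu')\n\n# Forward diffusion (training side): the model is trained to predict `noise`\n# from `noised` at timestep `t`, e.g. with an MSE loss.\nx0 = torch.randn(8, 3, 32, 32)\nt = torch.rand(8)\nnoised, noise = diffuzz.diffuse(x0, t)\n\n# Reverse diffusion (sampling side): the model is called as\n# model(x, t, **model_inputs) and must return the predicted noise.\nclass DummyModel(torch.nn.Module):\n    def forward(self, x, t):\n        return torch.zeros_like(x)  # stand-in for a trained denoiser\n\npreds = diffuzz.sample(DummyModel(), model_inputs={}, shape=(1, 3, 32, 32),\n                       timesteps=20, cfg=None, sampler='ddim')\nfinal = preds[-1]  # `sample` returns the full list of intermediate steps"
  },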
  {
    "path": "torchtools/utils/diffusion2.py",
    "content": "import torch\nimport numpy as np\n\n# Samplers --------------------------------------------------------------------\nclass SimpleSampler():\n    def __init__(self, diffuzz, mode=\"v\"):\n        self.current_step = -1\n        self.diffuzz = diffuzz\n        if mode not in ['v', 'e', 'x']:\n            raise Exception(\"Mode should be either 'v', 'e' or 'x'\")\n        self.mode = mode\n\n    def __call__(self, *args, **kwargs):\n        self.current_step += 1\n        return self.step(*args, **kwargs)\n\n    def init_x(self, shape):\n        return torch.randn(*shape, device=self.diffuzz.device)\n\n    def step(self, x, t, t_prev, noise):\n        raise NotImplementedError(\"You should override the 'apply' function.\")\n\n# https://github.com/ozanciga/diffusion-for-beginners/blob/main/samplers/ddim.py\nclass DDIMSampler(SimpleSampler):\n    def step(self, x, t, t_prev, pred, eta=0):\n        alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])\n        alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]])\n\n        sigma_tau = eta * ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)).sqrt() * (1 - alpha_cumprod / alpha_cumprod_prev).sqrt() if eta > 0 else 0\n        if self.mode == 'v':\n            x0 = alpha_cumprod.sqrt() * x - (1-alpha_cumprod).sqrt() * pred\n            noise = (1-alpha_cumprod).sqrt() * x + alpha_cumprod.sqrt() * pred\n        elif self.mode == 'x':\n            x0 = pred\n            noise = (x - x0 * alpha_cumprod.sqrt()) / (1 - alpha_cumprod).sqrt()\n        else:\n            noise = pred\n            x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / alpha_cumprod.sqrt()\n        renoised = alpha_cumprod_prev.sqrt() * x0 + (1 - alpha_cumprod_prev - sigma_tau ** 2).sqrt() * noise + sigma_tau * torch.randn_like(x)\n        return x0, renoised, pred\n\nclass DDPMSampler(DDIMSampler):\n    def step(self, x, t, t_prev, pred, eta=1):\n        return super().step(x, t, t_prev, pred, eta)\n\nsampler_dict = {\n    'ddpm': DDPMSampler,\n    'ddim': DDIMSampler,\n}\n\n# Custom simplified foward/backward diffusion (cosine schedule)\nclass Diffuzz2():\n    def __init__(self, s=0.008, device=\"cpu\", cache_steps=None, scaler=1, clamp_range=(1e-7, 1-1e-7)):\n        self.device = device\n        self.s = torch.tensor([s]).to(device)\n        self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2\n        self.scaler = 2 * np.log(1/scaler)\n        self.cached_steps = None\n        self.clamp_range = clamp_range\n        if cache_steps is not None:\n            self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device))\n\n    def _alpha_cumprod(self, t):\n        if self.cached_steps is None:\n            alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod\n            alpha_cumprod = alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1])\n            if self.scaler != 1:\n                alpha_cumprod = (alpha_cumprod/(1-alpha_cumprod)).log().add(self.scaler).sigmoid().clamp(self.clamp_range[0], self.clamp_range[1])\n            return alpha_cumprod\n        else:\n            return self.cached_steps[t.mul(len(self.cached_steps)-1).long()]\n\n    def scale_t(self, t, scaler):\n        scaler = 2 * np.log(1/scaler)\n        alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod\n        
alpha_cumprod = alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1])\n        if scaler != 1:\n            alpha_cumprod = (alpha_cumprod/(1-alpha_cumprod)).log().add(scaler).sigmoid().clamp(self.clamp_range[0], self.clamp_range[1])\n        return (((alpha_cumprod * self._init_alpha_cumprod) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + self.s) - self.s\n\n    def diffuse(self, x, t, noise=None): # t -> [0, 1]\n        if noise is None:\n            noise = torch.randn_like(x)\n        alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])\n        return alpha_cumprod.sqrt() * x + (1-alpha_cumprod).sqrt() * noise, noise\n    \n    def get_v(self, x, t, noise):\n        alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])\n        # x0 = alpha_cumprod * noised − (1-alpha_cumprod).sqrt() * pred_v\n        # noise = (1-alpha_cumprod).sqrt() * noised + alpha_cumprod * pred_v\n        return alpha_cumprod.sqrt() * noise - (1-alpha_cumprod).sqrt() * x\n    \n    def x0_from_v(self, noised, pred_v, t):\n        alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in noised.shape[1:]])\n        return alpha_cumprod.sqrt() * noised - (1-alpha_cumprod).sqrt() * pred_v\n\n    def noise_from_v(self, noised, pred_v, t):\n        alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in noised.shape[1:]])\n        return (1-alpha_cumprod).sqrt() * noised + alpha_cumprod.sqrt() * pred_v\n\n    def undiffuse(self, x, t, t_prev, pred, sampler=None, **kwargs):\n        if sampler is None:\n            sampler = DDPMSampler(self)\n        return sampler(x, t, t_prev, pred, **kwargs)\n\n    def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_rho=0.7, unconditional_inputs=None, sampler='ddpm', dtype=None, sample_mode='v', sampler_params={}, t_scaler=1):\n        r_range = torch.linspace(t_start, t_end, timesteps+1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device)            \n        if t_scaler != 1:\n            r_range = self.scale_t(r_range, t_scaler)\n        if isinstance(sampler, str):\n            if sampler in sampler_dict:\n                sampler = sampler_dict[sampler](self, sample_mode)\n            else:\n                raise ValueError(f\"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}\")\n        elif issubclass(sampler, SimpleSampler):\n            sampler =  sampler(self, sample_mode)\n        else:\n            raise ValueError(\"Sampler should be either a string or a SimpleSampler object.\")\n    \n        x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone()\n        if dtype is not None:\n            r_range = r_range.to(dtype)\n            x = x.to(dtype)\n        if cfg is not None:\n            if unconditional_inputs is None:\n                unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}\n            model_inputs = {k:torch.cat([v, v_u]) if isinstance(v, torch.Tensor) else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())}\n        for i in range(0, timesteps):\n            if mask is not None and x_init is not None:\n                x_renoised, _ = self.diffuse(x_init, r_range[i])\n                x = x * mask + x_renoised * (1-mask)\n            if cfg is not None:\n                pred, pred_unconditional = model(torch.cat([x] * 2), 
torch.cat([r_range[i]] * 2), **model_inputs).chunk(2)\n                pred_cfg = torch.lerp(pred_unconditional, pred, cfg)\n                if cfg_rho > 0:\n                    std_pos, std_cfg = pred.std(),  pred_cfg.std()\n                    pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)\n                else:\n                    pred = pred_cfg\n            else:\n                pred = model(x, r_range[i], **model_inputs)\n\n            diff_out = self.undiffuse(x, r_range[i], r_range[i+1], pred, sampler=sampler, **sampler_params)\n            x = diff_out[1]\n            altered_vars = yield diff_out\n            \n            # Update some running variables if the user wants\n            if altered_vars is not None:\n                cfg = altered_vars.get('cfg', cfg)\n                cfg_rho = altered_vars.get('cfg_rho', cfg_rho)\n                sampler = altered_vars.get('sampler', sampler)\n                unconditional_inputs = altered_vars.get('unconditional_inputs', unconditional_inputs)\n                model_inputs = altered_vars.get('model_inputs', model_inputs)\n                x = altered_vars.get('x', x)\n                mask = altered_vars.get('mask', mask)\n                x_init = altered_vars.get('x_init', x_init)\n        \n    def p2_weight(self, t, k=1.0, gamma=1.0):\n        alpha_cumprod = self._alpha_cumprod(t)\n        return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma\n    \n    def truncated_snr_weight(self, t, min=1.0, max=None):\n        alpha_cumprod = self._alpha_cumprod(t)\n        srn = (alpha_cumprod / (1 - alpha_cumprod))\n        if min != None or max != None:\n            srn = srn.clamp(min=min, max=max)\n        return srn\n"
  },
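  {
    "path": "examples/diffuzz2_usage.py",
    "content": "# Illustrative usage sketch added for documentation; this file, its path and\n# DummyModel are hypothetical and not part of the original repository.\n# Diffuzz2 defaults to v-prediction targets, and its `sample` method is a\n# generator: it yields (pred_x0, renoised, pred) after every step, which also\n# allows updating variables like `cfg` mid-sampling via generator.send().\nimport torch\nfrom torchtools.utils import Diffuzz2\n\ndiffuzz = Diffuzz2(device='cpu')\n\n# Training pair for a v-prediction model: predict `target_v` from `noised` at t.\nx0 = torch.randn(8, 3, 32, 32)\nt = torch.rand(8)\nnoised, noise = diffuzz.diffuse(x0, t)\ntarget_v = diffuzz.get_v(x0, t, noise)\n\nclass DummyModel(torch.nn.Module):\n    def forward(self, x, t):\n        return torch.zeros_like(x)  # stand-in for a trained v-prediction model\n\n# Drive the generator to completion and keep the last step's outputs.\nfor pred_x0, renoised, pred in diffuzz.sample(DummyModel(), model_inputs={},\n                                              shape=(1, 3, 32, 32),\n                                              timesteps=20, cfg=None):\n    pass\nfinal = renoised  # the fully denoised sample after the last step"
  },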
  {
    "path": "torchtools/utils/gamma_parametrization.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass _GammaScaling(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.gamma = nn.Parameter(torch.ones(1))\n\n    def forward(self, w):\n        return w * self.gamma\n\ndef apply_gamma_reparam(module, name=\"weight\"): # this reparametrizes the parameters of a single module\n    nn.utils.parametrizations.spectral_norm(module, name)\n    nn.utils.parametrize.register_parametrization(module, name, _GammaScaling())\n    return module\n\ndef gamma_reparam_model(model):\n    for module in model.modules(): # this reparametrizes all linear layers of the model\n        if isinstance(module, nn.Linear) and not torch.nn.utils.parametrize.is_parametrized(module, \"weight\"):\n            apply_gamma_reparam(module, \"weight\")\n        elif isinstance(module, nn.MultiheadAttention) and not torch.nn.utils.parametrize.is_parametrized(module, \"in_proj_weight\"):\n            apply_gamma_reparam(module, \"in_proj_weight\")\n    return model\n\ndef remove_gamma_reparam(model):\n    for module in model.modules():\n        if torch.nn.utils.parametrize.is_parametrized(module, \"weight\"):\n            nn.utils.parametrize.remove_parametrizations(module, \"weight\")\n        elif torch.nn.utils.parametrize.is_parametrized(module, \"in_proj_weight\"):\n            nn.utils.parametrize.remove_parametrizations(module, \"in_proj_weight\")\n"
  },
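  {
    "path": "examples/gamma_reparam_usage.py",
    "content": "# Illustrative usage sketch added for documentation; this file and its path are\n# hypothetical and not part of the original repository.\n# gamma_reparam_model wraps the weights of every nn.Linear (and the in_proj\n# weights of nn.MultiheadAttention) in spectral norm plus a learnable scalar\n# gamma; weight_norm_model/remove_weight_norm follow the same pattern.\nimport torch\nfrom torch import nn\nfrom torchtools.utils import gamma_reparam_model, remove_gamma_reparam\n\nmodel = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))\nmodel = gamma_reparam_model(model)\n\nout = model(torch.randn(2, 16))  # forward passes use the reparametrized weights\n\n# remove_gamma_reparam bakes the current parametrized values back into plain\n# weight tensors (leave_parametrized=True is the PyTorch default).\nremove_gamma_reparam(model)"
  },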
  {
    "path": "torchtools/utils/weight_normalization.py",
    "content": "import torch\nfrom torch import nn\n\nclass _WeigthNorm(nn.Module):\n    def __init__(self, eps=1e-4):\n        super().__init__()\n        self.eps = eps\n        \n    def _normalize(self, w):\n        norm_dims = list(range(1, len(w.shape)))\n        w_norm = torch.linalg.vector_norm(w, dim=norm_dims, keepdim=True)\n        # w_norm = torch.norm_except_dim(w, 2, 0).clone()\n        return w / (w_norm + self.eps)\n\n    def forward(self, w):\n        if self.training:\n            with torch.no_grad():\n                fan_in = w[0].numel()**0.5\n                w.data = self._normalize(w.data.clone()) * fan_in\n                # w.copy_(self._normalize(w) * fan_in)\n        return self._normalize(w)\n\ndef apply_weight_norm(module, name=\"weight\", init_weight=True): # this reparametrizes the parameters of a single module\n    if init_weight:\n        torch.nn.init.normal(getattr(module, name))\n    nn.utils.parametrize.register_parametrization(module, name, _WeigthNorm(), unsafe=True)\n    return module\n\ndef weight_norm_model(model, whitelist=None, init_weight=True):\n    whitelist = whitelist or []\n\n    def check_parameter(module, name):\n        return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance(getattr(module, name), nn.Parameter)\n\n    for name, module in model.named_modules(): # this reparametrizes all layers of the model that have a \"weight\" parameter\n        if not any([w in name for w in whitelist]):\n            if check_parameter(module, \"weight\"):\n                apply_weight_norm(module, init_weight=init_weight)\n            elif check_parameter(module, \"in_proj_weight\"):\n                apply_weight_norm(module, 'in_proj_weight', init_weight=init_weight)\n    return model\n\ndef remove_weight_norm(model):\n    for module in model.modules():\n        if torch.nn.utils.parametrize.is_parametrized(module, \"weight\"):\n            nn.utils.parametrize.remove_parametrizations(module, \"weight\")\n        elif torch.nn.utils.parametrize.is_parametrized(module, \"in_proj_weight\"):\n            nn.utils.parametrize.remove_parametrizations(module, \"in_proj_weight\")"
  }
]