Repository: pabloppp/pytorch-tools Branch: master Commit: 6472bd5e2231 Files: 53 Total size: 135.7 KB Directory structure: gitextract_gjrwr8nv/ ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── readme.md ├── setup.cfg ├── setup.py └── torchtools/ ├── __init__.py ├── lr_scheduler/ │ ├── __init__.py │ ├── delayed.py │ └── inverse_sqrt.py ├── nn/ │ ├── __init__.py │ ├── adain.py │ ├── alias_free_activation.py │ ├── equal_layers.py │ ├── evonorm2d.py │ ├── fourier_features.py │ ├── functional/ │ │ ├── __init__.py │ │ ├── gradient_penalty.py │ │ ├── magnitude_preserving.py │ │ ├── perceptual.py │ │ └── vq.py │ ├── gp_loss.py │ ├── haar_dwt.py │ ├── magnitude_preserving.py │ ├── mish.py │ ├── modulation.py │ ├── perceptual.py │ ├── pixel_normalzation.py │ ├── pos_embeddings.py │ ├── simple_self_attention.py │ ├── stylegan2/ │ │ ├── __init__.py │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.py │ │ └── upfirdn2d_kernel.cu │ ├── transformers.py │ └── vq.py ├── optim/ │ ├── __init__.py │ ├── lamb.py │ ├── lookahead.py │ ├── novograd.py │ ├── over9000.py │ ├── radam.py │ ├── ralamb.py │ └── ranger.py ├── transforms/ │ ├── __init__.py │ ├── models/ │ │ ├── __init__.py │ │ └── saliency_model_v9.pt │ └── smart_crop.py └── utils/ ├── __init__.py ├── diffusion.py ├── diffusion2.py ├── gamma_parametrization.py └── weight_normalization.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ */**/__pycache__ *.egg-info /dist *.pyc .idea ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 Pablo Pernías Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. Footer ================================================ FILE: MANIFEST.in ================================================ include LICENSE include readme.md recursive-include torchtools/transforms/models * ================================================ FILE: readme.md ================================================ # Pytorch Tools ## Install Requirements: ``` PyTorch >= 1.0.0 Torchivision Numpy >= 1.0.0 ``` ``` # In order to install the latest (beta) use pip install git+https://github.com/pabloppp/pytorch-tools -U # if you want to install a specific version to avoid breaking changes (for example, v0.3.5), use pip install git+https://github.com/pabloppp/pytorch-tools@0.3.5 -U ``` # Current available tools ## Optimizers Comparison table taken from https://github.com/mgrankin/over9000 And the article explaining this recent improvements https://medium.com/@lessw/how-we-beat-the-fastai-leaderboard-score-by-19-77-a-cbb2338fab5c Dataset | LR Schedule| Imagenette size 128, 5 epoch | Imagewoof size 128, 5 epoch --- | -- | --- | --- Adam - baseline |OneCycle| 0.8493 | 0.6125 RangerLars (RAdam + LARS + Lookahead) |Flat and anneal| 0.8732 | 0.6523 Ralamb (RAdam + LARS) |Flat and anneal| 0.8675 | 0.6367 Ranger (RAdam + Lookahead) |Flat and anneal| 0.8594 | 0.5946 Novograd |Flat and anneal| 0.8711 | 0.6126 Radam |Flat and anneal| 0.8444 | 0.537 Lookahead |OneCycle| 0.8578 | 0.6106 Lamb |OneCycle| 0.8400 | 0.5597 ### Ranger Taken as is from https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d Example of use: ```python from torchtools.optim import Ranger optimizer = Ranger(model.parameters()) ``` ### RAdam Taken as is from https://github.com/LiyuanLucasLiu/RAdam Blog post: https://medium.com/@lessw/new-state-of-the-art-ai-optimizer-rectified-adam-radam-5d854730807b Original Paper: https://arxiv.org/abs/1908.03265 Example of use: ```python from torchtools.optim import RAdam, PlainRAdam, AdamW optimizer = RAdam(model.parameters()) # optimizer = PlainRAdam(model.parameters()) # optimizer = AdamW(model.parameters()) ``` ### RangerLars (Over9000) Taken as is from https://github.com/mgrankin/over9000 Example of use: ```python from torchtools.optim import RangerLars # Over9000 optimizer = RangerLars(model.parameters()) ``` ### Novograd Taken as is from https://github.com/mgrankin/over9000 Example of use: ```python from torchtools.optim import Novograd optimizer = Novograd(model.parameters()) ``` ### Ralamb Taken as is from https://github.com/mgrankin/over9000 Example of use: ```python from torchtools.optim import Ralamb optimizer = Ralamb(model.parameters()) ``` ### Lookahead Taken as is from https://github.com/lonePatient/lookahead_pytorch Original Paper: https://arxiv.org/abs/1907.08610 This lookahead can be used with any optimizer Example of use: ```python from torch import optim from torchtools.optim import Lookahead optimizer = optim.Adam(model.parameters(), lr=0.001) optimizer = Lookahead(base_optimizer=optimizer, k=5, alpha=0.5) # for a base Lookahead + Adam you can just do: # # from torchtools.optim import LookaheadAdam ``` ### Lamb Taken as is from https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py Original Paper: https://arxiv.org/abs/1904.00962 Example of use: ```python from torchtools.optim import Lamb optimizer = Lamb(model.parameters()) ``` ## LR Schedulers ### Delayed LR Allows for a customizable number of initial steps where the learning rate remains fixed. After those steps the learning rate will be updated according to the supplied scheduler's policy Example of use: ```python from torch import optim, nn from torchtools.lr_scheduler import DelayerScheduler optimizer = optim.Adam(model.parameters(), lr=0.001) # define here your optimizer, the lr that you set will be the one used for the initial delay steps delay_epochs = 10 total_epochs = 20 base_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, delay_epochs) # delay the scheduler for 10 steps delayed_scheduler = DelayerScheduler(optimizer, total_epochs - delay_epochs, base_scheduler) for epoch in range(total_epochs): # train(...) delayed_scheduler.step() # The lr will be 0.001 for the first 10 epochs, then will use the policy fro the base_scheduler for the rest of the epochs # for a base DelayerScheduler + CosineAnnealingLR you can just do: # # from torchtools.lr_scheduler import DelayedCosineAnnealingLR # scheduler = DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs) # the sum of both must be the total number of epochs ``` ## Activations ### Mish Original implementation: https://github.com/digantamisra98/Mish Original Paper: https://arxiv.org/abs/1908.08681v1 Implementation taken as is from https://github.com/lessw2020/mish Example of use: ```python from torchtools.nn import Mish # Then you can just use Mish as a replacement for any activation function, like ReLU ``` ### AliasFreeActivation Implementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L225 by Rosinality I modularized this activation so it can be easily used inside of any model without having to deal with complex initialization. This activation actually takes a lot of responsibility, since it internally defines the channels and size of the input based on a set of parameters, instead of receiving them as a parameter, this means that the rest of the layers (convolutions, positional embedding, etc...) must adapt to it. Example of use: ```python from torchtools.nn.alias_free_activation import AliasFreeActivation from torchtools.nn import EqualLeakyReLU # We can use the static function to get the filter parameters for a specific level. # It can be specially usefull to obtain the initial size and channels. max_size, max_channels = 256, 512 first_channels, first_size = AliasFreeActivation.alias_level_params( 0, max_levels=14, max_size=max_size, max_channels=max_channels )[-2:] class MyModel(nn.Module): def __init__(self, level, max_levels=14, max_size=256, max_channels=512, margin=10): ... # AdaIN will require the style vector to be 2*size leaky_relu = EqualLeakyReLU(negative_slope=0.2) self.activation = AliasFreeActivation( leaky_relu, level, max_levels=max_levels, max_size=max_size, max_channels=max_channels, margin=margin ) self.conv = nn.Conv2d(self.activation.channels_prev, self.activation.channels, kernel_size=3, padding=1) ... def forward(self, x): # x the channels and size of X are dependent on the level of this module. ... x = self.conv(x) x = self.activation(x) ... ``` ## Layers ### SimpleSelfAttention Implementation taken as is from https://github.com/sdoria/SimpleSelfAttention Example of use: ```python from torchtools.nn import SimpleSelfAttention # The input of the layer has to at least have 3 dimensions (B, C, N), # the attention will be performed in the 2nd dimension. # # For images, the input will be internally reshaped to 3 dimensions, # and reshaped back to the original shape before returning it ``` ### PixelNorm Inspired from https://github.com/github-pengge/PyTorch-progressive_growing_of_gans Example of use: ```python from torchtools.nn import PixelNorm model = nn.Linear( nn.Conv2d(...), PixelNorm(), nn.ReLU() ) # It doesn't require any parameter, it just performs a simple element-wise normalization # x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8) # # Just use it as a regular layer, generally after convolutions and before ReLU # (warning) since it performs a srtq root it's pretty slow if the layer sizes are big ``` ### Adaptive Instance Normalization - AdaIN Implementation based on https://github.com/SiskonEmilia/StyleGAN-PyTorch Original Paper https://arxiv.org/abs/1703.06868 Example of use: ```python from torchtools.nn import AdaIN class MyModel(nn.Module): def __init__(self, n_channels): ... # AdaIN will require the style vector to be 2*size self.style = nn.Linear(input_size, output_size*2) self.adain = AdaIN(output_size) ... def forward(self, x, w): ... x = self.adain(x, self.style(w)) ... # AdaIN will "transfer" a style encoded in a latent vector w into any tensor x. # In order to do this it first needs to be passed through a linear layer that will return 2 tensors (actually, one tensor of twice the size required, that we'll then split in 2) # It will then perform an instance normalization to "whiten" the tensor, followed with a de-normalization but using the values generated by the linear layer, thus encoding the original vector w in the tensor. ``` ### EvoNorm Implementation taken as is from https://github.com/digantamisra98/EvoNorm all credit goes to digantamisra98 Original Paper https://arxiv.org/abs/2004.02967 Example of use: ```python from torchtools.nn import EvoNorm2D model = nn.Linear( nn.Conv2d(...), EvoNorm2D(c_hidden), # For S0 version # evoB0 = EvoNorm2D(input, affine = True, version = 'B0', training = True) # For B0 version nn.ReLU() ) ``` ### GPT Transformer Encoder Layer Implementation based on MinGPT https://github.com/karpathy/minGPT by Andrej Karpathy It can be used as a drop-in replacement for the `torch.nn.TransformerEncoderLayer` Example of use: ```python from torchtools.nn import GPTTransformerEncoderLayer class MyTransformer(nn.Module): def __init__(self, n_channels): ... encoder_layer = GPTTransformerEncoderLayer(d_model=512, nhead=8) self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) ... ``` ### Stylegan2 ModulatedConv2d Implementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L143 by Rosinality It extends from `torch.nn.Conv2d` so you can use it as a drop-in replacement, the only Conv2d parámeter that you cannot use is 'groups' since it will be overriden for this to work. It also includes a parameter `ema_decay` that will add the EMA normalization used in Alias-free GAN (defaults to 1, meaning that it's disabled) Example of use: ```python from torchtools.nn import ModulatedConv2d class MyModel(nn.Module): def __init__(self): ... self.conv = ModulatedConv2d(16, 32, kernel_size=3, padding=1) # SUGESTIONS: # set bias=False if you want to handle bias on your own # set demodulate=False for RGB output # set ema_decay=0.9989 to imitate the alias-free gan setup ... def forward(self, x, w): ... x = self.conv(x, w) # 'x' is a 4D tensor (B x C x W x H) and 'w' is a 2D tensor (B x C) ... ``` ### Equal Layers (EqualNorm, EqualLinear) Implementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/stylegan2/model.py#L94 It extends the base classes (nn.Linear, nn.Conv2d, nn.LeakyReLU) so you can use this as a drop-in replacement, although it includes some optiona parameters. Example of use: ```python from torchtools.nn import EqualLinear, EqualLeakyReLU, EqualConv2d class MyModel(nn.Module): def __init__(self): ... self.linear = EqualLinear(16, 32, bias_init=1, lr_mul=0.01) # bias_init and lr_mul are extra optional params self.leaky_relu = EqualLeakyReLU(negative_slope=0.2) self.conv = EqualConv2d(16, 32, kernel_size=3, padding=1) # Since this classes extend from the base classes, you can use all parameters from the original classes. ... ``` ### FourierFeatures2d Implementation inspired on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L88 but improved using my own understanding of how this should work... It creates a 2d tensor of embeddings following a fourier series based on the parameters you provide, this features are dynamic, meaning that affine transformations can be applied to them in order to shift, rotate, and even scale (experimental). ```python from torchtools.nn import EqualLinear, EqualLeakyReLU, EqualConv2d class MyModel(nn.Module): def __init__(self, dim=256, margin=10, cutoff=2): ... self.feats = FourierFeatures2d(4+margin*2, dim, cutoff) # optionally enable scaling with allow_scaling=True # Also, you can randomize the frequencies if you plan on keeping them fixed, setting w_scale to any value > 0 ... def forward(self, affine): ... embds = self.feats(affine) # 'affine' should be a Bx4 tensor, or Bx6 if scaling is enabled... # the default or initial affine values should be [1, 0, 0, 0, 1, 1] => ([1, 0]: rotation, [0, 0]: shift, [1, 1]: scale) ... ``` ## Criterions ### Gradient Penalty (for WGAN-GP) Implementation taken with minor changes from https://github.com/caogang/wgan-gp Original paper https://arxiv.org/pdf/1704.00028.pdf Example of use: ```python from torchtools.nn import GPLoss # This criterion defines the gradient penalty for WGAN GP # For an example of a training cycle refer to this repo https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py#L185 discriminator = ... gpcriterion = GPLoss(discriminator) # l = 10 by default gradient_penalty = gpcriterion(real_data, fake_data) discriminator_loss = ... + gradient_penalty # add the gp component to the Wasserstein loss ``` ### Total Variation Loss Total Variation denoising https://www.wikiwand.com/en/Total_variation_denoising Example of use: ```python # This loss (or regularization) is usefull for removing artifacts and noise in generated images. # It's widely used in style transfer. from torchtools.nn import TVLoss tvcriterion = TVLoss() # reduction = 'sum' and alpha = 1e-4 by default G = ... # output image tv_loss = tvcriterion(G) loss = ... + tv_loss # add the tv loss component to your reconstruction loss ``` ## Vector Quantization ### VectorQuantize: Encodding based quantization [(source)](torchtools/vq.py#L5) This transforms any tensor to its quantized version using a codebook of embeddings. It uses a traight-forward approach for applying the gradients. Passing a tensor trough the **VectorQuantize** module will return a new tensor with the same dimension but changing each one of the tensors of the last dimension by the nearest neighbor from the codebook, which has a limited number of values, thus quantizing the tensor. For the quantization it relies in a differentiable function that you can see [here](torchtools/functional/vq.py#L4) The output of the model is a quantized tensor, as well as a Touple of the loss components of the codebook (needed for training), and the indices of the quantized vectors in the form: `qx, (vq_loss, commit_loss), indices` When **creating a new instance of the module**, it accepts the following parameters: - **embedding_size**: the size of the embeddings used in the codebook, should match the last dimension of the tensor you want to quantize - **k**: the size of the codebook, or number of embeddings. - **ema_decay** (default=0.99): the Exponentially Moving Average decay used (this only will be used if ema_loss is True) - **ema_loss** (default=False): Enables Exponentially Moving Average update of the codebook (instead of relying on gradient descent as EMA converges faster) When **calling the forward method** of the module, it accepts the following parameters: - **x**: this is the tensor you want to quantize, make sure the dimension that you want to quantize (by default is the last one) matches embedding_size defined when instantiating the module - **get_losses** (default=True): when False, the vq_loss and commit_loss components of the output will both be None, this should speed up a little bit the model when used for inference. - **dim** (default=-1): The dimension across which the input should be quantized. Example of use: ```python from torchtools.nn import VectorQuantize e = torch.randn(1, 16, 16, 8) # create a random tensor with 8 as its last dimension size vquantizer = VectorQuantize(8, k=32, ema_loss=True) # we create the module with embedding size of 8, a codebook of size 32 and make the codebook update using EMA qe, (vq_loss, commit_loss), indices = vquantizer.forward(e) # we quantize our tensor while also getting the loss components and the indices # NOTE While the model is in training mode, the codebook will always be updated when calling the forward method, in order to freeze the codebook for inference put it in evaluation mode with 'vquantizer.eval()' # NOTE 2 In order to update the module properly, add the loss components to the final model loss before calling backward(), if you set ema_loss to true you only need to add the commit_loss to the total loss, an it's usually multiplied by a value between 0.1 and 2, being 0.25 a good default value loss = ... # whatever loss you have for your final output loss += commit_loss * 0.25 # loss += vq_loss # only if you didn't set the ema_loss to True ... loss.backward() optimizer.step() ``` --- ### Binarize: binarize the input tensor [(source)](torchtools/vq.py#L55) This transfors the values of a tensor into 0 and 1 depending if they're above or below a specified threshold. It uses a traight-forward approach for applying the gradients, so it's effectively differentiable. For the quantization it relies in a differentiable function that you can see [here](torchtools/functional/vq.py#L36) Example of use: ```python from torchtools.nn import Binarize e = torch.randn(8, 16) # create a random tensor with any dimension binarizer = Binarize(threshold=0.5) # you can set the threshold you want, for example if your output was passed through a tanh activation, 0 might be a better theshold since tanh outputs values between -1 and 1 bq = binarizer(e) # will return a tensor with the same shape as e, but full of 0s and 1s ``` ## Embeddings ### RotaryEmbedding Implementation taken as is from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L161 Example of use: ```python from torchtools.nn import RotaryEmbedding class MyModel(nn.Module): def __init__(self, dim): ... self.rotary_pos_embd = RotaryEmbedding(dim) ... def forward(self, x): x = self.rotary_pos_embd(x) ... ``` ## Diffusion ### Diffuzz Custom (non-cached) continuous forward/backward diffusion. It's not SUPER performant since it calculates all the required values on the fly instead of caching them (although I add a very simple cache where you can specify the number of steps that you want to cache), but in general I think this will not make an extremely big difference in terms of performance, simplifies a lot the code, and removes the concept of having a fixed number of timesteps for the forward diffusion (since I always found it weird to train a model assuming 1000 forward diffusion steps, and then using way less steps during inference) by using a continuous value between 0 and 1 to decide how much noise we'll be adding to the output (1 being pure gaussian noise). During sampling, the same applies, instead of having a fixed number of steps, the diffuzz module will accept a noised input, a couple of values t & t_prev (between 0 and 1) and a predicted noise, and it will try to remove such noise in a scale such as to go from step t to step t_prev, so if we want to denoise in 10 steps we'll tell it to go from 1.0 to 0.9, then to 0.8, etc... while if we want to denoise in 100 steps, we'll start at 1.0 and go to 0.99, then to 0.98, etc... Example of use during training: ```python from torchtools.utils import Diffuzz device = "cuda" diffuzz = Diffuzz(device=device) # diffuzz = Diffuzz(device=device, cache_steps=10000) # optionally you can pass a 'cache_steps' parameter to speed up the noising process custom_unet = CustomUnet().to(device) # Custom model whith output size = input size input_tensor = torch.randn(8, 3, 16, 16, device=device) # an image, audio signal, or whatever... t = torch.rand(input_tensor.size(0), device=device) # get a tensor with batch_size of values between 0 and 1 noised_tensor, noise = diffuzz.diffuse(input_tensor, t) predicted_noise = custom_unet(noised_tensor, t) loss = nn.functional.mse_loss(predicted_noise, noise) # Optionally the diffuzz module provides loss gamma weighting (untested) but for this to work the loss # should not be averaged on the batch dimension before applying it. # loss = nn.functional.mse_loss(predicted_noise, noise, reduction='none').mean(dim=[1, 2, 3]) # loss = (loss * diffuzz.p2_weight(t)).mean() ``` Example of use for sampling: ```python from torchtools.utils import Diffuzz device = "cuda" sampled = diffuzz.sample( custom_unet, {'c': conditioning}, (conditioning.size(0), 3, 16, 16), timesteps=20, sampler='ddim' )[-1] ``` the `sample` method accepts a `sampler` parameter, currently only `ddpm` (default) and `ddim` are implemented, but I'm planning on adding more, very likely by borrowing (and appropriately citing) code from this repo https://github.com/ozanciga/diffusion-for-beginners ================================================ FILE: setup.cfg ================================================ [metadata] description-file = readme.md ================================================ FILE: setup.py ================================================ from setuptools import setup, find_packages setup( name='torchtools', packages=find_packages(), description='PyTorch useful tools', version='0.3.5', url='https://github.com/pabloppp/pytorch-tools', author='Pablo Pernías', author_email='pablo@pernias.com', keywords=['pip', 'pytorch', 'tools', 'RAdam', 'Lookahead', 'RALamb', 'quantization'], zip_safe=False, install_requires=[ 'torch>=1.6', 'torchvision', 'numpy>=1.0', 'ninja>=1.0' ], package_data={ 'stylegan2.tools': ['torchtools/nn/stylegan2/*'], 'transforms.models': ['torchtools/transforms/models/*'] }, include_package_data=True, ) ================================================ FILE: torchtools/__init__.py ================================================ ================================================ FILE: torchtools/lr_scheduler/__init__.py ================================================ from .delayed import DelayerScheduler, DelayedCosineAnnealingLR from .inverse_sqrt import InverseSqrtLR ================================================ FILE: torchtools/lr_scheduler/delayed.py ================================================ from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR class DelayerScheduler(_LRScheduler): """ Starts with a flat lr schedule until it reaches N epochs the applies a scheduler Args: optimizer (Optimizer): Wrapped optimizer. delay_epochs: number of epochs to keep the initial lr until starting aplying the scheduler after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) """ def __init__(self, optimizer, delay_epochs, after_scheduler): self.delay_epochs = delay_epochs self.after_scheduler = after_scheduler self.finished = False super().__init__(optimizer) def get_lr(self): if self.last_epoch >= self.delay_epochs: if not self.finished: self.after_scheduler.base_lrs = self.base_lrs self.finished = True return self.after_scheduler.get_lr() return self.base_lrs def step(self, epoch=None): if self.finished: if epoch is None: self.after_scheduler.step(None) else: self.after_scheduler.step(epoch - self.delay_epochs) else: return super(DelayerScheduler, self).step(epoch) def DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs): base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs) return DelayerScheduler(optimizer, delay_epochs, base_scheduler) ================================================ FILE: torchtools/lr_scheduler/inverse_sqrt.py ================================================ import warnings from torch.optim.lr_scheduler import LRScheduler class InverseSqrtLR(LRScheduler): def __init__(self, optimizer, lr, warmup_steps, pre_warmup_lr=None, last_epoch=-1, verbose=False): warmup_steps = max(warmup_steps, 1) self.lr = lr * warmup_steps**0.5 self.warmup_steps = warmup_steps self.pre_warmup_lr = pre_warmup_lr if pre_warmup_lr is not None else lr super().__init__(optimizer, last_epoch, verbose) def _process_lr(self, _): warmup_factor = min(self.last_epoch/self.warmup_steps, 1) # this grows linearly from 0 to 1 during the warmup base_lr = self.lr / max(self.last_epoch, self.warmup_steps)**0.5 return warmup_factor * base_lr + (1-warmup_factor)*self.pre_warmup_lr def get_lr(self): if not self._get_lr_called_within_step: warnings.warn("To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning) lr = self._process_lr(self.lr) return [lr for _ in self.optimizer.param_groups] def _get_closed_form_lr(self): return [self._process_lr(base_lr) for base_lr in self.base_lrs] ================================================ FILE: torchtools/nn/__init__.py ================================================ from .mish import Mish from .simple_self_attention import SimpleSelfAttention from .vq import VectorQuantize, Binarize, FSQ from .gp_loss import GPLoss from .pixel_normalzation import PixelNorm from .perceptual import TVLoss from .adain import AdaIN from .transformers import GPTTransformerEncoderLayer from .evonorm2d import EvoNorm2D from .pos_embeddings import RotaryEmbedding from .modulation import ModulatedConv2d from .equal_layers import EqualConv2d, EqualLeakyReLU, EqualLinear from .fourier_features import FourierFeatures2d # from .alias_free_activation import AliasFreeActivation from .magnitude_preserving import MP_GELU, MP_SiLU, Gain from .haar_dwt import HaarForward, HaarInverse ================================================ FILE: torchtools/nn/adain.py ================================================ import torch from torch import nn class AdaIN(nn.Module): def __init__(self, n_channels): super(AdaIN, self).__init__() self.norm = nn.InstanceNorm2d(n_channels) def forward(self, image, style): factor, bias = style.view(style.size(0), style.size(1), 1, 1).chunk(2, dim=1) result = self.norm(image) * factor + bias return result ================================================ FILE: torchtools/nn/alias_free_activation.py ================================================ import torch from torch import nn import math from .stylegan2 import upfirdn2d #### # TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM # https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L225 # But I simplified it into a single (almots) self-contained module. # Probably I give this module too much reponsibility but meh... #### class AliasFreeActivation(nn.Module): def __init__(self, activation, level, max_levels, max_size, max_channels, margin, start_cutoff=2, critical_layers=2, window_size=6): super().__init__() self.activation = activation # Filter features self.cutoff, self.stopband, self.band_half, self.channels, self.size = self.alias_level_params( level, max_levels, max_size, max_channels, start_cutoff, critical_layers ) self.cutoff_prev, self.stopband_prev, self.band_half_prev, self.channels_prev, self.size_prev = self.alias_level_params( max(level-1, 0), max_levels, max_size, max_channels, start_cutoff, critical_layers ) # Filters self.scale_factor = 2 if self.size_prev < self.size else 1 up_filter = self._lowpass_filter( window_size * self.scale_factor * 2, self.cutoff_prev, self.band_half_prev, self.size * self.scale_factor * 2 ) self.register_buffer("up_filter", (up_filter / up_filter.sum()) * 2 * self.scale_factor) down_filter = self._lowpass_filter( window_size * self.scale_factor, self.cutoff, self.band_half, self.size * self.scale_factor * 2 ) self.register_buffer("down_filter", down_filter / down_filter.sum()) p = self.up_filter.shape[0] - (2*self.scale_factor) self.up_pad = ((p + 1) // 2 + (2*self.scale_factor) - 1, p // 2) p = self.down_filter.shape[0] - 2 self.down_pad = ((p + 1) // 2, p // 2) self.margin = margin @staticmethod def alias_level_params(level, max_levels, max_size, max_channels, start_cutoff=2, critical_layers=2, base_channels=2**14): end_cutoff = max_size//2 cutoff = start_cutoff * (end_cutoff / start_cutoff) ** min(level / (max_levels - critical_layers), 1) start_stopband = start_cutoff ** 2.1 end_stopband = end_cutoff * (2 ** 0.3) stopband = start_stopband * (end_stopband/start_stopband) ** min(level / (max_levels - critical_layers), 1) size = 2 ** math.ceil(math.log(min(2 * stopband, max_size), 2)) band_half = max(stopband, size / 2) - cutoff channels = min(round(base_channels / size), max_channels) return cutoff, stopband, band_half, channels, size def _lowpass_filter(self, n_taps, cutoff, band_half, sr): window = self._kaiser_window(n_taps, band_half, sr) ind = torch.arange(n_taps) - (n_taps - 1) / 2 lowpass = 2 * cutoff / sr * torch.sinc(2 * cutoff / sr * ind) * window return lowpass def _kaiser_window(self, n_taps, f_h, sr): beta = self._kaiser_beta(n_taps, f_h, sr) ind = torch.arange(n_taps) - (n_taps - 1) / 2 return torch.i0(beta * torch.sqrt(1 - ((2 * ind) / (n_taps - 1)) ** 2)) / torch.i0(torch.tensor(beta)) def _kaiser_attenuation(self, n_taps, f_h, sr): df = (2 * f_h) / (sr / 2) return 2.285 * (n_taps - 1) * math.pi * df + 7.95 def _kaiser_beta(self, n_taps, f_h, sr): atten = self._kaiser_attenuation(n_taps, f_h, sr) if atten > 50: return 0.1102 * (atten - 8.7) elif 50 >= atten >= 21: return 0.5842 * (atten - 21) ** 0.4 + 0.07886 * (atten - 21) else: return 0.0 def forward(self, x): x = self._upsample(x, self.up_filter, 2*self.scale_factor, pad=self.up_pad) x = self.activation(x) x = self._downsample(x, self.down_filter, 2, pad=self.down_pad) if self.scale_factor > 1 and self.margin > 0: m = self.scale_factor * self.margin // 2 x = x[:, :, m:-m, m:-m] return x def _upsample(self, x, kernel, factor, pad=(0, 0)): x = upfirdn2d(x, kernel.unsqueeze(0), up=(factor, 1), pad=(*pad, 0, 0)) x = upfirdn2d(x, kernel.unsqueeze(1), up=(1, factor), pad=(0, 0, *pad)) return x def _downsample(self, x, kernel, factor, pad=(0, 0)): x = upfirdn2d(x, kernel.unsqueeze(0), down=(factor, 1), pad=(*pad, 0, 0)) x = upfirdn2d(x, kernel.unsqueeze(1), down=(1, factor), pad=(0, 0, *pad)) return x def extra_repr(self): info_string = f'cutoff={self.cutoff}, stopband={self.stopband}, band_half={self.band_half}, channels={self.channels}, size={self.size}' return info_string ================================================ FILE: torchtools/nn/equal_layers.py ================================================ import torch from torch import nn import math #### # TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM # https://github.com/rosinality/alias-free-gan-pytorch/blob/main/stylegan2/model.py#L94 # But made it extend from the base modules to avoid some boilerplate #### class EqualLinear(nn.Linear): def __init__(self, *args, bias_init=0, lr_mul=1, **kwargs): super().__init__(*args, **kwargs) self.scale = (1 / math.sqrt(self.in_features)) * lr_mul self.lr_mul = lr_mul nn.init.normal_(self.weight, std=1/lr_mul) if self.bias is not None: nn.init.constant_(self.bias, bias_init) def forward(self, x): return nn.functional.linear(x, self.weight * self.scale, self.bias * self.lr_mul) class EqualConv2d(nn.Conv2d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) fan_in = self.in_channels * self.kernel_size[0] ** 2 self.scale = 1 / math.sqrt(fan_in) nn.init.normal_(self.weight) if self.bias is not None: nn.init.zeros_(self.bias) def forward(self, x): return self._conv_forward(x, self.weight * self.scale, self.bias) class EqualLeakyReLU(nn.LeakyReLU): def __init__(self, *args, scale=2**0.5, **kwargs): super().__init__(*args, **kwargs) self.scale = scale def forward(self, x): return super().forward(x) * self.scale ================================================ FILE: torchtools/nn/evonorm2d.py ================================================ import torch import torch.nn as nn ## Taken as is from https://github.com/digantamisra98/EvoNorm all credit goes to digantamisra98 class SwishImplementation(torch.autograd.Function): @staticmethod def forward(ctx, i): ctx.save_for_backward(i) return i * torch.sigmoid(i) @staticmethod def backward(ctx, grad_output): sigmoid_i = torch.sigmoid(ctx.saved_variables[0]) return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i))) class MemoryEfficientSwish(nn.Module): def forward(self, x): return SwishImplementation.apply(x) def instance_std(x, eps=1e-5): var = torch.var(x, dim = (2, 3), keepdim=True).expand_as(x) if torch.isnan(var).any(): var = torch.zeros(var.shape) return torch.sqrt(var + eps) def group_std(x, groups = 32, eps = 1e-5): N, C, H, W = x.size() x = torch.reshape(x, (N, groups, C // groups, H, W)) var = torch.var(x, dim = (2, 3, 4), keepdim = True).expand_as(x) return torch.reshape(torch.sqrt(var + eps), (N, C, H, W)) class EvoNorm2D(nn.Module): def __init__(self, input, non_linear = True, version = 'S0', efficient = False, affine = True, momentum = 0.9, eps = 1e-5, groups = 32, training = True): super(EvoNorm2D, self).__init__() self.non_linear = non_linear self.version = version self.training = training self.momentum = momentum self.efficient = efficient if self.version == 'S0': self.swish = MemoryEfficientSwish() self.groups = groups self.eps = eps if self.version not in ['B0', 'S0']: raise ValueError("Invalid EvoNorm version") self.insize = input self.affine = affine if self.affine: self.gamma = nn.Parameter(torch.ones(1, self.insize, 1, 1)) self.beta = nn.Parameter(torch.zeros(1, self.insize, 1, 1)) if self.non_linear: self.v = nn.Parameter(torch.ones(1,self.insize,1,1)) else: self.register_parameter('gamma', None) self.register_parameter('beta', None) self.register_buffer('v', None) self.register_buffer('running_var', torch.ones(1, self.insize, 1, 1)) self.reset_parameters() def reset_parameters(self): self.running_var.fill_(1) def _check_input_dim(self, x): if x.dim() != 4: raise ValueError('expected 4D input (got {}D input)' .format(x.dim())) def forward(self, x): self._check_input_dim(x) if self.version == 'S0': if self.non_linear: if not self.efficient: num = x * torch.sigmoid(self.v * x) # Original Swish Implementation, however memory intensive. else: num = self.swish(x) # Experimental Memory Efficient Variant of Swish return num / group_std(x, groups = self.groups, eps = self.eps) * self.gamma + self.beta else: return x * self.gamma + self.beta if self.version == 'B0': if self.training: var = torch.var(x, dim = (0, 2, 3), unbiased = False, keepdim = True) self.running_var.mul_(self.momentum) self.running_var.add_((1 - self.momentum) * var) else: var = self.running_var if self.non_linear: den = torch.max((var+self.eps).sqrt(), self.v * x + instance_std(x, eps = self.eps)) return x / den * self.gamma + self.beta else: return x * self.gamma + self.beta ================================================ FILE: torchtools/nn/fourier_features.py ================================================ import torch from torch import nn import math class FourierFeatures2d(nn.Module): def __init__(self, size, dim, cutoff, affine_eps=1e-8, freq_range=[-0.5, 0.5], w_scale=0, allow_scaling=False, op_order=['r', 't', 's']): super().__init__() self.size = size self.dim = dim self.cutoff = cutoff self.freq_range = freq_range self.affine_eps = affine_eps self.w_scale = w_scale coords = torch.linspace(freq_range[0], freq_range[1], size+1)[:-1] freqs = torch.linspace(0, cutoff, dim // 4) if w_scale > 0: freqs = freqs @ (torch.randn(dim // 4, dim // 4) * w_scale) coord_map = torch.outer(freqs, coords) coord_map = 2 * math.pi * coord_map self.register_buffer("coord_h", coord_map.view(freqs.shape[0], 1, size)) self.register_buffer("coord_w", self.coord_h.transpose(1, 2).detach()) self.register_buffer("lf", freqs.view(1, dim // 4, 1, 1) * 2*math.pi * 2/size) self.allow_scaling = allow_scaling for op in op_order: assert op in ['r', 't', 's'], f"Operation not valid: {op}" self.op_order = op_order def forward(self, affine): norm = ((affine[:, 0:1].pow(2) + affine[:, 1:2].pow(2)).sqrt() + self.affine_eps).expand(affine.size(0), 4) if self.allow_scaling: assert affine.size(-1) == 6, f"If scaling is enabled, 2 extra values must be passed for a total of 6, and not {affine.size(-1)}" norm = torch.cat([norm, norm.new_ones(affine.size(0), 2)], dim=1) else: assert affine.size(-1) == 4, f"If scaling is disabled, 4 affine values should be passed, and not {affine.size(-1)}" affine = affine / norm affine = affine[:, :, None, None, None] coord_h, coord_w = self.coord_h.unsqueeze(0), self.coord_w.unsqueeze(0) for op in reversed(self.op_order): if op == 's' and self.allow_scaling: coord_h = coord_h / nn.functional.threshold(affine[:, 5], 1.0, 1.0) # scale coord_w = coord_w / nn.functional.threshold(affine[:, 4], 1.0, 1.0) elif op == 't': coord_h = coord_h - (affine[:, 3] * self.lf) # shift coord_w = coord_w - (affine[:, 2] * self.lf) elif op == 'r': _coord_h = -coord_w * affine[:, 1] + coord_h * affine[:, 0] # rotation coord_w = coord_w * affine[:, 0] + coord_h * affine[:, 1] coord_h = _coord_h coord_h = torch.cat((torch.sin(coord_h), torch.cos(coord_h)), 1) coord_w = torch.cat((torch.sin(coord_w), torch.cos(coord_w)), 1) coords = torch.cat((coord_h, coord_w), 1) return coords def extra_repr(self): info_string = f'size={self.size}, dim={self.dim}, cutoff={self.cutoff}, freq_range={self.freq_range}' if self.w_scale > 0: info_string += f', w_scale={self.w_scale}' if self.allow_scaling: info_string += f', allow_scaling={self.allow_scaling}' return info_string ================================================ FILE: torchtools/nn/functional/__init__.py ================================================ from .vq import vector_quantize, binarize from .gradient_penalty import gradient_penalty from .perceptual import total_variation from .magnitude_preserving import mp_cat, mp_sum ================================================ FILE: torchtools/nn/functional/gradient_penalty.py ================================================ #### # CODE TAKEN WITH FEW MODIFICATIONS FROM https://github.com/caogang/wgan-gp # ORIGINAL PAPER https://arxiv.org/pdf/1704.00028.pdf #### import torch from torch import autograd def gradient_penalty(netD, real_data, fake_data, l=10): batch_size = real_data.size(0) alpha = real_data.new_empty((batch_size, 1, 1, 1)).uniform_(0, 1) alpha = alpha.expand_as(real_data) interpolates = alpha * real_data + ((1 - alpha) * fake_data) interpolates = autograd.Variable(interpolates, requires_grad=True) disc_interpolates = netD(interpolates) gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, grad_outputs=real_data.new_ones(disc_interpolates.size()), create_graph=True, retain_graph=True, only_inputs=True)[0] gradients = gradients.view(gradients.size(0), -1) gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) gradient_penalty = ((gradients_norm - 1) ** 2).mean() * l return gradient_penalty ================================================ FILE: torchtools/nn/functional/magnitude_preserving.py ================================================ import torch def mp_cat(*args, dim=1, t=0.5): if isinstance(t, float): t = [1-t, t] assert len(args) == len(t), "t must be a single scalar or a list of scalars of length len(args)" w = [m/a.size(dim)**0.5 for a, m in zip(args, t)] C = (sum([a.size(dim) for a in args]) / sum([m**2 for m in t]))**0.5 return torch.cat([a*v for a, v in zip(args, w)], dim=dim) * C def mp_sum(*args, t=0.5): if isinstance(t, float): t = [1-t, t] assert len(args) == len(t), "t must be a single scalar or a list of scalars of length len(args)" assert abs(sum(t)-1) < 1e-3 , "the values of t should all add up to one" return sum([a*m for a, m in zip(args, t)]) / sum([m**2 for m in t])**0.5 ================================================ FILE: torchtools/nn/functional/perceptual.py ================================================ import torch def total_variation(X, reduction='sum'): tv_h = torch.abs(X[:, :, :, 1:] - X[:, :, :, :-1]) tv_v = torch.abs(X[:, :, 1:] - X[:, :, :-1]) tv = torch.mean(tv_h) + torch.mean(tv_v) if reduction == 'mean' else torch.sum(tv_h) + torch.sum(tv_v) return tv ================================================ FILE: torchtools/nn/functional/vq.py ================================================ import torch from torch.autograd import Function class vector_quantize(Function): @staticmethod def forward(ctx, x, codebook): with torch.no_grad(): codebook_sqr = torch.sum(codebook ** 2, dim=1) x_sqr = torch.sum(x ** 2, dim=1, keepdim=True) dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) _, indices = dist.min(dim=1) ctx.save_for_backward(indices, codebook) ctx.mark_non_differentiable(indices) nn = torch.index_select(codebook, 0, indices) return nn, indices @staticmethod def backward(ctx, grad_output, grad_indices): grad_inputs, grad_codebook = None, None if ctx.needs_input_grad[0]: grad_inputs = grad_output.clone() if ctx.needs_input_grad[1]: # Gradient wrt. the codebook indices, codebook = ctx.saved_tensors grad_codebook = torch.zeros_like(codebook) grad_codebook.index_add_(0, indices, grad_output) return (grad_inputs, grad_codebook) class binarize(Function): @staticmethod def forward(ctx, x, threshold=0.5): with torch.no_grad(): binarized = (x > threshold).float() ctx.mark_non_differentiable(binarized) return binarized @staticmethod def backward(ctx, grad_output): grad_inputs = None if ctx.needs_input_grad[0]: grad_inputs = grad_output.clone() return grad_inputs ================================================ FILE: torchtools/nn/gp_loss.py ================================================ import torch from torch import nn from .functional import gradient_penalty class GPLoss(nn.Module): def __init__(self, discriminator, l=10): super(GPLoss, self).__init__() self.discriminator = discriminator self.l = l def forward(self, real_data, fake_data): return gradient_penalty(self.discriminator, real_data, fake_data, self.l) ================================================ FILE: torchtools/nn/haar_dwt.py ================================================ import torch from torch import nn # Taken almost as is from https://github.com/bes-dev/haar_pytorch class HaarForward(nn.Module): """ Performs a 2d DWT Forward decomposition of an image using Haar Wavelets set beta=1 for regular haard dwt, with beta=2 we make a magnitude preserving dwt """ def __init__(self, beta=2): super().__init__() self.alpha = 0.5 self.beta = beta def forward(self, x: torch.Tensor) -> torch.Tensor: """ Performs a 2d DWT Forward decomposition of an image using Haar Wavelets Arguments: x (torch.Tensor): input tensor of shape [b, c, h, w] Returns: out (torch.Tensor): output tensor of shape [b, c * 4, h / 2, w / 2] """ ll = self.alpha/self.beta * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] + x[:,:,1::2,0::2] + x[:,:,1::2,1::2]) lh = self.alpha * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] - x[:,:,1::2,0::2] - x[:,:,1::2,1::2]) hl = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] + x[:,:,1::2,0::2] - x[:,:,1::2,1::2]) hh = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] - x[:,:,1::2,0::2] + x[:,:,1::2,1::2]) return torch.cat([ll,lh,hl,hh], axis=1) class HaarInverse(nn.Module): """ Performs a 2d DWT Inverse reconstruction of an image using Haar Wavelets set beta=1 for regular haard dwt, with beta=2 we make a magnitude preserving dwt """ def __init__(self, beta=2): super().__init__() self.alpha = 0.5 self.beta = beta def forward(self, x: torch.Tensor) -> torch.Tensor: """ Performs a 2d DWT Inverse reconstruction of an image using Haar Wavelets Arguments: x (torch.Tensor): input tensor of shape [b, c, h, w] Returns: out (torch.Tensor): output tensor of shape [b, c / 4, h * 2, w * 2] """ assert x.size(1) % 4 == 0, "The number of channels must be divisible by 4." size = [x.shape[0], x.shape[1] // 4, x.shape[2] * 2, x.shape[3] * 2] f = lambda i: x[:, size[1] * i : size[1] * (i + 1)] out = torch.zeros(size, dtype=x.dtype, device=x.device) out[:,:,0::2,0::2] = self.alpha * (f(0)*self.beta + f(1) + f(2) + f(3)) out[:,:,0::2,1::2] = self.alpha * (f(0)*self.beta + f(1) - f(2) - f(3)) out[:,:,1::2,0::2] = self.alpha * (f(0)*self.beta - f(1) + f(2) - f(3)) out[:,:,1::2,1::2] = self.alpha * (f(0)*self.beta - f(1) - f(2) + f(3)) return out ================================================ FILE: torchtools/nn/magnitude_preserving.py ================================================ import torch from torch import nn class MP_GELU(nn.GELU): def forward(self, x): return super().forward(x) / 0.652 # ¯\_(ツ)_/¯ class MP_SiLU(nn.SiLU): def forward(self, x): return super().forward(x) / 0.596 # ¯\_(ツ)_/¯ class Gain(nn.Module): def __init__(self, init_w=0.0): super().__init__() self.g = nn.Parameter(torch.tensor([init_w])) def forward(self, x): return x * self.g ================================================ FILE: torchtools/nn/mish.py ================================================ #### # CODE TAKEN FROM https://github.com/lessw2020/mish # ORIGINAL PAPER https://arxiv.org/abs/1908.08681v1 #### import torch import torch.nn as nn import torch.nn.functional as F #(uncomment if needed,but you likely already have it) #Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function" #https://arxiv.org/abs/1908.08681v1 #implemented for PyTorch / FastAI by lessw2020 #github: https://github.com/lessw2020/mish class Mish(nn.Module): def __init__(self): super().__init__() def forward(self, x): #inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!) return x *( torch.tanh(F.softplus(x))) ================================================ FILE: torchtools/nn/modulation.py ================================================ import torch from torch import nn import math #### # TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM # https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L143 # But made it extend from the base Conv2d to avoid some boilerplate #### class ModulatedConv2d(nn.Conv2d): def __init__(self, *args, demodulate=True, ema_decay=1.0, **kwargs): super().__init__(*args, **kwargs) fan_in = self.in_channels * self.kernel_size[0] ** 2 self.scale = 1 / math.sqrt(fan_in) self.demodulate = demodulate self.ema_decay = ema_decay self.register_buffer("ema_var", torch.tensor(1.0)) nn.init.normal_(self.weight) if self.bias is not None: nn.init.zeros_(self.bias) def forward(self, x, w): batch, in_channels, height, width = x.shape style = w.view(batch, 1, in_channels, 1, 1) weight = self.scale * self.weight.unsqueeze(0) * style if self.demodulate: demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) weight = weight * demod.view(batch, self.out_channels, 1, 1, 1) weight = weight.view( batch * self.out_channels, in_channels, self.kernel_size[0], self.kernel_size[1] ) if self.ema_decay < 1: if self.training: var = x.pow(2).mean((0, 1, 2, 3)) self.ema_var.mul_(self.ema_decay).add_(var.detach(), alpha=1 - self.ema_decay) weight = weight / (torch.sqrt(self.ema_var) + 1e-8) input = x.view(1, batch * in_channels, height, width) self.groups = batch out = self._conv_forward(input, weight, None) _, _, height, width = out.shape out = out.view(batch, self.out_channels, height, width) if self.bias is not None: out = out + self.bias.view(1, -1, 1, 1) return out ================================================ FILE: torchtools/nn/perceptual.py ================================================ import torch import torch.nn as nn from .functional import total_variation class TVLoss(nn.Module): def __init__(self, reduction='sum', alpha=1e-4): super(TVLoss, self).__init__() self.reduction = reduction self.alpha = alpha def forward(self, x): return total_variation(x, reduction=self.reduction) * self.alpha ================================================ FILE: torchtools/nn/pixel_normalzation.py ================================================ import torch from torch import nn class PixelNorm(nn.Module): def __init__(self, dim=1, eps=1e-4): super().__init__() self.dim = dim self.eps = eps def forward(self, x): return x / (torch.sqrt(torch.mean(x ** 2, dim=self.dim, keepdim=True)) + self.eps) ================================================ FILE: torchtools/nn/pos_embeddings.py ================================================ import torch from torch import nn #### # CODE TAKEN FROM https://github.com/lucidrains/x-transformers #### class RotaryEmbedding(nn.Module): def __init__(self, dim, base=10000): super().__init__() inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) self.seq_len_cached = None self.cos_cached = None self.sin_cached = None def forward(self, x, seq_dim=1): seq_len = x.shape[seq_dim] if seq_len != self.seq_len_cached: self.seq_len_cached = seq_len t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) freqs = torch.einsum('i,j->ij', t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1).to(x.device) self.cos_cached = emb.cos()[:, None, None, :] self.sin_cached = emb.sin()[:, None, None, :] return self.cos_cached, self.sin_cached # rotary pos emb helpers: def rotate_half(x): x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0 @torch.jit.script def apply_rotary_pos_emb(q, k, cos, sin): return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) ================================================ FILE: torchtools/nn/simple_self_attention.py ================================================ import torch.nn as nn import torch, math, sys #### # CODE TAKEN FROM https://github.com/sdoria/SimpleSelfAttention #### #Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py def conv1d(ni, no, ks=1, stride=1, padding=0, bias=False): "Create and initialize a `nn.Conv1d` layer with spectral normalization." conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias) nn.init.kaiming_normal_(conv.weight) if bias: conv.bias.data.zero_() return nn.utils.spectral_norm(conv) # Adapted from SelfAttention layer at https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py # Inspired by https://arxiv.org/pdf/1805.08318.pdf class SimpleSelfAttention(nn.Module): def __init__(self, n_in, ks=1, sym=False): super().__init__() self.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False) self.gamma = nn.Parameter(torch.Tensor([0.])) self.sym = sym self.n_in = n_in def forward(self, x): if self.sym: # symmetry hack by https://github.com/mgrankin c = self.conv.weight.view(self.n_in,self.n_in) c = (c + c.t())/2 self.conv.weight = c.view(self.n_in,self.n_in,1) size = x.size() x = x.view(*size[:2],-1) # (C,N) # changed the order of mutiplication to avoid O(N^2) complexity # (x*xT)*(W*x) instead of (x*(xT*(W*x))) convx = self.conv(x) # (C,C) * (C,N) = (C,N) => O(NC^2) xxT = torch.bmm(x, x.permute(0,2,1).contiguous()) # (C,N) * (N,C) = (C,C) => O(NC^2) o = torch.bmm(xxT, convx) # (C,C) * (C,N) = (C,N) => O(NC^2) o = self.gamma * o + x return o.view(*size).contiguous() ================================================ FILE: torchtools/nn/stylegan2/__init__.py ================================================ from .upfirdn2d import upfirdn2d ================================================ FILE: torchtools/nn/stylegan2/upfirdn2d.cpp ================================================ #include #include torch::Tensor upfirdn2d_op(const torch::Tensor &input, const torch::Tensor &kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1); #define CHECK_CUDA(x) \ TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) \ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1) { CHECK_INPUT(input); CHECK_INPUT(kernel); at::DeviceGuard guard(input.device()); return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); } ================================================ FILE: torchtools/nn/stylegan2/upfirdn2d.py ================================================ from collections import abc import os import torch from torch.nn import functional as F from torch.autograd import Function from torch.utils.cpp_extension import load module_path = os.path.dirname(__file__) upfirdn2d_op = load( "upfirdn2d", sources=[ os.path.join(module_path, "upfirdn2d.cpp"), os.path.join(module_path, "upfirdn2d_kernel.cu"), ], ) class UpFirDn2dBackward(Function): @staticmethod def forward( ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size ): up_x, up_y = up down_x, down_y = down g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) grad_input = upfirdn2d_op.upfirdn2d( grad_output, grad_kernel, down_x, down_y, up_x, up_y, g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1, ) grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) ctx.save_for_backward(kernel) pad_x0, pad_x1, pad_y0, pad_y1 = pad ctx.up_x = up_x ctx.up_y = up_y ctx.down_x = down_x ctx.down_y = down_y ctx.pad_x0 = pad_x0 ctx.pad_x1 = pad_x1 ctx.pad_y0 = pad_y0 ctx.pad_y1 = pad_y1 ctx.in_size = in_size ctx.out_size = out_size return grad_input @staticmethod def backward(ctx, gradgrad_input): kernel, = ctx.saved_tensors gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) gradgrad_out = upfirdn2d_op.upfirdn2d( gradgrad_input, kernel, ctx.up_x, ctx.up_y, ctx.down_x, ctx.down_y, ctx.pad_x0, ctx.pad_x1, ctx.pad_y0, ctx.pad_y1, ) # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) gradgrad_out = gradgrad_out.view( ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] ) return gradgrad_out, None, None, None, None, None, None, None, None class UpFirDn2d(Function): @staticmethod def forward(ctx, input, kernel, up, down, pad): up_x, up_y = up down_x, down_y = down pad_x0, pad_x1, pad_y0, pad_y1 = pad kernel_h, kernel_w = kernel.shape batch, channel, in_h, in_w = input.shape ctx.in_size = input.shape input = input.reshape(-1, in_h, in_w, 1) ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x ctx.out_size = (out_h, out_w) ctx.up = (up_x, up_y) ctx.down = (down_x, down_y) ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) g_pad_x0 = kernel_w - pad_x0 - 1 g_pad_y0 = kernel_h - pad_y0 - 1 g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) out = upfirdn2d_op.upfirdn2d( input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 ) # out = out.view(major, out_h, out_w, minor) out = out.view(-1, channel, out_h, out_w) return out @staticmethod def backward(ctx, grad_output): kernel, grad_kernel = ctx.saved_tensors grad_input = None if ctx.needs_input_grad[0]: grad_input = UpFirDn2dBackward.apply( grad_output, kernel, grad_kernel, ctx.up, ctx.down, ctx.pad, ctx.g_pad, ctx.in_size, ctx.out_size, ) return grad_input, None, None, None, None def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): if not isinstance(up, abc.Iterable): up = (up, up) if not isinstance(down, abc.Iterable): down = (down, down) if len(pad) == 2: pad = (pad[0], pad[1], pad[0], pad[1]) if input.device.type == "cpu": out = upfirdn2d_native(input, kernel, *up, *down, *pad) else: out = UpFirDn2d.apply(input, kernel, up, down, pad) return out def upfirdn2d_native( input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 ): _, channel, in_h, in_w = input.shape input = input.reshape(-1, in_h, in_w, 1) _, in_h, in_w, minor = input.shape kernel_h, kernel_w = kernel.shape out = input.view(-1, in_h, 1, in_w, 1, minor) out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) out = out.view(-1, in_h * up_y, in_w * up_x, minor) out = F.pad( out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] ) out = out[ :, max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), :, ] out = out.permute(0, 3, 1, 2) out = out.reshape( [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] ) w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) out = F.conv2d(out, w) out = out.reshape( -1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) out = out.permute(0, 2, 3, 1) out = out[:, ::down_y, ::down_x, :] out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x return out.view(-1, channel, out_h, out_w) ================================================ FILE: torchtools/nn/stylegan2/upfirdn2d_kernel.cu ================================================ // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. // // This work is made available under the Nvidia Source Code License-NC. // To view a copy of this license, visit // https://nvlabs.github.io/stylegan2/license.html #include #include #include #include #include #include #include static __device__ __forceinline__ int floor_div(int a, int b) { int t = 1 - a / b; return (a + t * b) / b - t; } struct UpFirDn2DKernelParams { int up_x; int up_y; int down_x; int down_y; int pad_x0; int pad_x1; int pad_y0; int pad_y1; int major_dim; int in_h; int in_w; int minor_dim; int kernel_h; int kernel_w; int out_h; int out_w; int loop_major; int loop_x; }; template __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, const scalar_t *kernel, const UpFirDn2DKernelParams p) { int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; int out_y = minor_idx / p.minor_dim; minor_idx -= out_y * p.minor_dim; int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; int major_idx_base = blockIdx.z * p.loop_major; if (out_x_base >= p.out_w || out_y >= p.out_h || major_idx_base >= p.major_dim) { return; } int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major && major_idx < p.major_dim; loop_major++, major_idx++) { for (int loop_x = 0, out_x = out_x_base; loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; const scalar_t *x_p = &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; int x_px = p.minor_dim; int k_px = -p.up_x; int x_py = p.in_w * p.minor_dim; int k_py = -p.up_y * p.kernel_w; scalar_t v = 0.0f; for (int y = 0; y < h; y++) { for (int x = 0; x < w; x++) { v += static_cast(*x_p) * static_cast(*k_p); x_p += x_px; k_p += k_px; } x_p += x_py - w * x_px; k_p += k_py - w * k_px; } out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; } } } template __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, const scalar_t *kernel, const UpFirDn2DKernelParams p) { const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; __shared__ volatile float sk[kernel_h][kernel_w]; __shared__ volatile float sx[tile_in_h][tile_in_w]; int minor_idx = blockIdx.x; int tile_out_y = minor_idx / p.minor_dim; minor_idx -= tile_out_y * p.minor_dim; tile_out_y *= tile_out_h; int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; int major_idx_base = blockIdx.z * p.loop_major; if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { return; } for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { int ky = tap_idx / kernel_w; int kx = tap_idx - ky * kernel_w; scalar_t v = 0.0; if (kx < p.kernel_w & ky < p.kernel_h) { v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; } sk[ky][kx] = v; } for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; int tile_in_x = floor_div(tile_mid_x, up_x); int tile_in_y = floor_div(tile_mid_y, up_y); __syncthreads(); for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { int rel_in_y = in_idx / tile_in_w; int rel_in_x = in_idx - rel_in_y * tile_in_w; int in_x = rel_in_x + tile_in_x; int in_y = rel_in_y + tile_in_y; scalar_t v = 0.0; if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; } sx[rel_in_y][rel_in_x] = v; } __syncthreads(); for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { int rel_out_y = out_idx / tile_out_w; int rel_out_x = out_idx - rel_out_y * tile_out_w; int out_x = rel_out_x + tile_out_x; int out_y = rel_out_y + tile_out_y; int mid_x = tile_mid_x + rel_out_x * down_x; int mid_y = tile_mid_y + rel_out_y * down_y; int in_x = floor_div(mid_x, up_x); int in_y = floor_div(mid_y, up_y); int rel_in_x = in_x - tile_in_x; int rel_in_y = in_y - tile_in_y; int kernel_x = (in_x + 1) * up_x - mid_x - 1; int kernel_y = (in_y + 1) * up_y - mid_y - 1; if (out_x < p.out_w & out_y < p.out_h) { scalar_t v = 0.0; #pragma unroll for (int y = 0; y < kernel_h / up_y; y++) #pragma unroll for (int x = 0; x < kernel_w / up_x; x++) v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; } } } } } torch::Tensor upfirdn2d_op(const torch::Tensor &input, const torch::Tensor &kernel, int up_x, int up_y, int down_x, int down_y, int pad_x0, int pad_x1, int pad_y0, int pad_y1) { int curDevice = -1; cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); UpFirDn2DKernelParams p; auto x = input.contiguous(); auto k = kernel.contiguous(); p.major_dim = x.size(0); p.in_h = x.size(1); p.in_w = x.size(2); p.minor_dim = x.size(3); p.kernel_h = k.size(0); p.kernel_w = k.size(1); p.up_x = up_x; p.up_y = up_y; p.down_x = down_x; p.down_y = down_y; p.pad_x0 = pad_x0; p.pad_x1 = pad_x1; p.pad_y0 = pad_y0; p.pad_y1 = pad_y1; p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); int mode = -1; int tile_out_h = -1; int tile_out_w = -1; AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { void *cuda_kernel = (void *)upfirdn2d_kernel_large; if (p.up_x == 2 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 1 && p.kernel_w <= 24) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 8; tile_out_w = 128; } if (p.up_x == 2 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 1 && p.kernel_w <= 12) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 8; tile_out_w = 128; } if (p.up_x == 1 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 24 && p.kernel_w <= 1) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 32; tile_out_w = 32; } if (p.up_x == 1 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 12 && p.kernel_w <= 1) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 32; tile_out_w = 32; } // if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 1 && p.kernel_h <= 1 && p.kernel_w <= 24) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 8; tile_out_w = 64; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 1 && p.kernel_h <= 1 && p.kernel_w <= 12) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 8; tile_out_w = 64; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 2 && p.kernel_h <= 24 && p.kernel_w <= 1) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 16; tile_out_w = 32; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 2 && p.kernel_h <= 12 && p.kernel_w <= 1) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 16; tile_out_w = 32; } // if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 16; tile_out_w = 64; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 8; tile_out_w = 32; } if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { cuda_kernel = (void *)upfirdn2d_kernel; tile_out_h = 8; tile_out_w = 32; } dim3 block_size; dim3 grid_size; if (tile_out_h > 0 && tile_out_w > 0) { p.loop_major = (p.major_dim - 1) / 16384 + 1; p.loop_x = 1; block_size = dim3(32 * 8, 1, 1); grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, (p.major_dim - 1) / p.loop_major + 1); } else { p.loop_major = (p.major_dim - 1) / 16384 + 1; p.loop_x = 4; block_size = dim3(4, 32, 1); grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, (p.out_w - 1) / (p.loop_x * block_size.y) + 1, (p.major_dim - 1) / p.loop_major + 1); } scalar_t *out_p = out.data_ptr(); scalar_t *x_p = x.data_ptr(); scalar_t *k_p = k.data_ptr(); void *args[] = {&out_p, &x_p, &k_p, &p}; AT_CUDA_CHECK( cudaLaunchKernel(cuda_kernel, grid_size, block_size, args, 0, stream)); }); return out; } ================================================ FILE: torchtools/nn/transformers.py ================================================ import torch.nn as nn # Based on the GPT2 implementatyion from MinGPT https://github.com/karpathy/minGPT by Andrej Karpathy class GPTTransformerEncoderLayer(nn.Module): def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0): super().__init__() self.ln1 = nn.LayerNorm(d_model) self.ln2 = nn.LayerNorm(d_model) self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) self.mlp = nn.Sequential( nn.Linear(d_model, dim_feedforward), nn.GELU(), nn.Linear(dim_feedforward, d_model), nn.Dropout(dropout), ) def forward(self, x, src_mask=None, src_key_padding_mask=None): x = self.ln1(x) x = x + self.attn(x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] x = x + self.mlp(self.ln2(x)) return x ================================================ FILE: torchtools/nn/vq.py ================================================ import torch from torch import nn from .functional.vq import vector_quantize, binarize import numpy as np class VectorQuantize(nn.Module): def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): """ Takes an input of variable size (as long as the last dimension matches the embedding size). Returns one tensor containing the nearest neigbour embeddings to each of the inputs, with the same size as the input, vq and commitment components for the loss as a touple in the second output and the indices of the quantized vectors in the third: quantized, (vq_loss, commit_loss), indices """ super(VectorQuantize, self).__init__() self.codebook = nn.Embedding(k, embedding_size) self.codebook.weight.data.uniform_(-1./k, 1./k) self.vq = vector_quantize.apply self.ema_decay = ema_decay self.ema_loss = ema_loss if ema_loss: self.register_buffer('ema_element_count', torch.ones(k)) self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight)) def _laplace_smoothing(self, x, epsilon): n = torch.sum(x) return ((x + epsilon) / (n + x.size(0) * epsilon) * n) def _updateEMA(self, z_e_x, indices): mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() elem_count = mask.sum(dim=0) weight_sum = torch.mm(mask.t(), z_e_x) self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) def idx2vq(self, idx, dim=-1): q_idx = self.codebook(idx) if dim != -1: q_idx = q_idx.movedim(-1, dim) return q_idx def forward(self, x, get_losses=True, dim=-1): if dim != -1: x = x.movedim(dim, -1) z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) vq_loss, commit_loss = None, None if self.ema_loss and self.training: self._updateEMA(z_e_x.detach(), indices.detach()) # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) if get_losses: vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean() commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() z_q_x = z_q_x.view(x.shape) if dim != -1: z_q_x = z_q_x.movedim(-1, dim) return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) class Binarize(nn.Module): def __init__(self, threshold=0.5): """ Takes an input of any size. Returns an output of the same size but with its values binarized (0 if input is below a threshold, 1 if its above) """ super(Binarize, self).__init__() self.bin = binarize.apply self.threshold = threshold def forward(self, x): return self.bin(x, self.threshold) # Finite Scalar Quantization: https://arxiv.org/abs/2309.15505 class FSQ(nn.Module): def __init__(self, bins, dim=-1, eps=1e-1): super().__init__() self.dim = dim self.eps = eps self.register_buffer('bins', torch.tensor(bins)) self.register_buffer('bases', torch.tensor([1] + np.cumprod(bins[:-1]).tolist())) self.codebook_size = np.prod(bins) self.in_shift, self.out_shift = None, None def _round(self, x, quantize): x = x.sigmoid() * (1-1e-7) if quantize is True: x_rounded = x.sub(1/(self.bins*2)).mul(self.bins).round().div(self.bins).div(1-1/self.bins) x = x + (x_rounded - x).detach() x_sigmoid = x x = (x / (1-1e-7)).logit(eps=self.eps) return x, x_sigmoid def vq_to_idx(self, x, is_sigmoid=False): if not is_sigmoid: x = x.sigmoid() * (1-1e-7) x = x.sub(1/(self.bins*2)).mul(self.bins).round().div(self.bins).div(1-1/self.bins) x = x.mul(self.bins-1).long() x = (x * self.bases).sum(dim=-1).long() return x def idx_to_vq(self, x): x = x.unsqueeze(-1) // self.bases % self.bins x = x.div(self.bins-1) x = (x / (1-1e-7)).logit(eps=self.eps) if self.dim != -1: x = x.movedim(-1, self.dim) return x def forward(self, x, quantize=True): if self.dim != -1: x = x.movedim(self.dim, -1) x, x_sigmoid = self._round(x, quantize=quantize) idx = self.vq_to_idx(x_sigmoid, is_sigmoid=True) if self.dim != -1: x = x.movedim(-1, self.dim) return x, idx ================================================ FILE: torchtools/optim/__init__.py ================================================ from .radam import RAdam, PlainRAdam, AdamW from .ranger import Ranger from .lookahead import Lookahead, LookaheadAdam from .over9000 import Over9000, RangerLars from .ralamb import Ralamb from .novograd import Novograd from .lamb import Lamb ================================================ FILE: torchtools/optim/lamb.py ================================================ #### # CODE TAKEN FROM https://github.com/mgrankin/over9000 #### import collections import math import torch from torch.optim import Optimizer try: from tensorboardX import SummaryWriter def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): """Log a histogram of trust ratio scalars in across layers.""" results = collections.defaultdict(list) for group in optimizer.param_groups: for p in group['params']: state = optimizer.state[p] for i in ('weight_norm', 'adam_norm', 'trust_ratio'): if i in state: results[i].append(state[i]) for k, v in results.items(): event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) except ModuleNotFoundError as e: print("To use this log_lamb_rs, please run 'pip install tensorboardx'. Also you must have Tensorboard running to see results") class Lamb(Optimizer): r"""Implements Lamb algorithm. It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. Arguments: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) adam (bool, optional): always use trust ratio = 1, which turns this into Adam. Useful for comparison purposes. .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: https://arxiv.org/abs/1904.00962 """ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) self.adam = adam super(Lamb, self).__init__(params, defaults) def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data if grad.is_sparse: raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') state = self.state[p] # State initialization if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(p.data) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like(p.data) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] state['step'] += 1 # Decay the first and second moment running average coefficient # m_t exp_avg.mul_(beta1).add_(1 - beta1, grad) # v_t exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) # Paper v3 does not use debiasing. # bias_correction1 = 1 - beta1 ** state['step'] # bias_correction2 = 1 - beta2 ** state['step'] # Apply bias to lr to avoid broadcast. step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) if group['weight_decay'] != 0: adam_step.add_(group['weight_decay'], p.data) adam_norm = adam_step.pow(2).sum().sqrt() if weight_norm == 0 or adam_norm == 0: trust_ratio = 1 else: trust_ratio = weight_norm / adam_norm state['weight_norm'] = weight_norm state['adam_norm'] = adam_norm state['trust_ratio'] = trust_ratio if self.adam: trust_ratio = 1 p.data.add_(-step_size * trust_ratio, adam_step) return loss ================================================ FILE: torchtools/optim/lookahead.py ================================================ #### # CODE TAKEN FROM https://github.com/lonePatient/lookahead_pytorch # Original paper: https://arxiv.org/abs/1907.08610 #### # Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py """ Lookahead Optimizer Wrapper. Implementation modified from: https://github.com/alphadl/lookahead.pytorch Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 """ import torch from torch.optim.optimizer import Optimizer from collections import defaultdict class Lookahead(Optimizer): def __init__(self, base_optimizer, alpha=0.5, k=6): if not 0.0 <= alpha <= 1.0: raise ValueError(f'Invalid slow update rate: {alpha}') if not 1 <= k: raise ValueError(f'Invalid lookahead steps: {k}') defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) self.base_optimizer = base_optimizer self.param_groups = self.base_optimizer.param_groups self.defaults = base_optimizer.defaults self.defaults.update(defaults) self.state = defaultdict(dict) # manually add our defaults to the param groups for name, default in defaults.items(): for group in self.param_groups: group.setdefault(name, default) def update_slow(self, group): for fast_p in group["params"]: if fast_p.grad is None: continue param_state = self.state[fast_p] if 'slow_buffer' not in param_state: param_state['slow_buffer'] = torch.empty_like(fast_p.data) param_state['slow_buffer'].copy_(fast_p.data) slow = param_state['slow_buffer'] slow.add_(group['lookahead_alpha'], fast_p.data - slow) fast_p.data.copy_(slow) def sync_lookahead(self): for group in self.param_groups: self.update_slow(group) def step(self, closure=None): # print(self.k) # assert id(self.param_groups) == id(self.base_optimizer.param_groups) loss = self.base_optimizer.step(closure) for group in self.param_groups: group['lookahead_step'] += 1 if group['lookahead_step'] % group['lookahead_k'] == 0: self.update_slow(group) return loss def state_dict(self): fast_state_dict = self.base_optimizer.state_dict() slow_state = { (id(k) if isinstance(k, torch.Tensor) else k): v for k, v in self.state.items() } fast_state = fast_state_dict['state'] param_groups = fast_state_dict['param_groups'] return { 'state': fast_state, 'slow_state': slow_state, 'param_groups': param_groups, } def load_state_dict(self, state_dict): fast_state_dict = { 'state': state_dict['state'], 'param_groups': state_dict['param_groups'], } self.base_optimizer.load_state_dict(fast_state_dict) # We want to restore the slow state, but share param_groups reference # with base_optimizer. This is a bit redundant but least code slow_state_new = False if 'slow_state' not in state_dict: print('Loading state_dict from optimizer without Lookahead applied.') state_dict['slow_state'] = defaultdict(dict) slow_state_new = True slow_state_dict = { 'state': state_dict['slow_state'], 'param_groups': state_dict['param_groups'], # this is pointless but saves code } super(Lookahead, self).load_state_dict(slow_state_dict) self.param_groups = self.base_optimizer.param_groups # make both ref same container if slow_state_new: # reapply defaults to catch missing lookahead specific ones for name, default in self.defaults.items(): for group in self.param_groups: group.setdefault(name, default) def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs): adam = Adam(params, *args, **kwargs) return Lookahead(adam, alpha, k) ================================================ FILE: torchtools/optim/novograd.py ================================================ #### # CODE TAKEN FROM https://github.com/mgrankin/over9000 #### # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from torch.optim import Optimizer import math class AdamW(Optimizer): """Implements AdamW algorithm. It has been proposed in `Adam: A Method for Stochastic Optimization`_. Arguments: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) amsgrad (boolean, optional): whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`_ Adam: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) super(AdamW, self).__init__(params, defaults) def __setstate__(self, state): super(AdamW, self).__setstate__(state) for group in self.param_groups: group.setdefault('amsgrad', False) def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data if grad.is_sparse: raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') amsgrad = group['amsgrad'] state = self.state[p] # State initialization if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(p.data) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like(p.data) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state['max_exp_avg_sq'] = torch.zeros_like(p.data) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] if amsgrad: max_exp_avg_sq = state['max_exp_avg_sq'] beta1, beta2 = group['betas'] state['step'] += 1 # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(1 - beta1, grad) exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) if amsgrad: # Maintains the maximum of all 2nd moment running avg. till now torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) # Use the max. for normalizing running avg. of gradient denom = max_exp_avg_sq.sqrt().add_(group['eps']) else: denom = exp_avg_sq.sqrt().add_(group['eps']) bias_correction1 = 1 - beta1 ** state['step'] bias_correction2 = 1 - beta2 ** state['step'] step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom)) return loss class Novograd(Optimizer): """ Implements Novograd algorithm. Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.95, 0)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) grad_averaging: gradient averaging amsgrad (boolean, optional): whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`_ (default: False) """ def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8, weight_decay=0, grad_averaging=False, amsgrad=False): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, grad_averaging=grad_averaging, amsgrad=amsgrad) super(Novograd, self).__init__(params, defaults) def __setstate__(self, state): super(Novograd, self).__setstate__(state) for group in self.param_groups: group.setdefault('amsgrad', False) def step(self, closure=None): """Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss. """ loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data if grad.is_sparse: raise RuntimeError('Sparse gradients are not supported.') amsgrad = group['amsgrad'] state = self.state[p] # State initialization if len(state) == 0: state['step'] = 0 # Exponential moving average of gradient values state['exp_avg'] = torch.zeros_like(p.data) # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) if amsgrad: # Maintains max of all exp. moving avg. of sq. grad. values state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] if amsgrad: max_exp_avg_sq = state['max_exp_avg_sq'] beta1, beta2 = group['betas'] state['step'] += 1 norm = torch.sum(torch.pow(grad, 2)) if exp_avg_sq == 0: exp_avg_sq.copy_(norm) else: exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) if amsgrad: # Maintains the maximum of all 2nd moment running avg. till now torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) # Use the max. for normalizing running avg. of gradient denom = max_exp_avg_sq.sqrt().add_(group['eps']) else: denom = exp_avg_sq.sqrt().add_(group['eps']) grad.div_(denom) if group['weight_decay'] != 0: grad.add_(group['weight_decay'], p.data) if group['grad_averaging']: grad.mul_(1 - beta1) exp_avg.mul_(beta1).add_(grad) p.data.add_(-group['lr'], exp_avg) return loss ================================================ FILE: torchtools/optim/over9000.py ================================================ #### # CODE TAKEN FROM https://github.com/mgrankin/over9000 #### import torch, math from torch.optim.optimizer import Optimizer import itertools as it from .lookahead import Lookahead from .ralamb import Ralamb # RAdam + LARS + LookAHead # Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py # RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20 def Over9000(params, alpha=0.5, k=6, *args, **kwargs): ralamb = Ralamb(params, *args, **kwargs) return Lookahead(ralamb, alpha, k) RangerLars = Over9000 ================================================ FILE: torchtools/optim/radam.py ================================================ #### # CODE TAKEN FROM https://github.com/LiyuanLucasLiu/RAdam # Paper: https://arxiv.org/abs/1908.03265 #### import math import torch from torch.optim.optimizer import Optimizer, required class RAdam(Optimizer): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) self.degenerated_to_sgd = degenerated_to_sgd if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): for param in params: if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): param['buffer'] = [[None, None, None] for _ in range(10)] defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) super(RAdam, self).__init__(params, defaults) def __setstate__(self, state): super(RAdam, self).__setstate__(state) def step(self, closure=None): loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: raise RuntimeError('RAdam does not support sparse gradients') p_data_fp32 = p.data.float() state = self.state[p] if len(state) == 0: state['step'] = 0 state['exp_avg'] = torch.zeros_like(p_data_fp32) state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) else: state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) exp_avg.mul_(beta1).add_(1 - beta1, grad) state['step'] += 1 buffered = group['buffer'][int(state['step'] % 10)] if state['step'] == buffered[0]: N_sma, step_size = buffered[1], buffered[2] else: buffered[0] = state['step'] beta2_t = beta2 ** state['step'] N_sma_max = 2 / (1 - beta2) - 1 N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) buffered[1] = N_sma # more conservative since it's an approximated value if N_sma >= 5: step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) elif self.degenerated_to_sgd: step_size = 1.0 / (1 - beta1 ** state['step']) else: step_size = -1 buffered[2] = step_size # more conservative since it's an approximated value if N_sma >= 5: if group['weight_decay'] != 0: p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) denom = exp_avg_sq.sqrt().add_(group['eps']) p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) p.data.copy_(p_data_fp32) elif step_size > 0: if group['weight_decay'] != 0: p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) p_data_fp32.add_(-step_size * group['lr'], exp_avg) p.data.copy_(p_data_fp32) return loss class PlainRAdam(Optimizer): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) self.degenerated_to_sgd = degenerated_to_sgd defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) super(PlainRAdam, self).__init__(params, defaults) def __setstate__(self, state): super(PlainRAdam, self).__setstate__(state) def step(self, closure=None): loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: raise RuntimeError('RAdam does not support sparse gradients') p_data_fp32 = p.data.float() state = self.state[p] if len(state) == 0: state['step'] = 0 state['exp_avg'] = torch.zeros_like(p_data_fp32) state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) else: state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) exp_avg.mul_(beta1).add_(1 - beta1, grad) state['step'] += 1 beta2_t = beta2 ** state['step'] N_sma_max = 2 / (1 - beta2) - 1 N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) # more conservative since it's an approximated value if N_sma >= 5: if group['weight_decay'] != 0: p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) denom = exp_avg_sq.sqrt().add_(group['eps']) p_data_fp32.addcdiv_(-step_size, exp_avg, denom) p.data.copy_(p_data_fp32) elif self.degenerated_to_sgd: if group['weight_decay'] != 0: p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) step_size = group['lr'] / (1 - beta1 ** state['step']) p_data_fp32.add_(-step_size, exp_avg) p.data.copy_(p_data_fp32) return loss class AdamW(Optimizer): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): if not 0.0 <= lr: raise ValueError("Invalid learning rate: {}".format(lr)) if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) if not 0.0 <= betas[1] < 1.0: raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, warmup = warmup) super(AdamW, self).__init__(params, defaults) def __setstate__(self, state): super(AdamW, self).__setstate__(state) def step(self, closure=None): loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') p_data_fp32 = p.data.float() state = self.state[p] if len(state) == 0: state['step'] = 0 state['exp_avg'] = torch.zeros_like(p_data_fp32) state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) else: state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] state['step'] += 1 exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) exp_avg.mul_(beta1).add_(1 - beta1, grad) denom = exp_avg_sq.sqrt().add_(group['eps']) bias_correction1 = 1 - beta1 ** state['step'] bias_correction2 = 1 - beta2 ** state['step'] if group['warmup'] > state['step']: scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] else: scheduled_lr = group['lr'] step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 if group['weight_decay'] != 0: p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) p_data_fp32.addcdiv_(-step_size, exp_avg, denom) p.data.copy_(p_data_fp32) return loss ================================================ FILE: torchtools/optim/ralamb.py ================================================ #### # CODE TAKEN FROM https://github.com/mgrankin/over9000 #### import torch, math from torch.optim.optimizer import Optimizer # RAdam + LARS class Ralamb(Optimizer): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) self.buffer = [[None, None, None] for ind in range(10)] super(Ralamb, self).__init__(params, defaults) def __setstate__(self, state): super(Ralamb, self).__setstate__(state) def step(self, closure=None): loss = None if closure is not None: loss = closure() for group in self.param_groups: for p in group['params']: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: raise RuntimeError('Ralamb does not support sparse gradients') p_data_fp32 = p.data.float() state = self.state[p] if len(state) == 0: state['step'] = 0 state['exp_avg'] = torch.zeros_like(p_data_fp32) state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) else: state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] beta1, beta2 = group['betas'] # Decay the first and second moment running average coefficient # m_t exp_avg.mul_(beta1).add_(1 - beta1, grad) # v_t exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) state['step'] += 1 buffered = self.buffer[int(state['step'] % 10)] if state['step'] == buffered[0]: N_sma, radam_step_size = buffered[1], buffered[2] else: buffered[0] = state['step'] beta2_t = beta2 ** state['step'] N_sma_max = 2 / (1 - beta2) - 1 N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) buffered[1] = N_sma # more conservative since it's an approximated value if N_sma >= 5: radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) else: radam_step_size = 1.0 / (1 - beta1 ** state['step']) buffered[2] = radam_step_size if group['weight_decay'] != 0: p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) # more conservative since it's an approximated value radam_step = p_data_fp32.clone() if N_sma >= 5: denom = exp_avg_sq.sqrt().add_(group['eps']) radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom) else: radam_step.add_(-radam_step_size * group['lr'], exp_avg) radam_norm = radam_step.pow(2).sum().sqrt() weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) if weight_norm == 0 or radam_norm == 0: trust_ratio = 1 else: trust_ratio = weight_norm / radam_norm state['weight_norm'] = weight_norm state['adam_norm'] = radam_norm state['trust_ratio'] = trust_ratio if N_sma >= 5: p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom) else: p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg) p.data.copy_(p_data_fp32) return loss ================================================ FILE: torchtools/optim/ranger.py ================================================ #### # CODE TAKEN FROM https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer # Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d #### import math import torch from torch.optim.optimizer import Optimizer, required import itertools as it from .lookahead import Lookahead from .radam import RAdam def Ranger(params, alpha=0.5, k=6, betas=(.95, 0.999), *args, **kwargs): radam = RAdam(params, betas=betas, *args, **kwargs) return Lookahead(radam, alpha, k) ================================================ FILE: torchtools/transforms/__init__.py ================================================ from .smart_crop import SmartCrop ================================================ FILE: torchtools/transforms/models/__init__.py ================================================ ================================================ FILE: torchtools/transforms/smart_crop.py ================================================ import torch import torchvision from torch import nn import numpy as np import os # MICRO RESNET class ResBlock(nn.Module): def __init__(self, channels): super(ResBlock, self).__init__() self.resblock = nn.Sequential( nn.ReflectionPad2d(1), nn.Conv2d(channels, channels, kernel_size=3), nn.InstanceNorm2d(channels, affine=True), nn.ReLU(), nn.ReflectionPad2d(1), nn.Conv2d(channels, channels, kernel_size=3), nn.InstanceNorm2d(channels, affine=True), ) def forward(self, x): out = self.resblock(x) return out + x class Upsample2d(nn.Module): def __init__(self, scale_factor): super(Upsample2d, self).__init__() self.interp = nn.functional.interpolate self.scale_factor = scale_factor def forward(self, x): x = self.interp(x, scale_factor=self.scale_factor, mode='nearest') return x class MicroResNet(nn.Module): def __init__(self): super(MicroResNet, self).__init__() self.downsampler = nn.Sequential( nn.ReflectionPad2d(4), nn.Conv2d(3, 8, kernel_size=9, stride=4), nn.InstanceNorm2d(8, affine=True), nn.ReLU(), nn.ReflectionPad2d(1), nn.Conv2d(8, 16, kernel_size=3, stride=2), nn.InstanceNorm2d(16, affine=True), nn.ReLU(), nn.ReflectionPad2d(1), nn.Conv2d(16, 32, kernel_size=3, stride=2), nn.InstanceNorm2d(32, affine=True), nn.ReLU(), ) self.residual = nn.Sequential( ResBlock(32), nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32), ResBlock(64), ) self.segmentator = nn.Sequential( nn.ReflectionPad2d(1), nn.Conv2d(64, 16, kernel_size=3), nn.InstanceNorm2d(16, affine=True), nn.ReLU(), Upsample2d(scale_factor=2), nn.ReflectionPad2d(4), nn.Conv2d(16, 1, kernel_size=9), nn.Sigmoid() ) def forward(self, x): out = self.downsampler(x) out = self.residual(out) out = self.segmentator(out) return out # SmartCrop module class SmartCrop(nn.Module): def __init__(self, output_size, randomize_p=0.0, randomize_q=0.1, temperature=0.03): super().__init__() self.output_size = output_size self.randomize_p, self.randomize_q = randomize_p, randomize_q self.temperature = temperature if isinstance(self.output_size, int): self.output_size = (self.output_size, self.output_size) self.saliency_model = MicroResNet().eval().requires_grad_(False) checkpoint = torch.load(os.path.dirname(__file__) + "/models/saliency_model_v9.pt", map_location="cpu") self.saliency_model.load_state_dict(checkpoint) def forward(self, image): is_batch = len(image.shape) == 4 if not is_batch: image = image.unsqueeze(0) with torch.no_grad(): resized_image = torchvision.transforms.functional.resize(image, 240, antialias=True) saliency_map = self.saliency_model(resized_image) tempered_heatmap = saliency_map.view(saliency_map.size(0), -1).div(self.temperature).softmax(-1) tempered_heatmap = tempered_heatmap / tempered_heatmap.sum(dim=1) tempered_heatmap = (tempered_heatmap > tempered_heatmap.max(dim=-1)[0]*0.75).float() saliency_map = tempered_heatmap.view(*saliency_map.shape) # GET CENTROID coord_space = torch.cat([ torch.linspace(0, 1, saliency_map.size(-2))[None, None, :, None].expand(-1, -1, -1, saliency_map.size(-1)), torch.linspace(0, 1, saliency_map.size(-1))[None, None, None, :].expand(-1, -1, saliency_map.size(-2), -1), ], dim=1) centroid = (coord_space * saliency_map).sum(dim=[-1, -2]) / saliency_map.sum(dim=[-1, -2]) # CROP crops = [] for i in range(image.size(0)): if np.random.rand() < self.randomize_p: centroid[i, 0] += np.random.uniform(-self.randomize_q, self.randomize_q) centroid[i, 1] += np.random.uniform(-self.randomize_q, self.randomize_q) top = (centroid[i, 0]*image.size(-2)-self.output_size[-2]/2).clamp(min=0, max=max(0, image.size(-2)-self.output_size[-2])).int() left = (centroid[i, 1]*image.size(-1)-self.output_size[-1]/2).clamp(min=0, max=max(0, image.size(-1)-self.output_size[-1])).int() bottom, right = top + self.output_size[-2], left + self.output_size[-1] crop = image[i, :, top:bottom, left:right] if crop.size(-2) < self.output_size[-2] or crop.size(-1) < self.output_size[-1]: crop = torchvision.transforms.functional.center_crop(crop, self.output_size) crops.append(crop) if is_batch: crops = torch.stack(crops, dim=0) else: crops = crops[0] return crops ================================================ FILE: torchtools/utils/__init__.py ================================================ from .diffusion import Diffuzz from .diffusion2 import Diffuzz2 from .gamma_parametrization import apply_gamma_reparam, gamma_reparam_model, remove_gamma_reparam from .weight_normalization import apply_weight_norm, weight_norm_model, remove_weight_norm ================================================ FILE: torchtools/utils/diffusion.py ================================================ import torch # Samplers -------------------------------------------------------------------- class SimpleSampler(): def __init__(self, diffuzz): self.current_step = -1 self.diffuzz = diffuzz def __call__(self, *args, **kwargs): self.current_step += 1 return self.step(*args, **kwargs) def init_x(self, shape): return torch.randn(*shape, device=self.diffuzz.device) def step(self, x, t, t_prev, noise): raise NotImplementedError("You should override the 'apply' function.") class DDPMSampler(SimpleSampler): def step(self, x, t, t_prev, noise): alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) alpha = (alpha_cumprod / alpha_cumprod_prev) mu = (1.0 / alpha).sqrt() * (x - (1-alpha) * noise / (1-alpha_cumprod).sqrt()) std = ((1-alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * torch.randn_like(mu) return mu + std * (t_prev != 0).float().view(t_prev.size(0), *[1 for _ in x.shape[1:]]) class DDIMSampler(SimpleSampler): def step(self, x, t, t_prev, noise): alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / (alpha_cumprod).sqrt() dp_xt = (1 - alpha_cumprod_prev).sqrt() return (alpha_cumprod_prev).sqrt() * x0 + dp_xt * noise class DPMSolverPlusPlusSampler(SimpleSampler): # FIXME: CURRENTLY NOT WORKING def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.q_ts = {} def _get_coef(self, alpha_cumprod): log_alpha_t = alpha_cumprod.log() alpha_t = log_alpha_t.exp() sigma_t = (1-alpha_t ** 2).sqrt() lambda_t = log_alpha_t - sigma_t.log() return alpha_t, sigma_t, lambda_t def init_x(self, shape): alpha_cumprod = self.diffuzz._alpha_cumprod(torch.ones(shape[0], device=self.diffuzz.device)).view(-1, *[1 for _ in shape[1:]]) return torch.randn(*shape, device=self.diffuzz.device) * self._get_coef(alpha_cumprod)[1] def step(self, x, t, t_prev, noise): alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) stride = (t_prev - t) if self.current_step == 0: alpha_t, sigma_t, _ = self._get_coef(alpha_cumprod) elif self.current_step == 1: alpha_cumprod_next = self.diffuzz._alpha_cumprod(t+stride).view(t.size(0), *[1 for _ in x.shape[1:]]) alpha_t, sigma_t, lambda_t = self._get_coef(alpha_cumprod) _, sigma_t_next, lambda_t_next = self._get_coef(alpha_cumprod_next) h = lambda_t - lambda_t_next x = sigma_t / sigma_t_next * x - alpha_t * torch.expm1(-h) * self.q_ts[self.current_step-1] else: alpha_cumprod_next = self.diffuzz._alpha_cumprod(t+stride).view(t.size(0), *[1 for _ in x.shape[1:]]) alpha_cumprod_next_next = self.diffuzz._alpha_cumprod(t+stride*2).view(t.size(0), *[1 for _ in x.shape[1:]]) alpha_t, sigma_t, lambda_t = self._get_coef(alpha_cumprod) _, sigma_t_next, lambda_t_next = self._get_coef(alpha_cumprod_next) _, _, lambda_t_next_next = self._get_coef(alpha_cumprod_next_next) h = lambda_t - lambda_t_next h_next = lambda_t_next - lambda_t_next_next r = h_next / h D = (1 + 1 / (2 * r)) * self.q_ts[self.current_step-1] - 1 / (2 * r) * self.q_ts[self.current_step-2] x = sigma_t / sigma_t_next * x - alpha_t * torch.expm1(-h) * D self.q_ts[self.current_step] = (x - sigma_t * noise) / alpha_t return x sampler_dict = { 'ddpm': DDPMSampler, 'ddim': DDIMSampler, 'dpmsolver++': DPMSolverPlusPlusSampler, } # Custom simplified foward/backward diffusion (cosine schedule) class Diffuzz(): def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1, clamp_range=(0.0001, 0.9999)): self.device = device self.s = torch.tensor([s]).to(device) self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 self.scaler = scaler self.cached_steps = None self.clamp_range = clamp_range if cache_steps is not None: self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device)) def _alpha_cumprod(self, t): if self.cached_steps is None: if self.scaler > 1: t = 1 - (1-t) ** self.scaler elif self.scaler < 1: t = t ** self.scaler alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod return alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1]) else: return self.cached_steps[t.mul(len(self.cached_steps)-1).long()] def diffuse(self, x, t, noise=None): # t -> [0, 1] if noise is None: noise = torch.randn_like(x) alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) return alpha_cumprod.sqrt() * x + (1-alpha_cumprod).sqrt() * noise, noise def undiffuse(self, x, t, t_prev, noise, sampler=None): if sampler is None: sampler = DDPMSampler(self) return sampler(x, t, t_prev, noise) def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, unconditional_inputs=None, sampler='ddpm', half=False): r_range = torch.linspace(t_start, t_end, timesteps+1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device) if isinstance(sampler, str): if sampler in sampler_dict: sampler = sampler_dict[sampler](self) else: raise ValueError(f"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}") elif issubclass(sampler, SimpleSampler): sampler = sampler(self) else: raise ValueError("Sampler should be either a string or a SimpleSampler object.") preds = [] x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone() if half: r_range = r_range.half() x = x.half() for i in range(0, timesteps): if mask is not None and x_init is not None: x_renoised, _ = self.diffuse(x_init, r_range[i]) x = x * mask + x_renoised * (1-mask) pred_noise = model(x, r_range[i], **model_inputs) if cfg is not None: if unconditional_inputs is None: unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} pred_noise_unconditional = model(x, r_range[i], **unconditional_inputs) pred_noise = torch.lerp(pred_noise_unconditional, pred_noise, cfg) x = self.undiffuse(x, r_range[i], r_range[i+1], pred_noise, sampler=sampler) preds.append(x) return preds def p2_weight(self, t, k=1.0, gamma=1.0): alpha_cumprod = self._alpha_cumprod(t) return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma ================================================ FILE: torchtools/utils/diffusion2.py ================================================ import torch import numpy as np # Samplers -------------------------------------------------------------------- class SimpleSampler(): def __init__(self, diffuzz, mode="v"): self.current_step = -1 self.diffuzz = diffuzz if mode not in ['v', 'e', 'x']: raise Exception("Mode should be either 'v', 'e' or 'x'") self.mode = mode def __call__(self, *args, **kwargs): self.current_step += 1 return self.step(*args, **kwargs) def init_x(self, shape): return torch.randn(*shape, device=self.diffuzz.device) def step(self, x, t, t_prev, noise): raise NotImplementedError("You should override the 'apply' function.") # https://github.com/ozanciga/diffusion-for-beginners/blob/main/samplers/ddim.py class DDIMSampler(SimpleSampler): def step(self, x, t, t_prev, pred, eta=0): alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) sigma_tau = eta * ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)).sqrt() * (1 - alpha_cumprod / alpha_cumprod_prev).sqrt() if eta > 0 else 0 if self.mode == 'v': x0 = alpha_cumprod.sqrt() * x - (1-alpha_cumprod).sqrt() * pred noise = (1-alpha_cumprod).sqrt() * x + alpha_cumprod.sqrt() * pred elif self.mode == 'x': x0 = pred noise = (x - x0 * alpha_cumprod.sqrt()) / (1 - alpha_cumprod).sqrt() else: noise = pred x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / alpha_cumprod.sqrt() renoised = alpha_cumprod_prev.sqrt() * x0 + (1 - alpha_cumprod_prev - sigma_tau ** 2).sqrt() * noise + sigma_tau * torch.randn_like(x) return x0, renoised, pred class DDPMSampler(DDIMSampler): def step(self, x, t, t_prev, pred, eta=1): return super().step(x, t, t_prev, pred, eta) sampler_dict = { 'ddpm': DDPMSampler, 'ddim': DDIMSampler, } # Custom simplified foward/backward diffusion (cosine schedule) class Diffuzz2(): def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1, clamp_range=(1e-7, 1-1e-7)): self.device = device self.s = torch.tensor([s]).to(device) self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 self.scaler = 2 * np.log(1/scaler) self.cached_steps = None self.clamp_range = clamp_range if cache_steps is not None: self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device)) def _alpha_cumprod(self, t): if self.cached_steps is None: alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod alpha_cumprod = alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1]) if self.scaler != 1: alpha_cumprod = (alpha_cumprod/(1-alpha_cumprod)).log().add(self.scaler).sigmoid().clamp(self.clamp_range[0], self.clamp_range[1]) return alpha_cumprod else: return self.cached_steps[t.mul(len(self.cached_steps)-1).long()] def scale_t(self, t, scaler): scaler = 2 * np.log(1/scaler) alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod alpha_cumprod = alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1]) if scaler != 1: alpha_cumprod = (alpha_cumprod/(1-alpha_cumprod)).log().add(scaler).sigmoid().clamp(self.clamp_range[0], self.clamp_range[1]) return (((alpha_cumprod * self._init_alpha_cumprod) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + self.s) - self.s def diffuse(self, x, t, noise=None): # t -> [0, 1] if noise is None: noise = torch.randn_like(x) alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) return alpha_cumprod.sqrt() * x + (1-alpha_cumprod).sqrt() * noise, noise def get_v(self, x, t, noise): alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) # x0 = alpha_cumprod * noised − (1-alpha_cumprod).sqrt() * pred_v # noise = (1-alpha_cumprod).sqrt() * noised + alpha_cumprod * pred_v return alpha_cumprod.sqrt() * noise - (1-alpha_cumprod).sqrt() * x def x0_from_v(self, noised, pred_v, t): alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in noised.shape[1:]]) return alpha_cumprod.sqrt() * noised - (1-alpha_cumprod).sqrt() * pred_v def noise_from_v(self, noised, pred_v, t): alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in noised.shape[1:]]) return (1-alpha_cumprod).sqrt() * noised + alpha_cumprod.sqrt() * pred_v def undiffuse(self, x, t, t_prev, pred, sampler=None, **kwargs): if sampler is None: sampler = DDPMSampler(self) return sampler(x, t, t_prev, pred, **kwargs) def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_rho=0.7, unconditional_inputs=None, sampler='ddpm', dtype=None, sample_mode='v', sampler_params={}, t_scaler=1): r_range = torch.linspace(t_start, t_end, timesteps+1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device) if t_scaler != 1: r_range = self.scale_t(r_range, t_scaler) if isinstance(sampler, str): if sampler in sampler_dict: sampler = sampler_dict[sampler](self, sample_mode) else: raise ValueError(f"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}") elif issubclass(sampler, SimpleSampler): sampler = sampler(self, sample_mode) else: raise ValueError("Sampler should be either a string or a SimpleSampler object.") x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone() if dtype is not None: r_range = r_range.to(dtype) x = x.to(dtype) if cfg is not None: if unconditional_inputs is None: unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} model_inputs = {k:torch.cat([v, v_u]) if isinstance(v, torch.Tensor) else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())} for i in range(0, timesteps): if mask is not None and x_init is not None: x_renoised, _ = self.diffuse(x_init, r_range[i]) x = x * mask + x_renoised * (1-mask) if cfg is not None: pred, pred_unconditional = model(torch.cat([x] * 2), torch.cat([r_range[i]] * 2), **model_inputs).chunk(2) pred_cfg = torch.lerp(pred_unconditional, pred, cfg) if cfg_rho > 0: std_pos, std_cfg = pred.std(), pred_cfg.std() pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) else: pred = pred_cfg else: pred = model(x, r_range[i], **model_inputs) diff_out = self.undiffuse(x, r_range[i], r_range[i+1], pred, sampler=sampler, **sampler_params) x = diff_out[1] altered_vars = yield diff_out # Update some running variables if the user wants if altered_vars is not None: cfg = altered_vars.get('cfg', cfg) cfg_rho = altered_vars.get('cfg_rho', cfg_rho) sampler = altered_vars.get('sampler', sampler) unconditional_inputs = altered_vars.get('unconditional_inputs', unconditional_inputs) model_inputs = altered_vars.get('model_inputs', model_inputs) x = altered_vars.get('x', x) mask = altered_vars.get('mask', mask) x_init = altered_vars.get('x_init', x_init) def p2_weight(self, t, k=1.0, gamma=1.0): alpha_cumprod = self._alpha_cumprod(t) return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma def truncated_snr_weight(self, t, min=1.0, max=None): alpha_cumprod = self._alpha_cumprod(t) srn = (alpha_cumprod / (1 - alpha_cumprod)) if min != None or max != None: srn = srn.clamp(min=min, max=max) return srn ================================================ FILE: torchtools/utils/gamma_parametrization.py ================================================ import torch from torch import nn class _GammaScaling(nn.Module): def __init__(self): super().__init__() self.gamma = nn.Parameter(torch.ones(1)) def forward(self, w): return w * self.gamma def apply_gamma_reparam(module, name="weight"): # this reparametrizes the parameters of a single module nn.utils.parametrizations.spectral_norm(module, name) nn.utils.parametrize.register_parametrization(module, name, _GammaScaling()) return module def gamma_reparam_model(model): for module in model.modules(): # this reparametrizes all linear layers of the model if isinstance(module, nn.Linear) and not torch.nn.utils.parametrize.is_parametrized(module, "weight"): apply_gamma_reparam(module, "weight") elif isinstance(module, nn.MultiheadAttention) and not torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): apply_gamma_reparam(module, "in_proj_weight") return model def remove_gamma_reparam(model): for module in model.modules(): if torch.nn.utils.parametrize.is_parametrized(module, "weight"): nn.utils.parametrize.remove_parametrizations(module, "weight") elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight") ================================================ FILE: torchtools/utils/weight_normalization.py ================================================ import torch from torch import nn class _WeigthNorm(nn.Module): def __init__(self, eps=1e-4): super().__init__() self.eps = eps def _normalize(self, w): norm_dims = list(range(1, len(w.shape))) w_norm = torch.linalg.vector_norm(w, dim=norm_dims, keepdim=True) # w_norm = torch.norm_except_dim(w, 2, 0).clone() return w / (w_norm + self.eps) def forward(self, w): if self.training: with torch.no_grad(): fan_in = w[0].numel()**0.5 w.data = self._normalize(w.data.clone()) * fan_in # w.copy_(self._normalize(w) * fan_in) return self._normalize(w) def apply_weight_norm(module, name="weight", init_weight=True): # this reparametrizes the parameters of a single module if init_weight: torch.nn.init.normal(getattr(module, name)) nn.utils.parametrize.register_parametrization(module, name, _WeigthNorm(), unsafe=True) return module def weight_norm_model(model, whitelist=None, init_weight=True): whitelist = whitelist or [] def check_parameter(module, name): return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance(getattr(module, name), nn.Parameter) for name, module in model.named_modules(): # this reparametrizes all layers of the model that have a "weight" parameter if not any([w in name for w in whitelist]): if check_parameter(module, "weight"): apply_weight_norm(module, init_weight=init_weight) elif check_parameter(module, "in_proj_weight"): apply_weight_norm(module, 'in_proj_weight', init_weight=init_weight) return model def remove_weight_norm(model): for module in model.modules(): if torch.nn.utils.parametrize.is_parametrized(module, "weight"): nn.utils.parametrize.remove_parametrizations(module, "weight") elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight")