Repository: pabloppp/pytorch-tools
Branch: master
Commit: 6472bd5e2231
Files: 53
Total size: 135.7 KB
Directory structure:
gitextract_gjrwr8nv/
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── readme.md
├── setup.cfg
├── setup.py
└── torchtools/
├── __init__.py
├── lr_scheduler/
│ ├── __init__.py
│ ├── delayed.py
│ └── inverse_sqrt.py
├── nn/
│ ├── __init__.py
│ ├── adain.py
│ ├── alias_free_activation.py
│ ├── equal_layers.py
│ ├── evonorm2d.py
│ ├── fourier_features.py
│ ├── functional/
│ │ ├── __init__.py
│ │ ├── gradient_penalty.py
│ │ ├── magnitude_preserving.py
│ │ ├── perceptual.py
│ │ └── vq.py
│ ├── gp_loss.py
│ ├── haar_dwt.py
│ ├── magnitude_preserving.py
│ ├── mish.py
│ ├── modulation.py
│ ├── perceptual.py
│ ├── pixel_normalzation.py
│ ├── pos_embeddings.py
│ ├── simple_self_attention.py
│ ├── stylegan2/
│ │ ├── __init__.py
│ │ ├── upfirdn2d.cpp
│ │ ├── upfirdn2d.py
│ │ └── upfirdn2d_kernel.cu
│ ├── transformers.py
│ └── vq.py
├── optim/
│ ├── __init__.py
│ ├── lamb.py
│ ├── lookahead.py
│ ├── novograd.py
│ ├── over9000.py
│ ├── radam.py
│ ├── ralamb.py
│ └── ranger.py
├── transforms/
│ ├── __init__.py
│ ├── models/
│ │ ├── __init__.py
│ │ └── saliency_model_v9.pt
│ └── smart_crop.py
└── utils/
├── __init__.py
├── diffusion.py
├── diffusion2.py
├── gamma_parametrization.py
└── weight_normalization.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*/**/__pycache__
*.egg-info
/dist
*.pyc
.idea
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 Pablo Pernías
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: MANIFEST.in
================================================
include LICENSE
include readme.md
recursive-include torchtools/transforms/models *
================================================
FILE: readme.md
================================================
# Pytorch Tools
## Install
Requirements:
```
PyTorch >= 1.0.0
Torchvision
Numpy >= 1.0.0
```
```
# In order to install the latest (beta) use
pip install git+https://github.com/pabloppp/pytorch-tools -U
# if you want to install a specific version to avoid breaking changes (for example, v0.3.5), use
pip install git+https://github.com/pabloppp/pytorch-tools@0.3.5 -U
```
# Current available tools
## Optimizers
Comparison table taken from https://github.com/mgrankin/over9000
And the article explaining these recent improvements https://medium.com/@lessw/how-we-beat-the-fastai-leaderboard-score-by-19-77-a-cbb2338fab5c
Dataset | LR Schedule| Imagenette size 128, 5 epoch | Imagewoof size 128, 5 epoch
--- | -- | --- | ---
Adam - baseline |OneCycle| 0.8493 | 0.6125
RangerLars (RAdam + LARS + Lookahead) |Flat and anneal| 0.8732 | 0.6523
Ralamb (RAdam + LARS) |Flat and anneal| 0.8675 | 0.6367
Ranger (RAdam + Lookahead) |Flat and anneal| 0.8594 | 0.5946
Novograd |Flat and anneal| 0.8711 | 0.6126
Radam |Flat and anneal| 0.8444 | 0.537
Lookahead |OneCycle| 0.8578 | 0.6106
Lamb |OneCycle| 0.8400 | 0.5597
### Ranger
Taken as is from https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d
Example of use:
```python
from torchtools.optim import Ranger
optimizer = Ranger(model.parameters())
```
### RAdam
Taken as is from https://github.com/LiyuanLucasLiu/RAdam
Blog post: https://medium.com/@lessw/new-state-of-the-art-ai-optimizer-rectified-adam-radam-5d854730807b
Original Paper: https://arxiv.org/abs/1908.03265
Example of use:
```python
from torchtools.optim import RAdam, PlainRAdam, AdamW
optimizer = RAdam(model.parameters())
# optimizer = PlainRAdam(model.parameters())
# optimizer = AdamW(model.parameters())
```
### RangerLars (Over9000)
Taken as is from https://github.com/mgrankin/over9000
Example of use:
```python
from torchtools.optim import RangerLars # Over9000
optimizer = RangerLars(model.parameters())
```
### Novograd
Taken as is from https://github.com/mgrankin/over9000
Example of use:
```python
from torchtools.optim import Novograd
optimizer = Novograd(model.parameters())
```
### Ralamb
Taken as is from https://github.com/mgrankin/over9000
Example of use:
```python
from torchtools.optim import Ralamb
optimizer = Ralamb(model.parameters())
```
### Lookahead
Taken as is from https://github.com/lonePatient/lookahead_pytorch
Original Paper: https://arxiv.org/abs/1907.08610
This lookahead can be used with any optimizer
Example of use:
```python
from torch import optim
from torchtools.optim import Lookahead
optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer = Lookahead(base_optimizer=optimizer, k=5, alpha=0.5)
# for a base Lookahead + Adam you can just do:
#
# from torchtools.optim import LookaheadAdam
```
### Lamb
Taken as is from https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py
Original Paper: https://arxiv.org/abs/1904.00962
Example of use:
```python
from torchtools.optim import Lamb
optimizer = Lamb(model.parameters())
```
## LR Schedulers
### Delayed LR
Allows for a customizable number of initial steps where the learning rate remains fixed.
After those steps, the learning rate is updated according to the supplied scheduler's policy.
Example of use:
```python
from torch import optim, nn
from torchtools.lr_scheduler import DelayerScheduler
optimizer = optim.Adam(model.parameters(), lr=0.001) # define here your optimizer, the lr that you set will be the one used for the initial delay steps
delay_epochs = 10
total_epochs = 20
base_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epochs - delay_epochs) # anneal over the remaining 10 epochs
delayed_scheduler = DelayerScheduler(optimizer, delay_epochs, base_scheduler) # keep the lr flat for the first 10 epochs
for epoch in range(total_epochs):
# train(...)
delayed_scheduler.step()
# The lr will be 0.001 for the first 10 epochs, then will follow the policy from the base_scheduler for the rest of the epochs
# for a base DelayerScheduler + CosineAnnealingLR you can just do:
#
# from torchtools.lr_scheduler import DelayedCosineAnnealingLR
# scheduler = DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs) # the sum of both must be the total number of epochs
```
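The resulting schedule can be sketched numerically without touching the optimizer. This is a plain-Python illustration of the flat-then-cosine shape (the function name and `eta_min` default are mine, not part of the library):

```python
import math

def delayed_cosine_lr(epoch, base_lr, delay_epochs, anneal_epochs, eta_min=0.0):
    """Flat lr for `delay_epochs`, then cosine annealing over `anneal_epochs`."""
    if epoch < delay_epochs:
        return base_lr  # delay phase: keep the initial lr
    t = epoch - delay_epochs  # steps into the annealing phase
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * t / anneal_epochs)) / 2

schedule = [delayed_cosine_lr(e, 0.001, 10, 10) for e in range(20)]
# the first 10 values stay at 0.001, then the lr decays following the cosine curve
```

The same shape is what you get from `DelayedCosineAnnealingLR(optimizer, 10, 10)` over 20 epochs.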
## Activations
### Mish
Original implementation: https://github.com/digantamisra98/Mish
Original Paper: https://arxiv.org/abs/1908.08681v1
Implementation taken as is from https://github.com/lessw2020/mish
Example of use:
```python
from torchtools.nn import Mish
# Then you can just use Mish as a replacement for any activation function, like ReLU
```
### AliasFreeActivation
Implementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L225 by Rosinality
I modularized this activation so it can be easily used inside of any model without having to deal with complex initialization.
This activation actually takes on a lot of responsibility: it internally defines the channels and size of the input based on a set of parameters, instead of receiving them as arguments, which means that the rest of the layers (convolutions, positional embedding, etc...) must adapt to it.
Example of use:
```python
from torchtools.nn.alias_free_activation import AliasFreeActivation
from torchtools.nn import EqualLeakyReLU
# We can use the static function to get the filter parameters for a specific level.
# It can be especially useful to obtain the initial size and channels.
max_size, max_channels = 256, 512
first_channels, first_size = AliasFreeActivation.alias_level_params(
0, max_levels=14, max_size=max_size, max_channels=max_channels
)[-2:]
class MyModel(nn.Module):
def __init__(self, level, max_levels=14, max_size=256, max_channels=512, margin=10):
...
leaky_relu = EqualLeakyReLU(negative_slope=0.2)
self.activation = AliasFreeActivation(
leaky_relu, level, max_levels=max_levels, max_size=max_size, max_channels=max_channels, margin=margin
)
self.conv = nn.Conv2d(self.activation.channels_prev, self.activation.channels, kernel_size=3, padding=1)
...
def forward(self, x): # the channels and size of x depend on the level of this module
...
x = self.conv(x)
x = self.activation(x)
...
```
## Layers
### SimpleSelfAttention
Implementation taken as is from https://github.com/sdoria/SimpleSelfAttention
Example of use:
```python
from torchtools.nn import SimpleSelfAttention
# The input of the layer must have at least 3 dimensions (B, C, N);
# the attention is performed on the 2nd dimension.
#
# For images, the input will be internally reshaped to 3 dimensions,
# and reshaped back to the original shape before returning it
```
### PixelNorm
Inspired from https://github.com/github-pengge/PyTorch-progressive_growing_of_gans
Example of use:
```python
from torchtools.nn import PixelNorm
model = nn.Sequential(
nn.Conv2d(...),
PixelNorm(),
nn.ReLU()
)
# It doesn't require any parameter, it just performs a simple element-wise normalization
# x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)
#
# Just use it as a regular layer, generally after convolutions and before ReLU
# (warning) since it performs a square root it's pretty slow if the layer sizes are big
```
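The normalization is exactly the formula in the comment above. As a minimal pure-Python sketch for a single pixel's channel vector (illustrative only, not the module's code):

```python
import math

def pixel_norm(channels, eps=1e-8):
    """Divide a channel vector by the root mean square of its values."""
    rms = math.sqrt(sum(c * c for c in channels) / len(channels) + eps)
    return [c / rms for c in channels]

v = pixel_norm([3.0, 4.0])  # rms = sqrt((9 + 16) / 2) = sqrt(12.5)
# after normalization the mean of the squared values is ~1, direction is preserved
```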
### Adaptive Instance Normalization - AdaIN
Implementation based on https://github.com/SiskonEmilia/StyleGAN-PyTorch
Original Paper https://arxiv.org/abs/1703.06868
Example of use:
```python
from torchtools.nn import AdaIN
class MyModel(nn.Module):
def __init__(self, n_channels):
...
# AdaIN will require the style vector to be 2*size
self.style = nn.Linear(input_size, output_size*2)
self.adain = AdaIN(output_size)
...
def forward(self, x, w):
...
x = self.adain(x, self.style(w))
...
# AdaIN will "transfer" a style encoded in a latent vector w into any tensor x.
# In order to do this, w first needs to be passed through a linear layer that returns 2 tensors (actually one tensor of twice the required size, which we then split in 2)
# It will then perform an instance normalization to "whiten" the tensor, followed by a de-normalization using the values generated by the linear layer, thus encoding the original vector w in the tensor.
```
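The whiten-then-restyle step can be sketched on a single channel in plain Python (illustrative only; the real module uses `nn.InstanceNorm2d`, and the function name here is mine):

```python
import math

def adain_1ch(x, factor, bias, eps=1e-5):
    """Instance-normalize one channel, then re-scale/shift it with style values."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    normed = [(v - mean) / math.sqrt(var + eps) for v in x]  # "whitened" channel
    return [n * factor + bias for n in normed]  # de-normalize with the style values

styled = adain_1ch([1.0, 2.0, 3.0, 4.0], factor=2.0, bias=0.5)
# the output channel now has (approximately) std = 2 and mean = 0.5,
# i.e. the statistics dictated by the style, regardless of the input statistics
```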
### EvoNorm
Implementation taken as is from https://github.com/digantamisra98/EvoNorm, all credit goes to digantamisra98
Original Paper https://arxiv.org/abs/2004.02967
Example of use:
```python
from torchtools.nn import EvoNorm2D
model = nn.Sequential(
nn.Conv2d(...),
EvoNorm2D(c_hidden), # For S0 version
# evoB0 = EvoNorm2D(input, affine = True, version = 'B0', training = True) # For B0 version
nn.ReLU()
)
```
### GPT Transformer Encoder Layer
Implementation based on MinGPT https://github.com/karpathy/minGPT by Andrej Karpathy
It can be used as a drop-in replacement for the `torch.nn.TransformerEncoderLayer`
Example of use:
```python
from torchtools.nn import GPTTransformerEncoderLayer
class MyTransformer(nn.Module):
def __init__(self, n_channels):
...
encoder_layer = GPTTransformerEncoderLayer(d_model=512, nhead=8)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
...
```
### Stylegan2 ModulatedConv2d
Implementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L143 by Rosinality
It extends `torch.nn.Conv2d` so you can use it as a drop-in replacement; the only Conv2d parameter you cannot use is 'groups', since it will be overridden for this to work.
It also includes an `ema_decay` parameter that will add the EMA normalization used in Alias-Free GAN (defaults to 1, meaning that it's disabled)
Example of use:
```python
from torchtools.nn import ModulatedConv2d
class MyModel(nn.Module):
def __init__(self):
...
self.conv = ModulatedConv2d(16, 32, kernel_size=3, padding=1)
# SUGGESTIONS:
# set bias=False if you want to handle bias on your own
# set demodulate=False for RGB output
# set ema_decay=0.9989 to imitate the alias-free gan setup
...
def forward(self, x, w):
...
x = self.conv(x, w) # 'x' is a 4D tensor (B x C x W x H) and 'w' is a 2D tensor (B x C)
...
```
### Equal Layers (EqualNorm, EqualLinear)
Implementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/stylegan2/model.py#L94
It extends the base classes (nn.Linear, nn.Conv2d, nn.LeakyReLU) so you can use these as drop-in replacements, although it includes some optional parameters.
Example of use:
```python
from torchtools.nn import EqualLinear, EqualLeakyReLU, EqualConv2d
class MyModel(nn.Module):
def __init__(self):
...
self.linear = EqualLinear(16, 32, bias_init=1, lr_mul=0.01) # bias_init and lr_mul are extra optional params
self.leaky_relu = EqualLeakyReLU(negative_slope=0.2)
self.conv = EqualConv2d(16, 32, kernel_size=3, padding=1)
# Since these classes extend the base classes, you can use all parameters from the original classes.
...
```
### FourierFeatures2d
Implementation inspired by https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L88
but improved using my own understanding of how this should work...
It creates a 2d tensor of embeddings following a Fourier series based on the parameters you provide. These features are dynamic, meaning that affine transformations can be applied to them in order to shift, rotate, and even scale them (scaling is experimental).
```python
from torchtools.nn import FourierFeatures2d
class MyModel(nn.Module):
def __init__(self, dim=256, margin=10, cutoff=2):
...
self.feats = FourierFeatures2d(4+margin*2, dim, cutoff) # optionally enable scaling with allow_scaling=True
# Also, you can randomize the frequencies if you plan on keeping them fixed, setting w_scale to any value > 0
...
def forward(self, affine):
...
embds = self.feats(affine) # 'affine' should be a Bx4 tensor, or Bx6 if scaling is enabled...
# the default or initial affine values should be [1, 0, 0, 0, 1, 1] => ([1, 0]: rotation, [0, 0]: shift, [1, 1]: scale)
...
```
## Criterions
### Gradient Penalty (for WGAN-GP)
Implementation taken with minor changes from https://github.com/caogang/wgan-gp
Original paper https://arxiv.org/pdf/1704.00028.pdf
Example of use:
```python
from torchtools.nn import GPLoss
# This criterion defines the gradient penalty for WGAN GP
# For an example of a training cycle refer to this repo https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py#L185
discriminator = ...
gpcriterion = GPLoss(discriminator) # l = 10 by default
gradient_penalty = gpcriterion(real_data, fake_data)
discriminator_loss = ... + gradient_penalty # add the gp component to the Wasserstein loss
```
### Total Variation Loss
Total Variation denoising https://www.wikiwand.com/en/Total_variation_denoising
Example of use:
```python
# This loss (or regularization) is useful for removing artifacts and noise in generated images.
# It's widely used in style transfer.
from torchtools.nn import TVLoss
tvcriterion = TVLoss() # reduction = 'sum' and alpha = 1e-4 by default
G = ... # output image
tv_loss = tvcriterion(G)
loss = ... + tv_loss # add the tv loss component to your reconstruction loss
```
## Vector Quantization
### VectorQuantize: Encoding-based quantization [(source)](torchtools/nn/vq.py#L5)
This transforms any tensor to its quantized version using a codebook of embeddings.
It uses a straight-through approach for applying the gradients.
Passing a tensor through the **VectorQuantize** module will return a new tensor with the same dimensions, replacing each vector along the last dimension with its nearest neighbor from the codebook; since the codebook has a limited number of entries, this quantizes the tensor.
For the quantization it relies on a differentiable function that you can see [here](torchtools/nn/functional/vq.py#L4)
The output of the module is the quantized tensor, as well as a tuple of the codebook loss components (needed for training) and the indices of the quantized vectors, in the form: `qx, (vq_loss, commit_loss), indices`
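The nearest-neighbour lookup at the core of the quantization can be sketched in plain Python (illustrative only; the module's real implementation is vectorized in torch and adds the straight-through gradient trick on top):

```python
def nearest_code(vector, codebook):
    """Return the index of the codebook entry closest (squared L2) to `vector`."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    # the quantized value is simply codebook[nearest_code(...)]
    return min(range(len(codebook)), key=lambda i: sq_dist(vector, codebook[i]))

codebook = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]
idx = nearest_code([0.9, 1.2], codebook)  # -> 1, so the vector quantizes to [1.0, 1.0]
```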
When **creating a new instance of the module**, it accepts the following parameters:
- **embedding_size**: the size of the embeddings used in the codebook, should match the last dimension of the tensor you want to quantize
- **k**: the size of the codebook, or number of embeddings.
- **ema_decay** (default=0.99): the Exponential Moving Average decay (only used when ema_loss is True)
- **ema_loss** (default=False): enables Exponential Moving Average updates of the codebook (instead of relying on gradient descent, as EMA converges faster)
When **calling the forward method** of the module, it accepts the following parameters:
- **x**: the tensor you want to quantize; make sure the dimension you want to quantize (the last one by default) matches the embedding_size defined when instantiating the module
- **get_losses** (default=True): when False, the vq_loss and commit_loss components of the output will both be None, which should speed up the module a little when used for inference
- **dim** (default=-1): The dimension across which the input should be quantized.
Example of use:
```python
from torchtools.nn import VectorQuantize
e = torch.randn(1, 16, 16, 8) # create a random tensor with 8 as its last dimension size
vquantizer = VectorQuantize(8, k=32, ema_loss=True) # we create the module with embedding size of 8, a codebook of size 32 and make the codebook update using EMA
qe, (vq_loss, commit_loss), indices = vquantizer.forward(e) # we quantize our tensor while also getting the loss components and the indices
# NOTE: while the module is in training mode, the codebook will be updated on every forward call; in order to freeze the codebook for inference, put it in evaluation mode with 'vquantizer.eval()'
# NOTE 2: in order to update the module properly, add the loss components to the final model loss before calling backward(); if you set ema_loss to True you only need to add the commit_loss to the total loss, and it's usually multiplied by a value between 0.1 and 2, with 0.25 being a good default value
loss = ... # whatever loss you have for your final output
loss += commit_loss * 0.25
# loss += vq_loss # only if you didn't set the ema_loss to True
...
loss.backward()
optimizer.step()
```
---
### Binarize: binarize the input tensor [(source)](torchtools/nn/vq.py#L55)
This transforms the values of a tensor into 0s and 1s depending on whether they're above or below a specified threshold.
It uses a straight-through approach for applying the gradients, so it's effectively differentiable.
For the quantization it relies on a differentiable function that you can see [here](torchtools/nn/functional/vq.py#L36)
Example of use:
```python
from torchtools.nn import Binarize
e = torch.randn(8, 16) # create a random tensor with any dimension
binarizer = Binarize(threshold=0.5) # you can set any threshold you want; for example, if your output was passed through a tanh activation, 0 might be a better threshold since tanh outputs values between -1 and 1
bq = binarizer(e) # will return a tensor with the same shape as e, but full of 0s and 1s
```
## Embeddings
### RotaryEmbedding
Implementation taken as is from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L161
Example of use:
```python
from torchtools.nn import RotaryEmbedding
class MyModel(nn.Module):
def __init__(self, dim):
...
self.rotary_pos_embd = RotaryEmbedding(dim)
...
def forward(self, x):
x = self.rotary_pos_embd(x)
...
```
## Diffusion
### Diffuzz
Custom (non-cached) continuous forward/backward diffusion.
It's not SUPER performant, since it calculates all the required values on the fly instead of caching them (although I added a very simple cache where you can specify the number of steps to cache). In general, though, I don't think this makes an extremely big difference in performance, it simplifies the code a lot, and it removes the concept of a fixed number of timesteps for the forward diffusion (I always found it weird to train a model assuming 1000 forward diffusion steps and then use way fewer steps during inference): instead, a continuous value between 0 and 1 decides how much noise is added to the output (1 being pure gaussian noise).
During sampling the same applies: instead of a fixed number of steps, the diffuzz module accepts a noised input, a pair of values t & t_prev (between 0 and 1) and a predicted noise, and it will try to remove that noise at a scale that takes it from step t to step t_prev. So if we want to denoise in 10 steps we tell it to go from 1.0 to 0.9, then to 0.8, etc..., while if we want to denoise in 100 steps we start at 1.0 and go to 0.99, then to 0.98, etc...
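The continuous-t idea can be sketched in plain Python. The readme doesn't show Diffuzz's actual noise schedule, so the cosine `alpha_bar` below is an assumption borrowed from the common diffusion literature, and `diffuse` is an illustrative stand-in, not the module's implementation:

```python
import math
import random

def alpha_bar(t, s=0.008):
    """Cosine cumulative signal level: ~1 at t=0 (clean), ~0 at t=1 (pure noise).
    NOTE: assumed schedule for illustration; Diffuzz's internal schedule may differ."""
    return math.cos((t + s) / (1 + s) * math.pi / 2) ** 2

def diffuse(x, t):
    """Mix a sample with gaussian noise by a continuous amount t in [0, 1]."""
    a = alpha_bar(t)
    noise = [random.gauss(0.0, 1.0) for _ in x]
    noised = [math.sqrt(a) * xi + math.sqrt(1 - a) * ni for xi, ni in zip(x, noise)]
    return noised, noise

# t close to 0 keeps the input almost untouched, t close to 1 is almost pure noise,
# with no fixed step count anywhere: any t in between is a valid noising level
```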
Example of use during training:
```python
from torchtools.utils import Diffuzz
device = "cuda"
diffuzz = Diffuzz(device=device)
# diffuzz = Diffuzz(device=device, cache_steps=10000) # optionally you can pass a 'cache_steps' parameter to speed up the noising process
custom_unet = CustomUnet().to(device) # custom model with output size = input size
input_tensor = torch.randn(8, 3, 16, 16, device=device) # an image, audio signal, or whatever...
t = torch.rand(input_tensor.size(0), device=device) # get a tensor with batch_size of values between 0 and 1
noised_tensor, noise = diffuzz.diffuse(input_tensor, t)
predicted_noise = custom_unet(noised_tensor, t)
loss = nn.functional.mse_loss(predicted_noise, noise)
# Optionally, the diffuzz module provides loss gamma weighting (untested), but for this to work
# the loss should not be averaged over the batch dimension before applying it.
# loss = nn.functional.mse_loss(predicted_noise, noise, reduction='none').mean(dim=[1, 2, 3])
# loss = (loss * diffuzz.p2_weight(t)).mean()
```
Example of use for sampling:
```python
from torchtools.utils import Diffuzz
device = "cuda"
# 'diffuzz' and 'custom_unet' are the same objects defined in the training example;
# 'conditioning' is whatever conditioning tensor your model expects
sampled = diffuzz.sample(
custom_unet, {'c': conditioning},
(conditioning.size(0), 3, 16, 16),
timesteps=20, sampler='ddim'
)[-1]
```
The `sample` method accepts a `sampler` parameter; currently only `ddpm` (default) and `ddim` are implemented, but I'm planning on adding more, very likely by borrowing (and appropriately citing) code from this repo https://github.com/ozanciga/diffusion-for-beginners
================================================
FILE: setup.cfg
================================================
[metadata]
description-file = readme.md
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
setup(
name='torchtools',
packages=find_packages(),
description='PyTorch useful tools',
version='0.3.5',
url='https://github.com/pabloppp/pytorch-tools',
author='Pablo Pernías',
author_email='pablo@pernias.com',
keywords=['pip', 'pytorch', 'tools', 'RAdam', 'Lookahead', 'RALamb', 'quantization'],
zip_safe=False,
install_requires=[
'torch>=1.6',
'torchvision',
'numpy>=1.0',
'ninja>=1.0'
],
package_data={
'stylegan2.tools': ['torchtools/nn/stylegan2/*'],
'transforms.models': ['torchtools/transforms/models/*']
},
include_package_data=True,
)
================================================
FILE: torchtools/__init__.py
================================================
================================================
FILE: torchtools/lr_scheduler/__init__.py
================================================
from .delayed import DelayerScheduler, DelayedCosineAnnealingLR
from .inverse_sqrt import InverseSqrtLR
================================================
FILE: torchtools/lr_scheduler/delayed.py
================================================
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
class DelayerScheduler(_LRScheduler):
""" Starts with a flat lr schedule until it reaches N epochs, then applies a scheduler
Args:
optimizer (Optimizer): Wrapped optimizer.
delay_epochs: number of epochs to keep the initial lr before starting to apply the scheduler
after_scheduler: after delay_epochs, use this scheduler (eg. ReduceLROnPlateau)
"""
def __init__(self, optimizer, delay_epochs, after_scheduler):
self.delay_epochs = delay_epochs
self.after_scheduler = after_scheduler
self.finished = False
super().__init__(optimizer)
def get_lr(self):
if self.last_epoch >= self.delay_epochs:
if not self.finished:
self.after_scheduler.base_lrs = self.base_lrs
self.finished = True
return self.after_scheduler.get_lr()
return self.base_lrs
def step(self, epoch=None):
if self.finished:
if epoch is None:
self.after_scheduler.step(None)
else:
self.after_scheduler.step(epoch - self.delay_epochs)
else:
return super(DelayerScheduler, self).step(epoch)
def DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs):
base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs)
return DelayerScheduler(optimizer, delay_epochs, base_scheduler)
================================================
FILE: torchtools/lr_scheduler/inverse_sqrt.py
================================================
import warnings
from torch.optim.lr_scheduler import LRScheduler
class InverseSqrtLR(LRScheduler):
def __init__(self, optimizer, lr, warmup_steps, pre_warmup_lr=None, last_epoch=-1, verbose=False):
warmup_steps = max(warmup_steps, 1)
self.lr = lr * warmup_steps**0.5
self.warmup_steps = warmup_steps
self.pre_warmup_lr = pre_warmup_lr if pre_warmup_lr is not None else lr
super().__init__(optimizer, last_epoch, verbose)
def _process_lr(self, _):
warmup_factor = min(self.last_epoch/self.warmup_steps, 1) # this grows linearly from 0 to 1 during the warmup
base_lr = self.lr / max(self.last_epoch, self.warmup_steps)**0.5
return warmup_factor * base_lr + (1-warmup_factor)*self.pre_warmup_lr
def get_lr(self):
if not self._get_lr_called_within_step:
warnings.warn("To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning)
lr = self._process_lr(self.lr)
return [lr for _ in self.optimizer.param_groups]
def _get_closed_form_lr(self):
return [self._process_lr(base_lr) for base_lr in self.base_lrs]
================================================
FILE: torchtools/nn/__init__.py
================================================
from .mish import Mish
from .simple_self_attention import SimpleSelfAttention
from .vq import VectorQuantize, Binarize, FSQ
from .gp_loss import GPLoss
from .pixel_normalzation import PixelNorm
from .perceptual import TVLoss
from .adain import AdaIN
from .transformers import GPTTransformerEncoderLayer
from .evonorm2d import EvoNorm2D
from .pos_embeddings import RotaryEmbedding
from .modulation import ModulatedConv2d
from .equal_layers import EqualConv2d, EqualLeakyReLU, EqualLinear
from .fourier_features import FourierFeatures2d
# from .alias_free_activation import AliasFreeActivation
from .magnitude_preserving import MP_GELU, MP_SiLU, Gain
from .haar_dwt import HaarForward, HaarInverse
================================================
FILE: torchtools/nn/adain.py
================================================
import torch
from torch import nn
class AdaIN(nn.Module):
def __init__(self, n_channels):
super(AdaIN, self).__init__()
self.norm = nn.InstanceNorm2d(n_channels)
def forward(self, image, style):
factor, bias = style.view(style.size(0), style.size(1), 1, 1).chunk(2, dim=1)
result = self.norm(image) * factor + bias
return result
================================================
FILE: torchtools/nn/alias_free_activation.py
================================================
import torch
from torch import nn
import math
from .stylegan2 import upfirdn2d
####
# TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM
# https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L225
# But I simplified it into a single (almost) self-contained module.
# Probably I give this module too much responsibility but meh...
####
class AliasFreeActivation(nn.Module):
def __init__(self, activation, level, max_levels, max_size, max_channels, margin, start_cutoff=2, critical_layers=2, window_size=6):
super().__init__()
self.activation = activation
# Filter features
self.cutoff, self.stopband, self.band_half, self.channels, self.size = self.alias_level_params(
level, max_levels, max_size, max_channels, start_cutoff, critical_layers
)
self.cutoff_prev, self.stopband_prev, self.band_half_prev, self.channels_prev, self.size_prev = self.alias_level_params(
max(level-1, 0), max_levels, max_size, max_channels, start_cutoff, critical_layers
)
# Filters
self.scale_factor = 2 if self.size_prev < self.size else 1
up_filter = self._lowpass_filter(
window_size * self.scale_factor * 2, self.cutoff_prev, self.band_half_prev, self.size * self.scale_factor * 2
)
self.register_buffer("up_filter", (up_filter / up_filter.sum()) * 2 * self.scale_factor)
down_filter = self._lowpass_filter(
window_size * self.scale_factor, self.cutoff, self.band_half, self.size * self.scale_factor * 2
)
self.register_buffer("down_filter", down_filter / down_filter.sum())
p = self.up_filter.shape[0] - (2*self.scale_factor)
self.up_pad = ((p + 1) // 2 + (2*self.scale_factor) - 1, p // 2)
p = self.down_filter.shape[0] - 2
self.down_pad = ((p + 1) // 2, p // 2)
self.margin = margin
@staticmethod
def alias_level_params(level, max_levels, max_size, max_channels, start_cutoff=2, critical_layers=2, base_channels=2**14):
end_cutoff = max_size//2
cutoff = start_cutoff * (end_cutoff / start_cutoff) ** min(level / (max_levels - critical_layers), 1)
start_stopband = start_cutoff ** 2.1
end_stopband = end_cutoff * (2 ** 0.3)
stopband = start_stopband * (end_stopband/start_stopband) ** min(level / (max_levels - critical_layers), 1)
size = 2 ** math.ceil(math.log(min(2 * stopband, max_size), 2))
band_half = max(stopband, size / 2) - cutoff
channels = min(round(base_channels / size), max_channels)
return cutoff, stopband, band_half, channels, size
def _lowpass_filter(self, n_taps, cutoff, band_half, sr):
window = self._kaiser_window(n_taps, band_half, sr)
ind = torch.arange(n_taps) - (n_taps - 1) / 2
lowpass = 2 * cutoff / sr * torch.sinc(2 * cutoff / sr * ind) * window
return lowpass
def _kaiser_window(self, n_taps, f_h, sr):
beta = self._kaiser_beta(n_taps, f_h, sr)
ind = torch.arange(n_taps) - (n_taps - 1) / 2
return torch.i0(beta * torch.sqrt(1 - ((2 * ind) / (n_taps - 1)) ** 2)) / torch.i0(torch.tensor(beta))
def _kaiser_attenuation(self, n_taps, f_h, sr):
df = (2 * f_h) / (sr / 2)
return 2.285 * (n_taps - 1) * math.pi * df + 7.95
def _kaiser_beta(self, n_taps, f_h, sr):
atten = self._kaiser_attenuation(n_taps, f_h, sr)
if atten > 50:
return 0.1102 * (atten - 8.7)
elif 50 >= atten >= 21:
return 0.5842 * (atten - 21) ** 0.4 + 0.07886 * (atten - 21)
else:
return 0.0
def forward(self, x):
x = self._upsample(x, self.up_filter, 2*self.scale_factor, pad=self.up_pad)
x = self.activation(x)
x = self._downsample(x, self.down_filter, 2, pad=self.down_pad)
if self.scale_factor > 1 and self.margin > 0:
m = self.scale_factor * self.margin // 2
x = x[:, :, m:-m, m:-m]
return x
def _upsample(self, x, kernel, factor, pad=(0, 0)):
x = upfirdn2d(x, kernel.unsqueeze(0), up=(factor, 1), pad=(*pad, 0, 0))
x = upfirdn2d(x, kernel.unsqueeze(1), up=(1, factor), pad=(0, 0, *pad))
return x
def _downsample(self, x, kernel, factor, pad=(0, 0)):
x = upfirdn2d(x, kernel.unsqueeze(0), down=(factor, 1), pad=(*pad, 0, 0))
x = upfirdn2d(x, kernel.unsqueeze(1), down=(1, factor), pad=(0, 0, *pad))
return x
def extra_repr(self):
info_string = f'cutoff={self.cutoff}, stopband={self.stopband}, band_half={self.band_half}, channels={self.channels}, size={self.size}'
return info_string
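The Kaiser-window helpers above follow the classic empirical design formulas. A standalone sketch of the attenuation-to-beta relationship (mirroring `_kaiser_attenuation` and `_kaiser_beta`, rewritten here as hypothetical free functions for illustration):

```python
import math

def kaiser_attenuation(n_taps, f_h, sr):
    # Estimated stopband attenuation (dB) for a Kaiser window with
    # n_taps taps and transition half-band f_h at sampling rate sr.
    df = (2 * f_h) / (sr / 2)
    return 2.285 * (n_taps - 1) * math.pi * df + 7.95

def kaiser_beta(n_taps, f_h, sr):
    # Kaiser's empirical beta, branched on the attenuation regime.
    atten = kaiser_attenuation(n_taps, f_h, sr)
    if atten > 50:
        return 0.1102 * (atten - 8.7)
    elif atten >= 21:
        return 0.5842 * (atten - 21) ** 0.4 + 0.07886 * (atten - 21)
    return 0.0
```

Longer filters (more taps) or wider transition bands give higher attenuation, which in turn selects a larger beta.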
================================================
FILE: torchtools/nn/equal_layers.py
================================================
import torch
from torch import nn
import math
####
# TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM
# https://github.com/rosinality/alias-free-gan-pytorch/blob/main/stylegan2/model.py#L94
# But made it extend from the base modules to avoid some boilerplate
####
class EqualLinear(nn.Linear):
def __init__(self, *args, bias_init=0, lr_mul=1, **kwargs):
super().__init__(*args, **kwargs)
self.scale = (1 / math.sqrt(self.in_features)) * lr_mul
self.lr_mul = lr_mul
nn.init.normal_(self.weight, std=1/lr_mul)
if self.bias is not None:
nn.init.constant_(self.bias, bias_init)
def forward(self, x):
return nn.functional.linear(x, self.weight * self.scale, self.bias * self.lr_mul if self.bias is not None else None)
class EqualConv2d(nn.Conv2d):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
fan_in = self.in_channels * self.kernel_size[0] ** 2
self.scale = 1 / math.sqrt(fan_in)
nn.init.normal_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, x):
return self._conv_forward(x, self.weight * self.scale, self.bias)
class EqualLeakyReLU(nn.LeakyReLU):
def __init__(self, *args, scale=2**0.5, **kwargs):
super().__init__(*args, **kwargs)
self.scale = scale
def forward(self, x):
return super().forward(x) * self.scale
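The equalized-learning-rate trick used by these layers initializes weights at unit std and applies the He-style scale at runtime instead of at init. A minimal sketch of just the scale computation (the helper name is illustrative, not part of the library):

```python
import math

def equal_lr_scale(fan_in, lr_mul=1.0):
    # Runtime multiplier applied to a unit-std weight, so the effective
    # initialization matches He init while gradients are rescaled by lr_mul.
    return (1.0 / math.sqrt(fan_in)) * lr_mul

# EqualLinear with in_features=512  -> scale = 1/sqrt(512)
# EqualConv2d with in_channels=64, 3x3 kernel -> fan_in = 64 * 3**2 = 576
```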
================================================
FILE: torchtools/nn/evonorm2d.py
================================================
import torch
import torch.nn as nn
## Taken as is from https://github.com/digantamisra98/EvoNorm all credit goes to digantamisra98
class SwishImplementation(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
ctx.save_for_backward(i)
return i * torch.sigmoid(i)
@staticmethod
def backward(ctx, grad_output):
sigmoid_i = torch.sigmoid(ctx.saved_tensors[0])
return grad_output * (sigmoid_i * (1 + ctx.saved_tensors[0] * (1 - sigmoid_i)))
class MemoryEfficientSwish(nn.Module):
def forward(self, x):
return SwishImplementation.apply(x)
def instance_std(x, eps=1e-5):
var = torch.var(x, dim = (2, 3), keepdim=True).expand_as(x)
if torch.isnan(var).any():
var = torch.zeros_like(var)  # keep device and dtype instead of allocating on CPU
return torch.sqrt(var + eps)
def group_std(x, groups = 32, eps = 1e-5):
N, C, H, W = x.size()
x = torch.reshape(x, (N, groups, C // groups, H, W))
var = torch.var(x, dim = (2, 3, 4), keepdim = True).expand_as(x)
return torch.reshape(torch.sqrt(var + eps), (N, C, H, W))
class EvoNorm2D(nn.Module):
def __init__(self, input, non_linear = True, version = 'S0', efficient = False, affine = True, momentum = 0.9, eps = 1e-5, groups = 32, training = True):
super(EvoNorm2D, self).__init__()
self.non_linear = non_linear
self.version = version
self.training = training
self.momentum = momentum
self.efficient = efficient
if self.version == 'S0':
self.swish = MemoryEfficientSwish()
self.groups = groups
self.eps = eps
if self.version not in ['B0', 'S0']:
raise ValueError("Invalid EvoNorm version")
self.insize = input
self.affine = affine
if self.affine:
self.gamma = nn.Parameter(torch.ones(1, self.insize, 1, 1))
self.beta = nn.Parameter(torch.zeros(1, self.insize, 1, 1))
if self.non_linear:
self.v = nn.Parameter(torch.ones(1,self.insize,1,1))
else:
self.register_parameter('gamma', None)
self.register_parameter('beta', None)
self.register_buffer('v', None)
self.register_buffer('running_var', torch.ones(1, self.insize, 1, 1))
self.reset_parameters()
def reset_parameters(self):
self.running_var.fill_(1)
def _check_input_dim(self, x):
if x.dim() != 4:
raise ValueError('expected 4D input (got {}D input)'
.format(x.dim()))
def forward(self, x):
self._check_input_dim(x)
if self.version == 'S0':
if self.non_linear:
if not self.efficient:
num = x * torch.sigmoid(self.v * x) # Original Swish Implementation, however memory intensive.
else:
num = self.swish(x) # Experimental Memory Efficient Variant of Swish
return num / group_std(x, groups = self.groups, eps = self.eps) * self.gamma + self.beta
else:
return x * self.gamma + self.beta
if self.version == 'B0':
if self.training:
var = torch.var(x, dim = (0, 2, 3), unbiased = False, keepdim = True)
self.running_var.mul_(self.momentum)
self.running_var.add_((1 - self.momentum) * var)
else:
var = self.running_var
if self.non_linear:
den = torch.max((var+self.eps).sqrt(), self.v * x + instance_std(x, eps = self.eps))
return x / den * self.gamma + self.beta
else:
return x * self.gamma + self.beta
================================================
FILE: torchtools/nn/fourier_features.py
================================================
import torch
from torch import nn
import math
class FourierFeatures2d(nn.Module):
def __init__(self, size, dim, cutoff, affine_eps=1e-8, freq_range=[-0.5, 0.5], w_scale=0, allow_scaling=False, op_order=['r', 't', 's']):
super().__init__()
self.size = size
self.dim = dim
self.cutoff = cutoff
self.freq_range = freq_range
self.affine_eps = affine_eps
self.w_scale = w_scale
coords = torch.linspace(freq_range[0], freq_range[1], size+1)[:-1]
freqs = torch.linspace(0, cutoff, dim // 4)
if w_scale > 0:
freqs = freqs @ (torch.randn(dim // 4, dim // 4) * w_scale)
coord_map = torch.outer(freqs, coords)
coord_map = 2 * math.pi * coord_map
self.register_buffer("coord_h", coord_map.view(freqs.shape[0], 1, size))
self.register_buffer("coord_w", self.coord_h.transpose(1, 2).detach())
self.register_buffer("lf", freqs.view(1, dim // 4, 1, 1) * 2*math.pi * 2/size)
self.allow_scaling = allow_scaling
for op in op_order:
assert op in ['r', 't', 's'], f"Operation not valid: {op}"
self.op_order = op_order
def forward(self, affine):
norm = ((affine[:, 0:1].pow(2) + affine[:, 1:2].pow(2)).sqrt() + self.affine_eps).expand(affine.size(0), 4)
if self.allow_scaling:
assert affine.size(-1) == 6, f"If scaling is enabled, 2 extra values must be passed for a total of 6, and not {affine.size(-1)}"
norm = torch.cat([norm, norm.new_ones(affine.size(0), 2)], dim=1)
else:
assert affine.size(-1) == 4, f"If scaling is disabled, 4 affine values should be passed, and not {affine.size(-1)}"
affine = affine / norm
affine = affine[:, :, None, None, None]
coord_h, coord_w = self.coord_h.unsqueeze(0), self.coord_w.unsqueeze(0)
for op in reversed(self.op_order):
if op == 's' and self.allow_scaling:
coord_h = coord_h / nn.functional.threshold(affine[:, 5], 1.0, 1.0) # scale
coord_w = coord_w / nn.functional.threshold(affine[:, 4], 1.0, 1.0)
elif op == 't':
coord_h = coord_h - (affine[:, 3] * self.lf) # shift
coord_w = coord_w - (affine[:, 2] * self.lf)
elif op == 'r':
_coord_h = -coord_w * affine[:, 1] + coord_h * affine[:, 0] # rotation
coord_w = coord_w * affine[:, 0] + coord_h * affine[:, 1]
coord_h = _coord_h
coord_h = torch.cat((torch.sin(coord_h), torch.cos(coord_h)), 1)
coord_w = torch.cat((torch.sin(coord_w), torch.cos(coord_w)), 1)
coords = torch.cat((coord_h, coord_w), 1)
return coords
def extra_repr(self):
info_string = f'size={self.size}, dim={self.dim}, cutoff={self.cutoff}, freq_range={self.freq_range}'
if self.w_scale > 0:
info_string += f', w_scale={self.w_scale}'
if self.allow_scaling:
info_string += f', allow_scaling={self.allow_scaling}'
return info_string
================================================
FILE: torchtools/nn/functional/__init__.py
================================================
from .vq import vector_quantize, binarize
from .gradient_penalty import gradient_penalty
from .perceptual import total_variation
from .magnitude_preserving import mp_cat, mp_sum
================================================
FILE: torchtools/nn/functional/gradient_penalty.py
================================================
####
# CODE TAKEN WITH FEW MODIFICATIONS FROM https://github.com/caogang/wgan-gp
# ORIGINAL PAPER https://arxiv.org/pdf/1704.00028.pdf
####
import torch
from torch import autograd
def gradient_penalty(netD, real_data, fake_data, l=10):
batch_size = real_data.size(0)
alpha = real_data.new_empty((batch_size, 1, 1, 1)).uniform_(0, 1)
alpha = alpha.expand_as(real_data)
interpolates = alpha * real_data + ((1 - alpha) * fake_data)
interpolates = interpolates.detach().requires_grad_(True)  # autograd.Variable is deprecated since PyTorch 0.4
disc_interpolates = netD(interpolates)
gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
grad_outputs=real_data.new_ones(disc_interpolates.size()),
create_graph=True, retain_graph=True, only_inputs=True)[0]
gradients = gradients.view(gradients.size(0), -1)
gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)
gradient_penalty = ((gradients_norm - 1) ** 2).mean() * l
return gradient_penalty
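The final reduction above is just lambda times the mean squared deviation of each interpolate's gradient norm from 1. A plain-Python sketch of that reduction (gradient norms made up for illustration, helper name hypothetical):

```python
def wgan_gp_penalty(grad_norms, l=10.0):
    # Mean of (||grad|| - 1)^2 over the batch, scaled by lambda.
    return sum((n - 1.0) ** 2 for n in grad_norms) / len(grad_norms) * l
```

Norms exactly at 1 incur no penalty; deviations in either direction are penalized quadratically.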
================================================
FILE: torchtools/nn/functional/magnitude_preserving.py
================================================
import torch
def mp_cat(*args, dim=1, t=0.5):
if isinstance(t, float):
t = [1-t, t]
assert len(args) == len(t), "t must be a single scalar or a list of scalars of length len(args)"
w = [m/a.size(dim)**0.5 for a, m in zip(args, t)]
C = (sum([a.size(dim) for a in args]) / sum([m**2 for m in t]))**0.5
return torch.cat([a*v for a, v in zip(args, w)], dim=dim) * C
def mp_sum(*args, t=0.5):
if isinstance(t, float):
t = [1-t, t]
assert len(args) == len(t), "t must be a single scalar or a list of scalars of length len(args)"
assert abs(sum(t)-1) < 1e-3, "the values of t should all add up to one"
return sum([a*m for a, m in zip(args, t)]) / sum([m**2 for m in t])**0.5
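For scalars, `mp_sum` reduces to a weighted interpolation divided by the RMS of the weights, so two independent unit-variance inputs yield a unit-variance output. A scalar sketch (function name illustrative):

```python
import math

def mp_sum_scalar(a, b, t=0.5):
    # Var[(1-t)*a + t*b] = (1-t)^2 + t^2 for independent unit-variance a, b,
    # so dividing by sqrt of that restores unit variance.
    return ((1 - t) * a + t * b) / math.sqrt((1 - t) ** 2 + t ** 2)
```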
================================================
FILE: torchtools/nn/functional/perceptual.py
================================================
import torch
def total_variation(X, reduction='sum'):
tv_h = torch.abs(X[:, :, :, 1:] - X[:, :, :, :-1])
tv_v = torch.abs(X[:, :, 1:] - X[:, :, :-1])
tv = torch.mean(tv_h) + torch.mean(tv_v) if reduction == 'mean' else torch.sum(tv_h) + torch.sum(tv_v)
return tv
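On a single 2-D grid, total variation is just the sum of absolute horizontal and vertical neighbour differences. A list-based sketch of the same computation (helper name hypothetical):

```python
def total_variation_2d(img):
    # img: list of rows; sum |right - left| + |down - up| over all neighbours.
    tv_h = sum(abs(row[j + 1] - row[j]) for row in img for j in range(len(row) - 1))
    tv_v = sum(abs(img[i + 1][j] - img[i][j])
               for i in range(len(img) - 1) for j in range(len(img[0])))
    return tv_h + tv_v
```

A constant image has zero total variation; a checkerboard maximizes it.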
================================================
FILE: torchtools/nn/functional/vq.py
================================================
import torch
from torch.autograd import Function
class vector_quantize(Function):
@staticmethod
def forward(ctx, x, codebook):
with torch.no_grad():
codebook_sqr = torch.sum(codebook ** 2, dim=1)
x_sqr = torch.sum(x ** 2, dim=1, keepdim=True)
dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0)
_, indices = dist.min(dim=1)
ctx.save_for_backward(indices, codebook)
ctx.mark_non_differentiable(indices)
nn = torch.index_select(codebook, 0, indices)
return nn, indices
@staticmethod
def backward(ctx, grad_output, grad_indices):
grad_inputs, grad_codebook = None, None
if ctx.needs_input_grad[0]:
grad_inputs = grad_output.clone()
if ctx.needs_input_grad[1]:
# Gradient wrt. the codebook
indices, codebook = ctx.saved_tensors
grad_codebook = torch.zeros_like(codebook)
grad_codebook.index_add_(0, indices, grad_output)
return (grad_inputs, grad_codebook)
class binarize(Function):
@staticmethod
def forward(ctx, x, threshold=0.5):
with torch.no_grad():
binarized = (x > threshold).float()
ctx.mark_non_differentiable(binarized)
return binarized
@staticmethod
def backward(ctx, grad_output):
grad_inputs = None
if ctx.needs_input_grad[0]:
grad_inputs = grad_output.clone()
return grad_inputs
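The `vector_quantize` forward pass above computes squared L2 distances via the `addmm` expansion trick and takes the argmin. A naive pure-Python equivalent of the nearest-neighbour lookup (names illustrative), useful for checking the distance logic:

```python
def vector_quantize_naive(x, codebook):
    # For each input vector, pick the codebook row with minimum squared
    # L2 distance; returns (quantized vectors, indices).
    indices = []
    for v in x:
        dists = [sum((vi - ci) ** 2 for vi, ci in zip(v, c)) for c in codebook]
        indices.append(dists.index(min(dists)))
    return [codebook[i] for i in indices], indices
```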
================================================
FILE: torchtools/nn/gp_loss.py
================================================
import torch
from torch import nn
from .functional import gradient_penalty
class GPLoss(nn.Module):
def __init__(self, discriminator, l=10):
super(GPLoss, self).__init__()
self.discriminator = discriminator
self.l = l
def forward(self, real_data, fake_data):
return gradient_penalty(self.discriminator, real_data, fake_data, self.l)
================================================
FILE: torchtools/nn/haar_dwt.py
================================================
import torch
from torch import nn
# Taken almost as is from https://github.com/bes-dev/haar_pytorch
class HaarForward(nn.Module):
"""
Performs a 2d DWT Forward decomposition of an image using Haar Wavelets
set beta=1 for a regular Haar DWT; with beta=2 the transform is magnitude preserving
"""
def __init__(self, beta=2):
super().__init__()
self.alpha = 0.5
self.beta = beta
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs a 2d DWT Forward decomposition of an image using Haar Wavelets
Arguments:
x (torch.Tensor): input tensor of shape [b, c, h, w]
Returns:
out (torch.Tensor): output tensor of shape [b, c * 4, h / 2, w / 2]
"""
ll = self.alpha/self.beta * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] + x[:,:,1::2,0::2] + x[:,:,1::2,1::2])
lh = self.alpha * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] - x[:,:,1::2,0::2] - x[:,:,1::2,1::2])
hl = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] + x[:,:,1::2,0::2] - x[:,:,1::2,1::2])
hh = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] - x[:,:,1::2,0::2] + x[:,:,1::2,1::2])
return torch.cat([ll,lh,hl,hh], axis=1)
class HaarInverse(nn.Module):
"""
Performs a 2d DWT Inverse reconstruction of an image using Haar Wavelets
set beta=1 for a regular Haar DWT; with beta=2 the transform is magnitude preserving
"""
def __init__(self, beta=2):
super().__init__()
self.alpha = 0.5
self.beta = beta
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Performs a 2d DWT Inverse reconstruction of an image using Haar Wavelets
Arguments:
x (torch.Tensor): input tensor of shape [b, c, h, w]
Returns:
out (torch.Tensor): output tensor of shape [b, c / 4, h * 2, w * 2]
"""
assert x.size(1) % 4 == 0, "The number of channels must be divisible by 4."
size = [x.shape[0], x.shape[1] // 4, x.shape[2] * 2, x.shape[3] * 2]
f = lambda i: x[:, size[1] * i : size[1] * (i + 1)]
out = torch.zeros(size, dtype=x.dtype, device=x.device)
out[:,:,0::2,0::2] = self.alpha * (f(0)*self.beta + f(1) + f(2) + f(3))
out[:,:,0::2,1::2] = self.alpha * (f(0)*self.beta + f(1) - f(2) - f(3))
out[:,:,1::2,0::2] = self.alpha * (f(0)*self.beta - f(1) + f(2) - f(3))
out[:,:,1::2,1::2] = self.alpha * (f(0)*self.beta - f(1) - f(2) + f(3))
return out
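On a single 2x2 block, the forward/inverse pair above reduces to four butterfly sums. A scalar round-trip sketch with the same alpha=0.5 and beta=2 (function names illustrative):

```python
def haar_forward_2x2(a, b, c, d, alpha=0.5, beta=2):
    # Pixels [a b / c d] -> (ll, lh, hl, hh); ll is scaled down by beta.
    ll = alpha / beta * (a + b + c + d)
    lh = alpha * (a + b - c - d)
    hl = alpha * (a - b + c - d)
    hh = alpha * (a - b - c + d)
    return ll, lh, hl, hh

def haar_inverse_2x2(ll, lh, hl, hh, alpha=0.5, beta=2):
    # Exact inverse: ll is rescaled back up by beta before the butterflies.
    a = alpha * (ll * beta + lh + hl + hh)
    b = alpha * (ll * beta + lh - hl - hh)
    c = alpha * (ll * beta - lh + hl - hh)
    d = alpha * (ll * beta - lh - hl + hh)
    return a, b, c, d
```

Because beta cancels between forward and inverse, reconstruction is exact for any beta.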
================================================
FILE: torchtools/nn/magnitude_preserving.py
================================================
import torch
from torch import nn
class MP_GELU(nn.GELU):
def forward(self, x):
return super().forward(x) / 0.652  # approx. std of GELU(x) for x ~ N(0, 1), so unit-variance inputs stay unit-variance
class MP_SiLU(nn.SiLU):
def forward(self, x):
return super().forward(x) / 0.596  # approx. std of SiLU(x) for x ~ N(0, 1), so unit-variance inputs stay unit-variance
class Gain(nn.Module):
def __init__(self, init_w=0.0):
super().__init__()
self.g = nn.Parameter(torch.tensor([init_w]))
def forward(self, x):
return x * self.g
================================================
FILE: torchtools/nn/mish.py
================================================
####
# CODE TAKEN FROM https://github.com/lessw2020/mish
# ORIGINAL PAPER https://arxiv.org/abs/1908.08681v1
####
import torch
import torch.nn as nn
import torch.nn.functional as F
#Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
#https://arxiv.org/abs/1908.08681v1
#implemented for PyTorch / FastAI by lessw2020
#github: https://github.com/lessw2020/mish
class Mish(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
#inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
return x * torch.tanh(F.softplus(x))
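For a single scalar, Mish is just x * tanh(softplus(x)) and needs only the stdlib; a sketch (function name illustrative):

```python
import math

def mish(x):
    # softplus(x) = log(1 + e^x); Mish gates x by tanh of it.
    return x * math.tanh(math.log1p(math.exp(x)))
```

Mish is zero at the origin, nearly linear for large positive inputs, and decays to zero for large negative inputs.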
================================================
FILE: torchtools/nn/modulation.py
================================================
import torch
from torch import nn
import math
####
# TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM
# https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L143
# But made it extend from the base Conv2d to avoid some boilerplate
####
class ModulatedConv2d(nn.Conv2d):
def __init__(self, *args, demodulate=True, ema_decay=1.0, **kwargs):
super().__init__(*args, **kwargs)
fan_in = self.in_channels * self.kernel_size[0] ** 2
self.scale = 1 / math.sqrt(fan_in)
self.demodulate = demodulate
self.ema_decay = ema_decay
self.register_buffer("ema_var", torch.tensor(1.0))
nn.init.normal_(self.weight)
if self.bias is not None:
nn.init.zeros_(self.bias)
def forward(self, x, w):
batch, in_channels, height, width = x.shape
style = w.view(batch, 1, in_channels, 1, 1)
weight = self.scale * self.weight.unsqueeze(0) * style
if self.demodulate:
demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
weight = weight * demod.view(batch, self.out_channels, 1, 1, 1)
weight = weight.view(
batch * self.out_channels, in_channels, self.kernel_size[0], self.kernel_size[1]
)
if self.ema_decay < 1:
if self.training:
var = x.pow(2).mean((0, 1, 2, 3))
self.ema_var.mul_(self.ema_decay).add_(var.detach(), alpha=1 - self.ema_decay)
weight = weight / (torch.sqrt(self.ema_var) + 1e-8)
input = x.view(1, batch * in_channels, height, width)
self.groups = batch
out = self._conv_forward(input, weight, None)
_, _, height, width = out.shape
out = out.view(batch, self.out_channels, height, width)
if self.bias is not None:
out = out + self.bias.view(1, -1, 1, 1)
return out
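The demodulation step above rescales each output channel by the inverse L2 norm of its modulated weights, so the expected output magnitude stays near 1. A scalar sketch of just that factor (helper name hypothetical):

```python
import math

def demod_factor(weights, eps=1e-8):
    # 1/sqrt(sum of squared modulated weights) for one output channel,
    # matching torch.rsqrt(weight.pow(2).sum(...) + eps) in ModulatedConv2d.
    return 1.0 / math.sqrt(sum(w * w for w in weights) + eps)
```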
================================================
FILE: torchtools/nn/perceptual.py
================================================
import torch
import torch.nn as nn
from .functional import total_variation
class TVLoss(nn.Module):
def __init__(self, reduction='sum', alpha=1e-4):
super(TVLoss, self).__init__()
self.reduction = reduction
self.alpha = alpha
def forward(self, x):
return total_variation(x, reduction=self.reduction) * self.alpha
================================================
FILE: torchtools/nn/pixel_normalzation.py
================================================
import torch
from torch import nn
class PixelNorm(nn.Module):
def __init__(self, dim=1, eps=1e-4):
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x):
return x / (torch.sqrt(torch.mean(x ** 2, dim=self.dim, keepdim=True)) + self.eps)
================================================
FILE: torchtools/nn/pos_embeddings.py
================================================
import torch
from torch import nn
####
# CODE TAKEN FROM https://github.com/lucidrains/x-transformers
####
class RotaryEmbedding(nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
self.seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
def forward(self, x, seq_dim=1):
seq_len = x.shape[seq_dim]
if seq_len != self.seq_len_cached:
self.seq_len_cached = seq_len
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
freqs = torch.einsum('i,j->ij', t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[:, None, None, :]
self.sin_cached = emb.sin()[:, None, None, :]
return self.cos_cached, self.sin_cached
# rotary pos emb helpers:
def rotate_half(x):
x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
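For a two-dimensional head, `apply_rotary_pos_emb` is exactly a plane rotation. A scalar sketch assuming the half-split layout that `rotate_half` uses (function names illustrative):

```python
import math

def rotate_half_2d(x):
    # (x1, x2) -> (-x2, x1): the 2-dim case of rotate_half.
    return (-x[1], x[0])

def apply_rotary_2d(q, theta):
    # q*cos(theta) + rotate_half(q)*sin(theta) == rotating q by theta.
    c, s = math.cos(theta), math.sin(theta)
    r = rotate_half_2d(q)
    return (q[0] * c + r[0] * s, q[1] * c + r[1] * s)
```

Rotating (1, 0) by pi/2 gives (0, 1), confirming the identity with a standard 2-D rotation matrix.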
================================================
FILE: torchtools/nn/simple_self_attention.py
================================================
import torch.nn as nn
import torch, math, sys
####
# CODE TAKEN FROM https://github.com/sdoria/SimpleSelfAttention
####
#Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
def conv1d(ni, no, ks=1, stride=1, padding=0, bias=False):
"Create and initialize a `nn.Conv1d` layer with spectral normalization."
conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
nn.init.kaiming_normal_(conv.weight)
if bias: conv.bias.data.zero_()
return nn.utils.spectral_norm(conv)
# Adapted from SelfAttention layer at https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
# Inspired by https://arxiv.org/pdf/1805.08318.pdf
class SimpleSelfAttention(nn.Module):
def __init__(self, n_in, ks=1, sym=False):
super().__init__()
self.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False)
self.gamma = nn.Parameter(torch.Tensor([0.]))
self.sym = sym
self.n_in = n_in
def forward(self, x):
if self.sym:
# symmetry hack by https://github.com/mgrankin
c = self.conv.weight.view(self.n_in,self.n_in)
c = (c + c.t())/2
self.conv.weight = c.view(self.n_in,self.n_in,1)
size = x.size()
x = x.view(*size[:2],-1) # (C,N)
# changed the order of multiplication to avoid O(N^2) complexity
# (x*xT)*(W*x) instead of (x*(xT*(W*x)))
convx = self.conv(x) # (C,C) * (C,N) = (C,N) => O(NC^2)
xxT = torch.bmm(x, x.permute(0,2,1).contiguous()) # (C,N) * (N,C) = (C,C) => O(NC^2)
o = torch.bmm(xxT, convx) # (C,C) * (C,N) = (C,N) => O(NC^2)
o = self.gamma * o + x
return o.view(*size).contiguous()
================================================
FILE: torchtools/nn/stylegan2/__init__.py
================================================
from .upfirdn2d import upfirdn2d
================================================
FILE: torchtools/nn/stylegan2/upfirdn2d.cpp
================================================
#include <ATen/ATen.h>
#include <torch/extension.h>
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1);
#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
int up_x, int up_y, int down_x, int down_y, int pad_x0,
int pad_x1, int pad_y0, int pad_y1) {
CHECK_INPUT(input);
CHECK_INPUT(kernel);
at::DeviceGuard guard(input.device());
return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
pad_y0, pad_y1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
================================================
FILE: torchtools/nn/stylegan2/upfirdn2d.py
================================================
from collections import abc
import os
import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load
module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
"upfirdn2d",
sources=[
os.path.join(module_path, "upfirdn2d.cpp"),
os.path.join(module_path, "upfirdn2d_kernel.cu"),
],
)
class UpFirDn2dBackward(Function):
@staticmethod
def forward(
ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
):
up_x, up_y = up
down_x, down_y = down
g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
grad_input = upfirdn2d_op.upfirdn2d(
grad_output,
grad_kernel,
down_x,
down_y,
up_x,
up_y,
g_pad_x0,
g_pad_x1,
g_pad_y0,
g_pad_y1,
)
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
ctx.save_for_backward(kernel)
pad_x0, pad_x1, pad_y0, pad_y1 = pad
ctx.up_x = up_x
ctx.up_y = up_y
ctx.down_x = down_x
ctx.down_y = down_y
ctx.pad_x0 = pad_x0
ctx.pad_x1 = pad_x1
ctx.pad_y0 = pad_y0
ctx.pad_y1 = pad_y1
ctx.in_size = in_size
ctx.out_size = out_size
return grad_input
@staticmethod
def backward(ctx, gradgrad_input):
kernel, = ctx.saved_tensors
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
gradgrad_out = upfirdn2d_op.upfirdn2d(
gradgrad_input,
kernel,
ctx.up_x,
ctx.up_y,
ctx.down_x,
ctx.down_y,
ctx.pad_x0,
ctx.pad_x1,
ctx.pad_y0,
ctx.pad_y1,
)
# gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
gradgrad_out = gradgrad_out.view(
ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
)
return gradgrad_out, None, None, None, None, None, None, None, None
class UpFirDn2d(Function):
@staticmethod
def forward(ctx, input, kernel, up, down, pad):
up_x, up_y = up
down_x, down_y = down
pad_x0, pad_x1, pad_y0, pad_y1 = pad
kernel_h, kernel_w = kernel.shape
batch, channel, in_h, in_w = input.shape
ctx.in_size = input.shape
input = input.reshape(-1, in_h, in_w, 1)
ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
ctx.out_size = (out_h, out_w)
ctx.up = (up_x, up_y)
ctx.down = (down_x, down_y)
ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
g_pad_x0 = kernel_w - pad_x0 - 1
g_pad_y0 = kernel_h - pad_y0 - 1
g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
out = upfirdn2d_op.upfirdn2d(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
)
# out = out.view(major, out_h, out_w, minor)
out = out.view(-1, channel, out_h, out_w)
return out
@staticmethod
def backward(ctx, grad_output):
kernel, grad_kernel = ctx.saved_tensors
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = UpFirDn2dBackward.apply(
grad_output,
kernel,
grad_kernel,
ctx.up,
ctx.down,
ctx.pad,
ctx.g_pad,
ctx.in_size,
ctx.out_size,
)
return grad_input, None, None, None, None
def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
if not isinstance(up, abc.Iterable):
up = (up, up)
if not isinstance(down, abc.Iterable):
down = (down, down)
if len(pad) == 2:
pad = (pad[0], pad[1], pad[0], pad[1])
if input.device.type == "cpu":
out = upfirdn2d_native(input, kernel, *up, *down, *pad)
else:
out = UpFirDn2d.apply(input, kernel, up, down, pad)
return out
def upfirdn2d_native(
input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
):
_, channel, in_h, in_w = input.shape
input = input.reshape(-1, in_h, in_w, 1)
_, in_h, in_w, minor = input.shape
kernel_h, kernel_w = kernel.shape
out = input.view(-1, in_h, 1, in_w, 1, minor)
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(
out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
)
out = out[
:,
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
:,
]
out = out.permute(0, 3, 1, 2)
out = out.reshape(
[-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
)
w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
out = F.conv2d(out, w)
out = out.reshape(
-1,
minor,
in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
out = out[:, ::down_y, ::down_x, :]
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
return out.view(-1, channel, out_h, out_w)
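The native path above is upsample-by-zero-insertion, pad, FIR-filter, downsample. In 1-D with plain lists the whole pipeline fits in a few lines; a sketch of the core idea (not a substitute for the CUDA kernel's edge-case semantics, and the function name is illustrative):

```python
def upfirdn1d(x, kernel, up=1, down=1, pad=(0, 0)):
    # 1) zero-stuff by the upsampling factor
    ups = []
    for v in x:
        ups.append(v)
        ups.extend([0] * (up - 1))
    # 2) pad both ends
    ups = [0] * pad[0] + ups + [0] * pad[1]
    # 3) convolve: correlate with the flipped kernel, valid mode
    k = kernel[::-1]
    filt = [sum(ups[i + j] * k[j] for j in range(len(k)))
            for i in range(len(ups) - len(k) + 1)]
    # 4) keep every down-th sample
    return filt[::down]
```

With up=down=1, no padding, and a unit kernel this is the identity, and the output length matches the (in*up + pad0 + pad1 - kernel + down) // down formula used by `UpFirDn2d`.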
================================================
FILE: torchtools/nn/stylegan2/upfirdn2d_kernel.cu
================================================
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html
#include <torch/types.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
static __device__ __forceinline__ int floor_div(int a, int b) {
int t = 1 - a / b;
return (a + t * b) / b - t;
}
struct UpFirDn2DKernelParams {
int up_x;
int up_y;
int down_x;
int down_y;
int pad_x0;
int pad_x1;
int pad_y0;
int pad_y1;
int major_dim;
int in_h;
int in_w;
int minor_dim;
int kernel_h;
int kernel_w;
int out_h;
int out_w;
int loop_major;
int loop_x;
};
template <typename scalar_t>
__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
int out_y = minor_idx / p.minor_dim;
minor_idx -= out_y * p.minor_dim;
int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
int major_idx_base = blockIdx.z * p.loop_major;
if (out_x_base >= p.out_w || out_y >= p.out_h ||
major_idx_base >= p.major_dim) {
return;
}
int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major && major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, out_x = out_x_base;
loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
const scalar_t *x_p =
&input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
minor_idx];
const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
int x_px = p.minor_dim;
int k_px = -p.up_x;
int x_py = p.in_w * p.minor_dim;
int k_py = -p.up_y * p.kernel_w;
scalar_t v = 0.0f;
for (int y = 0; y < h; y++) {
for (int x = 0; x < w; x++) {
v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
x_p += x_px;
k_p += k_px;
}
x_p += x_py - w * x_px;
k_p += k_py - w * k_px;
}
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
template <typename scalar_t, int up_x, int up_y, int down_x, int down_y,
int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
const scalar_t *kernel,
const UpFirDn2DKernelParams p) {
const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
__shared__ volatile float sk[kernel_h][kernel_w];
__shared__ volatile float sx[tile_in_h][tile_in_w];
int minor_idx = blockIdx.x;
int tile_out_y = minor_idx / p.minor_dim;
minor_idx -= tile_out_y * p.minor_dim;
tile_out_y *= tile_out_h;
int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
int major_idx_base = blockIdx.z * p.loop_major;
if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
major_idx_base >= p.major_dim) {
return;
}
for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
tap_idx += blockDim.x) {
int ky = tap_idx / kernel_w;
int kx = tap_idx - ky * kernel_w;
scalar_t v = 0.0;
if (kx < p.kernel_w & ky < p.kernel_h) {
v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
}
sk[ky][kx] = v;
}
for (int loop_major = 0, major_idx = major_idx_base;
loop_major < p.loop_major & major_idx < p.major_dim;
loop_major++, major_idx++) {
for (int loop_x = 0, tile_out_x = tile_out_x_base;
loop_x < p.loop_x & tile_out_x < p.out_w;
loop_x++, tile_out_x += tile_out_w) {
int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
int tile_in_x = floor_div(tile_mid_x, up_x);
int tile_in_y = floor_div(tile_mid_y, up_y);
__syncthreads();
for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
in_idx += blockDim.x) {
int rel_in_y = in_idx / tile_in_w;
int rel_in_x = in_idx - rel_in_y * tile_in_w;
int in_x = rel_in_x + tile_in_x;
int in_y = rel_in_y + tile_in_y;
scalar_t v = 0.0;
if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
p.minor_dim +
minor_idx];
}
sx[rel_in_y][rel_in_x] = v;
}
__syncthreads();
for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
out_idx += blockDim.x) {
int rel_out_y = out_idx / tile_out_w;
int rel_out_x = out_idx - rel_out_y * tile_out_w;
int out_x = rel_out_x + tile_out_x;
int out_y = rel_out_y + tile_out_y;
int mid_x = tile_mid_x + rel_out_x * down_x;
int mid_y = tile_mid_y + rel_out_y * down_y;
int in_x = floor_div(mid_x, up_x);
int in_y = floor_div(mid_y, up_y);
int rel_in_x = in_x - tile_in_x;
int rel_in_y = in_y - tile_in_y;
int kernel_x = (in_x + 1) * up_x - mid_x - 1;
int kernel_y = (in_y + 1) * up_y - mid_y - 1;
if (out_x < p.out_w & out_y < p.out_h) {
scalar_t v = 0.0;
#pragma unroll
for (int y = 0; y < kernel_h / up_y; y++)
#pragma unroll
for (int x = 0; x < kernel_w / up_x; x++)
v += sx[rel_in_y + y][rel_in_x + x] *
sk[kernel_y + y * up_y][kernel_x + x * up_x];
out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
minor_idx] = v;
}
}
}
}
}
torch::Tensor upfirdn2d_op(const torch::Tensor &input,
const torch::Tensor &kernel, int up_x, int up_y,
int down_x, int down_y, int pad_x0, int pad_x1,
int pad_y0, int pad_y1) {
int curDevice = -1;
cudaGetDevice(&curDevice);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
UpFirDn2DKernelParams p;
auto x = input.contiguous();
auto k = kernel.contiguous();
p.major_dim = x.size(0);
p.in_h = x.size(1);
p.in_w = x.size(2);
p.minor_dim = x.size(3);
p.kernel_h = k.size(0);
p.kernel_w = k.size(1);
p.up_x = up_x;
p.up_y = up_y;
p.down_x = down_x;
p.down_y = down_y;
p.pad_x0 = pad_x0;
p.pad_x1 = pad_x1;
p.pad_y0 = pad_y0;
p.pad_y1 = pad_y1;
p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
p.down_y;
p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
p.down_x;
auto out =
at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
int mode = -1;
int tile_out_h = -1;
int tile_out_w = -1;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
void *cuda_kernel = (void *)upfirdn2d_kernel_large<scalar_t>;
if (p.up_x == 2 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 1 && p.kernel_w <= 24) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 2, 1, 1, 1, 1, 24, 8, 128>;
tile_out_h = 8;
tile_out_w = 128;
}
if (p.up_x == 2 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 1 && p.kernel_w <= 12) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 2, 1, 1, 1, 1, 12, 8, 128>;
tile_out_h = 8;
tile_out_w = 128;
}
if (p.up_x == 1 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 24 && p.kernel_w <= 1) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 1, 2, 1, 1, 24, 1, 32, 32>;
tile_out_h = 32;
tile_out_w = 32;
}
if (p.up_x == 1 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 12 && p.kernel_w <= 1) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 1, 2, 1, 1, 12, 1, 32, 32>;
tile_out_h = 32;
tile_out_w = 32;
}
//
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 1 &&
p.kernel_h <= 1 && p.kernel_w <= 24) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 1, 1, 2, 1, 1, 24, 8, 64>;
tile_out_h = 8;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 1 &&
p.kernel_h <= 1 && p.kernel_w <= 12) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 1, 1, 2, 1, 1, 12, 8, 64>;
tile_out_h = 8;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 2 &&
p.kernel_h <= 24 && p.kernel_w <= 1) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 1, 1, 1, 2, 24, 1, 16, 32>;
tile_out_h = 16;
tile_out_w = 32;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 2 &&
p.kernel_h <= 12 && p.kernel_w <= 1) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 1, 1, 1, 2, 12, 1, 16, 32>;
tile_out_h = 16;
tile_out_w = 32;
}
//
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 3 && p.kernel_w <= 3) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
p.kernel_h <= 2 && p.kernel_w <= 2) {
cuda_kernel =
(void *)upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>;
tile_out_h = 16;
tile_out_w = 64;
}
if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
p.kernel_h <= 4 && p.kernel_w <= 4) {
cuda_kernel = (void *)upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>;
tile_out_h = 8;
tile_out_w = 32;
}
    if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
        p.kernel_h <= 2 && p.kernel_w <= 2) {
      cuda_kernel = (void *)upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32>;
tile_out_h = 8;
tile_out_w = 32;
}
dim3 block_size;
dim3 grid_size;
if (tile_out_h > 0 && tile_out_w > 0) {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 1;
block_size = dim3(32 * 8, 1, 1);
grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
(p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
(p.major_dim - 1) / p.loop_major + 1);
} else {
p.loop_major = (p.major_dim - 1) / 16384 + 1;
p.loop_x = 4;
block_size = dim3(4, 32, 1);
grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
(p.out_w - 1) / (p.loop_x * block_size.y) + 1,
(p.major_dim - 1) / p.loop_major + 1);
}
scalar_t *out_p = out.data_ptr<scalar_t>();
scalar_t *x_p = x.data_ptr<scalar_t>();
scalar_t *k_p = k.data_ptr<scalar_t>();
void *args[] = {&out_p, &x_p, &k_p, &p};
AT_CUDA_CHECK(
cudaLaunchKernel(cuda_kernel, grid_size, block_size, args, 0, stream));
});
return out;
}
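The output-size arithmetic in `upfirdn2d_op` (upsample, pad, filter, downsample) can be checked with a small standalone Python sketch; the sizes below are hypothetical values chosen for illustration, not taken from the repo:

```python
# Hypothetical sketch of upfirdn2d_op's output-size formula:
#   out = (in * up + pad0 + pad1 - kernel + down) // down
in_h, up_y, pad_y0, pad_y1, kernel_h, down_y = 64, 2, 2, 1, 4, 1
out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
# 64 * 2 + 2 + 1 - 4 + 1 = 128
```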
================================================
FILE: torchtools/nn/transformers.py
================================================
import torch.nn as nn
# Based on the GPT-2 implementation from minGPT (https://github.com/karpathy/minGPT) by Andrej Karpathy
class GPTTransformerEncoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0):
super().__init__()
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.mlp = nn.Sequential(
nn.Linear(d_model, dim_feedforward),
nn.GELU(),
nn.Linear(dim_feedforward, d_model),
nn.Dropout(dropout),
)
    def forward(self, x, src_mask=None, src_key_padding_mask=None):
        # Pre-LN residual: normalize only the attention input, keep the residual stream un-normalized
        x_norm = self.ln1(x)
        x = x + self.attn(x_norm, x_norm, x_norm, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
x = x + self.mlp(self.ln2(x))
return x
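For GPT-style decoding, `src_mask` is an additive causal mask in which position i may only attend to positions <= i. A minimal framework-free sketch of that mask (the sequence length and values are illustrative assumptions):

```python
# Additive causal mask sketch: 0.0 where attention is allowed, -inf where it is blocked.
S = 4  # hypothetical sequence length
NEG_INF = float('-inf')
mask = [[0.0 if j <= i else NEG_INF for j in range(S)] for i in range(S)]
# Row i blocks every column to its right, so each token only attends to its past.
```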
================================================
FILE: torchtools/nn/vq.py
================================================
import torch
from torch import nn
from .functional.vq import vector_quantize, binarize
import numpy as np
class VectorQuantize(nn.Module):
def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
"""
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns a tensor containing the nearest-neighbour embedding for each input vector
        (same size as the input), the vq and commitment components of the loss as a tuple
        in the second output, and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices
"""
super(VectorQuantize, self).__init__()
self.codebook = nn.Embedding(k, embedding_size)
self.codebook.weight.data.uniform_(-1./k, 1./k)
self.vq = vector_quantize.apply
self.ema_decay = ema_decay
self.ema_loss = ema_loss
if ema_loss:
self.register_buffer('ema_element_count', torch.ones(k))
self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))
def _laplace_smoothing(self, x, epsilon):
n = torch.sum(x)
return ((x + epsilon) / (n + x.size(0) * epsilon) * n)
def _updateEMA(self, z_e_x, indices):
mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
elem_count = mask.sum(dim=0)
weight_sum = torch.mm(mask.t(), z_e_x)
self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count)
self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum)
self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)
def idx2vq(self, idx, dim=-1):
q_idx = self.codebook(idx)
if dim != -1:
q_idx = q_idx.movedim(-1, dim)
return q_idx
def forward(self, x, get_losses=True, dim=-1):
if dim != -1:
x = x.movedim(dim, -1)
z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
vq_loss, commit_loss = None, None
if self.ema_loss and self.training:
self._updateEMA(z_e_x.detach(), indices.detach())
            # pick the gradient-carrying embeddings after updating the codebook, for a more accurate commitment loss
z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
if get_losses:
vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()
z_q_x = z_q_x.view(x.shape)
if dim != -1:
z_q_x = z_q_x.movedim(-1, dim)
return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
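The nearest-neighbour lookup at the heart of `vector_quantize` can be illustrated with a small framework-free sketch (the codebook and input values below are made up):

```python
# Hypothetical nearest-neighbour quantization sketch (what the VQ lookup computes per row).
codebook = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]   # k=3 codes, embedding_size=2
z = [0.9, 1.2]                                     # one input vector
dists = [sum((a - b) ** 2 for a, b in zip(z, c)) for c in codebook]
idx = dists.index(min(dists))                      # index of the closest code
z_q = codebook[idx]                                # quantized output
```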
class Binarize(nn.Module):
def __init__(self, threshold=0.5):
"""
        Takes an input of any size.
        Returns an output of the same size with its values binarized (0 if the input is below the threshold, 1 if it is above).
"""
super(Binarize, self).__init__()
self.bin = binarize.apply
self.threshold = threshold
def forward(self, x):
return self.bin(x, self.threshold)
# Finite Scalar Quantization: https://arxiv.org/abs/2309.15505
class FSQ(nn.Module):
def __init__(self, bins, dim=-1, eps=1e-1):
super().__init__()
self.dim = dim
self.eps = eps
self.register_buffer('bins', torch.tensor(bins))
self.register_buffer('bases', torch.tensor([1] + np.cumprod(bins[:-1]).tolist()))
self.codebook_size = np.prod(bins)
self.in_shift, self.out_shift = None, None
def _round(self, x, quantize):
x = x.sigmoid() * (1-1e-7)
if quantize is True:
x_rounded = x.sub(1/(self.bins*2)).mul(self.bins).round().div(self.bins).div(1-1/self.bins)
x = x + (x_rounded - x).detach()
x_sigmoid = x
x = (x / (1-1e-7)).logit(eps=self.eps)
return x, x_sigmoid
def vq_to_idx(self, x, is_sigmoid=False):
if not is_sigmoid:
x = x.sigmoid() * (1-1e-7)
x = x.sub(1/(self.bins*2)).mul(self.bins).round().div(self.bins).div(1-1/self.bins)
x = x.mul(self.bins-1).long()
x = (x * self.bases).sum(dim=-1).long()
return x
def idx_to_vq(self, x):
x = x.unsqueeze(-1) // self.bases % self.bins
x = x.div(self.bins-1)
x = (x / (1-1e-7)).logit(eps=self.eps)
if self.dim != -1:
x = x.movedim(-1, self.dim)
return x
def forward(self, x, quantize=True):
if self.dim != -1:
x = x.movedim(self.dim, -1)
x, x_sigmoid = self._round(x, quantize=quantize)
idx = self.vq_to_idx(x_sigmoid, is_sigmoid=True)
if self.dim != -1:
x = x.movedim(-1, self.dim)
return x, idx
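FSQ flattens the per-dimension quantization levels into a single codebook index with a mixed-radix encoding, which is exactly what the `bases` buffer (cumulative products of `bins`) implements. A standalone sketch with made-up bins and digits:

```python
# Mixed-radix index packing sketch, mirroring FSQ's `bases` buffer.
bins = [8, 5, 5]                      # levels per latent dimension (hypothetical)
bases = [1, 8, 40]                    # bases[i] = prod(bins[:i])
digits = [3, 2, 4]                    # quantized level per dimension
idx = sum(d * b for d, b in zip(digits, bases))   # single flat index
# Unpacking recovers the digits: (idx // base) % bin, as in idx_to_vq.
unpacked = [(idx // b) % n for b, n in zip(bases, bins)]
```

The total codebook size is the product of the bins (here 8 * 5 * 5 = 200), matching `self.codebook_size`.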
================================================
FILE: torchtools/optim/__init__.py
================================================
from .radam import RAdam, PlainRAdam, AdamW
from .ranger import Ranger
from .lookahead import Lookahead, LookaheadAdam
from .over9000 import Over9000, RangerLars
from .ralamb import Ralamb
from .novograd import Novograd
from .lamb import Lamb
================================================
FILE: torchtools/optim/lamb.py
================================================
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####
import collections
import math
import torch
from torch.optim import Optimizer
try:
from tensorboardX import SummaryWriter
def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
"""Log a histogram of trust ratio scalars in across layers."""
results = collections.defaultdict(list)
for group in optimizer.param_groups:
for p in group['params']:
state = optimizer.state[p]
for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
if i in state:
results[i].append(state[i])
for k, v in results.items():
event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
except ModuleNotFoundError as e:
    print("To use log_lamb_rs, run 'pip install tensorboardx'. You must also have TensorBoard running to see the results.")
class Lamb(Optimizer):
r"""Implements Lamb algorithm.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
adam (bool, optional): always use trust ratio = 1, which turns this into
Adam. Useful for comparison purposes.
.. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
weight_decay=0, adam=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
self.adam = adam
super(Lamb, self).__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
# m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# Paper v3 does not use debiasing.
# bias_correction1 = 1 - beta1 ** state['step']
# bias_correction2 = 1 - beta2 ** state['step']
# Apply bias to lr to avoid broadcast.
step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
if group['weight_decay'] != 0:
                    adam_step.add_(p.data, alpha=group['weight_decay'])
adam_norm = adam_step.pow(2).sum().sqrt()
if weight_norm == 0 or adam_norm == 0:
trust_ratio = 1
else:
trust_ratio = weight_norm / adam_norm
state['weight_norm'] = weight_norm
state['adam_norm'] = adam_norm
state['trust_ratio'] = trust_ratio
if self.adam:
trust_ratio = 1
                p.data.add_(adam_step, alpha=-step_size * trust_ratio)
return loss
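The layer-wise trust ratio computed in `step` above is simply the ratio of the (clamped) weight norm to the Adam-update norm. A framework-free sketch with made-up values:

```python
import math

# Trust-ratio sketch mirroring the per-layer computation in Lamb.step (values are hypothetical).
weights = [3.0, 4.0]                  # ||w|| = 5, below the clamp at 10
adam_step = [0.6, 0.8]                # ||update|| = 1
weight_norm = min(math.sqrt(sum(w * w for w in weights)), 10.0)
adam_norm = math.sqrt(sum(a * a for a in adam_step))
trust_ratio = 1.0 if weight_norm == 0 or adam_norm == 0 else weight_norm / adam_norm
# The parameter update is then -lr * trust_ratio * adam_step.
```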
================================================
FILE: torchtools/optim/lookahead.py
================================================
####
# CODE TAKEN FROM https://github.com/lonePatient/lookahead_pytorch
# Original paper: https://arxiv.org/abs/1907.08610
####
# Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py
""" Lookahead Optimizer Wrapper.
Implementation modified from: https://github.com/alphadl/lookahead.pytorch
Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
"""
import torch
from torch.optim.optimizer import Optimizer
from collections import defaultdict
class Lookahead(Optimizer):
def __init__(self, base_optimizer, alpha=0.5, k=6):
if not 0.0 <= alpha <= 1.0:
raise ValueError(f'Invalid slow update rate: {alpha}')
if not 1 <= k:
raise ValueError(f'Invalid lookahead steps: {k}')
defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
self.base_optimizer = base_optimizer
self.param_groups = self.base_optimizer.param_groups
self.defaults = base_optimizer.defaults
self.defaults.update(defaults)
self.state = defaultdict(dict)
# manually add our defaults to the param groups
for name, default in defaults.items():
for group in self.param_groups:
group.setdefault(name, default)
def update_slow(self, group):
for fast_p in group["params"]:
if fast_p.grad is None:
continue
param_state = self.state[fast_p]
if 'slow_buffer' not in param_state:
param_state['slow_buffer'] = torch.empty_like(fast_p.data)
param_state['slow_buffer'].copy_(fast_p.data)
slow = param_state['slow_buffer']
            slow.add_(fast_p.data - slow, alpha=group['lookahead_alpha'])
fast_p.data.copy_(slow)
def sync_lookahead(self):
for group in self.param_groups:
self.update_slow(group)
def step(self, closure=None):
# print(self.k)
# assert id(self.param_groups) == id(self.base_optimizer.param_groups)
loss = self.base_optimizer.step(closure)
for group in self.param_groups:
group['lookahead_step'] += 1
if group['lookahead_step'] % group['lookahead_k'] == 0:
self.update_slow(group)
return loss
def state_dict(self):
fast_state_dict = self.base_optimizer.state_dict()
slow_state = {
(id(k) if isinstance(k, torch.Tensor) else k): v
for k, v in self.state.items()
}
fast_state = fast_state_dict['state']
param_groups = fast_state_dict['param_groups']
return {
'state': fast_state,
'slow_state': slow_state,
'param_groups': param_groups,
}
def load_state_dict(self, state_dict):
fast_state_dict = {
'state': state_dict['state'],
'param_groups': state_dict['param_groups'],
}
self.base_optimizer.load_state_dict(fast_state_dict)
# We want to restore the slow state, but share param_groups reference
# with base_optimizer. This is a bit redundant but least code
slow_state_new = False
if 'slow_state' not in state_dict:
print('Loading state_dict from optimizer without Lookahead applied.')
state_dict['slow_state'] = defaultdict(dict)
slow_state_new = True
slow_state_dict = {
'state': state_dict['slow_state'],
'param_groups': state_dict['param_groups'], # this is pointless but saves code
}
super(Lookahead, self).load_state_dict(slow_state_dict)
self.param_groups = self.base_optimizer.param_groups # make both ref same container
if slow_state_new:
# reapply defaults to catch missing lookahead specific ones
for name, default in self.defaults.items():
for group in self.param_groups:
group.setdefault(name, default)
def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs):
    adam = torch.optim.Adam(params, *args, **kwargs)
    return Lookahead(adam, alpha, k)
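The slow/fast interplay reduces to: run the inner optimizer for k steps, then pull the slow weights a fraction alpha toward the fast ones and restart the fast weights from there. A scalar sketch (all numbers hypothetical):

```python
# Scalar sketch of Lookahead's rule: slow += alpha * (fast - slow) every k steps.
alpha, k = 0.5, 6
slow = fast = 0.0
for step in range(1, k + 1):
    fast += 1.0                          # stand-in for an inner optimizer step
    if step % k == 0:
        slow += alpha * (fast - slow)    # interpolate toward the fast weights
        fast = slow                      # restart the fast weights from the slow ones
```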
================================================
FILE: torchtools/optim/novograd.py
================================================
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.optim import Optimizer
import math
class AdamW(Optimizer):
"""Implements AdamW algorithm.
    AdamW is Adam with decoupled weight decay; Adam itself was proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
weight_decay=0, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad)
super(AdamW, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdamW, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
# Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
                p.data.add_(torch.mul(p.data, group['weight_decay']).addcdiv_(exp_avg, denom), alpha=-step_size)
return loss
class Novograd(Optimizer):
"""
Implements Novograd algorithm.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.95, 0))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
grad_averaging: gradient averaging
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False)
"""
def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,
weight_decay=0, grad_averaging=False, amsgrad=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay,
grad_averaging=grad_averaging,
amsgrad=amsgrad)
super(Novograd, self).__init__(params, defaults)
def __setstate__(self, state):
super(Novograd, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
raise RuntimeError('Sparse gradients are not supported.')
amsgrad = group['amsgrad']
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
if amsgrad:
max_exp_avg_sq = state['max_exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
norm = torch.sum(torch.pow(grad, 2))
if exp_avg_sq == 0:
exp_avg_sq.copy_(norm)
else:
                    exp_avg_sq.mul_(beta2).add_(norm, alpha=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
# Use the max. for normalizing running avg. of gradient
denom = max_exp_avg_sq.sqrt().add_(group['eps'])
else:
denom = exp_avg_sq.sqrt().add_(group['eps'])
grad.div_(denom)
if group['weight_decay'] != 0:
                    grad.add_(p.data, alpha=group['weight_decay'])
if group['grad_averaging']:
grad.mul_(1 - beta1)
exp_avg.mul_(beta1).add_(grad)
                p.data.add_(exp_avg, alpha=-group['lr'])
return loss
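Unlike Adam, Novograd keeps a single scalar second moment per layer: an EMA of the squared gradient norm, so the whole layer is scaled by one denominator. A framework-free sketch of the first step (all values hypothetical):

```python
import math

# Novograd second-moment sketch: one scalar per layer (EMA of ||grad||^2), not per-element.
grad = [3.0, 4.0]                       # hypothetical layer gradient
eps = 1e-8
norm = sum(g * g for g in grad)         # ||grad||^2 = 25
exp_avg_sq = norm                       # first step copies the norm, as in step() above
denom = math.sqrt(exp_avg_sq) + eps
normalized = [g / denom for g in grad]  # every element divided by the same layer norm
```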
================================================
FILE: torchtools/optim/over9000.py
================================================
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####
import torch, math
from torch.optim.optimizer import Optimizer
import itertools as it
from .lookahead import Lookahead
from .ralamb import Ralamb
# RAdam + LARS + LookAHead
# Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py
# RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20
def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
ralamb = Ralamb(params, *args, **kwargs)
return Lookahead(ralamb, alpha, k)
RangerLars = Over9000
================================================
FILE: torchtools/optim/radam.py
================================================
####
# CODE TAKEN FROM https://github.com/LiyuanLucasLiu/RAdam
# Paper: https://arxiv.org/abs/1908.03265
####
import math
import torch
from torch.optim.optimizer import Optimizer, required
class RAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
self.degenerated_to_sgd = degenerated_to_sgd
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
for param in params:
if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
param['buffer'] = [[None, None, None] for _ in range(10)]
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
super(RAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(RAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state['step'] += 1
buffered = group['buffer'][int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
# more conservative since it's an approximated value
if N_sma >= 5:
step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
elif self.degenerated_to_sgd:
step_size = 1.0 / (1 - beta1 ** state['step'])
else:
step_size = -1
buffered[2] = step_size
# more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
                elif step_size > 0:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
                    p.data.copy_(p_data_fp32)
return loss
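The buffered rectification in `step` can be reproduced standalone; for beta2 = 0.999 the approximated SMA length N_sma crosses the >= 5 threshold within a handful of steps. The numbers below are a hypothetical check, not repo output:

```python
import math

# Standalone sketch of RAdam's rectification term for one (beta2, step) pair.
beta2, step = 0.999, 100
beta2_t = beta2 ** step
n_sma_max = 2 / (1 - beta2) - 1
n_sma = n_sma_max - 2 * step * beta2_t / (1 - beta2_t)
if n_sma >= 5:  # variance is tractable: apply the rectified step size
    rect = math.sqrt((1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4)
                     * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2))
```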
class PlainRAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
self.degenerated_to_sgd = degenerated_to_sgd
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(PlainRAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(PlainRAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state['step'] += 1
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
# more conservative since it's an approximated value
                if N_sma >= 5:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                    p.data.copy_(p_data_fp32)
                elif self.degenerated_to_sgd:
                    if group['weight_decay'] != 0:
                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                    step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)
                    p.data.copy_(p_data_fp32)
return loss
class AdamW(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, warmup = warmup)
super(AdamW, self).__init__(params, defaults)
def __setstate__(self, state):
super(AdamW, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
denom = exp_avg_sq.sqrt().add_(group['eps'])
bias_correction1 = 1 - beta1 ** state['step']
bias_correction2 = 1 - beta2 ** state['step']
if group['warmup'] > state['step']:
scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
else:
scheduled_lr = group['lr']
step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * scheduled_lr)
                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
p.data.copy_(p_data_fp32)
return loss
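The `warmup` branch in `AdamW.step` above linearly ramps the effective learning rate from (almost) zero up to `lr` over the first `warmup` steps. A minimal stand-alone sketch of that schedule (the name `warmup_lr` is illustrative):

```python
def warmup_lr(step, base_lr, warmup):
    """Linear warmup as in the AdamW variant above: ramp from ~0 (1e-8)
    up to base_lr over `warmup` steps, then hold base_lr."""
    if warmup > step:
        return 1e-8 + step * base_lr / warmup
    return base_lr
```

Halfway through a 100-step warmup with `base_lr=1e-3`, the scheduled rate is roughly `5e-4`; from step 100 onward it stays at `1e-3`.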
================================================
FILE: torchtools/optim/ralamb.py
================================================
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####
import torch, math
from torch.optim.optimizer import Optimizer
# RAdam + LARS
class Ralamb(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
self.buffer = [[None, None, None] for ind in range(10)]
super(Ralamb, self).__init__(params, defaults)
def __setstate__(self, state):
super(Ralamb, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('Ralamb does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
# Decay the first and second moment running average coefficient
# m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
state['step'] += 1
buffered = self.buffer[int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, radam_step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
# more conservative since it's an approximated value
if N_sma >= 5:
radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
else:
radam_step_size = 1.0 / (1 - beta1 ** state['step'])
buffered[2] = radam_step_size
                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
                # more conservative since it's an approximated value
                radam_step = p_data_fp32.clone()
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    radam_step.addcdiv_(exp_avg, denom, value=-radam_step_size * group['lr'])
                else:
                    radam_step.add_(exp_avg, alpha=-radam_step_size * group['lr'])
radam_norm = radam_step.pow(2).sum().sqrt()
weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
if weight_norm == 0 or radam_norm == 0:
trust_ratio = 1
else:
trust_ratio = weight_norm / radam_norm
state['weight_norm'] = weight_norm
state['adam_norm'] = radam_norm
state['trust_ratio'] = trust_ratio
                if N_sma >= 5:
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-radam_step_size * group['lr'] * trust_ratio)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-radam_step_size * group['lr'] * trust_ratio)
p.data.copy_(p_data_fp32)
return loss
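Ralamb's LARS-style layer scaling boils down to a ratio of two norms. A sketch in plain Python (illustrative names; note that `Ralamb.step` above takes the norm of the tentatively stepped weights, `radam_step`, rather than of the raw update):

```python
import math

def trust_ratio(weights, stepped, clamp_max=10.0):
    """LARS-style trust ratio as computed in Ralamb.step: the L2 norm of
    the current weights (clamped to [0, clamp_max]) divided by the norm
    of the tentatively stepped weights; falls back to 1 if either norm
    is zero."""
    w_norm = min(math.sqrt(sum(w * w for w in weights)), clamp_max)
    s_norm = math.sqrt(sum(v * v for v in stepped))
    if w_norm == 0 or s_norm == 0:
        return 1.0
    return w_norm / s_norm
```

The clamp keeps very large layers from getting a runaway trust ratio, and the zero-norm fallback makes freshly initialized (all-zero) parameters behave like plain RAdam.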
================================================
FILE: torchtools/optim/ranger.py
================================================
####
# CODE TAKEN FROM https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
# Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d
####
import math
import torch
from torch.optim.optimizer import Optimizer, required
import itertools as it
from .lookahead import Lookahead
from .radam import RAdam
def Ranger(params, alpha=0.5, k=6, betas=(.95, 0.999), *args, **kwargs):
radam = RAdam(params, betas=betas, *args, **kwargs)
return Lookahead(radam, alpha, k)
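Ranger is simply RAdam wrapped in Lookahead (defined in `lookahead.py`, outside this chunk). Lookahead keeps a set of slow weights alongside the optimizer's fast weights and, every `k` inner steps, interpolates the slow weights toward the fast ones and resets the fast weights. A minimal sketch of that synchronization rule, with illustrative names:

```python
def lookahead_sync(slow, fast, alpha=0.5):
    """One Lookahead synchronization (performed every k inner steps):
    slow <- slow + alpha * (fast - slow), then fast is reset to slow.
    Stand-alone sketch; the real Lookahead operates on optimizer
    parameter groups, not plain lists."""
    new_slow = [s + alpha * (f - s) for s, f in zip(slow, fast)]
    return new_slow, list(new_slow)
```

With `alpha=0.5`, fast weights `[2.0, 4.0]` and slow weights `[0.0, 0.0]` both end up at `[1.0, 2.0]` after one synchronization.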
================================================
FILE: torchtools/transforms/__init__.py
================================================
from .smart_crop import SmartCrop
================================================
FILE: torchtools/transforms/models/__init__.py
================================================
================================================
FILE: torchtools/transforms/smart_crop.py
================================================
import torch
import torchvision
from torch import nn
import numpy as np
import os
# MICRO RESNET
class ResBlock(nn.Module):
def __init__(self, channels):
super(ResBlock, self).__init__()
self.resblock = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(channels, channels, kernel_size=3),
nn.InstanceNorm2d(channels, affine=True),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(channels, channels, kernel_size=3),
nn.InstanceNorm2d(channels, affine=True),
)
def forward(self, x):
out = self.resblock(x)
return out + x
class Upsample2d(nn.Module):
def __init__(self, scale_factor):
super(Upsample2d, self).__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
def forward(self, x):
x = self.interp(x, scale_factor=self.scale_factor, mode='nearest')
return x
class MicroResNet(nn.Module):
def __init__(self):
super(MicroResNet, self).__init__()
self.downsampler = nn.Sequential(
nn.ReflectionPad2d(4),
nn.Conv2d(3, 8, kernel_size=9, stride=4),
nn.InstanceNorm2d(8, affine=True),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(8, 16, kernel_size=3, stride=2),
nn.InstanceNorm2d(16, affine=True),
nn.ReLU(),
nn.ReflectionPad2d(1),
nn.Conv2d(16, 32, kernel_size=3, stride=2),
nn.InstanceNorm2d(32, affine=True),
nn.ReLU(),
)
self.residual = nn.Sequential(
ResBlock(32),
nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32),
ResBlock(64),
)
self.segmentator = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(64, 16, kernel_size=3),
nn.InstanceNorm2d(16, affine=True),
nn.ReLU(),
Upsample2d(scale_factor=2),
nn.ReflectionPad2d(4),
nn.Conv2d(16, 1, kernel_size=9),
nn.Sigmoid()
)
def forward(self, x):
out = self.downsampler(x)
out = self.residual(out)
out = self.segmentator(out)
return out
# SmartCrop module
class SmartCrop(nn.Module):
def __init__(self, output_size, randomize_p=0.0, randomize_q=0.1, temperature=0.03):
super().__init__()
self.output_size = output_size
self.randomize_p, self.randomize_q = randomize_p, randomize_q
self.temperature = temperature
if isinstance(self.output_size, int):
self.output_size = (self.output_size, self.output_size)
self.saliency_model = MicroResNet().eval().requires_grad_(False)
checkpoint = torch.load(os.path.dirname(__file__) + "/models/saliency_model_v9.pt", map_location="cpu")
self.saliency_model.load_state_dict(checkpoint)
def forward(self, image):
is_batch = len(image.shape) == 4
if not is_batch:
image = image.unsqueeze(0)
with torch.no_grad():
resized_image = torchvision.transforms.functional.resize(image, 240, antialias=True)
saliency_map = self.saliency_model(resized_image)
tempered_heatmap = saliency_map.view(saliency_map.size(0), -1).div(self.temperature).softmax(-1)
            tempered_heatmap = tempered_heatmap / tempered_heatmap.sum(dim=1, keepdim=True)
            tempered_heatmap = (tempered_heatmap > tempered_heatmap.max(dim=-1, keepdim=True)[0] * 0.75).float()
saliency_map = tempered_heatmap.view(*saliency_map.shape)
# GET CENTROID
coord_space = torch.cat([
torch.linspace(0, 1, saliency_map.size(-2))[None, None, :, None].expand(-1, -1, -1, saliency_map.size(-1)),
torch.linspace(0, 1, saliency_map.size(-1))[None, None, None, :].expand(-1, -1, saliency_map.size(-2), -1),
], dim=1)
centroid = (coord_space * saliency_map).sum(dim=[-1, -2]) / saliency_map.sum(dim=[-1, -2])
# CROP
crops = []
for i in range(image.size(0)):
if np.random.rand() < self.randomize_p:
centroid[i, 0] += np.random.uniform(-self.randomize_q, self.randomize_q)
centroid[i, 1] += np.random.uniform(-self.randomize_q, self.randomize_q)
top = (centroid[i, 0]*image.size(-2)-self.output_size[-2]/2).clamp(min=0, max=max(0, image.size(-2)-self.output_size[-2])).int()
left = (centroid[i, 1]*image.size(-1)-self.output_size[-1]/2).clamp(min=0, max=max(0, image.size(-1)-self.output_size[-1])).int()
bottom, right = top + self.output_size[-2], left + self.output_size[-1]
crop = image[i, :, top:bottom, left:right]
if crop.size(-2) < self.output_size[-2] or crop.size(-1) < self.output_size[-1]:
crop = torchvision.transforms.functional.center_crop(crop, self.output_size)
crops.append(crop)
if is_batch:
crops = torch.stack(crops, dim=0)
else:
crops = crops[0]
return crops
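The centroid computed in `SmartCrop.forward` is a saliency-weighted average of normalized pixel coordinates. A pure-Python sketch of the same reduction (the name `saliency_centroid` is illustrative; the row coordinate comes first, matching the `coord_space` ordering above, and maps with at least 2 rows and 2 columns are assumed):

```python
def saliency_centroid(saliency):
    """Saliency-weighted centroid in normalized [0, 1] coordinates,
    returned as (row, col). `saliency` is a 2D nested list of
    non-negative weights with h, w >= 2."""
    h, w = len(saliency), len(saliency[0])
    total = sum(v for row in saliency for v in row)
    # linspace(0, 1, n) assigns coordinate i / (n - 1) to index i
    cy = sum((i / (h - 1)) * v for i, row in enumerate(saliency) for v in row)
    cx = sum((j / (w - 1)) * v for row in saliency for j, v in enumerate(row))
    return cy / total, cx / total
```

A single hot pixel in the middle of a 3x3 map yields the centroid `(0.5, 0.5)`; the crop window above is then centered on that point and clamped to the image bounds.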
================================================
FILE: torchtools/utils/__init__.py
================================================
from .diffusion import Diffuzz
from .diffusion2 import Diffuzz2
from .gamma_parametrization import apply_gamma_reparam, gamma_reparam_model, remove_gamma_reparam
from .weight_normalization import apply_weight_norm, weight_norm_model, remove_weight_norm
================================================
FILE: torchtools/utils/diffusion.py
================================================
import torch
# Samplers --------------------------------------------------------------------
class SimpleSampler():
def __init__(self, diffuzz):
self.current_step = -1
self.diffuzz = diffuzz
def __call__(self, *args, **kwargs):
self.current_step += 1
return self.step(*args, **kwargs)
def init_x(self, shape):
return torch.randn(*shape, device=self.diffuzz.device)
def step(self, x, t, t_prev, noise):
        raise NotImplementedError("You should override the 'step' method.")
class DDPMSampler(SimpleSampler):
def step(self, x, t, t_prev, noise):
alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])
alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]])
alpha = (alpha_cumprod / alpha_cumprod_prev)
mu = (1.0 / alpha).sqrt() * (x - (1-alpha) * noise / (1-alpha_cumprod).sqrt())
std = ((1-alpha) * (1. - alpha_cumprod_prev) / (1. - alpha_cumprod)).sqrt() * torch.randn_like(mu)
return mu + std * (t_prev != 0).float().view(t_prev.size(0), *[1 for _ in x.shape[1:]])
class DDIMSampler(SimpleSampler):
def step(self, x, t, t_prev, noise):
alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])
alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]])
x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / (alpha_cumprod).sqrt()
dp_xt = (1 - alpha_cumprod_prev).sqrt()
return (alpha_cumprod_prev).sqrt() * x0 + dp_xt * noise
class DPMSolverPlusPlusSampler(SimpleSampler): # FIXME: CURRENTLY NOT WORKING
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.q_ts = {}
def _get_coef(self, alpha_cumprod):
log_alpha_t = alpha_cumprod.log()
alpha_t = log_alpha_t.exp()
sigma_t = (1-alpha_t ** 2).sqrt()
lambda_t = log_alpha_t - sigma_t.log()
return alpha_t, sigma_t, lambda_t
def init_x(self, shape):
alpha_cumprod = self.diffuzz._alpha_cumprod(torch.ones(shape[0], device=self.diffuzz.device)).view(-1, *[1 for _ in shape[1:]])
return torch.randn(*shape, device=self.diffuzz.device) * self._get_coef(alpha_cumprod)[1]
def step(self, x, t, t_prev, noise):
alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])
stride = (t_prev - t)
if self.current_step == 0:
alpha_t, sigma_t, _ = self._get_coef(alpha_cumprod)
elif self.current_step == 1:
alpha_cumprod_next = self.diffuzz._alpha_cumprod(t+stride).view(t.size(0), *[1 for _ in x.shape[1:]])
alpha_t, sigma_t, lambda_t = self._get_coef(alpha_cumprod)
_, sigma_t_next, lambda_t_next = self._get_coef(alpha_cumprod_next)
h = lambda_t - lambda_t_next
x = sigma_t / sigma_t_next * x - alpha_t * torch.expm1(-h) * self.q_ts[self.current_step-1]
else:
alpha_cumprod_next = self.diffuzz._alpha_cumprod(t+stride).view(t.size(0), *[1 for _ in x.shape[1:]])
alpha_cumprod_next_next = self.diffuzz._alpha_cumprod(t+stride*2).view(t.size(0), *[1 for _ in x.shape[1:]])
alpha_t, sigma_t, lambda_t = self._get_coef(alpha_cumprod)
_, sigma_t_next, lambda_t_next = self._get_coef(alpha_cumprod_next)
_, _, lambda_t_next_next = self._get_coef(alpha_cumprod_next_next)
h = lambda_t - lambda_t_next
h_next = lambda_t_next - lambda_t_next_next
r = h_next / h
D = (1 + 1 / (2 * r)) * self.q_ts[self.current_step-1] - 1 / (2 * r) * self.q_ts[self.current_step-2]
x = sigma_t / sigma_t_next * x - alpha_t * torch.expm1(-h) * D
self.q_ts[self.current_step] = (x - sigma_t * noise) / alpha_t
return x
sampler_dict = {
'ddpm': DDPMSampler,
'ddim': DDIMSampler,
'dpmsolver++': DPMSolverPlusPlusSampler,
}
# Custom simplified forward/backward diffusion (cosine schedule)
class Diffuzz():
def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1, clamp_range=(0.0001, 0.9999)):
self.device = device
self.s = torch.tensor([s]).to(device)
self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
self.scaler = scaler
self.cached_steps = None
self.clamp_range = clamp_range
if cache_steps is not None:
self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device))
def _alpha_cumprod(self, t):
if self.cached_steps is None:
if self.scaler > 1:
t = 1 - (1-t) ** self.scaler
elif self.scaler < 1:
t = t ** self.scaler
alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod
return alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1])
else:
return self.cached_steps[t.mul(len(self.cached_steps)-1).long()]
def diffuse(self, x, t, noise=None): # t -> [0, 1]
if noise is None:
noise = torch.randn_like(x)
alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])
return alpha_cumprod.sqrt() * x + (1-alpha_cumprod).sqrt() * noise, noise
def undiffuse(self, x, t, t_prev, noise, sampler=None):
if sampler is None:
sampler = DDPMSampler(self)
return sampler(x, t, t_prev, noise)
def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, unconditional_inputs=None, sampler='ddpm', half=False):
r_range = torch.linspace(t_start, t_end, timesteps+1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device)
if isinstance(sampler, str):
if sampler in sampler_dict:
sampler = sampler_dict[sampler](self)
else:
raise ValueError(f"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}")
        elif isinstance(sampler, type) and issubclass(sampler, SimpleSampler):
            sampler = sampler(self)
        else:
            raise ValueError("Sampler should be either a string or a SimpleSampler subclass.")
preds = []
x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone()
if half:
r_range = r_range.half()
x = x.half()
for i in range(0, timesteps):
if mask is not None and x_init is not None:
x_renoised, _ = self.diffuse(x_init, r_range[i])
x = x * mask + x_renoised * (1-mask)
pred_noise = model(x, r_range[i], **model_inputs)
if cfg is not None:
if unconditional_inputs is None:
unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
pred_noise_unconditional = model(x, r_range[i], **unconditional_inputs)
pred_noise = torch.lerp(pred_noise_unconditional, pred_noise, cfg)
x = self.undiffuse(x, r_range[i], r_range[i+1], pred_noise, sampler=sampler)
preds.append(x)
return preds
def p2_weight(self, t, k=1.0, gamma=1.0):
alpha_cumprod = self._alpha_cumprod(t)
return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma
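`Diffuzz._alpha_cumprod` implements the familiar squared-cosine noise schedule, normalized so the value at `t=0` is ~1 and clamped to the configured range. A scalar sketch for the default settings (`scaler=1`, no caching; the function name is illustrative and the default clamp range is hard-coded):

```python
import math

def alpha_cumprod(t, s=0.008):
    """Scalar version of Diffuzz._alpha_cumprod: squared cosine,
    normalized by its value at t=0, clamped to (0.0001, 0.9999)."""
    init = math.cos(s / (1 + s) * math.pi * 0.5) ** 2
    a = math.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2 / init
    return min(max(a, 0.0001), 0.9999)
```

The schedule decreases monotonically from ~1 at `t=0` (clean signal) toward ~0 at `t=1` (pure noise), which is what `diffuse` relies on when mixing `alpha_cumprod.sqrt() * x + (1-alpha_cumprod).sqrt() * noise`.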
================================================
FILE: torchtools/utils/diffusion2.py
================================================
import torch
import numpy as np
# Samplers --------------------------------------------------------------------
class SimpleSampler():
def __init__(self, diffuzz, mode="v"):
self.current_step = -1
self.diffuzz = diffuzz
if mode not in ['v', 'e', 'x']:
            raise ValueError("Mode should be either 'v', 'e' or 'x'")
self.mode = mode
def __call__(self, *args, **kwargs):
self.current_step += 1
return self.step(*args, **kwargs)
def init_x(self, shape):
return torch.randn(*shape, device=self.diffuzz.device)
def step(self, x, t, t_prev, noise):
        raise NotImplementedError("You should override the 'step' method.")
# https://github.com/ozanciga/diffusion-for-beginners/blob/main/samplers/ddim.py
class DDIMSampler(SimpleSampler):
def step(self, x, t, t_prev, pred, eta=0):
alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])
alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]])
sigma_tau = eta * ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)).sqrt() * (1 - alpha_cumprod / alpha_cumprod_prev).sqrt() if eta > 0 else 0
if self.mode == 'v':
x0 = alpha_cumprod.sqrt() * x - (1-alpha_cumprod).sqrt() * pred
noise = (1-alpha_cumprod).sqrt() * x + alpha_cumprod.sqrt() * pred
elif self.mode == 'x':
x0 = pred
noise = (x - x0 * alpha_cumprod.sqrt()) / (1 - alpha_cumprod).sqrt()
else:
noise = pred
x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / alpha_cumprod.sqrt()
renoised = alpha_cumprod_prev.sqrt() * x0 + (1 - alpha_cumprod_prev - sigma_tau ** 2).sqrt() * noise + sigma_tau * torch.randn_like(x)
return x0, renoised, pred
class DDPMSampler(DDIMSampler):
def step(self, x, t, t_prev, pred, eta=1):
return super().step(x, t, t_prev, pred, eta)
sampler_dict = {
'ddpm': DDPMSampler,
'ddim': DDIMSampler,
}
# Custom simplified forward/backward diffusion (cosine schedule)
class Diffuzz2():
def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1, clamp_range=(1e-7, 1-1e-7)):
self.device = device
self.s = torch.tensor([s]).to(device)
self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
self.scaler = 2 * np.log(1/scaler)
self.cached_steps = None
self.clamp_range = clamp_range
if cache_steps is not None:
self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device))
def _alpha_cumprod(self, t):
if self.cached_steps is None:
alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod
alpha_cumprod = alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1])
            if self.scaler != 0:
alpha_cumprod = (alpha_cumprod/(1-alpha_cumprod)).log().add(self.scaler).sigmoid().clamp(self.clamp_range[0], self.clamp_range[1])
return alpha_cumprod
else:
return self.cached_steps[t.mul(len(self.cached_steps)-1).long()]
def scale_t(self, t, scaler):
scaler = 2 * np.log(1/scaler)
alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod
alpha_cumprod = alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1])
        if scaler != 0:
alpha_cumprod = (alpha_cumprod/(1-alpha_cumprod)).log().add(scaler).sigmoid().clamp(self.clamp_range[0], self.clamp_range[1])
return (((alpha_cumprod * self._init_alpha_cumprod) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + self.s) - self.s
def diffuse(self, x, t, noise=None): # t -> [0, 1]
if noise is None:
noise = torch.randn_like(x)
alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])
return alpha_cumprod.sqrt() * x + (1-alpha_cumprod).sqrt() * noise, noise
def get_v(self, x, t, noise):
alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]])
# x0 = alpha_cumprod * noised − (1-alpha_cumprod).sqrt() * pred_v
# noise = (1-alpha_cumprod).sqrt() * noised + alpha_cumprod * pred_v
return alpha_cumprod.sqrt() * noise - (1-alpha_cumprod).sqrt() * x
def x0_from_v(self, noised, pred_v, t):
alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in noised.shape[1:]])
return alpha_cumprod.sqrt() * noised - (1-alpha_cumprod).sqrt() * pred_v
def noise_from_v(self, noised, pred_v, t):
alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in noised.shape[1:]])
return (1-alpha_cumprod).sqrt() * noised + alpha_cumprod.sqrt() * pred_v
def undiffuse(self, x, t, t_prev, pred, sampler=None, **kwargs):
if sampler is None:
sampler = DDPMSampler(self)
return sampler(x, t, t_prev, pred, **kwargs)
def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_rho=0.7, unconditional_inputs=None, sampler='ddpm', dtype=None, sample_mode='v', sampler_params={}, t_scaler=1):
r_range = torch.linspace(t_start, t_end, timesteps+1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device)
if t_scaler != 1:
r_range = self.scale_t(r_range, t_scaler)
if isinstance(sampler, str):
if sampler in sampler_dict:
sampler = sampler_dict[sampler](self, sample_mode)
else:
raise ValueError(f"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}")
        elif isinstance(sampler, type) and issubclass(sampler, SimpleSampler):
            sampler = sampler(self, sample_mode)
        else:
            raise ValueError("Sampler should be either a string or a SimpleSampler subclass.")
x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone()
if dtype is not None:
r_range = r_range.to(dtype)
x = x.to(dtype)
if cfg is not None:
if unconditional_inputs is None:
unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
model_inputs = {k:torch.cat([v, v_u]) if isinstance(v, torch.Tensor) else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())}
for i in range(0, timesteps):
if mask is not None and x_init is not None:
x_renoised, _ = self.diffuse(x_init, r_range[i])
x = x * mask + x_renoised * (1-mask)
if cfg is not None:
pred, pred_unconditional = model(torch.cat([x] * 2), torch.cat([r_range[i]] * 2), **model_inputs).chunk(2)
pred_cfg = torch.lerp(pred_unconditional, pred, cfg)
if cfg_rho > 0:
std_pos, std_cfg = pred.std(), pred_cfg.std()
pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho)
else:
pred = pred_cfg
else:
pred = model(x, r_range[i], **model_inputs)
diff_out = self.undiffuse(x, r_range[i], r_range[i+1], pred, sampler=sampler, **sampler_params)
x = diff_out[1]
altered_vars = yield diff_out
# Update some running variables if the user wants
if altered_vars is not None:
cfg = altered_vars.get('cfg', cfg)
cfg_rho = altered_vars.get('cfg_rho', cfg_rho)
sampler = altered_vars.get('sampler', sampler)
unconditional_inputs = altered_vars.get('unconditional_inputs', unconditional_inputs)
model_inputs = altered_vars.get('model_inputs', model_inputs)
x = altered_vars.get('x', x)
mask = altered_vars.get('mask', mask)
x_init = altered_vars.get('x_init', x_init)
def p2_weight(self, t, k=1.0, gamma=1.0):
alpha_cumprod = self._alpha_cumprod(t)
return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma
    def truncated_snr_weight(self, t, min=1.0, max=None):
        alpha_cumprod = self._alpha_cumprod(t)
        snr = alpha_cumprod / (1 - alpha_cumprod)
        if min is not None or max is not None:
            snr = snr.clamp(min=min, max=max)
        return snr
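`Diffuzz2` applies its `scaler` as a shift in logit space: `alpha/(1-alpha)` is logged, shifted by `2*ln(1/scaler)`, and mapped back through a sigmoid, so `scaler=1` is (up to clamping) the identity. A scalar sketch (the name `scaled_alpha` is illustrative):

```python
import math

def scaled_alpha(alpha, scaler):
    """Scalar version of Diffuzz2's schedule rescaling: shift alpha_cumprod
    in logit space by 2*ln(1/scaler), then map back with a sigmoid.
    scaler=1 is the identity; scaler>1 pushes alpha down (noisier)."""
    shift = 2 * math.log(1 / scaler)
    logit = math.log(alpha / (1 - alpha))
    return 1 / (1 + math.exp(-(logit + shift)))
```

Because logit and sigmoid are inverses, `scaler=1` (shift 0) returns the input unchanged, while `scaler=2` shifts every alpha toward 0 and so concentrates sampling on noisier timesteps.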
================================================
FILE: torchtools/utils/gamma_parametrization.py
================================================
import torch
from torch import nn
class _GammaScaling(nn.Module):
def __init__(self):
super().__init__()
self.gamma = nn.Parameter(torch.ones(1))
def forward(self, w):
return w * self.gamma
def apply_gamma_reparam(module, name="weight"): # this reparametrizes the parameters of a single module
nn.utils.parametrizations.spectral_norm(module, name)
nn.utils.parametrize.register_parametrization(module, name, _GammaScaling())
return module
def gamma_reparam_model(model):
for module in model.modules(): # this reparametrizes all linear layers of the model
if isinstance(module, nn.Linear) and not torch.nn.utils.parametrize.is_parametrized(module, "weight"):
apply_gamma_reparam(module, "weight")
elif isinstance(module, nn.MultiheadAttention) and not torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"):
apply_gamma_reparam(module, "in_proj_weight")
return model
def remove_gamma_reparam(model):
for module in model.modules():
if torch.nn.utils.parametrize.is_parametrized(module, "weight"):
nn.utils.parametrize.remove_parametrizations(module, "weight")
elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"):
nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight")
================================================
FILE: torchtools/utils/weight_normalization.py
================================================
import torch
from torch import nn
class _WeightNorm(nn.Module):
    def __init__(self, eps=1e-4):
        super().__init__()
        self.eps = eps
    def _normalize(self, w):
        norm_dims = list(range(1, len(w.shape)))
        w_norm = torch.linalg.vector_norm(w, dim=norm_dims, keepdim=True)
        # w_norm = torch.norm_except_dim(w, 2, 0).clone()
        return w / (w_norm + self.eps)
    def forward(self, w):
        if self.training:
            with torch.no_grad():
                fan_in = w[0].numel() ** 0.5
                w.data = self._normalize(w.data.clone()) * fan_in
                # w.copy_(self._normalize(w) * fan_in)
        return self._normalize(w)
def apply_weight_norm(module, name="weight", init_weight=True):  # this reparametrizes the parameters of a single module
    if init_weight:
        torch.nn.init.normal_(getattr(module, name))
    nn.utils.parametrize.register_parametrization(module, name, _WeightNorm(), unsafe=True)
    return module
def weight_norm_model(model, whitelist=None, init_weight=True):
whitelist = whitelist or []
def check_parameter(module, name):
return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance(getattr(module, name), nn.Parameter)
for name, module in model.named_modules(): # this reparametrizes all layers of the model that have a "weight" parameter
if not any([w in name for w in whitelist]):
if check_parameter(module, "weight"):
apply_weight_norm(module, init_weight=init_weight)
elif check_parameter(module, "in_proj_weight"):
apply_weight_norm(module, 'in_proj_weight', init_weight=init_weight)
return model
def remove_weight_norm(model):
for module in model.modules():
if torch.nn.utils.parametrize.is_parametrized(module, "weight"):
nn.utils.parametrize.remove_parametrizations(module, "weight")
elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"):
nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight")
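The weight-norm parametrization above divides each output channel (all dims except dim 0) by its L2 norm plus `eps`, and during training additionally rescales the stored weight by the square root of the fan-in. The normalization itself, sketched on plain nested lists for a 2D weight (illustrative name `normalize_rows`):

```python
import math

def normalize_rows(w, eps=1e-4):
    """Divide each row of a 2D weight matrix by its L2 norm plus eps,
    mirroring the per-output-channel normalization above (for an
    nn.Linear weight, dim 0 indexes output channels)."""
    out = []
    for row in w:
        norm = math.sqrt(sum(v * v for v in row))
        out.append([v / (norm + eps) for v in row])
    return out
```

After this, every row has norm ~1 (exactly 1 only in the `eps -> 0` limit), which keeps the effective weight magnitude stable regardless of how the raw parameter drifts during training.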
method forward (line 28) | def forward(self, affine):
method extra_repr (line 60) | def extra_repr(self):
FILE: torchtools/nn/functional/gradient_penalty.py
function gradient_penalty (line 9) | def gradient_penalty(netD, real_data, fake_data, l=10):
FILE: torchtools/nn/functional/magnitude_preserving.py
function mp_cat (line 3) | def mp_cat(*args, dim=1, t=0.5):
function mp_sum (line 13) | def mp_sum(*args, t=0.5):
FILE: torchtools/nn/functional/perceptual.py
function total_variation (line 3) | def total_variation(X, reduction='sum'):
FILE: torchtools/nn/functional/vq.py
class vector_quantize (line 4) | class vector_quantize(Function):
method forward (line 6) | def forward(ctx, x, codebook):
method backward (line 21) | def backward(ctx, grad_output, grad_indices):
class binarize (line 36) | class binarize(Function):
method forward (line 38) | def forward(ctx, x, threshold=0.5):
method backward (line 46) | def backward(ctx, grad_output):
FILE: torchtools/nn/gp_loss.py
class GPLoss (line 5) | class GPLoss(nn.Module):
method __init__ (line 6) | def __init__(self, discriminator, l=10):
method forward (line 11) | def forward(self, real_data, fake_data):
FILE: torchtools/nn/haar_dwt.py
class HaarForward (line 5) | class HaarForward(nn.Module):
method __init__ (line 10) | def __init__(self, beta=2):
method forward (line 15) | def forward(self, x: torch.Tensor) -> torch.Tensor:
class HaarInverse (line 33) | class HaarInverse(nn.Module):
method __init__ (line 38) | def __init__(self, beta=2):
method forward (line 43) | def forward(self, x: torch.Tensor) -> torch.Tensor:
FILE: torchtools/nn/magnitude_preserving.py
class MP_GELU (line 4) | class MP_GELU(nn.GELU):
method forward (line 5) | def forward(self, x):
class MP_SiLU (line 8) | class MP_SiLU(nn.SiLU):
method forward (line 9) | def forward(self, x):
class Gain (line 12) | class Gain(nn.Module):
method __init__ (line 13) | def __init__(self, init_w=0.0):
method forward (line 17) | def forward(self, x):
FILE: torchtools/nn/mish.py
class Mish (line 15) | class Mish(nn.Module):
method __init__ (line 16) | def __init__(self):
method forward (line 19) | def forward(self, x):
FILE: torchtools/nn/modulation.py
class ModulatedConv2d (line 11) | class ModulatedConv2d(nn.Conv2d):
method __init__ (line 12) | def __init__(self, *args, demodulate=True, ema_decay=1.0, **kwargs):
method forward (line 25) | def forward(self, x, w):
FILE: torchtools/nn/perceptual.py
class TVLoss (line 5) | class TVLoss(nn.Module):
method __init__ (line 6) | def __init__(self, reduction='sum', alpha=1e-4):
method forward (line 11) | def forward(self, x):
FILE: torchtools/nn/pixel_normalzation.py
class PixelNorm (line 4) | class PixelNorm(nn.Module):
method __init__ (line 5) | def __init__(self, dim=1, eps=1e-4):
method forward (line 10) | def forward(self, x):
FILE: torchtools/nn/pos_embeddings.py
class RotaryEmbedding (line 8) | class RotaryEmbedding(nn.Module):
method __init__ (line 9) | def __init__(self, dim, base=10000):
method forward (line 17) | def forward(self, x, seq_dim=1):
function rotate_half (line 29) | def rotate_half(x):
function apply_rotary_pos_emb (line 34) | def apply_rotary_pos_emb(q, k, cos, sin):
FILE: torchtools/nn/simple_self_attention.py
function conv1d (line 9) | def conv1d(ni, no, ks=1, stride=1, padding=0, bias=False):
class SimpleSelfAttention (line 18) | class SimpleSelfAttention(nn.Module):
method __init__ (line 20) | def __init__(self, n_in, ks=1, sym=False):
method forward (line 27) | def forward(self, x):
FILE: torchtools/nn/stylegan2/upfirdn2d.cpp
function upfirdn2d (line 17) | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor ...
function PYBIND11_MODULE (line 29) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: torchtools/nn/stylegan2/upfirdn2d.py
class UpFirDn2dBackward (line 20) | class UpFirDn2dBackward(Function):
method forward (line 22) | def forward(
method backward (line 64) | def backward(ctx, gradgrad_input):
class UpFirDn2d (line 89) | class UpFirDn2d(Function):
method forward (line 91) | def forward(ctx, input, kernel, up, down, pad):
method backward (line 128) | def backward(ctx, grad_output):
function upfirdn2d (line 149) | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
function upfirdn2d_native (line 168) | def upfirdn2d_native(
FILE: torchtools/nn/transformers.py
class GPTTransformerEncoderLayer (line 5) | class GPTTransformerEncoderLayer(nn.Module):
method __init__ (line 6) | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0):
method forward (line 18) | def forward(self, x, src_mask=None, src_key_padding_mask=None):
FILE: torchtools/nn/vq.py
class VectorQuantize (line 6) | class VectorQuantize(nn.Module):
method __init__ (line 7) | def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
method _laplace_smoothing (line 27) | def _laplace_smoothing(self, x, epsilon):
method _updateEMA (line 31) | def _updateEMA(self, z_e_x, indices):
method idx2vq (line 42) | def idx2vq(self, idx, dim=-1):
method forward (line 48) | def forward(self, x, get_losses=True, dim=-1):
class Binarize (line 67) | class Binarize(nn.Module):
method __init__ (line 68) | def __init__(self, threshold=0.5):
method forward (line 78) | def forward(self, x):
class FSQ (line 82) | class FSQ(nn.Module):
method __init__ (line 83) | def __init__(self, bins, dim=-1, eps=1e-1):
method _round (line 93) | def _round(self, x, quantize):
method vq_to_idx (line 102) | def vq_to_idx(self, x, is_sigmoid=False):
method idx_to_vq (line 110) | def idx_to_vq(self, x):
method forward (line 118) | def forward(self, x, quantize=True):
FILE: torchtools/optim/lamb.py
function log_lamb_rs (line 14) | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token...
class Lamb (line 29) | class Lamb(Optimizer):
method __init__ (line 47) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
method step (line 62) | def step(self, closure=None):
FILE: torchtools/optim/lookahead.py
class Lookahead (line 16) | class Lookahead(Optimizer):
method __init__ (line 17) | def __init__(self, base_optimizer, alpha=0.5, k=6):
method update_slow (line 33) | def update_slow(self, group):
method sync_lookahead (line 45) | def sync_lookahead(self):
method step (line 49) | def step(self, closure=None):
method state_dict (line 59) | def state_dict(self):
method load_state_dict (line 73) | def load_state_dict(self, state_dict):
function LookaheadAdam (line 100) | def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs):
FILE: torchtools/optim/novograd.py
class AdamW (line 24) | class AdamW(Optimizer):
method __init__ (line 47) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
method __setstate__ (line 61) | def __setstate__(self, state):
method step (line 66) | def step(self, closure=None):
class Novograd (line 124) | class Novograd(Optimizer):
method __init__ (line 143) | def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,
method __setstate__ (line 160) | def __setstate__(self, state):
method step (line 165) | def step(self, closure=None):
FILE: torchtools/optim/over9000.py
function Over9000 (line 17) | def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
FILE: torchtools/optim/radam.py
class RAdam (line 10) | class RAdam(Optimizer):
method __init__ (line 11) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weig...
method __setstate__ (line 29) | def __setstate__(self, state):
method step (line 32) | def step(self, closure=None):
class PlainRAdam (line 100) | class PlainRAdam(Optimizer):
method __init__ (line 102) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weig...
method __setstate__ (line 117) | def __setstate__(self, state):
method step (line 120) | def step(self, closure=None):
class AdamW (line 177) | class AdamW(Optimizer):
method __init__ (line 179) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weig...
method __setstate__ (line 193) | def __setstate__(self, state):
method step (line 196) | def step(self, closure=None):
FILE: torchtools/optim/ralamb.py
class Ralamb (line 9) | class Ralamb(Optimizer):
method __init__ (line 11) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weig...
method __setstate__ (line 16) | def __setstate__(self, state):
method step (line 19) | def step(self, closure=None):
FILE: torchtools/optim/ranger.py
function Ranger (line 14) | def Ranger(params, alpha=0.5, k=6, betas=(.95, 0.999), *args, **kwargs):
FILE: torchtools/transforms/smart_crop.py
class ResBlock (line 8) | class ResBlock(nn.Module):
method __init__ (line 9) | def __init__(self, channels):
method forward (line 22) | def forward(self, x):
class Upsample2d (line 26) | class Upsample2d(nn.Module):
method __init__ (line 27) | def __init__(self, scale_factor):
method forward (line 33) | def forward(self, x):
class MicroResNet (line 37) | class MicroResNet(nn.Module):
method __init__ (line 38) | def __init__(self):
method forward (line 73) | def forward(self, x):
class SmartCrop (line 80) | class SmartCrop(nn.Module):
method __init__ (line 81) | def __init__(self, output_size, randomize_p=0.0, randomize_q=0.1, temp...
method forward (line 92) | def forward(self, image):
FILE: torchtools/utils/diffusion.py
class SimpleSampler (line 4) | class SimpleSampler():
method __init__ (line 5) | def __init__(self, diffuzz):
method __call__ (line 9) | def __call__(self, *args, **kwargs):
method init_x (line 13) | def init_x(self, shape):
method step (line 16) | def step(self, x, t, t_prev, noise):
class DDPMSampler (line 19) | class DDPMSampler(SimpleSampler):
method step (line 20) | def step(self, x, t, t_prev, noise):
class DDIMSampler (line 29) | class DDIMSampler(SimpleSampler):
method step (line 30) | def step(self, x, t, t_prev, noise):
class DPMSolverPlusPlusSampler (line 38) | class DPMSolverPlusPlusSampler(SimpleSampler): # FIXME: CURRENTLY NOT W...
method __init__ (line 39) | def __init__(self, *args, **kwargs):
method _get_coef (line 43) | def _get_coef(self, alpha_cumprod):
method init_x (line 50) | def init_x(self, shape):
method step (line 54) | def step(self, x, t, t_prev, noise):
class Diffuzz (line 89) | class Diffuzz():
method __init__ (line 90) | def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1, ...
method _alpha_cumprod (line 100) | def _alpha_cumprod(self, t):
method diffuse (line 111) | def diffuse(self, x, t, noise=None): # t -> [0, 1]
method undiffuse (line 117) | def undiffuse(self, x, t, t_prev, noise, sampler=None):
method sample (line 122) | def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t...
method p2_weight (line 152) | def p2_weight(self, t, k=1.0, gamma=1.0):
FILE: torchtools/utils/diffusion2.py
class SimpleSampler (line 5) | class SimpleSampler():
method __init__ (line 6) | def __init__(self, diffuzz, mode="v"):
method __call__ (line 13) | def __call__(self, *args, **kwargs):
method init_x (line 17) | def init_x(self, shape):
method step (line 20) | def step(self, x, t, t_prev, noise):
class DDIMSampler (line 24) | class DDIMSampler(SimpleSampler):
method step (line 25) | def step(self, x, t, t_prev, pred, eta=0):
class DDPMSampler (line 42) | class DDPMSampler(DDIMSampler):
method step (line 43) | def step(self, x, t, t_prev, pred, eta=1):
class Diffuzz2 (line 52) | class Diffuzz2():
method __init__ (line 53) | def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1, ...
method _alpha_cumprod (line 63) | def _alpha_cumprod(self, t):
method scale_t (line 73) | def scale_t(self, t, scaler):
method diffuse (line 81) | def diffuse(self, x, t, noise=None): # t -> [0, 1]
method get_v (line 87) | def get_v(self, x, t, noise):
method x0_from_v (line 93) | def x0_from_v(self, noised, pred_v, t):
method noise_from_v (line 97) | def noise_from_v(self, noised, pred_v, t):
method undiffuse (line 101) | def undiffuse(self, x, t, t_prev, pred, sampler=None, **kwargs):
method sample (line 106) | def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t...
method p2_weight (line 158) | def p2_weight(self, t, k=1.0, gamma=1.0):
method truncated_snr_weight (line 162) | def truncated_snr_weight(self, t, min=1.0, max=None):
FILE: torchtools/utils/gamma_parametrization.py
class _GammaScaling (line 5) | class _GammaScaling(nn.Module):
method __init__ (line 6) | def __init__(self):
method forward (line 10) | def forward(self, w):
function apply_gamma_reparam (line 13) | def apply_gamma_reparam(module, name="weight"): # this reparametrizes th...
function gamma_reparam_model (line 18) | def gamma_reparam_model(model):
function remove_gamma_reparam (line 26) | def remove_gamma_reparam(model):
FILE: torchtools/utils/weight_normalization.py
class _WeigthNorm (line 4) | class _WeigthNorm(nn.Module):
method __init__ (line 5) | def __init__(self, eps=1e-4):
method _normalize (line 9) | def _normalize(self, w):
method forward (line 15) | def forward(self, w):
function apply_weight_norm (line 23) | def apply_weight_norm(module, name="weight", init_weight=True): # this r...
function weight_norm_model (line 29) | def weight_norm_model(model, whitelist=None, init_weight=True):
function remove_weight_norm (line 43) | def remove_weight_norm(model):
Condensed preview — 53 files, each showing path, character count, and a content snippet (147K chars total in the full structured content).
[
{
"path": ".gitignore",
"chars": 45,
"preview": "*/**/__pycache__\n*.egg-info\n/dist\n*.pyc\n.idea"
},
{
"path": "LICENSE",
"chars": 1077,
"preview": "MIT License\n\nCopyright (c) 2022 Pablo Pernías\n\nPermission is hereby granted, free of charge, to any person obtaining a c"
},
{
"path": "MANIFEST.in",
"chars": 82,
"preview": "include LICENSE\ninclude readme.md\nrecursive-include torchtools/transforms/models *"
},
{
"path": "readme.md",
"chars": 21042,
"preview": "# Pytorch Tools\n\n## Install\n\nRequirements:\n```\nPyTorch >= 1.0.0\nTorchivision\nNumpy >= 1.0.0\n```\n\n```\n# In order to insta"
},
{
"path": "setup.cfg",
"chars": 39,
"preview": "[metadata]\ndescription-file = readme.md"
},
{
"path": "setup.py",
"chars": 694,
"preview": "from setuptools import setup, find_packages\n\nsetup(\n name='torchtools',\n packages=find_packages(),\n description"
},
{
"path": "torchtools/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "torchtools/lr_scheduler/__init__.py",
"chars": 104,
"preview": "from .delayed import DelayerScheduler, DelayedCosineAnnealingLR\nfrom .inverse_sqrt import InverseSqrtLR\n"
},
{
"path": "torchtools/lr_scheduler/delayed.py",
"chars": 1298,
"preview": "from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR\n\nclass DelayerScheduler(_LRScheduler):\n\t\"\"\" Starts "
},
{
"path": "torchtools/lr_scheduler/inverse_sqrt.py",
"chars": 1170,
"preview": "import warnings\n\nfrom torch.optim.lr_scheduler import LRScheduler\n\n\nclass InverseSqrtLR(LRScheduler):\n def __init__(s"
},
{
"path": "torchtools/nn/__init__.py",
"chars": 695,
"preview": "from .mish import Mish\nfrom .simple_self_attention import SimpleSelfAttention\nfrom .vq import VectorQuantize, Binarize, "
},
{
"path": "torchtools/nn/adain.py",
"chars": 378,
"preview": "import torch\nfrom torch import nn\n\nclass AdaIN(nn.Module):\n def __init__(self, n_channels):\n super(AdaIN, self"
},
{
"path": "torchtools/nn/alias_free_activation.py",
"chars": 4702,
"preview": "import torch\nfrom torch import nn\nimport math\nfrom .stylegan2 import upfirdn2d\n\n####\n# TOTALLY INSPIRED AND EVEN COPIED "
},
{
"path": "torchtools/nn/equal_layers.py",
"chars": 1424,
"preview": "\nimport torch\nfrom torch import nn\nimport math\n\n####\n# TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM \n# https://gith"
},
{
"path": "torchtools/nn/evonorm2d.py",
"chars": 3690,
"preview": "import torch\nimport torch.nn as nn\n\n\n## Taken as is from https://github.com/digantamisra98/EvoNorm all credit goes to di"
},
{
"path": "torchtools/nn/fourier_features.py",
"chars": 3103,
"preview": "import torch\nfrom torch import nn\nimport math\n\nclass FourierFeatures2d(nn.Module):\n def __init__(self, size, dim, cut"
},
{
"path": "torchtools/nn/functional/__init__.py",
"chars": 178,
"preview": "from .vq import vector_quantize, binarize\nfrom .gradient_penalty import gradient_penalty\nfrom .perceptual import total_v"
},
{
"path": "torchtools/nn/functional/gradient_penalty.py",
"chars": 1039,
"preview": "####\n# CODE TAKEN WITH FEW MODIFICATIONS FROM https://github.com/caogang/wgan-gp\n# ORIGINAL PAPER https://arxiv.org/pdf/"
},
{
"path": "torchtools/nn/functional/magnitude_preserving.py",
"chars": 731,
"preview": "import torch\n\ndef mp_cat(*args, dim=1, t=0.5):\n if isinstance(t, float):\n t = [1-t, t]\n assert len(args) =="
},
{
"path": "torchtools/nn/functional/perceptual.py",
"chars": 270,
"preview": "import torch\n\ndef total_variation(X, reduction='sum'):\n\ttv_h = torch.abs(X[:, :, :, 1:] - X[:, :, :, :-1])\n\ttv_v = torch"
},
{
"path": "torchtools/nn/functional/vq.py",
"chars": 1320,
"preview": "import torch\nfrom torch.autograd import Function\n\nclass vector_quantize(Function):\n\t@staticmethod\n\tdef forward(ctx, x, c"
},
{
"path": "torchtools/nn/gp_loss.py",
"chars": 344,
"preview": "import torch\nfrom torch import nn\nfrom .functional import gradient_penalty\n\nclass GPLoss(nn.Module):\n\tdef __init__(self,"
},
{
"path": "torchtools/nn/haar_dwt.py",
"chars": 2515,
"preview": "import torch\nfrom torch import nn\n\n# Taken almost as is from https://github.com/bes-dev/haar_pytorch\nclass HaarForward(n"
},
{
"path": "torchtools/nn/magnitude_preserving.py",
"chars": 443,
"preview": "import torch\nfrom torch import nn\n\nclass MP_GELU(nn.GELU):\n def forward(self, x):\n return super().forward(x) /"
},
{
"path": "torchtools/nn/mish.py",
"chars": 691,
"preview": "####\n# CODE TAKEN FROM https://github.com/lessw2020/mish\n# ORIGINAL PAPER https://arxiv.org/abs/1908.08681v1\n####\n\nimpor"
},
{
"path": "torchtools/nn/modulation.py",
"chars": 1886,
"preview": "\nimport torch\nfrom torch import nn\nimport math\n\n####\n# TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM \n# https://gith"
},
{
"path": "torchtools/nn/perceptual.py",
"chars": 324,
"preview": "import torch\nimport torch.nn as nn\nfrom .functional import total_variation\n\nclass TVLoss(nn.Module):\n\tdef __init__(self,"
},
{
"path": "torchtools/nn/pixel_normalzation.py",
"chars": 295,
"preview": "import torch\nfrom torch import nn\n\nclass PixelNorm(nn.Module):\n def __init__(self, dim=1, eps=1e-4):\n super()."
},
{
"path": "torchtools/nn/pos_embeddings.py",
"chars": 1323,
"preview": "import torch\nfrom torch import nn\n\n####\n# CODE TAKEN FROM https://github.com/lucidrains/x-transformers\n####\n\nclass Rotar"
},
{
"path": "torchtools/nn/simple_self_attention.py",
"chars": 1711,
"preview": "import torch.nn as nn\nimport torch, math, sys\n\n####\n# CODE TAKEN FROM https://github.com/sdoria/SimpleSelfAttention\n####"
},
{
"path": "torchtools/nn/stylegan2/__init__.py",
"chars": 32,
"preview": "from .upfirdn2d import upfirdn2d"
},
{
"path": "torchtools/nn/stylegan2/upfirdn2d.cpp",
"chars": 1343,
"preview": "#include <ATen/ATen.h>\r\n#include <torch/extension.h>\r\n\r\ntorch::Tensor upfirdn2d_op(const torch::Tensor &input,\r\n "
},
{
"path": "torchtools/nn/stylegan2/upfirdn2d.py",
"chars": 6121,
"preview": "from collections import abc\r\nimport os\r\n\r\nimport torch\r\nfrom torch.nn import functional as F\r\nfrom torch.autograd import"
},
{
"path": "torchtools/nn/stylegan2/upfirdn2d_kernel.cu",
"chars": 13022,
"preview": "// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.\r\n//\r\n// This work is made available under the Nvidia Sou"
},
{
"path": "torchtools/nn/transformers.py",
"chars": 870,
"preview": "import torch.nn as nn\n\n\n# Based on the GPT2 implementatyion from MinGPT https://github.com/karpathy/minGPT by Andrej Kar"
},
{
"path": "torchtools/nn/vq.py",
"chars": 4711,
"preview": "import torch\nfrom torch import nn\nfrom .functional.vq import vector_quantize, binarize\nimport numpy as np\n\nclass VectorQ"
},
{
"path": "torchtools/optim/__init__.py",
"chars": 243,
"preview": "from .radam import RAdam, PlainRAdam, AdamW\nfrom .ranger import Ranger\nfrom .lookahead import Lookahead, LookaheadAdam\nf"
},
{
"path": "torchtools/optim/lamb.py",
"chars": 5272,
"preview": "####\n# CODE TAKEN FROM https://github.com/mgrankin/over9000\n####\n\nimport collections\nimport math\n\nimport torch\nfrom torc"
},
{
"path": "torchtools/optim/lookahead.py",
"chars": 4176,
"preview": "####\n# CODE TAKEN FROM https://github.com/lonePatient/lookahead_pytorch\n# Original paper: https://arxiv.org/abs/1907.086"
},
{
"path": "torchtools/optim/novograd.py",
"chars": 9611,
"preview": "####\n# CODE TAKEN FROM https://github.com/mgrankin/over9000\n####\n\n# Copyright (c) 2019, NVIDIA CORPORATION. All rights r"
},
{
"path": "torchtools/optim/over9000.py",
"chars": 618,
"preview": "####\n# CODE TAKEN FROM https://github.com/mgrankin/over9000\n####\n\nimport torch, math\nfrom torch.optim.optimizer import O"
},
{
"path": "torchtools/optim/radam.py",
"chars": 10459,
"preview": "####\n# CODE TAKEN FROM https://github.com/LiyuanLucasLiu/RAdam\n# Paper: https://arxiv.org/abs/1908.03265\n####\n\nimport ma"
},
{
"path": "torchtools/optim/ralamb.py",
"chars": 4091,
"preview": "####\n# CODE TAKEN FROM https://github.com/mgrankin/over9000\n####\n\nimport torch, math\nfrom torch.optim.optimizer import O"
},
{
"path": "torchtools/optim/ranger.py",
"chars": 564,
"preview": "####\n# CODE TAKEN FROM https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer\n# Blog post: https://medium.com/@less"
},
{
"path": "torchtools/transforms/__init__.py",
"chars": 33,
"preview": "from .smart_crop import SmartCrop"
},
{
"path": "torchtools/transforms/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "torchtools/transforms/smart_crop.py",
"chars": 5174,
"preview": "import torch\nimport torchvision\nfrom torch import nn\nimport numpy as np\nimport os\n\n# MICRO RESNET\nclass ResBlock(nn.Modu"
},
{
"path": "torchtools/utils/__init__.py",
"chars": 252,
"preview": "from .diffusion import Diffuzz\nfrom .diffusion2 import Diffuzz2\nfrom .gamma_parametrization import apply_gamma_reparam, "
},
{
"path": "torchtools/utils/diffusion.py",
"chars": 7610,
"preview": "import torch\n\n# Samplers --------------------------------------------------------------------\nclass SimpleSampler():\n "
},
{
"path": "torchtools/utils/diffusion2.py",
"chars": 8626,
"preview": "import torch\nimport numpy as np\n\n# Samplers --------------------------------------------------------------------\nclass S"
},
{
"path": "torchtools/utils/gamma_parametrization.py",
"chars": 1360,
"preview": "import torch\nfrom torch import nn\n\n\nclass _GammaScaling(nn.Module):\n def __init__(self):\n super().__init__()\n "
},
{
"path": "torchtools/utils/weight_normalization.py",
"chars": 2122,
"preview": "import torch\nfrom torch import nn\n\nclass _WeigthNorm(nn.Module):\n def __init__(self, eps=1e-4):\n super().__ini"
}
]
// ... and 1 more file (truncated)