Repository: haoheliu/SemantiCodec-inference
Branch: main
Commit: 7c05d426d297
Files: 51
Total size: 373.1 KB
Directory structure:
gitextract_ma67qwgl/
├── .gitignore
├── LICENSE
├── README.md
├── semanticodec/
│ ├── __init__.py
│ ├── config.py
│ ├── main.py
│ ├── modules/
│ │ ├── __init__.py
│ │ ├── audiomae/
│ │ │ ├── AudioMAE.py
│ │ │ ├── __init__.py
│ │ │ ├── models_mae.py
│ │ │ ├── patch_embed.py
│ │ │ └── pos_embed.py
│ │ ├── decoder/
│ │ │ ├── __init__.py
│ │ │ ├── hifigan/
│ │ │ │ ├── LICENSE
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models.py
│ │ │ │ └── models_v2.py
│ │ │ ├── latent_diffusion/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── models/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── ddim.py
│ │ │ │ │ ├── ddpm.py
│ │ │ │ │ └── dpm_solver/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ └── dpm_solver.py
│ │ │ │ ├── modules/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── attention.py
│ │ │ │ │ ├── diffusionmodules/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ ├── model.py
│ │ │ │ │ │ ├── openaimodel.py
│ │ │ │ │ │ └── util.py
│ │ │ │ │ ├── distributions/
│ │ │ │ │ │ ├── __init__.py
│ │ │ │ │ │ └── distributions.py
│ │ │ │ │ ├── ema.py
│ │ │ │ │ ├── mamba.py
│ │ │ │ │ ├── nn.py
│ │ │ │ │ └── x_transformer.py
│ │ │ │ └── util.py
│ │ │ ├── latent_encoder/
│ │ │ │ ├── __init__.py
│ │ │ │ └── autoencoder.py
│ │ │ └── utilities/
│ │ │ ├── __init__.py
│ │ │ ├── audio/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── audio_processing.py
│ │ │ │ ├── stft.py
│ │ │ │ └── tools.py
│ │ │ ├── model.py
│ │ │ └── tools.py
│ │ └── encoder/
│ │ ├── __init__.py
│ │ └── encoder.py
│ └── utils.py
├── setup.py
└── test/
├── encoding.py
└── test_all_settings.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
pretrained
*.npy
*.wav
g_*
*.pyc
*.pkl
*.json
codebook_idx
long_audio*
output*
sample*
*.json
*.egg-info
.ipynb*
trim_checkpoint.py
__pycache*
.DS*
build
================================================
FILE: LICENSE
================================================
Copyright (c) 2012-2024 Scott Chacon and others
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
================================================
FILE: README.md
================================================
[arXiv](https://arxiv.org/abs/2405.00233) [Project page](https://haoheliu.github.io/SemantiCodec/)
# SemantiCodec
Ultra-low-bitrate neural audio codec with richer semantics in the latent space.
**Highlights**
- Bitrate: 0.31 kbps - 1.40 kbps
- Token rate: 25, 50, or 100 per second
- Runs on CPU, CUDA, and MPS devices
# Usage
## Installation
```bash
pip install git+https://github.com/haoheliu/SemantiCodec-inference.git
```
## Encoding and decoding
**Checkpoints are downloaded automatically the first time you initialize SemantiCodec, as in the code below.**
```python
from semanticodec import SemantiCodec
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=16384)
filepath = "test/test.wav" # audio with arbitrary length
tokens = semanticodec.encode(filepath)
waveform = semanticodec.decode(tokens)
# Save the reconstruction file
import soundfile as sf
sf.write("output.wav", waveform[0,0], 16000)
```
## Other Settings
```python
from semanticodec import SemantiCodec
############### Choose one of the following ###############
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768) # 1.40 kbps
semanticodec = SemantiCodec(token_rate=50, semantic_vocab_size=32768) # 0.70 kbps
semanticodec = SemantiCodec(token_rate=25, semantic_vocab_size=32768) # 0.35 kbps
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=16384) # 1.35 kbps
semanticodec = SemantiCodec(token_rate=50, semantic_vocab_size=16384) # 0.68 kbps
semanticodec = SemantiCodec(token_rate=25, semantic_vocab_size=16384) # 0.34 kbps
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=8192) # 1.30 kbps
semanticodec = SemantiCodec(token_rate=50, semantic_vocab_size=8192) # 0.65 kbps
semanticodec = SemantiCodec(token_rate=25, semantic_vocab_size=8192) # 0.33 kbps
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=4096) # 1.25 kbps
semanticodec = SemantiCodec(token_rate=50, semantic_vocab_size=4096) # 0.63 kbps
semanticodec = SemantiCodec(token_rate=25, semantic_vocab_size=4096) # 0.31 kbps
#####################################
filepath = "test/test.wav"
tokens = semanticodec.encode(filepath)
waveform = semanticodec.decode(tokens)
import soundfile as sf
sf.write("output.wav", waveform[0,0], 16000)
```
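The bitrate figures in the comments above follow directly from the token rate and vocabulary size. A rough sanity check, assuming the token rate counts semantic and acoustic tokens together and the acoustic codebook is fixed at 8192 entries (13 bits), consistent with the `codebook_size` in the model config:
```python
import math

for token_rate in (100, 50, 25):
    for vocab in (32768, 16384, 8192, 4096):
        # Half of the tokens are semantic (log2(vocab) bits each),
        # half acoustic (13 bits each).
        kbps = (token_rate / 2) * (math.log2(vocab) + 13) / 1000
        print(f"token_rate={token_rate:3d}, vocab={vocab:5d}: {kbps:.2f} kbps")
```
This reproduces the numbers above up to rounding.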
If you would like to reuse the evaluation pipeline and data from the paper, please refer to this [Zenodo repo](https://zenodo.org/records/11047204).
## Citation
If you find this repo helpful, please consider citing it in the following format:
```bibtex
@ARTICLE{semanticodec2024,
author={Liu, Haohe and Xu, Xuenan and Yuan, Yi and Wu, Mengyue and Wang, Wenwu and Plumbley, Mark D.},
journal={IEEE Journal of Selected Topics in Signal Processing},
title={SemantiCodec: An Ultra Low Bitrate Semantic Audio Codec for General Sound},
year={2024},
volume={18},
number={8},
pages={1448-1461},
doi={10.1109/JSTSP.2024.3506286}
}
```

================================================
FILE: semanticodec/__init__.py
================================================
from semanticodec.main import SemantiCodec
================================================
FILE: semanticodec/config.py
================================================
def get_config(token_rate=100, vocab_size=None, checkpoint_path=None):
assert vocab_size in [4096, 8192, 16384, 32768], "vocab_size must be 4096, 8192, 16384 or 32768"
assert token_rate in [25, 50, 100], "token_rate must be 25, 50 or 100"
if checkpoint_path is not None:
semantic_codebook = {
25: {
4096: f"{checkpoint_path}/codebook_idx/combine_128_audioset_dominate/codebook_2048_0.npy",
8192: f"{checkpoint_path}/codebook_idx/combine_128_audioset_dominate/codebook_4096_0.npy",
16384: f"{checkpoint_path}/codebook_idx/combine_128_audioset_dominate/codebook_8192_0.npy",
32768: f"{checkpoint_path}/codebook_idx/combine_128_audioset_dominate/codebook_16384_0.npy",
},
50: {
4096: f"{checkpoint_path}/codebook_idx/combine_256_audioset_dominate/codebook_2048_0.npy",
8192: f"{checkpoint_path}/codebook_idx/combine_256_audioset_dominate/codebook_4096_0.npy",
16384: f"{checkpoint_path}/codebook_idx/combine_256_audioset_dominate/codebook_8192_0.npy",
32768: f"{checkpoint_path}/codebook_idx/combine_256_audioset_dominate/codebook_16384_0.npy",
},
100: {
4096: f"{checkpoint_path}/codebook_idx/combine_512_audioset_dominate/codebook_2048_0.npy",
8192: f"{checkpoint_path}/codebook_idx/combine_512_audioset_dominate/codebook_4096_0.npy",
16384: f"{checkpoint_path}/codebook_idx/combine_512_audioset_dominate/codebook_8192_0.npy",
32768: f"{checkpoint_path}/codebook_idx/combine_512_audioset_dominate/codebook_16384_0.npy",
},
}
else:
semantic_codebook = {
25: {
4096: "codebook_idx/combine_128_audioset_dominate/codebook_2048_0.npy",
8192: "codebook_idx/combine_128_audioset_dominate/codebook_4096_0.npy",
16384: "codebook_idx/combine_128_audioset_dominate/codebook_8192_0.npy",
32768: "codebook_idx/combine_128_audioset_dominate/codebook_16384_0.npy",
},
50: {
4096: "codebook_idx/combine_256_audioset_dominate/codebook_2048_0.npy",
8192: "codebook_idx/combine_256_audioset_dominate/codebook_4096_0.npy",
16384: "codebook_idx/combine_256_audioset_dominate/codebook_8192_0.npy",
32768: "codebook_idx/combine_256_audioset_dominate/codebook_16384_0.npy",
},
100: {
4096: "codebook_idx/combine_512_audioset_dominate/codebook_2048_0.npy",
8192: "codebook_idx/combine_512_audioset_dominate/codebook_4096_0.npy",
16384: "codebook_idx/combine_512_audioset_dominate/codebook_8192_0.npy",
32768: "codebook_idx/combine_512_audioset_dominate/codebook_16384_0.npy",
},
}
basic_config = {
"model": {
"params": {
"latent_t_size": 256,
"scale_by_std": True,
"sampling_rate": 16000,
"first_stage_config": {
"params": {
"monitor": "val/rec_loss",
"image_key": "fbank",
"embed_dim": 8,
"batchsize": 16,
"reload_from_ckpt": "/mnt/bn/lqhaoheliu/exps/checkpoints/audioldm/vae_32k/2023_06_22_vae_16k_64_4/last.ckpt",
"subband": 1,
"time_shuffle": 1,
"sampling_rate": 16000,
"ddconfig": {
"ch": 128,
"double_z": True,
"out_ch": 1,
"attn_resolutions": [],
"dropout": 0.0,
"mel_bins": 64,
"ch_mult": [
1,
2,
4
],
"num_res_blocks": 2,
"z_channels": 8,
"downsample_time": False,
"in_channels": 1,
"resolution": 256
},
"lossconfig": {
"params": {
"disc_start": 50001,
"kl_weight": 1000.0,
"disc_in_channels": 1,
"disc_weight": 0.5
},
"target": "semanticodec.modules.decoder.latent_diffusion.modules.losses.LPIPSWithDiscriminator"
}
},
"target": "semanticodec.modules.decoder.latent_encoder.autoencoder.AutoencoderKL",
"base_learning_rate": 8e-06
},
"unet_config": {
"params": {
"channel_mult": [
1,
2,
3,
5
],
"out_channels": 8,
"attention_resolutions": [
8,
4,
2
],
"context_dim": [
1728
],
"num_res_blocks": 2,
"in_channels": 8,
"image_size": 64,
"transformer_depth": 1,
"use_spatial_transformer": True,
"model_channels": 64,
"num_head_channels": 32
},
"target": "semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel"
},
"base_learning_rate": 0.0001,
"channels": 8,
"linear_start": 0.0015,
"first_stage_key": "fbank",
"parameterization": "v",
"cond_stage_config": {
"crossattn_audiomae_pooled": {
"cond_stage_key": "ta_kaldi_fbank",
"params": {
"use_oracle": False,
"lstm_bidirectional": True,
"feature_dimension": 768,
"codebook_size": 8192,
"residual_encoder": "lstm",
"rvq_layers": 0,
"lstm_layer": 4
},
"target": "semanticodec.modules.encoder.encoder.AudioMAEConditionQuantResEncoder",
"conditioning_key": "crossattn"
}
},
"num_timesteps_cond": 1,
"timesteps": 1000,
"latent_f_size": 16,
"linear_end": 0.0195
},
"target": "semanticodec.modules.decoder.latent_diffusion.models.ddpm.LatentDiffusion"
}
}
if token_rate == 50:
# modify context_dim
basic_config["model"]["params"]["unet_config"]["params"]["context_dim"] = [3264]
# modify cond_stage_config
basic_config["model"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["lstm_layer"] = 3
basic_config["model"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["feature_dimension"] = 768 * 2
elif token_rate == 25:
# modify context_dim
basic_config["model"]["params"]["unet_config"]["params"]["context_dim"] = [6336]
# modify cond_stage_config
basic_config["model"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["lstm_layer"] = 2
basic_config["model"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["feature_dimension"] = 768 * 4
elif token_rate == 100:
pass
else:
raise ValueError("token_rate must be 50, 25 or 100")
if checkpoint_path is None:
checkpoint_path = "semanticodec_tokenrate_%s" % token_rate
else:
print("Using custom checkpoint path: %s" % checkpoint_path)
feature_dim = basic_config["model"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["feature_dimension"]
lstm_layers = basic_config["model"]["params"]["cond_stage_config"]["crossattn_audiomae_pooled"]["params"]["lstm_layer"]
return basic_config, checkpoint_path, feature_dim, lstm_layers, semantic_codebook[token_rate][vocab_size]
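# A quick, illustrative check of the returned values for one setting
# (paths below are the defaults resolved against the Hugging Face repo at runtime).
if __name__ == "__main__":
    config, ckpt_path, feature_dim, lstm_layers, codebook = get_config(
        token_rate=50, vocab_size=16384
    )
    print(ckpt_path)     # semanticodec_tokenrate_50
    print(feature_dim)   # 1536 (= 768 * 2 at token rate 50)
    print(lstm_layers)   # 3
    print(codebook)      # codebook_idx/combine_256_audioset_dominate/codebook_8192_0.npy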
================================================
FILE: semanticodec/main.py
================================================
import torch
import torch.nn as nn
import os
import torchaudio
import math
from semanticodec.modules.encoder.encoder import AudioMAEConditionQuantResEncoder
from semanticodec.modules.decoder.latent_diffusion.models.ddpm import (
extract_encoder_state_dict,
overlap_add_waveform,
)
from semanticodec.config import get_config
from semanticodec.modules.decoder.latent_diffusion.util import instantiate_from_config
from semanticodec.utils import extract_kaldi_fbank_feature
from huggingface_hub import hf_hub_download
# Constants
SAMPLE_RATE = 16000
SEGMENT_DURATION = 10.24
MEL_TARGET_LENGTH = 1024
AUDIOMAE_PATCH_DURATION = 0.16
SEGMENT_OVERLAP_RATIO = 0.0625
class SemantiCodec(nn.Module):
def __init__(
self,
token_rate,
semantic_vocab_size,
ddim_sample_step=50,
cfg_scale=2.0,
checkpoint_path = None,
cache_path="pretrained",
):
super().__init__()
self.token_rate = token_rate
self.stack_factor_K = 100 / self.token_rate
self.ddim_sample_step = ddim_sample_step
self.cfg_scale = cfg_scale
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
else:
self.device = torch.device("cpu")
# Initialize encoder and decoder
config, checkpoint_path, feature_dim, lstm_layers, semanticodebook = get_config(
token_rate, semantic_vocab_size, checkpoint_path
)
encoder_checkpoint_path = os.path.join(checkpoint_path, "encoder.ckpt")
if not os.path.exists(encoder_checkpoint_path):
if not os.path.exists(cache_path):
os.makedirs(cache_path)
print(f"checkpoint cache dir '{cache_path}' was created.")
            encoder_checkpoint_path = hf_hub_download(
                repo_id="haoheliu/SemantiCodec",
                filename=checkpoint_path + "/encoder.ckpt",
                cache_dir=cache_path,
            )
decoder_checkpoint_path = os.path.join(checkpoint_path, "decoder.ckpt")
if not os.path.exists(decoder_checkpoint_path):
            decoder_checkpoint_path = hf_hub_download(
                repo_id="haoheliu/SemantiCodec",
                filename=checkpoint_path + "/decoder.ckpt",
                cache_dir=cache_path,
            )
if not os.path.exists(semanticodebook):
semanticodebook = "/".join(semanticodebook.split("/")[-3:])
            semanticodebook = hf_hub_download(
                repo_id="haoheliu/SemantiCodec",
                filename=semanticodebook,
                cache_dir=cache_path,
            )
# Initialize encoder
print("🚀 Loading SemantiCodec encoder")
state_dict = torch.load(encoder_checkpoint_path, map_location="cpu")
self.encoder = AudioMAEConditionQuantResEncoder(
feature_dimension=feature_dim,
lstm_layer=lstm_layers,
centroid_npy_path=semanticodebook,
)
self.encoder.load_state_dict(state_dict)
self.encoder = self.encoder.to(self.device)
print("✅ Encoder loaded")
# Initialize decoder
print("🚀 Loading SemantiCodec decoder")
self.decoder = instantiate_from_config(config["model"])
checkpoint = torch.load(decoder_checkpoint_path, map_location="cpu")
self.decoder.load_state_dict(checkpoint)
self.decoder = self.decoder.to(self.device)
print("✅ Decoder loaded")
def load_audio(self, filepath):
if not os.path.exists(filepath):
raise FileNotFoundError(f"{filepath} does not exist")
assert isinstance(filepath, str)
waveform, sr = torchaudio.load(filepath)
# resample to 16000
if sr != SAMPLE_RATE:
waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)
sr = SAMPLE_RATE
# if stereo to mono
if waveform.shape[0] > 1:
waveform = waveform[0:1]
# Calculate the original duration
original_duration = waveform.shape[1] / sr
        # Pad the duration up to a multiple of 0.16 seconds so that the original audio can be reconstructed
original_duration = original_duration + (
AUDIOMAE_PATCH_DURATION - original_duration % AUDIOMAE_PATCH_DURATION
)
# Calculate the token length in theory
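        # 128 mel bins / 16-bin patches = 8 AudioMAE patches per 0.16 s column;
        # stacking K columns along time gives the final token positions.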
target_token_len = (
8 * original_duration / AUDIOMAE_PATCH_DURATION / self.stack_factor_K
)
segment_sample_length = int(SAMPLE_RATE * SEGMENT_DURATION)
        # Pad audio to a multiple of 10.24 seconds for easier segmentation
        if waveform.shape[1] % segment_sample_length != 0:
waveform = torch.cat(
[
waveform,
torch.zeros(
1,
int(
segment_sample_length
- waveform.shape[1] % segment_sample_length
),
),
],
dim=1,
)
mel_target_length = MEL_TARGET_LENGTH * int(
waveform.shape[1] / segment_sample_length
)
# Calculate the mel spectrogram
mel = extract_kaldi_fbank_feature(
waveform, sr, target_length=mel_target_length
)["ta_kaldi_fbank"].unsqueeze(0)
mel = mel.squeeze(1)
assert mel.shape[-1] == 128 and mel.shape[-2] % 1024 == 0
return mel, target_token_len
def encode(self, filepath):
mel, target_token_len = self.load_audio(filepath)
tokens = self.encoder(mel.to(self.device))
tokens = tokens[:, : math.ceil(target_token_len), :]
return tokens
def decode(self, tokens):
windowed_token_list = self.encoder.long_token_split_window(
tokens,
window_length=int(512 / self.stack_factor_K),
overlap=SEGMENT_OVERLAP_RATIO,
)
windowed_waveform = []
for _, windowed_token in enumerate(windowed_token_list):
latent = self.encoder.token_to_quantized_feature(windowed_token)
latent = torch.cat(
[
latent,
torch.ones(
latent.shape[0],
int(512 / self.stack_factor_K) - latent.shape[1],
latent.shape[2],
).to(latent.device)
* -1,
],
dim=1,
)
waveform = self.decoder.generate_sample(
latent,
ddim_steps=self.ddim_sample_step,
unconditional_guidance_scale=self.cfg_scale,
)
windowed_waveform.append(waveform)
output = overlap_add_waveform(
windowed_waveform, overlap_duration=SEGMENT_DURATION * SEGMENT_OVERLAP_RATIO
)
        # tokens.shape[1] / 8 gives time columns (8 frequency patches per column);
        # each column spans 16 mel frames of 0.01 s each, scaled by the stack factor K
trim_duration = (tokens.shape[1] / 8) * 16 * 0.01 * self.stack_factor_K
return output[..., : int(trim_duration * SAMPLE_RATE)]
def forward(self, filepath):
tokens = self.encode(filepath)
waveform = self.decode(tokens)
return waveform
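if __name__ == "__main__":
    # Minimal round-trip sketch mirroring the README usage
    # ("test/test.wav" is a placeholder path).
    import soundfile as sf

    codec = SemantiCodec(token_rate=100, semantic_vocab_size=16384)
    tokens = codec.encode("test/test.wav")
    print("token shape:", tokens.shape)
    waveform = codec.decode(tokens)
    sf.write("output.wav", waveform[0, 0], SAMPLE_RATE)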
================================================
FILE: semanticodec/modules/__init__.py
================================================
================================================
FILE: semanticodec/modules/audiomae/AudioMAE.py
================================================
"""
Reference Repo: https://github.com/facebookresearch/AudioMAE
"""
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple
import semanticodec.modules.audiomae.models_mae as models_mae
# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
class PatchEmbed_new(nn.Module):
"""Flexible Image to Patch Embedding"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
stride = to_2tuple(stride)
self.img_size = img_size
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=stride
) # with overlapped patches
# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
# self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
_, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
self.patch_hw = (h, w)
self.num_patches = h * w
def get_output_shape(self, img_size):
# todo: don't be lazy..
return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
x = x.flatten(2).transpose(1, 2)
return x
class Vanilla_AudioMAE(nn.Module):
"""Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM)"""
def __init__(
self,
):
super().__init__()
model = models_mae.__dict__["mae_vit_base_patch16"](
in_chans=1, audio_exp=True, img_size=(1024, 128)
)
# checkpoint_path = "/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth"
# checkpoint = torch.load(checkpoint_path, map_location="cpu")
# model.load_state_dict(checkpoint["model"], strict=False)
# Skip the missing keys of decoder modules (not required)
# print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')
self.model = model.eval()
def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
"""
x: mel fbank [Batch, 1, 1024 (T), 128 (F)]
mask_ratio: 'masking ratio (percentage of removed patches).'
"""
with torch.no_grad():
# embed: [B, 513, 768] for mask_ratio=0.0
if no_mask:
if no_average:
raise RuntimeError("This function is deprecated")
embed = self.model.forward_encoder_no_random_mask_no_average(
x
) # mask_ratio
else:
embed = self.model.forward_encoder_no_mask(x) # mask_ratio
else:
raise RuntimeError("This function is deprecated")
embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)
return embed
if __name__ == "__main__":
model = Vanilla_AudioMAE().cuda()
input = torch.randn(4, 1, 1024, 128).cuda()
print("The first run")
embed = model(input, mask_ratio=0.0, no_mask=True)
print(embed)
print("The second run")
embed = model(input, mask_ratio=0.0)
print(embed)
================================================
FILE: semanticodec/modules/audiomae/__init__.py
================================================
================================================
FILE: semanticodec/modules/audiomae/models_mae.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
from timm.models.vision_transformer import Block
from semanticodec.modules.audiomae.pos_embed import (
get_2d_sincos_pos_embed,
get_2d_sincos_pos_embed_flexible,
get_1d_sincos_pos_embed_from_grid,
)
from semanticodec.modules.audiomae.patch_embed import PatchEmbed_new, PatchEmbed_org
class MaskedAutoencoderViT(nn.Module):
"""Masked Autoencoder with VisionTransformer backbone"""
def __init__(
self,
img_size=224,
patch_size=16,
stride=10,
in_chans=3,
embed_dim=1024,
depth=24,
num_heads=16,
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=16,
mlp_ratio=4.0,
norm_layer=nn.LayerNorm,
norm_pix_loss=False,
audio_exp=False,
alpha=0.0,
temperature=0.2,
mode=0,
contextual_depth=8,
use_custom_patch=False,
split_pos=False,
pos_trainable=False,
use_nce=False,
beta=4.0,
decoder_mode=0,
mask_t_prob=0.6,
mask_f_prob=0.5,
mask_2d=False,
epoch=0,
no_shift=False,
):
super().__init__()
self.audio_exp = audio_exp
self.embed_dim = embed_dim
self.decoder_embed_dim = decoder_embed_dim
# --------------------------------------------------------------------------
# MAE encoder specifics
if use_custom_patch:
print(
f"Use custom patch_emb with patch size: {patch_size}, stride: {stride}"
)
self.patch_embed = PatchEmbed_new(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
stride=stride,
)
else:
self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)
self.use_custom_patch = use_custom_patch
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# self.split_pos = split_pos # not useful
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable
) # fixed sin-cos embedding
self.encoder_depth = depth
self.contextual_depth = contextual_depth
self.blocks = nn.ModuleList(
[
Block(
embed_dim,
num_heads,
mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
) # qk_scale=None
for i in range(depth)
]
)
self.norm = norm_layer(embed_dim)
# --------------------------------------------------------------------------
# MAE decoder specifics
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.decoder_pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, decoder_embed_dim),
requires_grad=pos_trainable,
) # fixed sin-cos embedding
self.no_shift = no_shift
self.decoder_mode = decoder_mode
if (
self.use_custom_patch
): # overlapped patches as in AST. Similar performance yet compute heavy
window_size = (6, 6)
feat_size = (102, 12)
else:
window_size = (4, 4)
feat_size = (64, 8)
if self.decoder_mode == 1:
decoder_modules = []
for index in range(16):
if self.no_shift:
shift_size = (0, 0)
else:
if (index % 2) == 0:
shift_size = (0, 0)
else:
shift_size = (2, 0)
# shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])
decoder_modules.append(
SwinTransformerBlock(
dim=decoder_embed_dim,
num_heads=16,
feat_size=feat_size,
window_size=window_size,
shift_size=shift_size,
mlp_ratio=mlp_ratio,
drop=0.0,
drop_attn=0.0,
drop_path=0.0,
extra_norm=False,
sequential_attn=False,
norm_layer=norm_layer, # nn.LayerNorm,
)
)
self.decoder_blocks = nn.ModuleList(decoder_modules)
else:
            # Transformer
self.decoder_blocks = nn.ModuleList(
[
Block(
decoder_embed_dim,
decoder_num_heads,
mlp_ratio,
qkv_bias=True,
norm_layer=norm_layer,
) # qk_scale=None,
for i in range(decoder_depth)
]
)
self.decoder_norm = norm_layer(decoder_embed_dim)
self.decoder_pred = nn.Linear(
decoder_embed_dim, patch_size**2 * in_chans, bias=True
) # decoder to patch
# --------------------------------------------------------------------------
self.norm_pix_loss = norm_pix_loss
self.patch_size = patch_size
self.stride = stride
# audio exps
self.alpha = alpha
self.T = temperature
self.mode = mode
self.use_nce = use_nce
self.beta = beta
self.log_softmax = nn.LogSoftmax(dim=-1)
self.mask_t_prob = mask_t_prob
self.mask_f_prob = mask_f_prob
self.mask_2d = mask_2d
self.epoch = epoch
self.initialize_weights()
def initialize_weights(self):
# initialization
# initialize (and freeze) pos_embed by sin-cos embedding
if self.audio_exp:
pos_embed = get_2d_sincos_pos_embed_flexible(
self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True
)
else:
pos_embed = get_2d_sincos_pos_embed(
self.pos_embed.shape[-1],
int(self.patch_embed.num_patches**0.5),
cls_token=True,
)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
if self.audio_exp:
decoder_pos_embed = get_2d_sincos_pos_embed_flexible(
self.decoder_pos_embed.shape[-1],
self.patch_embed.patch_hw,
cls_token=True,
)
else:
decoder_pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
int(self.patch_embed.num_patches**0.5),
cls_token=True,
)
self.decoder_pos_embed.data.copy_(
torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
)
# initialize patch_embed like nn.Linear (instead of nn.Conv2d)
w = self.patch_embed.proj.weight.data
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
torch.nn.init.normal_(self.cls_token, std=0.02)
torch.nn.init.normal_(self.mask_token, std=0.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def patchify(self, imgs):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
L = (H/p)*(W/p)
"""
p = self.patch_embed.patch_size[0]
# assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
if self.audio_exp:
if self.use_custom_patch: # overlapped patch
h, w = self.patch_embed.patch_hw
# todo: fixed h/w patch size and stride size. Make hw custom in the future
x = imgs.unfold(2, self.patch_size, self.stride).unfold(
3, self.patch_size, self.stride
) # n,1,H,W -> n,1,h,w,p,p
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
# x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
# x = torch.einsum('nchpwq->nhwpqc', x)
# x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
else:
h = imgs.shape[2] // p
w = imgs.shape[3] // p
# h,w = self.patch_embed.patch_hw
x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))
x = torch.einsum("nchpwq->nhwpqc", x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))
else:
h = w = imgs.shape[2] // p
x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
x = torch.einsum("nchpwq->nhwpqc", x)
x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
return x
def unpatchify(self, x):
"""
x: (N, L, patch_size**2 *3)
specs: (N, 1, H, W)
"""
p = self.patch_embed.patch_size[0]
h = 1024 // p
w = 128 // p
x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))
x = torch.einsum("nhwpqc->nchpwq", x)
specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))
return specs
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio))
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
ids_shuffle = torch.argsort(
noise, dim=1
) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
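        # argsort of the shuffle indices inverts the permutation, so ids_restore
        # maps each kept/masked slot back to its original position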
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
"""
        2D: Spectrogram (masking t and f under mask_t_prob and mask_f_prob)
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
if self.use_custom_patch: # overlapped patch
T = 101
F = 12
else:
T = 64
F = 8
# x = x.reshape(N, T, F, D)
len_keep_t = int(T * (1 - mask_t_prob))
len_keep_f = int(F * (1 - mask_f_prob))
# noise for mask in time
noise_t = torch.rand(N, T, device=x.device) # noise in [0, 1]
        # sort noise for each sample along time
ids_shuffle_t = torch.argsort(
noise_t, dim=1
) # ascend: small is keep, large is remove
ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)
ids_keep_t = ids_shuffle_t[:, :len_keep_t]
# noise mask in freq
noise_f = torch.rand(N, F, device=x.device) # noise in [0, 1]
ids_shuffle_f = torch.argsort(
noise_f, dim=1
) # ascend: small is keep, large is remove
ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)
ids_keep_f = ids_shuffle_f[:, :len_keep_f] #
# generate the binary mask: 0 is keep, 1 is remove
# mask in freq
mask_f = torch.ones(N, F, device=x.device)
mask_f[:, :len_keep_f] = 0
mask_f = (
torch.gather(mask_f, dim=1, index=ids_restore_f)
.unsqueeze(1)
.repeat(1, T, 1)
) # N,T,F
# mask in time
mask_t = torch.ones(N, T, device=x.device)
mask_t[:, :len_keep_t] = 0
mask_t = (
torch.gather(mask_t, dim=1, index=ids_restore_t)
.unsqueeze(1)
.repeat(1, F, 1)
.permute(0, 2, 1)
) # N,T,F
mask = 1 - (1 - mask_t) * (1 - mask_f) # N, T, F
# get masked x
id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device)
id2res = id2res + 999 * mask # add a large value for masked elements
id2res2 = torch.argsort(id2res.flatten(start_dim=1))
ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t]
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
ids_restore = torch.argsort(id2res2.flatten(start_dim=1))
mask = mask.flatten(start_dim=1)
return x_masked, mask, ids_restore
def forward_encoder(self, x, mask_ratio, mask_2d=False):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
if mask_2d:
x, mask, ids_restore = self.random_masking_2d(
x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob
)
else:
x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x, mask, ids_restore, None
def forward_encoder_no_random_mask_no_average(self, x):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
# if mask_2d:
# x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)
# else:
# x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x
def forward_encoder_no_mask(self, x):
# embed patches
x = self.patch_embed(x)
# add pos embed w/o cls token
x = x + self.pos_embed[:, 1:, :]
# masking: length -> length * mask_ratio
# x, mask, ids_restore = self.random_masking(x, mask_ratio)
# append cls token
cls_token = self.cls_token + self.pos_embed[:, :1, :]
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
# apply Transformer blocks
contextual_embs = []
for n, blk in enumerate(self.blocks):
x = blk(x)
if n > self.contextual_depth:
contextual_embs.append(self.norm(x))
# x = self.norm(x)
contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)
return contextual_emb
def forward_decoder(self, x, ids_restore):
# embed tokens
x = self.decoder_embed(x)
# append mask tokens to sequence
mask_tokens = self.mask_token.repeat(
x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
)
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
x_ = torch.gather(
x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
) # unshuffle
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
# add pos embed
x = x + self.decoder_pos_embed
if self.decoder_mode != 0:
B, L, D = x.shape
x = x[:, 1:, :]
if self.use_custom_patch:
x = x.reshape(B, 101, 12, D)
x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1) # hack
x = x.reshape(B, 1224, D)
if self.decoder_mode > 3: # mvit
x = self.decoder_blocks(x)
else:
# apply Transformer blocks
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
pred = self.decoder_pred(x)
# remove cls token
if self.decoder_mode != 0:
if self.use_custom_patch:
pred = pred.reshape(B, 102, 12, 256)
pred = pred[:, :101, :, :]
pred = pred.reshape(B, 1212, 256)
else:
pred = pred
else:
pred = pred[:, 1:, :]
return pred, None, None # emb, emb_pixel
def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
"""
imgs: [N, 3, H, W]
pred: [N, L, p*p*3]
mask: [N, L], 0 is keep, 1 is remove,
"""
target = self.patchify(imgs)
if norm_pix_loss:
mean = target.mean(dim=-1, keepdim=True)
var = target.var(dim=-1, keepdim=True)
target = (target - mean) / (var + 1.0e-6) ** 0.5
loss = (pred - target) ** 2
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
return loss
def forward(self, imgs, mask_ratio=0.8):
emb_enc, mask, ids_restore, _ = self.forward_encoder(
imgs, mask_ratio, mask_2d=self.mask_2d
)
pred, _, _ = self.forward_decoder(emb_enc, ids_restore) # [N, L, p*p*3]
loss_recon = self.forward_loss(
imgs, pred, mask, norm_pix_loss=self.norm_pix_loss
)
        loss_contrastive = torch.FloatTensor([0.0]).to(imgs.device)
return loss_recon, pred, mask, loss_contrastive
def mae_vit_small_patch16_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16,
embed_dim=384,
depth=12,
num_heads=6,
decoder_embed_dim=512,
decoder_num_heads=16,
mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs,
)
return model
def mae_vit_base_patch16_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16,
embed_dim=768,
depth=12,
num_heads=12,
decoder_embed_dim=512,
decoder_num_heads=16,
mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs,
)
return model
def mae_vit_large_patch16_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=16,
embed_dim=1024,
depth=24,
num_heads=16,
decoder_embed_dim=512,
decoder_num_heads=16,
mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs,
)
return model
def mae_vit_huge_patch14_dec512d8b(**kwargs):
model = MaskedAutoencoderViT(
patch_size=14,
embed_dim=1280,
depth=32,
num_heads=16,
decoder_embed_dim=512,
decoder_num_heads=16,
mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs,
)
return model
# set recommended archs
mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks
mae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b # decoder: 512 dim, 8 blocks
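if __name__ == "__main__":
    # Illustrative shape check for the AudioSet configuration used by
    # Vanilla_AudioMAE: a 10.24 s / 128-mel-bin fbank yields 64 x 8 = 512
    # patches, plus one cls token.
    model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))
    fbank = torch.randn(1, 1, 1024, 128)
    embed = model.forward_encoder_no_random_mask_no_average(fbank)
    print(embed.shape)  # torch.Size([1, 513, 768])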
================================================
FILE: semanticodec/modules/audiomae/patch_embed.py
================================================
import torch
import torch.nn as nn
from timm.models.layers import to_2tuple
class PatchEmbed_org(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
)
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
y = x.flatten(2).transpose(1, 2)
return y
class PatchEmbed_new(nn.Module):
"""Flexible Image to Patch Embedding"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
stride = to_2tuple(stride)
self.img_size = img_size
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_size, stride=stride
) # with overlapped patches
# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
# self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])
# self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
_, _, h, w = self.get_output_shape(img_size) # n, emb_dim, h, w
self.patch_hw = (h, w)
self.num_patches = h * w
def get_output_shape(self, img_size):
# todo: don't be lazy..
return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
# assert H == self.img_size[0] and W == self.img_size[1], \
# f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# x = self.proj(x).flatten(2).transpose(1, 2)
x = self.proj(x) # 32, 1, 1024, 128 -> 32, 768, 101, 12
x = x.flatten(2) # 32, 768, 101, 12 -> 32, 768, 1212
x = x.transpose(1, 2) # 32, 768, 1212 -> 32, 1212, 768
return x
class PatchEmbed3D_new(nn.Module):
"""Flexible Image to Patch Embedding"""
def __init__(
self,
video_size=(16, 224, 224),
patch_size=(2, 16, 16),
in_chans=3,
embed_dim=768,
stride=(2, 16, 16),
):
super().__init__()
self.video_size = video_size
self.patch_size = patch_size
self.in_chans = in_chans
self.proj = nn.Conv3d(
in_chans, embed_dim, kernel_size=patch_size, stride=stride
)
_, _, t, h, w = self.get_output_shape(video_size) # n, emb_dim, h, w
self.patch_thw = (t, h, w)
self.num_patches = t * h * w
def get_output_shape(self, video_size):
# todo: don't be lazy..
return self.proj(
torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2])
).shape
def forward(self, x):
B, C, T, H, W = x.shape
x = self.proj(x) # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14
x = x.flatten(2) # 32, 768, 1568
x = x.transpose(1, 2) # 32, 768, 1568 -> 32, 1568, 768
return x
if __name__ == "__main__":
# patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16))
# input = torch.rand(8,1,1024,128)
# output = patch_emb(input)
# print(output.shape) # (8,512,64)
patch_emb = PatchEmbed3D_new(
video_size=(6, 224, 224),
patch_size=(2, 16, 16),
in_chans=3,
embed_dim=768,
stride=(2, 16, 16),
)
input = torch.rand(8, 3, 6, 224, 224)
output = patch_emb(input)
    print(output.shape)  # (8, 588, 768)
================================================
FILE: semanticodec/modules/audiomae/pos_embed.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------
import numpy as np
import torch
# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
"""
    grid_size: (height, width) of the patch grid
    return:
    pos_embed: [grid_size[0]*grid_size[1], embed_dim] or [1+grid_size[0]*grid_size[1], embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size[0], dtype=np.float32)
grid_w = np.arange(grid_size[1], dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
# omega = np.arange(embed_dim // 2, dtype=np.float)
omega = np.arange(embed_dim // 2, dtype=float)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
if "pos_embed" in checkpoint_model:
pos_embed_checkpoint = checkpoint_model["pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches**0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print(
"Position interpolate from %dx%d to %dx%d"
% (orig_size, orig_size, new_size, new_size)
)
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(
-1, orig_size, orig_size, embedding_size
).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode="bicubic",
align_corners=False,
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model["pos_embed"] = new_pos_embed
def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size):
if "pos_embed" in checkpoint_model:
pos_embed_checkpoint = checkpoint_model["pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
# orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
# new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print(
"Position interpolate from %dx%d to %dx%d"
% (orig_size[0], orig_size[1], new_size[0], new_size[1])
)
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(
-1, orig_size[0], orig_size[1], embedding_size
).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size[0], new_size[1]),
mode="bicubic",
align_corners=False,
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model["pos_embed"] = new_pos_embed
def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size):
if "pos_embed" in checkpoint_model:
pos_embed_checkpoint = checkpoint_model["pos_embed"]
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
if orig_size != new_size:
print(
"Position interpolate from %dx%d to %dx%d"
% (orig_size[0], orig_size[1], new_size[0], new_size[1])
)
# extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1)
pos_tokens = pos_embed_checkpoint[:, 1:, :] # remove
pos_tokens = pos_tokens.reshape(
-1, orig_size[0], orig_size[1], embedding_size
) # .permute(0, 3, 1, 2)
# pos_tokens = torch.nn.functional.interpolate(
# pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)
# pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
pos_tokens = pos_tokens[:, :, : new_size[1], :] # assume only time diff
pos_tokens = pos_tokens.flatten(1, 2)
new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1)
checkpoint_model["pos_embed"] = new_pos_embed
def interpolate_patch_embed_audio(
model,
checkpoint_model,
orig_channel,
new_channel=1,
kernel_size=(16, 16),
stride=(16, 16),
padding=(0, 0),
):
if orig_channel != new_channel:
if "patch_embed.proj.weight" in checkpoint_model:
# aggregate 3 channels in rgb ckpt to 1 channel for audio
new_proj_weight = torch.nn.Parameter(
torch.sum(checkpoint_model["patch_embed.proj.weight"], dim=1).unsqueeze(
1
)
)
checkpoint_model["patch_embed.proj.weight"] = new_proj_weight
================================================
FILE: semanticodec/modules/decoder/__init__.py
================================================
================================================
FILE: semanticodec/modules/decoder/hifigan/LICENSE
================================================
MIT License
Copyright (c) 2020 Jungil Kong
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: semanticodec/modules/decoder/hifigan/__init__.py
================================================
from .models_v2 import Generator
from .models import Generator as Generator_old
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
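# Illustrative use (hypothetical values): HiFi-GAN configs are plain dicts
# promoted to attribute access, e.g. AttrDict({"resblock": "1", "num_mels": 64}).num_mels -> 64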
================================================
FILE: semanticodec/modules/decoder/hifigan/models.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
LRELU_SLOPE = 0.1
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class ResBlock(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock, self).__init__()
self.h = h
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class Generator(torch.nn.Module):
def __init__(self, h):
super(Generator, self).__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.conv_pre = weight_norm(
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
)
resblock = ResBlock
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
def forward(self, x):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
# print("Removing weight norm...")
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
================================================
FILE: semanticodec/modules/decoder/hifigan/models_v2.py
================================================
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
LRELU_SLOPE = 0.1
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find("Conv") != -1:
m.weight.data.normal_(mean, std)
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
)
),
]
)
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
)
),
]
)
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs1:
remove_weight_norm(l)
for l in self.convs2:
remove_weight_norm(l)
class ResBlock2(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.h = h
self.convs = nn.ModuleList(
[
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
)
),
weight_norm(
Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
)
),
]
)
self.convs.apply(init_weights)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
def remove_weight_norm(self):
for l in self.convs:
remove_weight_norm(l)
class Generator(torch.nn.Module):
def __init__(self, h):
super(Generator, self).__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.conv_pre = weight_norm(
Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3)
)
resblock = ResBlock1 if h.resblock == "1" else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(
weight_norm(
ConvTranspose1d(
h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2 ** (i + 1)),
u * 2,
u,
padding=u // 2 + u % 2,
output_padding=u % 2,
)
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
def forward(self, x):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
for l in self.ups:
remove_weight_norm(l)
for l in self.resblocks:
l.remove_weight_norm()
remove_weight_norm(self.conv_pre)
remove_weight_norm(self.conv_post)
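# The overall upsampling factor equals prod(h.upsample_rates): each
# ConvTranspose1d above (kernel u * 2, stride u, padding u // 2 + u % 2,
# output_padding u % 2) scales the length by exactly u, so a latent of
# shape (B, 256, T) decodes to a waveform of shape (B, 1, T * prod(u)).
# A minimal smoke-test sketch (the config fields mirror what Generator
# reads; the values are illustrative, not the released checkpoint's):
#
#   from types import SimpleNamespace
#   h = SimpleNamespace(
#       resblock="1",
#       resblock_kernel_sizes=[3, 7, 11],
#       resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
#       upsample_rates=[8, 8, 2, 2],
#       upsample_kernel_sizes=[16, 16, 4, 4],
#       upsample_initial_channel=512,
#   )
#   g = Generator(h)
#   wav = g(torch.randn(1, 256, 32))  # -> torch.Size([1, 1, 8192])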
================================================
FILE: semanticodec/modules/decoder/latent_diffusion/__init__.py
================================================
================================================
FILE: semanticodec/modules/decoder/latent_diffusion/models/__init__.py
================================================
================================================
FILE: semanticodec/modules/decoder/latent_diffusion/models/ddim.py
================================================
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
class DDIMSampler(object):
def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
def register_buffer(self, name, attr):
        if isinstance(attr, torch.Tensor):
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
def make_schedule(
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose,
)
alphas_cumprod = self.model.alphas_cumprod
assert (
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
), "alphas have to be defined for each timestep"
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer("betas", to_torch(self.model.betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer(
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer(
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_one_minus_alphas_cumprod",
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod",
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
)
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose,
)
if torch.backends.mps.is_available():
ddim_sigmas = ddim_sigmas.to(torch.float32)
ddim_alphas = ddim_alphas.to(torch.float32)
ddim_alphas_prev = ddim_alphas_prev.astype(np.float32)
self.register_buffer("ddim_sigmas", ddim_sigmas)
self.register_buffer("ddim_alphas", ddim_alphas)
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev)
/ (1 - self.alphas_cumprod)
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
)
self.register_buffer(
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.0,
mask=None,
x0=None,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,  # must come in the same format as the conditioning, e.g. as encoded tokens
dynamic_threshold=None,
ucg_schedule=None,
**kwargs
):
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule,
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(
self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
dynamic_threshold=None,
ucg_schedule=None,
):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = (
self.ddpm_num_timesteps
if ddim_use_original_steps
else self.ddim_timesteps
)
elif timesteps is not None and not ddim_use_original_steps:
subset_end = (
int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]
)
- 1
)
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {"x_inter": [img], "pred_x0": [img]}
time_range = (
reversed(range(0, timesteps))
if ddim_use_original_steps
else np.flip(timesteps)
)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b,), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
) # TODO: deterministic forward pass?
img = img_orig * mask + (1.0 - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
img, pred_x0 = outs
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates["x_inter"].append(img)
intermediates["pred_x0"].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(
self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.0,
noise_dropout=0.0,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
dynamic_threshold=None,
):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:
model_output = self.model.apply_model(x, t, c)
else:
x_in = x
t_in = t
assert isinstance(c, dict)
assert isinstance(unconditional_conditioning, dict)
model_uncond = self.model.apply_model(
x_in, t_in, unconditional_conditioning
)
model_t = self.model.apply_model(x_in, t_in, c)
model_output = model_uncond + unconditional_guidance_scale * (
model_t - model_uncond
)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", "not implemented"
e_t = score_corrector.modify_score(
self.model, e_t, x, t, c, **corrector_kwargs
)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = (
self.model.alphas_cumprod_prev
if use_original_steps
else self.ddim_alphas_prev
)
sqrt_one_minus_alphas = (
self.model.sqrt_one_minus_alphas_cumprod
if use_original_steps
else self.ddim_sqrt_one_minus_alphas
)
sigmas = (
self.model.ddim_sigmas_for_original_num_steps
if use_original_steps
else self.ddim_sigmas
)
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full(
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
)
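        # DDIM update assembled below:
        #   x_{t-1} = sqrt(a_prev) * pred_x0
        #             + sqrt(1 - a_prev - sigma_t**2) * e_t
        #             + sigma_t * z,
        # where pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t) under the
        # eps parameterization.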
# current prediction for x_0
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.0:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(
self,
x0,
c,
t_enc,
use_original_steps=False,
return_intermediates=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
callback=None,
):
num_reference_steps = (
self.ddpm_num_timesteps
if use_original_steps
else self.ddim_timesteps.shape[0]
)
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc="Encoding Image"):
t = torch.full(
(x0.shape[0],), i, device=self.model.device, dtype=torch.long
)
if unconditional_guidance_scale == 1.0:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(
torch.cat((x_next, x_next)),
torch.cat((t, t)),
torch.cat((unconditional_conditioning, c)),
),
2,
)
noise_pred = e_t_uncond + unconditional_guidance_scale * (
noise_pred - e_t_uncond
)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
weighted_noise_pred = (
alphas_next[i].sqrt()
* ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())
* noise_pred
)
x_next = xt_weighted + weighted_noise_pred
if (
return_intermediates
and i % (num_steps // return_intermediates) == 0
and i < num_steps - 1
):
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback:
callback(i)
out = {"x_encoded": x_next, "intermediate_steps": inter_steps}
if return_intermediates:
out.update({"intermediates": intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise
)
@torch.no_grad()
def decode(
self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
callback=None,
):
timesteps = (
np.arange(self.ddpm_num_timesteps)
if use_original_steps
else self.ddim_timesteps
)
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
iterator = tqdm(time_range, desc="Decoding image", total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full(
(x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long
)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
if callback:
callback(i)
return x_dec
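# A minimal usage sketch (assumes `model` is a trained LatentDiffusion-style
# wrapper exposing apply_model, alphas_cumprod, betas, etc., and `cond` is a
# conditioning dict in the format used elsewhere in this repo):
#
#   sampler = DDIMSampler(model, device=model.device)
#   samples, intermediates = sampler.sample(
#       S=200,  # number of DDIM steps
#       batch_size=1,
#       shape=(model.channels, model.latent_t_size, model.latent_f_size),
#       conditioning=cond,
#       eta=1.0,
#       verbose=False,
#   )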
================================================
FILE: semanticodec/modules/decoder/latent_diffusion/models/ddpm.py
================================================
import torch
import torch.nn as nn
import numpy as np
from contextlib import contextmanager
from functools import partial
from tqdm import tqdm
from semanticodec.modules.decoder.latent_diffusion.util import (
exists,
default,
count_params,
instantiate_from_config,
)
from semanticodec.modules.decoder.latent_diffusion.modules.ema import LitEma
from semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.util import (
make_beta_schedule,
extract_into_tensor,
noise_like,
)
from semanticodec.modules.decoder.latent_diffusion.models.ddim import DDIMSampler
from semanticodec.modules.decoder.latent_diffusion.util import disabled_train
from semanticodec.utils import PositionalEncoding
class DDPM(nn.Module):
# classic DDPM with Gaussian diffusion, in image space
def __init__(
self,
unet_config,
sampling_rate=None,
timesteps=1000,
beta_schedule="linear",
use_ema=True,
first_stage_key="image",
latent_t_size=256,
latent_f_size=16,
channels=3,
clip_denoised=True,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
given_betas=None,
v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
conditioning_key=None,
parameterization="eps", # all assuming fixed variance schedules
logvar_init=0.0,
):
super().__init__()
assert parameterization in [
"eps",
"x0",
"v",
], 'currently only supporting "eps" and "x0" and "v"'
self.parameterization = parameterization
self.state = None
assert sampling_rate is not None
self.validation_folder_name = "temp_name"
self.clip_denoised = clip_denoised
self.first_stage_key = first_stage_key
self.sampling_rate = sampling_rate
self.latent_t_size = latent_t_size
self.latent_f_size = latent_f_size
self.v_posterior = v_posterior
self.channels = channels
self.model = DiffusionWrapper(unet_config, conditioning_key)
count_params(self.model, verbose=True)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model)
# print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.register_schedule(
given_betas=given_betas,
beta_schedule=beta_schedule,
timesteps=timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
self.logvar = nn.Parameter(self.logvar, requires_grad=False)
self.pos_embed = PositionalEncoding(seq_length=512, embedding_dim=192)
def register_schedule(
self,
given_betas=None,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
if exists(given_betas):
betas = given_betas
else:
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s,
)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert (
alphas_cumprod.shape[0] == self.num_timesteps
), "alphas have to be defined for each timestep"
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer("betas", to_torch(betas))
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer(
"sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
)
self.register_buffer(
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
)
self.register_buffer(
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
)
self.register_buffer(
"sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
)
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = (1 - self.v_posterior) * betas * (
1.0 - alphas_cumprod_prev
) / (1.0 - alphas_cumprod) + self.v_posterior * betas
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer("posterior_variance", to_torch(posterior_variance))
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer(
"posterior_log_variance_clipped",
to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
)
self.register_buffer(
"posterior_mean_coef1",
to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
)
self.register_buffer(
"posterior_mean_coef2",
to_torch(
(1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
),
)
if self.parameterization == "eps":
lvlb_weights = self.betas**2 / (
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
)
elif self.parameterization == "x0":
lvlb_weights = (
0.5
* np.sqrt(torch.Tensor(alphas_cumprod))
/ (2.0 * 1 - torch.Tensor(alphas_cumprod))
)
elif self.parameterization == "v":
lvlb_weights = torch.ones_like(
self.betas**2
/ (
2
* self.posterior_variance
* to_torch(alphas)
* (1 - self.alphas_cumprod)
)
)
else:
raise NotImplementedError("mu not supported")
# TODO how to choose this term
lvlb_weights[0] = lvlb_weights[1]
self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
assert not torch.isnan(self.lvlb_weights).all()
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
# if context is not None:
# print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
# if context is not None:
# print(f"{context}: Restored training weights")
def q_mean_variance(self, x_start, t):
"""
Get the distribution q(x_t | x_0).
:param x_start: the [N x C x ...] tensor of noiseless inputs.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
"""
mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
log_variance = extract_into_tensor(
self.log_one_minus_alphas_cumprod, t, x_start.shape
)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
* noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract_into_tensor(
self.posterior_log_variance_clipped, t, x_t.shape
)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
model_out = self.model(x, t)
if self.parameterization == "eps":
x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
elif self.parameterization == "x0":
x_recon = model_out
if clip_denoised:
x_recon.clamp_(-1.0, 1.0)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
x_start=x_recon, x_t=x, t=t
)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(
x=x, t=t, clip_denoised=clip_denoised
)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (
(1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()
)
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(
reversed(range(0, self.num_timesteps)),
desc="Sampling t",
total=self.num_timesteps,
):
img = self.p_sample(
img,
torch.full((b,), i, device=device, dtype=torch.long),
clip_denoised=self.clip_denoised,
)
return img
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
        # NOTE: p_sample_loop only returns the final sample;
        # return_intermediates is accepted for API compatibility but unused.
        channels = self.channels
        shape = (batch_size, channels, self.latent_t_size, self.latent_f_size)
        return self.p_sample_loop(shape)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
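    # q_sample draws from the closed-form forward process in one step:
    #   x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps,
    # instead of iterating t single-step noising operations.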
def predict_start_from_z_and_v(self, x_t, t, v):
# self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
# self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
)
def predict_eps_from_z_and_v(self, x_t, t, v):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
* x_t
)
def get_v(self, x, noise, t):
return (
extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
- extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
)
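    # v-parameterization identities used by the three methods above
    # (Salimans & Ho, 2022):
    #   v   = sqrt(alphabar_t) * eps - sqrt(1 - alphabar_t) * x_0  (get_v)
    #   x_0 = sqrt(alphabar_t) * x_t - sqrt(1 - alphabar_t) * v    (predict_start_from_z_and_v)
    #   eps = sqrt(alphabar_t) * v   + sqrt(1 - alphabar_t) * x_t  (predict_eps_from_z_and_v)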
class LatentDiffusion(DDPM):
"""main class"""
def __init__(
self,
first_stage_config,
cond_stage_config=None,
num_timesteps_cond=None,
scale_factor=1.0,
evaluation_params={},
scale_by_std=False,
base_learning_rate=None,
*args,
**kwargs,
):
if torch.cuda.is_available():
self.device = torch.device("cuda")
elif torch.backends.mps.is_available():
self.device = torch.device("mps")
else:
self.device = torch.device("cpu")
self.learning_rate = base_learning_rate
self.num_timesteps_cond = default(num_timesteps_cond, 1)
self.scale_by_std = scale_by_std
self.evaluation_params = evaluation_params
assert self.num_timesteps_cond <= kwargs["timesteps"]
conditioning_key = list(cond_stage_config.keys())
self.conditioning_key = conditioning_key
super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
try:
self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
        except Exception:
self.num_downs = 0
if not scale_by_std:
self.scale_factor = scale_factor
else:
self.register_buffer("scale_factor", torch.tensor(scale_factor))
self.instantiate_first_stage(first_stage_config)
self.cond_stage_models = nn.ModuleList([])
self.clip_denoised = False
self.bbox_tokenizer = None
self.conditional_dry_run_finished = False
self.restarted_from_ckpt = False
def make_cond_schedule(
self,
):
self.cond_ids = torch.full(
size=(self.num_timesteps,),
fill_value=self.num_timesteps - 1,
dtype=torch.long,
)
ids = torch.round(
torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)
).long()
self.cond_ids[: self.num_timesteps_cond] = ids
def register_schedule(
self,
given_betas=None,
beta_schedule="linear",
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3,
):
super().register_schedule(
given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s
)
self.shorten_cond_schedule = self.num_timesteps_cond > 1
if self.shorten_cond_schedule:
self.make_cond_schedule()
def instantiate_first_stage(self, config):
model = instantiate_from_config(config)
self.first_stage_model = model.eval()
self.first_stage_model.train = disabled_train
for param in self.first_stage_model.parameters():
param.requires_grad = False
def decode_first_stage(self, z):
with torch.no_grad():
z = 1.0 / self.scale_factor * z
decoding = self.first_stage_model.decode(z)
return decoding
def mel_spectrogram_to_waveform(self, mel):
# Mel: [bs, 1, t-steps, fbins]
if len(mel.size()) == 4:
mel = mel.squeeze(1)
mel = mel.permute(0, 2, 1)
waveform = self.first_stage_model.vocoder(mel)
waveform = waveform.cpu().detach().numpy()
return waveform
def encode_first_stage(self, x):
with torch.no_grad():
return self.first_stage_model.encode(x)
@torch.no_grad()
def sample_log(
self,
cond,
batch_size,
ddim_steps,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
mask=None,
**kwargs,
):
if mask is not None:
shape = (self.channels, mask.size()[-2], mask.size()[-1])
else:
shape = (self.channels, self.latent_t_size, self.latent_f_size)
# print("Use ddim sampler")
        ddim_sampler = DDIMSampler(self, device=self.device)
samples, intermediates = ddim_sampler.sample(
ddim_steps,
batch_size,
shape,
cond,
verbose=False,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
mask=mask,
**kwargs,
)
return samples, intermediates
def apply_model(self, x_noisy, t, cond, return_ids=False):
x_recon = self.model(x_noisy, t, cond_dict=cond)
if isinstance(x_recon, tuple) and not return_ids:
return x_recon[0]
else:
return x_recon
@torch.no_grad()
def generate_sample(
self,
quanized_feature,
ddim_steps=200,
ddim_eta=1.0,
x_T=None,
unconditional_guidance_scale=1.0,
):
batch_size = quanized_feature.shape[0]
pe = self.pos_embed(quanized_feature)
unconditional_conditioning = {}
if unconditional_guidance_scale != 1.0:
unconditional_quanized_feature = torch.cat(
[
quanized_feature * 0.0,
pe.repeat(quanized_feature.size(0), 1, 1).to(
quanized_feature.device
),
],
dim=-1,
)
unconditional_conditioning = {
"crossattn_audiomae_pooled": [
unconditional_quanized_feature,
torch.ones(
(
unconditional_quanized_feature.size(0),
unconditional_quanized_feature.size(1),
)
)
.to(unconditional_quanized_feature.device)
.float(),
]
}
quanized_feature = torch.cat(
[
quanized_feature,
pe.repeat(quanized_feature.size(0), 1, 1).to(quanized_feature.device),
],
dim=-1,
)
latent = {
"crossattn_audiomae_pooled": [
quanized_feature,
torch.ones((quanized_feature.size(0), quanized_feature.size(1)))
.to(quanized_feature.device)
.float(),
]
}
samples, _ = self.sample_log(
cond=latent,
batch_size=batch_size,
x_T=x_T,
ddim=True,
ddim_steps=ddim_steps,
eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
mel = self.decode_first_stage(samples)
return self.mel_spectrogram_to_waveform(mel)
class DiffusionWrapper(nn.Module):
def __init__(self, diff_model_config, conditioning_key):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)
self.conditioning_key = conditioning_key
def forward(self, x, t, cond_dict: dict = {}):
x = x.contiguous()
t = t.contiguous()
context_list, attn_mask_list = [], []
context, attn_mask = cond_dict["crossattn_audiomae_pooled"]
context_list.append(context)
attn_mask_list.append(attn_mask)
out = self.diffusion_model(
x,
t,
context_list=context_list,
y=None,
context_attn_mask_list=attn_mask_list,
)
return out
def extract_encoder_state_dict(checkpoint_path):
state_dict = torch.load(checkpoint_path)["state_dict"]
new_state_dict = {}
for key in state_dict.keys():
if "cond_stage_models.0" in key:
if "pos_embed.pe" in key:
continue
new_key_name = key.replace("cond_stage_models.0.", "")
new_state_dict[new_key_name] = state_dict[key]
return new_state_dict
def overlap_add_waveform(windowed_waveforms, overlap_duration=0.64):
"""
    Concatenate a series of windowed waveforms with overlap, applying linear fade-in and fade-out to the overlapping regions.
    Parameters:
    - windowed_waveforms: a list of numpy arrays with shape (1, 1, samples_per_waveform)
    - overlap_duration: overlap between consecutive windows, in seconds, at the assumed 16000 Hz sampling rate; values below 1e-4 fall back to plain concatenation
    Returns:
    - A single waveform numpy array resulting from the overlap-add process.
    """
    # Assuming a sampling rate of 16000 Hz
if overlap_duration < 1e-4:
return np.concatenate(windowed_waveforms, axis=-1)
sampling_rate = 16000
overlap_samples = int(overlap_duration * sampling_rate)
# Initialize the output waveform
output_waveform = np.array([]).reshape(1, 1, -1)
for i, waveform in enumerate(windowed_waveforms):
# If not the first waveform, apply fade-in at the beginning
if i > 0:
fade_in = np.linspace(0, 1, overlap_samples).reshape(1, 1, -1)
waveform[:, :, :overlap_samples] *= fade_in
# If output waveform already has content, apply fade-out to its last overlap and add the overlapping parts
if output_waveform.size > 0:
fade_out = np.linspace(1, 0, overlap_samples).reshape(1, 1, -1)
# Apply fade-out to the end of the output waveform
output_waveform[:, :, -overlap_samples:] *= fade_out
# Add the faded-in start of the current waveform to the faded-out end of the output waveform
output_waveform[:, :, -overlap_samples:] += waveform[:, :, :overlap_samples]
# Concatenate the current waveform (minus the initial overlap if not the first) to the output
if output_waveform.size == 0:
output_waveform = waveform
else:
output_waveform = np.concatenate(
(output_waveform, waveform[:, :, overlap_samples:]), axis=2
)
return output_waveform
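# A minimal sketch of the overlap-add behaviour (illustrative values): two
# 1.0 s windows at 16 kHz with the default 0.64 s overlap yield
# 2 * 16000 - int(0.64 * 16000) = 21760 samples, and because the linear
# fade-in and fade-out sum to exactly 1 at every sample, a constant input
# stays constant across the crossfade:
#
#   chunks = [np.ones((1, 1, 16000)), np.ones((1, 1, 16000))]
#   out = overlap_add_waveform(chunks)  # out.shape == (1, 1, 21760)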
================================================
FILE: semanticodec/modules/decoder/latent_diffusion/models/dpm_solver/__init__.py
================================================
from .sampler import DPMSolverSampler
================================================
FILE: semanticodec/modules/decoder/latent_diffusion/models/dpm_solver/dpm_solver.py
================================================
import torch
import torch.nn.functional as F
import math
class NoiseScheduleVP:
def __init__(
self,
schedule="discrete",
betas=None,
alphas_cumprod=None,
continuous_beta_0=0.1,
continuous_beta_1=20.0,
):
"""Create a wrapper class for the forward SDE (VP type).
***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
        We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
        ***
        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
log_alpha_t = self.marginal_log_mean_coeff(t)
sigma_t = self.marginal_std(t)
lambda_t = self.marginal_lambda(t)
Moreover, as lambda(t) is an invertible function, we also support its inverse function:
t = self.inverse_lambda(lambda_t)
===============================================================
We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
1. For discrete-time DPMs:
For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
t_i = (i + 1) / N
e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
Args:
betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
            Note that alphas_cumprod here equals cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
            **Important**: Please pay special attention to the args for `alphas_cumprod`:
The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
alpha_{t_n} = \sqrt{\hat{alpha_n}},
and
log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
2. For continuous-time DPMs:
We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
schedule are the default settings in DDPM and improved-DDPM:
Args:
beta_min: A `float` number. The smallest beta for the linear schedule.
beta_max: A `float` number. The largest beta for the linear schedule.
cosine_s: A `float` number. The hyperparameter in the cosine schedule.
cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
T: A `float` number. The ending time of the forward process.
===============================================================
Args:
schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
'linear' or 'cosine' for continuous-time DPMs.
Returns:
A wrapper object of the forward SDE (VP type).
===============================================================
Example:
# For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', betas=betas)
# For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
>>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
# For continuous-time DPMs (VPSDE), linear schedule:
>>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
"""
if schedule not in ["discrete", "linear", "cosine"]:
raise ValueError(
"Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
schedule
)
)
self.schedule = schedule
if schedule == "discrete":
if betas is not None:
log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
else:
assert alphas_cumprod is not None
log_alphas = 0.5 * torch.log(alphas_cumprod)
self.total_N = len(log_alphas)
self.T = 1.0
self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape(
(1, -1)
)
self.log_alpha_array = log_alphas.reshape(
(
1,
-1,
)
)
else:
self.total_N = 1000
self.beta_0 = continuous_beta_0
self.beta_1 = continuous_beta_1
self.cosine_s = 0.008
self.cosine_beta_max = 999.0
self.cosine_t_max = (
math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
* 2.0
* (1.0 + self.cosine_s)
/ math.pi
- self.cosine_s
)
self.cosine_log_alpha_0 = math.log(
math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)
)
self.schedule = schedule
if schedule == "cosine":
# For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
# Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
self.T = 0.9946
else:
self.T = 1.0
def marginal_log_mean_coeff(self, t):
"""
Compute log(alpha_t) of a given continuous-time label t in [0, T].
"""
if self.schedule == "discrete":
return interpolate_fn(
t.reshape((-1, 1)),
self.t_array.to(t.device),
self.log_alpha_array.to(t.device),
).reshape((-1))
elif self.schedule == "linear":
return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
elif self.schedule == "cosine":
log_alpha_fn = lambda s: torch.log(
torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)
)
log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
return log_alpha_t
def marginal_alpha(self, t):
"""
Compute alpha_t of a given continuous-time label t in [0, T].
"""
return torch.exp(self.marginal_log_mean_coeff(t))
def marginal_std(self, t):
"""
Compute sigma_t of a given continuous-time label t in [0, T].
"""
return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))
def marginal_lambda(self, t):
"""
Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
"""
log_mean_coeff = self.marginal_log_mean_coeff(t)
log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
return log_mean_coeff - log_std
def inverse_lambda(self, lamb):
"""
Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
"""
if self.schedule == "linear":
tmp = (
2.0
* (self.beta_1 - self.beta_0)
* torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
)
Delta = self.beta_0**2 + tmp
return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
elif self.schedule == "discrete":
log_alpha = -0.5 * torch.logaddexp(
torch.zeros((1,)).to(lamb.device), -2.0 * lamb
)
t = interpolate_fn(
log_alpha.reshape((-1, 1)),
torch.flip(self.log_alpha_array.to(lamb.device), [1]),
torch.flip(self.t_array.to(lamb.device), [1]),
)
return t.reshape((-1,))
else:
log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
t_fn = (
lambda log_alpha_t: torch.arccos(
torch.exp(log_alpha_t + self.cosine_log_alpha_0)
)
* 2.0
* (1.0 + self.cosine_s)
/ math.pi
- self.cosine_s
)
t = t_fn(log_alpha)
return t
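# A minimal sketch of NoiseScheduleVP usage (linear VPSDE with the DDPM
# defaults mentioned in the docstring):
#
#   ns = NoiseScheduleVP("linear", continuous_beta_0=0.1, continuous_beta_1=20.0)
#   t = torch.tensor([0.5])
#   alpha_t = ns.marginal_alpha(t)  # exp(-0.25 t^2 (b1 - b0) - 0.5 t b0)
#   sigma_t = ns.marginal_std(t)    # sqrt(1 - alpha_t^2)
#   t_back = ns.inverse_lambda(ns.marginal_lambda(t))  # recovers t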
def model_wrapper(
model,
noise_schedule,
model_type="noise",
model_kwargs={},
guidance_type="uncond",
condition=None,
unconditional_condition=None,
guidance_scale=1.0,
classifier_fn=None,
classifier_kwargs={},
):
"""Create a wrapper function for the noise prediction model.
DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    first wrap the model function into a noise prediction model that accepts the continuous time as the input.
We support four types of the diffusion model by setting `model_type`:
1. "noise": noise prediction model. (Trained by predicting noise).
2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
3. "v": velocity prediction model. (Trained by predicting the velocity).
The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
[1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
arXiv preprint arXiv:2202.00512 (2022).
[2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
arXiv preprint arXiv:2210.02303 (2022).
4. "score": marginal score function. (Trained by denoising score matching).
        Note that the score function and the noise prediction model follow a simple relationship:
```
noise(x_t, t) = -sigma_t * score(x_t, t)
```
We support three types of guided sampling by DPMs by setting `guidance_type`:
1. "uncond": unconditional sampling by DPMs.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
The input `model` has the following format:
``
model(x, t_input, **model_kwargs) -> noise | x_start | v | score
``
The input `classifier_fn` has the following format:
``
classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
``
[3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
The input `model` has the following format:
``
model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
``
And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
[4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
arXiv preprint arXiv:2207.12598 (2022).
The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
or continuous-time labels (i.e. epsilon to T).
We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
``
def model_fn(x, t_continuous) -> noise:
t_input = get_model_input_time(t_continuous)
return noise_pred(model, x, t_input, **model_kwargs)
``
where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
===============================================================
Args:
model: A diffusion model with the corresponding format described above.
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
model_type: A `str`. The parameterization type of the diffusion model.
"noise" or "x_start" or "v" or "score".
model_kwargs: A `dict`. A dict for the other inputs of the model function.
guidance_type: A `str`. The type of the guidance for sampling.
"uncond" or "classifier" or "classifier-free".
condition: A pytorch tensor. The condition for the guided sampling.
Only used for "classifier" or "classifier-free" guidance type.
unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
Only used for "classifier-free" guidance type.
guidance_scale: A `float`. The scale for the guided sampling.
classifier_fn: A classifier function. Only used for the classifier guidance.
classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
Returns:
A noise prediction model that accepts the noised data and the continuous time as the inputs.
"""
def get_model_input_time(t_continuous):
"""
Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
For continuous-time DPMs, we just use `t_continuous`.
"""
if noise_schedule.schedule == "discrete":
return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
else:
return t_continuous
def noise_pred_fn(x, t_continuous, cond=None):
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
t_input = get_model_input_time(t_continuous)
if cond is None:
output = model(x, t_input, **model_kwargs)
else:
output = model(x, t_input, cond, **model_kwargs)
if model_type == "noise":
return output
elif model_type == "x_start":
alpha_t, sigma_t = noise_schedule.marginal_alpha(
t_continuous
), noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return (x - expand_dims(alpha_t, dims) * output) / expand_dims(
sigma_t, dims
)
elif model_type == "v":
alpha_t, sigma_t = noise_schedule.marginal_alpha(
t_continuous
), noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
elif model_type == "score":
sigma_t = noise_schedule.marginal_std(t_continuous)
dims = x.dim()
return -expand_dims(sigma_t, dims) * output
def cond_grad_fn(x, t_input):
"""
Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
"""
with torch.enable_grad():
x_in = x.detach().requires_grad_(True)
log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
return torch.autograd.grad(log_prob.sum(), x_in)[0]
def model_fn(x, t_continuous):
"""
        The noise prediction model function that is used by DPM-Solver.
"""
if t_continuous.reshape((-1,)).shape[0] == 1:
t_continuous = t_continuous.expand((x.shape[0]))
if guidance_type == "uncond":
return noise_pred_fn(x, t_continuous)
elif guidance_type == "classifier":
assert classifier_fn is not None
t_input = get_model_input_time(t_continuous)
cond_grad = cond_grad_fn(x, t_input)
sigma_t = noise_schedule.marginal_std(t_continuous)
noise = noise_pred_fn(x, t_continuous)
return (
noise
- guidance_scale
* expand_dims(sigma_t, dims=cond_grad.dim())
* cond_grad
)
elif guidance_type == "classifier-free":
if guidance_scale == 1.0 or unconditional_condition is None:
return noise_pred_fn(x, t_continuous, cond=condition)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t_continuous] * 2)
c_in = torch.cat([unconditional_condition, condition])
noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
return noise_uncond + guidance_scale * (noise - noise_uncond)
assert model_type in ["noise", "x_start", "v"]
assert guidance_type in ["uncond", "classifier", "classifier-free"]
return model_fn
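# A minimal sketch of classifier-free guided wrapping (assumes `unet`,
# `betas`, `cond` and `uncond` are supplied by the caller):
#
#   ns = NoiseScheduleVP("discrete", betas=betas)
#   model_fn = model_wrapper(
#       unet,
#       ns,
#       model_type="noise",
#       guidance_type="classifier-free",
#       condition=cond,
#       unconditional_condition=uncond,
#       guidance_scale=3.5,
#   )
#   # model_fn(x, t_continuous) then returns guided noise predictions and is
#   # the function DPM_Solver below consumes.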
class DPM_Solver:
def __init__(
self,
model_fn,
noise_schedule,
predict_x0=False,
thresholding=False,
max_val=1.0,
):
"""Construct a DPM-Solver.
We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
        In that case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
Args:
model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
``
def model_fn(x, t_continuous):
return noise
``
noise_schedule: A noise schedule object, such as NoiseScheduleVP.
predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
[1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
"""
self.model = model_fn
self.noise_schedule = noise_schedule
self.predict_x0 = predict_x0
self.thresholding = thresholding
self.max_val = max_val
def noise_prediction_fn(self, x, t):
"""
Return the noise prediction model.
"""
return self.model(x, t)
def data_prediction_fn(self, x, t):
"""
Return the data prediction model (with thresholding).
"""
noise = self.noise_prediction_fn(x, t)
dims = x.dim()
alpha_t, sigma_t = self.noise_schedule.marginal_alpha(
t
), self.noise_schedule.marginal_std(t)
x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
if self.thresholding:
p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(
torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims
)
x0 = torch.clamp(x0, -s, s) / s
return x0
def model_fn(self, x, t):
"""
Convert the model to the noise prediction model or the data prediction model.
"""
if self.predict_x0:
return self.data_prediction_fn(x, t)
else:
return self.noise_prediction_fn(x, t)
def get_time_steps(self, skip_type, t_T, t_0, N, device):
"""Compute the intermediate time steps for sampling.
Args:
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: An `int`. The number of time-step intervals (the result has N + 1 entries).
device: A torch device.
Returns:
A pytorch tensor of the time steps, with the shape (N + 1,).
"""
if skip_type == "logSNR":
lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
logSNR_steps = torch.linspace(
lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1
).to(device)
return self.noise_schedule.inverse_lambda(logSNR_steps)
elif skip_type == "time_uniform":
return torch.linspace(t_T, t_0, N + 1).to(device)
elif skip_type == "time_quadratic":
t_order = 2
t = (
torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)
.pow(t_order)
.to(device)
)
return t
else:
raise ValueError(
"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(
skip_type
)
)
def get_orders_and_timesteps_for_singlestep_solver(
self, steps, order, skip_type, t_T, t_0, device
):
"""
Get the order of each step for sampling by the singlestep DPM-Solver.
        We combine DPM-Solver-1, 2 and 3 to use all the function evaluations, a scheme named "DPM-Solver-fast".
Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
- If order == 1:
We take `steps` of DPM-Solver-1 (i.e. DDIM).
- If order == 2:
- Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of DPM-Solver-2.
- If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If order == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
============================================
Args:
            order: An `int`. The max order for the solver (2 or 3).
            steps: An `int`. The total number of function evaluations (NFE).
skip_type: A `str`. The type for the spacing of the time steps. We support three types:
- 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
device: A torch device.
Returns:
orders: A list of the solver order of each step.
"""
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3] * (K - 1) + [1]
            else:
                orders = [3] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2] * K
            else:
                K = steps // 2 + 1
                orders = [2] * (K - 1) + [1]
        elif order == 1:
            K = 1
            orders = [1] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
if skip_type == "logSNR":
# To reproduce the results in DPM-Solver paper
timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
else:
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
                torch.cumsum(torch.tensor([0] + orders), dim=0).to(device)
            ]
return timesteps_outer, orders
def denoise_to_zero_fn(self, x, s):
"""
        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
"""
return self.data_prediction_fn(x, s)
def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
"""
DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(
s
), ns.marginal_log_mean_coeff(t)
sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
if self.predict_x0:
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
)
if return_intermediate:
return x_t, {"model_s": model_s}
else:
return x_t
else:
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
)
if return_intermediate:
return x_t, {"model_s": model_s}
else:
return x_t
def singlestep_dpm_solver_second_update(
self,
x,
s,
t,
r1=0.5,
model_s=None,
return_intermediate=False,
solver_type="dpm_solver",
):
"""
Singlestep solver DPM-Solver-2 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
r1: A `float`. The hyperparameter of the second-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend the 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpm_solver", "taylor"]:
raise ValueError(
"'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
solver_type
)
)
if r1 is None:
r1 = 0.5
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
s1 = ns.inverse_lambda(lambda_s1)
log_alpha_s, log_alpha_s1, log_alpha_t = (
ns.marginal_log_mean_coeff(s),
ns.marginal_log_mean_coeff(s1),
ns.marginal_log_mean_coeff(t),
)
sigma_s, sigma_s1, sigma_t = (
ns.marginal_std(s),
ns.marginal_std(s1),
ns.marginal_std(t),
)
alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
if self.predict_x0:
phi_11 = torch.expm1(-r1 * h)
phi_1 = torch.expm1(-h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (
expand_dims(sigma_s1 / sigma_s, dims) * x
- expand_dims(alpha_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
if solver_type == "dpm_solver":
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
- (0.5 / r1)
* expand_dims(alpha_t * phi_1, dims)
* (model_s1 - model_s)
)
elif solver_type == "taylor":
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
+ (1.0 / r1)
* expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims)
* (model_s1 - model_s)
)
else:
phi_11 = torch.expm1(r1 * h)
phi_1 = torch.expm1(h)
if model_s is None:
model_s = self.model_fn(x, s)
x_s1 = (
expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
- expand_dims(sigma_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
if solver_type == "dpm_solver":
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- (0.5 / r1)
* expand_dims(sigma_t * phi_1, dims)
* (model_s1 - model_s)
)
elif solver_type == "taylor":
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- (1.0 / r1)
* expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims)
* (model_s1 - model_s)
)
if return_intermediate:
return x_t, {"model_s": model_s, "model_s1": model_s1}
else:
return x_t
def singlestep_dpm_solver_third_update(
self,
x,
s,
t,
r1=1.0 / 3.0,
r2=2.0 / 3.0,
model_s=None,
model_s1=None,
return_intermediate=False,
solver_type="dpm_solver",
):
"""
Singlestep solver DPM-Solver-3 from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
r1: A `float`. The hyperparameter of the third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
model_s: A pytorch tensor. The model function evaluated at time `s`.
If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpm_solver", "taylor"]:
raise ValueError(
"'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
solver_type
)
)
if r1 is None:
r1 = 1.0 / 3.0
if r2 is None:
r2 = 2.0 / 3.0
ns = self.noise_schedule
dims = x.dim()
lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
h = lambda_t - lambda_s
lambda_s1 = lambda_s + r1 * h
lambda_s2 = lambda_s + r2 * h
s1 = ns.inverse_lambda(lambda_s1)
s2 = ns.inverse_lambda(lambda_s2)
log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
ns.marginal_log_mean_coeff(s),
ns.marginal_log_mean_coeff(s1),
ns.marginal_log_mean_coeff(s2),
ns.marginal_log_mean_coeff(t),
)
sigma_s, sigma_s1, sigma_s2, sigma_t = (
ns.marginal_std(s),
ns.marginal_std(s1),
ns.marginal_std(s2),
ns.marginal_std(t),
)
alpha_s1, alpha_s2, alpha_t = (
torch.exp(log_alpha_s1),
torch.exp(log_alpha_s2),
torch.exp(log_alpha_t),
)
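        # Two intermediate evaluations at s1 and s2 (lambda offsets r1 * h and r2 * h)
        # provide the divided differences used below, i.e. estimates of the first and
        # second lambda-derivatives of the model output, for a third-order accurate
        # step at the cost of three function evaluations.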
if self.predict_x0:
phi_11 = torch.expm1(-r1 * h)
phi_12 = torch.expm1(-r2 * h)
phi_1 = torch.expm1(-h)
phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
phi_2 = phi_1 / h + 1.0
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (
expand_dims(sigma_s1 / sigma_s, dims) * x
- expand_dims(alpha_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
expand_dims(sigma_s2 / sigma_s, dims) * x
- expand_dims(alpha_s2 * phi_12, dims) * model_s
+ r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == "dpm_solver":
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
+ (1.0 / r2)
* expand_dims(alpha_t * phi_2, dims)
* (model_s2 - model_s)
)
elif solver_type == "taylor":
D1_0 = (1.0 / r1) * (model_s1 - model_s)
D1_1 = (1.0 / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
x_t = (
expand_dims(sigma_t / sigma_s, dims) * x
- expand_dims(alpha_t * phi_1, dims) * model_s
+ expand_dims(alpha_t * phi_2, dims) * D1
- expand_dims(alpha_t * phi_3, dims) * D2
)
else:
phi_11 = torch.expm1(r1 * h)
phi_12 = torch.expm1(r2 * h)
phi_1 = torch.expm1(h)
phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
phi_2 = phi_1 / h - 1.0
phi_3 = phi_2 / h - 0.5
if model_s is None:
model_s = self.model_fn(x, s)
if model_s1 is None:
x_s1 = (
expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
- expand_dims(sigma_s1 * phi_11, dims) * model_s
)
model_s1 = self.model_fn(x_s1, s1)
x_s2 = (
expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
- expand_dims(sigma_s2 * phi_12, dims) * model_s
- r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
)
model_s2 = self.model_fn(x_s2, s2)
if solver_type == "dpm_solver":
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- (1.0 / r2)
* expand_dims(sigma_t * phi_2, dims)
* (model_s2 - model_s)
)
elif solver_type == "taylor":
D1_0 = (1.0 / r1) * (model_s1 - model_s)
D1_1 = (1.0 / r2) * (model_s2 - model_s)
D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
- expand_dims(sigma_t * phi_1, dims) * model_s
- expand_dims(sigma_t * phi_2, dims) * D1
- expand_dims(sigma_t * phi_3, dims) * D2
)
if return_intermediate:
return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
else:
return x_t
def multistep_dpm_solver_second_update(
self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"
):
"""
Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
model_prev_list: A list of pytorch tensors. The previously computed model values.
t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if solver_type not in ["dpm_solver", "taylor"]:
raise ValueError(
"'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(
solver_type
)
)
ns = self.noise_schedule
dims = x.dim()
model_prev_1, model_prev_0 = model_prev_list
t_prev_1, t_prev_0 = t_prev_list
lambda_prev_1, lambda_prev_0, lambda_t = (
ns.marginal_lambda(t_prev_1),
ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t),
)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
t_prev_0
), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0 = h_0 / h
D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
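        # D1_0 approximates the lambda-derivative of the model output from the two
        # cached evaluations (with relative step r0 = h_0 / h), so the second-order
        # correction below reuses history instead of spending an extra function
        # evaluation: each multistep update costs one NFE.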
if self.predict_x0:
if solver_type == "dpm_solver":
x_t = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
- 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0
)
elif solver_type == "taylor":
x_t = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims)
* D1_0
)
else:
if solver_type == "dpm_solver":
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
- 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0
)
elif solver_type == "taylor":
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
- expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims)
* D1_0
)
return x_t
def multistep_dpm_solver_third_update(
self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"
):
"""
Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
model_prev_list: A list of pytorch tensors. The previously computed model values.
t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
ns = self.noise_schedule
dims = x.dim()
model_prev_2, model_prev_1, model_prev_0 = model_prev_list
t_prev_2, t_prev_1, t_prev_0 = t_prev_list
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
ns.marginal_lambda(t_prev_2),
ns.marginal_lambda(t_prev_1),
ns.marginal_lambda(t_prev_0),
ns.marginal_lambda(t),
)
log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(
t_prev_0
), ns.marginal_log_mean_coeff(t)
sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
alpha_t = torch.exp(log_alpha_t)
h_1 = lambda_prev_1 - lambda_prev_2
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0, r1 = h_0 / h, h_1 / h
D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)
D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2)
D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1)
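        # D1 and D2 are first- and second-order divided differences built from the
        # three cached model values, so the third-order multistep update below still
        # costs only one new function evaluation per step.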
if self.predict_x0:
x_t = (
expand_dims(sigma_t / sigma_prev_0, dims) * x
- expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0
+ expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1
- expand_dims(
alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims
)
* D2
)
else:
x_t = (
expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
- expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0
- expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1
- expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims)
* D2
)
return x_t
def singlestep_dpm_solver_update(
self,
x,
s,
t,
order,
return_intermediate=False,
solver_type="dpm_solver",
r1=None,
r2=None,
):
"""
Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `s`.
s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
order: An `int`. The order of DPM-Solver. We only support order == 1, 2, or 3.
return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
r1: A `float`. The hyperparameter of the second-order or third-order solver.
r2: A `float`. The hyperparameter of the third-order solver.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(
x, s, t, return_intermediate=return_intermediate
)
elif order == 2:
return self.singlestep_dpm_solver_second_update(
x,
s,
t,
return_intermediate=return_intermediate,
solver_type=solver_type,
r1=r1,
)
elif order == 3:
return self.singlestep_dpm_solver_third_update(
x,
s,
t,
return_intermediate=return_intermediate,
solver_type=solver_type,
r1=r1,
r2=r2,
)
else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
def multistep_dpm_solver_update(
self, x, model_prev_list, t_prev_list, t, order, solver_type="dpm_solver"
):
"""
Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
Args:
x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
model_prev_list: A list of pytorch tensors. The previously computed model values.
t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
order: An `int`. The order of DPM-Solver. We only support order == 1, 2, or 3.
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_t: A pytorch tensor. The approximated solution at time `t`.
"""
if order == 1:
return self.dpm_solver_first_update(
x, t_prev_list[-1], t, model_s=model_prev_list[-1]
)
elif order == 2:
return self.multistep_dpm_solver_second_update(
x, model_prev_list, t_prev_list, t, solver_type=solver_type
)
elif order == 3:
return self.multistep_dpm_solver_third_update(
x, model_prev_list, t_prev_list, t, solver_type=solver_type
)
else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
def dpm_solver_adaptive(
self,
x,
order,
t_T,
t_0,
h_init=0.05,
atol=0.0078,
rtol=0.05,
theta=0.9,
t_err=1e-5,
solver_type="dpm_solver",
):
"""
The adaptive step size solver based on singlestep DPM-Solver.
Args:
x: A pytorch tensor. The initial value at time `t_T`.
order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
t_T: A `float`. The starting time of the sampling (default is T).
t_0: A `float`. The ending time of the sampling (default is epsilon).
h_init: A `float`. The initial step size (for logSNR).
atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
current time and `t_0` is less than `t_err`. The default setting is 1e-5.
solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
Returns:
x_0: A pytorch tensor. The approximated solution at time `t_0`.
[1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
"""
ns = self.noise_schedule
s = t_T * torch.ones((x.shape[0],)).to(x)
lambda_s = ns.marginal_lambda(s)
lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
h = h_init * torch.ones_like(s).to(x)
x_prev = x
nfe = 0
if order == 2:
r1 = 0.5
lower_update = lambda x, s, t: self.dpm_solver_first_update(
x, s, t, return_intermediate=True
)
higher_update = (
lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
x, s, t, r1=r1, solver_type=solver_type, **kwargs
)
)
elif order == 3:
r1, r2 = 1.0 / 3.0, 2.0 / 3.0
lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
)
higher_update = (
lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
)
)
else:
raise ValueError(
"For adaptive step size solver, order must be 2 or 3, got {}".format(
order
)
)
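        # Classic embedded-pair step-size control: compare a lower-order and a
        # higher-order proposal, accept the step when the scaled error E <= 1, and
        # rescale the logSNR step as h <- theta * h * E^(-1 / order), clipped so
        # the final step lands exactly on lambda_0.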
while torch.abs((s - t_0)).mean() > t_err:
t = ns.inverse_lambda(lambda_s + h)
x_lower, lower_noise_kwargs = lower_update(x, s, t)
x_higher = higher_update(x, s, t, **lower_noise_kwargs)
delta = torch.max(
torch.ones_like(x).to(x) * atol,
rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)),
)
norm_fn = lambda v: torch.sqrt(
torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)
)
E = norm_fn((x_higher - x_lower) / delta).max()
if torch.all(E <= 1.0):
x = x_higher
s = t
x_prev = x_lower
lambda_s = ns.marginal_lambda(s)
h = torch.min(
theta * h * torch.float_power(E, -1.0 / order).float(),
lambda_0 - lambda_s,
)
nfe += order
print("adaptive solver nfe", nfe)
return x
def sample(
self,
x,
steps=20,
t_start=None,
t_end=None,
order=3,
skip_type="time_uniform",
method="singlestep",
lower_order_final=True,
denoise_to_zero=False,
solver_type="dpm_solver",
atol=0.0078,
rtol=0.05,
):
"""
Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
=====================================================
We support the following algorithms for both noise prediction model and data prediction model:
- 'singlestep':
Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
The total number of function evaluations (NFE) == `steps`.
Given a fixed NFE == `steps`, the sampling procedure is:
- If `order` == 1:
- Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
- If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
- If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If `order` == 3:
- Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
- If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
- If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
- If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
- 'multistep':
Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
We initialize the first `order` values by lower order multistep solvers.
Given a fixed NFE == `steps`, the sampling procedure is:
Denote K = steps.
- If `order` == 1:
- We use K steps of DPM-Solver-1 (i.e. DDIM).
- If `order` == 2:
- We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
- If `order` == 3:
- We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
- 'singlestep_fixed':
Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
- 'adaptive':
Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
We ignore `steps` and use the adaptive step size DPM-Solver with the (higher) order `order`.
You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation cost
(NFE) and the sample quality.
- If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
- If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
=====================================================
Some advice on choosing the algorithm:
- For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
e.g.
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
skip_type='time_uniform', method='singlestep')
- For **guided sampling with large guidance scale** by DPMs:
Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
e.g.
>>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
>>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
skip_type='time_uniform', method='multistep')
We support three types of `skip_type`:
- 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
- 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
- 'time_quadratic': quadratic time for the time steps.
=====================================================
Args:
x: A pytorch tensor. The initial value at time `t_start`
e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
steps: An `int`. The total number of function evaluations (NFE).
t_start: A `float`. The starting time of the sampling.
If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
t_end: A `float`. The ending time of the sampling.
If `t_end` is None, we use 1. / self.noise_schedule.total_N.
e.g. if total_N == 1000, we have `t_end` == 1e-3.
For discrete-time DPMs:
- We recommend `t_end` == 1. / self.noise_schedule.total_N.
For continuous-time DPMs:
- We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
order: An `int`. The order of DPM-Solver.
skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
sampling diffusion models via diffusion SDEs for low-resolution images
(such as CIFAR-10). However, we observed that it does not matter for
high-resolution images. As it needs an additional NFE, we do not recommend
it for high-resolution images.
lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
Only valid for `method=multistep` and `steps < 15`. We empirically find that
this trick is key to stabilizing sampling with DPM-Solver at very few steps
(especially for steps <= 10), so we recommend setting it to `True`.
solver_type: A `str`. The Taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
Returns:
x_end: A pytorch tensor. The approximated solution at time `t_end`.
"""
t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
t_T = self.noise_schedule.T if t_start is None else t_start
device = x.device
if method == "adaptive":
with torch.no_grad():
x = self.dpm_solver_adaptive(
x,
order=order,
t_T=t_T,
t_0=t_0,
atol=atol,
rtol=rtol,
solver_type=solver_type,
)
elif method == "multistep":
assert steps >= order
timesteps = self.get_time_steps(
skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device
)
assert timesteps.shape[0] - 1 == steps
with torch.no_grad():
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0])
x = self.multistep_dpm_solver_update(
x,
model_prev_list,
t_prev_list,
vec_t,
init_order,
solver_type=solver_type,
)
model_prev_list.append(self.model_fn(x, vec_t))
t_prev_list.append(vec_t)
# Compute the remaining values by `order`-th order multistep DPM-Solver.
for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final and steps < 15:
step_order = min(order, steps + 1 - step)
else:
step_order = order
x = self.multistep_dpm_solver_update(
x,
model_prev_list,
t_prev_list,
vec_t,
step_order,
solver_type=solver_type,
)
for i in range(order - 1):
t_prev_list[i] = t_prev_list[i + 1]
model_prev_list[i] = model_prev_list[i + 1]
t_prev_list[-1] = vec_t
# We do not need to evaluate the final model value.
if step < steps:
model_prev_list[-1] = self.model_fn(x, vec_t)
elif method in ["singlestep", "singlestep_fixed"]:
if method == "singlestep":
(
timesteps_outer,
orders,
) = self.get_orders_and_timesteps_for_singlestep_solver(
steps=steps,
order=order,
skip_type=skip_type,
t_T=t_T,
t_0=t_0,
device=device,
)
elif method == "singlestep_fixed":
K = steps // order
orders = [
order,
] * K
timesteps_outer = self.get_time_steps(
skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device
)
for i, order in enumerate(orders):
t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
timesteps_inner = self.get_time_steps(
skip_type=skip_type,
t_T=t_T_inner.item(),
t_0=t_0_inner.item(),
N=order,
device=device,
)
lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
h = lambda_inner[-1] - lambda_inner[0]
r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
x = self.singlestep_dpm_solver_update(
x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2
)
if denoise_to_zero:
x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
return x
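# A minimal end-to-end usage sketch (hedged: `model`, `betas`, `batch_size` and
# `shape` are placeholders, not defined in this file; parameters follow the
# examples in the `sample` docstring above):
# >>> ns = NoiseScheduleVP(schedule="discrete", betas=betas)
# >>> model_fn = model_wrapper(model, ns, model_type="noise")
# >>> dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True)
# >>> x_T = torch.randn((batch_size, *shape))
# >>> x_0 = dpm_solver.sample(x_T, steps=20, order=2, method="multistep")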
#############################################################
# other utility functions
#############################################################
def interpolate_fn(x, xp, yp):
"""
A piecewise linear function y = f(x), using xp and yp as keypoints.
We implement f(x) in a differentiable way (i.e. applicable for autograd).
The function f(x) is well-defined over the whole x-axis. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
Args:
x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
yp: PyTorch tensor with shape [C, K].
Returns:
The function values f(x), with shape [N, C].
"""
N, K = x.shape[0], xp.shape[1]
all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
sorted_all_x, x_indices = torch.sort(all_x, dim=2)
x_idx = torch.argmin(x_indices, dim=2)
cand_start_idx = x_idx - 1
start_idx = torch.where(
torch.eq(x_idx, 0),
torch.tensor(1, device=x.device),
torch.where(
torch.eq(x_idx, K),
torch.tensor(K - 2, device=x.device),
cand_start_idx,
),
)
end_idx = torch.where(
torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1
)
start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
start_idx2 = torch.where(
torch.eq(x_idx, 0),
torch.tensor(0, device=x.device),
torch.where(
torch.eq(x_idx, K),
torch.tensor(K - 2, device=x.device),
cand_start_idx,
),
)
y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
start_y = torch.gather(
y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)
).squeeze(2)
end_y = torch.gather(
y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)
).squeeze(2)
cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
return cand
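# Example (values are illustrative): with keypoints xp = [[0., 1.]] and
# yp = [[0., 2.]], interpolate_fn(torch.tensor([[0.5]]), xp, yp) returns
# tensor([[1.0]]); inputs outside [0, 1] are linearly extrapolated from the
# two outermost keypoints.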
def expand_dims(v, dims):
"""
Expand the tensor `v` to `dims` dimensions.
Args:
`v`: a PyTorch tensor with shape [N].
`dims`: an `int`.
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1], with `dims` dimensions in total.
"""
return v[(...,) + (None,) * (dims - 1)]
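# Example: for v of shape [N], expand_dims(v, 4) returns a view of shape
# [N, 1, 1, 1], suitable for broadcasting against an [N, C, H, W] tensor.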
================================================
FILE: semanticodec/modules/decoder/latent_diffusion/modules/__init__.py
================================================
================================================
FILE: semanticodec/modules/decoder/latent_diffusion/modules/attention.py
================================================
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.util import (
checkpoint,
)
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
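    # The projection doubles the width; forward() splits it into a value half
    # and a gate half and returns value * GELU(gate) (GEGLU, arXiv:2002.05202).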
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = (
nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
if not glu
else GEGLU(dim, inner_dim)
)
self.net = nn.Sequential(
project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
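    # Kernelized "linear" attention: softmax is taken over the key axis and the
    # (c x c) context = k @ v^T is formed first, then contracted with q, so the
    # cost scales as O(n * c^2) in the sequence length n = h * w instead of the
    # O(n^2 * c) of standard attention.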
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
)
k = k.softmax(dim=-1)
context = torch.einsum("bhdn,bhen->bhde", k, v)
out = torch.einsum("bhde,bhdn->bhen", context, q)
out = rearrange(
out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, "b c h w -> b (h w) c")
k = rearrange(k, "b c h w -> b c (h w)")
w_ = torch.einsum("bij,bjk->bik", q, k)
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, "b c h w -> b c (h w)")
w_ = rearrange(w_, "b i j -> b j i")
h_ = torch.einsum("bij,bjk->bik", v, w_)
h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
h_ = self.proj_out(h_)
return x + h_
# class CrossAttention(nn.Module):
# """
# ### Cross Attention Layer
# This falls-back to self-attention when conditional embeddings are not specified.
# """
# use_flash_attention: bool = True
# # use_flash_attention: bool = False
# def __init__(
# self,
# query_dim,
# context_dim=None,
# heads=8,
# dim_head=64,
# dropout=0.0,
# is_inplace: bool = True,
# ):
# # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
# """
# :param d_model: is the input embedding size
# :param n_heads: is the number of attention heads
# :param d_head: is the size of a attention head
# :param d_cond: is the size of the conditional embeddings
# :param is_inplace: specifies whether to perform the attention softmax computation inplace to
# save memory
# """
# super().__init__()
# self.is_inplace = is_inplace
# self.n_heads = heads
# self.d_head = dim_head
# # Attention scaling factor
# self.scale = dim_head**-0.5
# # The normal self-attention layer
# if context_dim is None:
# context_dim = query_dim
# # Query, key and value mappings
# d_attn = dim_head * heads
# self.to_q = nn.Linear(query_dim, d_attn, bias=False)
# self.to_k = nn.Linear(context_dim, d_attn, bias=False)
# self.to_v = nn.Linear(context_dim, d_attn, bias=False)
# # Final linear layer
# self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))
# # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
# # Flash attention is only used if it's installed
# # and `CrossAttention.use_flash_attention` is set to `True`.
# try:
# # You can install flash attention by cloning their Github repo,
# # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
# # and then running `python setup.py install`
# from flash_attn.flash_attention import FlashAttention
# self.flash = FlashAttention()
# # Set the scale for scaled dot-product attention.
# self.flash.softmax_scale = self.scale
# # Set to `None` if it's not installed
# except ImportError:
# self.flash = None
# def forward(self, x, context=None, mask=None):
# """
# :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
# :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
# """
# # If `cond` is `None` we perform self attention
# has_cond = context is not None
# if not has_cond:
# context = x
# # Get query, key and value vectors
# q = self.to_q(x)
# k = self.to_k(context)
# v = self.to_v(context)
# # Use flash attention if it's available and the head size is less than or equal to `128`
# if (
# CrossAttention.use_flash_attention
# and self.flash is not None
# and not has_cond
# and self.d_head <= 128
# ):
# return self.flash_attention(q, k, v)
# # Otherwise, fallback to normal attention
# else:
# return self.normal_attention(q, k, v)
# def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
# """
# #### Flash Attention
# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
# """
# # Get batch size and number of elements along sequence axis (`width * height`)
# batch_size, seq_len, _ = q.shape
# # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
# # shape `[batch_size, seq_len, 3, n_heads * d_head]`
# qkv = torch.stack((q, k, v), dim=2)
# # Split the heads
# qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)
# # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
# # fit this size.
# if self.d_head <= 32:
# pad = 32 - self.d_head
# elif self.d_head <= 64:
# pad = 64 - self.d_head
# elif self.d_head <= 128:
# pad = 128 - self.d_head
# else:
# raise ValueError(f"Head size ${self.d_head} too large for Flash Attention")
# # Pad the heads
# if pad:
# qkv = torch.cat(
# (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
# )
# # Compute attention
# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
# # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
# # TODO here I add the dtype changing
# out, _ = self.flash(qkv.type(torch.float16))
# # Truncate the extra head size
# out = out[:, :, :, : self.d_head].float()
# # Reshape to `[batch_size, seq_len, n_heads * d_head]`
# out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
# # Map to `[batch_size, height * width, d_model]` with a linear layer
# return self.to_out(out)
# def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
# """
# #### Normal Attention
# :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
# :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
# :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
# """
# # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
# q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32]
# k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32]
# v = v.view(*v.shape[:2], self.n_heads, -1)
# # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
# attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
# # Compute softmax
# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
# if self.is_inplace:
# half = attn.shape[0] // 2
# attn[half:] = attn[half:].softmax(dim=-1)
# attn[:half] = attn[:half].softmax(dim=-1)
# else:
# attn = attn.softmax(dim=-1)
# # Compute attention output
# # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
# # attn: [bs, 20, 64, 1]
# # v: [bs, 1, 20, 32]
# out = torch.einsum("bhij,bjhd->bihd", attn, v)
# # Reshape to `[batch_size, height * width, n_heads * d_head]`
# out = out.reshape(*out.shape[:2], -1)
# # Map to `[batch_size, height * width, d_model]` with a linear layer
# return self.to_out(out)
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
)
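    # Shape sketch (illustrative numbers, not from this repo's configs): with
    # query_dim=320, context_dim=768, heads=8, dim_head=40,
    #   x: [B, N, 320], context: [B, T, 768] -> out: [B, N, 320].
    # When `context` is None the layer falls back to self-attention over `x`.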
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
if exists(mask):
mask = rearrange(mask, "b ... -> b (...)")
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, "b j -> (b h) () j", h=h)
sim.masked_fill_(~(mask == 1), max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum("b i j, b j d -> b i d", attn, v)
out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim,
n_heads,
d_head,
dropout=0.0,
context_dim=None,
gated_ff=True,
checkpoint=True,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
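    # When `checkpoint` is True, forward() wraps _forward in activation
    # checkpointing (recompute activations in the backward pass), trading
    # extra compute for lower memory use.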
def forward(self, x, context=None, mask=None):
if context is None:
return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
else:
return checkpoint(
self._forward, (x, context, mask), self.parameters(), self.checkpoint
)
def _forward(self, x, context=None, mask=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape back to an image.
"""
def __init__(
self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.0,
context_dim=None,
):
super().__init__()
context_dim = context_dim
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
)
for d in range(depth)
]
)
self.proj_out = zero_module(
nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
)
def forward(self, x, context=None, mask=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, "b c h w -> b (h w) c")
for block in self.transformer_blocks:
x = block(x, context=context, mask=mask)
x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
x = self.proj_out(x)
return x + x_in
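# Data flow summary: [B, C, H, W] -> GroupNorm -> 1x1 conv to inner_dim ->
# flatten to [B, H*W, inner_dim] -> `depth` BasicTransformerBlocks (self-attn,
# cross-attn over `context`, feed-forward) -> zero-initialized 1x1 conv back
# to [B, C, H, W], added residually to the input.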
================================================
FILE: semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/__init__.py
================================================
================================================
FILE: semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/model.py
================================================
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from semanticodec.modules.decoder.latent_diffusion.util import instantiate_from_config
from semanticodec.modules.decoder.latent_diffusion.modules.attention import (
LinearAttention,
)
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1:  # zero pad
    emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
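# Example: get_timestep_embedding(torch.tensor([0, 1]), 4) returns a [2, 4]
# tensor whose first half holds the sin terms and second half the cos terms;
# row 0 (timestep 0) is [0, 0, 1, 1].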
SYMBOL INDEX (500 symbols across 32 files)
FILE: semanticodec/config.py
function get_config (line 2) | def get_config(token_rate=100, vocab_size=None, checkpoint_path=None):
FILE: semanticodec/main.py
class SemantiCodec (line 26) | class SemantiCodec(nn.Module):
method __init__ (line 27) | def __init__(
method load_audio (line 87) | def load_audio(self, filepath):
method encode (line 139) | def encode(self, filepath):
method decode (line 145) | def decode(self, tokens):
method forward (line 179) | def forward(self, filepath):
FILE: semanticodec/modules/audiomae/AudioMAE.py
class PatchEmbed_new (line 13) | class PatchEmbed_new(nn.Module):
method __init__ (line 16) | def __init__(
method get_output_shape (line 38) | def get_output_shape(self, img_size):
method forward (line 42) | def forward(self, x):
class Vanilla_AudioMAE (line 52) | class Vanilla_AudioMAE(nn.Module):
method __init__ (line 55) | def __init__(
method forward (line 72) | def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):
FILE: semanticodec/modules/audiomae/models_mae.py
class MaskedAutoencoderViT (line 27) | class MaskedAutoencoderViT(nn.Module):
method __init__ (line 30) | def __init__(
method initialize_weights (line 201) | def initialize_weights(self):
method _init_weights (line 243) | def _init_weights(self, m):
method patchify (line 253) | def patchify(self, imgs):
method unpatchify (line 288) | def unpatchify(self, x):
method random_masking (line 301) | def random_masking(self, x, mask_ratio):
method random_masking_2d (line 330) | def random_masking_2d(self, x, mask_t_prob, mask_f_prob):
method forward_encoder (line 396) | def forward_encoder(self, x, mask_ratio, mask_2d=False):
method forward_encoder_no_random_mask_no_average (line 422) | def forward_encoder_no_random_mask_no_average(self, x):
method forward_encoder_no_mask (line 446) | def forward_encoder_no_mask(self, x):
method forward_decoder (line 471) | def forward_decoder(self, x, ids_restore):
method forward_loss (line 518) | def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
method forward (line 536) | def forward(self, imgs, mask_ratio=0.8):
function mae_vit_small_patch16_dec512d8b (line 548) | def mae_vit_small_patch16_dec512d8b(**kwargs):
function mae_vit_base_patch16_dec512d8b (line 563) | def mae_vit_base_patch16_dec512d8b(**kwargs):
function mae_vit_large_patch16_dec512d8b (line 578) | def mae_vit_large_patch16_dec512d8b(**kwargs):
function mae_vit_huge_patch14_dec512d8b (line 593) | def mae_vit_huge_patch14_dec512d8b(**kwargs):
FILE: semanticodec/modules/audiomae/patch_embed.py
class PatchEmbed_org (line 6) | class PatchEmbed_org(nn.Module):
method __init__ (line 9) | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=...
method forward (line 23) | def forward(self, x):
class PatchEmbed_new (line 33) | class PatchEmbed_new(nn.Module):
method __init__ (line 36) | def __init__(
method get_output_shape (line 58) | def get_output_shape(self, img_size):
method forward (line 62) | def forward(self, x):
class PatchEmbed3D_new (line 74) | class PatchEmbed3D_new(nn.Module):
method __init__ (line 77) | def __init__(
method get_output_shape (line 98) | def get_output_shape(self, video_size):
method forward (line 104) | def forward(self, x):
FILE: semanticodec/modules/audiomae/pos_embed.py
function get_2d_sincos_pos_embed (line 21) | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
function get_2d_sincos_pos_embed_flexible (line 39) | def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=Fal...
function get_2d_sincos_pos_embed_from_grid (line 57) | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
function get_1d_sincos_pos_embed_from_grid (line 68) | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
function interpolate_pos_embed (line 95) | def interpolate_pos_embed(model, checkpoint_model):
function interpolate_pos_embed_img2audio (line 128) | def interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, ...
function interpolate_pos_embed_audio (line 161) | def interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_...
function interpolate_patch_embed_audio (line 189) | def interpolate_patch_embed_audio(
FILE: semanticodec/modules/decoder/hifigan/__init__.py
class AttrDict (line 5) | class AttrDict(dict):
method __init__ (line 6) | def __init__(self, *args, **kwargs):
FILE: semanticodec/modules/decoder/hifigan/models.py
function init_weights (line 10) | def init_weights(m, mean=0.0, std=0.01):
function get_padding (line 16) | def get_padding(kernel_size, dilation=1):
class ResBlock (line 20) | class ResBlock(torch.nn.Module):
method __init__ (line 21) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
method forward (line 96) | def forward(self, x):
method remove_weight_norm (line 105) | def remove_weight_norm(self):
class Generator (line 112) | class Generator(torch.nn.Module):
method __init__ (line 113) | def __init__(self, h):
method forward (line 149) | def forward(self, x):
method remove_weight_norm (line 167) | def remove_weight_norm(self):
FILE: semanticodec/modules/decoder/hifigan/models_v2.py
function init_weights (line 10) | def init_weights(m, mean=0.0, std=0.01):
function get_padding (line 16) | def get_padding(kernel_size, dilation=1):
class ResBlock1 (line 20) | class ResBlock1(torch.nn.Module):
method __init__ (line 21) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
method forward (line 96) | def forward(self, x):
method remove_weight_norm (line 105) | def remove_weight_norm(self):
class ResBlock2 (line 112) | class ResBlock2(torch.nn.Module):
method __init__ (line 113) | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
method forward (line 142) | def forward(self, x):
method remove_weight_norm (line 149) | def remove_weight_norm(self):
class Generator (line 154) | class Generator(torch.nn.Module):
method __init__ (line 155) | def __init__(self, h):
method forward (line 192) | def forward(self, x):
method remove_weight_norm (line 211) | def remove_weight_norm(self):
FILE: semanticodec/modules/decoder/latent_diffusion/models/ddim.py
class DDIMSampler (line 15) | class DDIMSampler(object):
method __init__ (line 16) | def __init__(self, model, schedule="linear", device=torch.device("cuda...
method register_buffer (line 23) | def register_buffer(self, name, attr):
method make_schedule (line 29) | def make_schedule(
method sample (line 96) | def sample(
method ddim_sampling (line 167) | def ddim_sampling(
method p_sample_ddim (line 265) | def p_sample_ddim(
method encode (line 358) | def encode(
method stochastic_encode (line 434) | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
method decode (line 452) | def decode(
FILE: semanticodec/modules/decoder/latent_diffusion/models/ddpm.py
class DDPM (line 27) | class DDPM(nn.Module):
method __init__ (line 29) | def __init__(
method register_schedule (line 89) | def register_schedule(
method ema_scope (line 194) | def ema_scope(self, context=None):
method q_mean_variance (line 208) | def q_mean_variance(self, x_start, t):
method predict_start_from_noise (line 222) | def predict_start_from_noise(self, x_t, t, noise):
method q_posterior (line 229) | def q_posterior(self, x_start, x_t, t):
method p_mean_variance (line 240) | def p_mean_variance(self, x, t, clip_denoised: bool):
method p_sample (line 255) | def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
method p_sample_loop (line 268) | def p_sample_loop(self, shape):
method sample (line 285) | def sample(self, batch_size=16, return_intermediates=False):
method q_sample (line 290) | def q_sample(self, x_start, t, noise=None):
method predict_start_from_z_and_v (line 298) | def predict_start_from_z_and_v(self, x_t, t, v):
method predict_eps_from_z_and_v (line 306) | def predict_eps_from_z_and_v(self, x_t, t, v):
method get_v (line 313) | def get_v(self, x, noise, t):
class LatentDiffusion (line 320) | class LatentDiffusion(DDPM):
method __init__ (line 323) | def __init__(
method make_cond_schedule (line 371) | def make_cond_schedule(
method register_schedule (line 384) | def register_schedule(
method instantiate_first_stage (line 401) | def instantiate_first_stage(self, config):
method decode_first_stage (line 408) | def decode_first_stage(self, z):
method mel_spectrogram_to_waveform (line 414) | def mel_spectrogram_to_waveform(self, mel):
method encode_first_stage (line 423) | def encode_first_stage(self, x):
method sample_log (line 428) | def sample_log(
method apply_model (line 459) | def apply_model(self, x_noisy, t, cond, return_ids=False):
method generate_sample (line 468) | def generate_sample(
class DiffusionWrapper (line 537) | class DiffusionWrapper(nn.Module):
method __init__ (line 538) | def __init__(self, diff_model_config, conditioning_key):
method forward (line 543) | def forward(self, x, t, cond_dict: dict = {}):
function extract_encoder_state_dict (line 560) | def extract_encoder_state_dict(checkpoint_path):
function overlap_add_waveform (line 572) | def overlap_add_waveform(windowed_waveforms, overlap_duration=0.64):
FILE: semanticodec/modules/decoder/latent_diffusion/models/dpm_solver/dpm_solver.py
class NoiseScheduleVP (line 6) | class NoiseScheduleVP:
method __init__ (line 7) | def __init__(
method marginal_log_mean_coeff (line 144) | def marginal_log_mean_coeff(self, t):
method marginal_alpha (line 163) | def marginal_alpha(self, t):
method marginal_std (line 169) | def marginal_std(self, t):
method marginal_lambda (line 175) | def marginal_lambda(self, t):
method inverse_lambda (line 183) | def inverse_lambda(self, lamb):
function model_wrapper (line 220) | def model_wrapper(
class DPM_Solver (line 405) | class DPM_Solver:
method __init__ (line 406) | def __init__(
method noise_prediction_fn (line 441) | def noise_prediction_fn(self, x, t):
method data_prediction_fn (line 447) | def data_prediction_fn(self, x, t):
method model_fn (line 466) | def model_fn(self, x, t):
method get_time_steps (line 475) | def get_time_steps(self, skip_type, t_T, t_0, N, device):
method get_orders_and_timesteps_for_singlestep_solver (line 514) | def get_orders_and_timesteps_for_singlestep_solver(
method denoise_to_zero_fn (line 604) | def denoise_to_zero_fn(self, x, s):
method dpm_solver_first_update (line 610) | def dpm_solver_first_update(self, x, s, t, model_s=None, return_interm...
method singlestep_dpm_solver_second_update (line 659) | def singlestep_dpm_solver_second_update(
method singlestep_dpm_solver_third_update (line 770) | def singlestep_dpm_solver_third_update(
method multistep_dpm_solver_second_update (line 925) | def multistep_dpm_solver_second_update(
method multistep_dpm_solver_third_update (line 996) | def multistep_dpm_solver_third_update(
method singlestep_dpm_solver_update (line 1056) | def singlestep_dpm_solver_update(
method multistep_dpm_solver_update (line 1109) | def multistep_dpm_solver_update(
method dpm_solver_adaptive (line 1141) | def dpm_solver_adaptive(
method sample (line 1233) | def sample(
function interpolate_fn (line 1457) | def interpolate_fn(x, xp, yp):
function expand_dims (line 1509) | def expand_dims(v, dims):
FILE: semanticodec/modules/decoder/latent_diffusion/modules/attention.py
function exists (line 13) | def exists(val):
function uniq (line 17) | def uniq(arr):
function default (line 21) | def default(val, d):
function max_neg_value (line 27) | def max_neg_value(t):
function init_ (line 31) | def init_(tensor):
class GEGLU (line 39) | class GEGLU(nn.Module):
method __init__ (line 40) | def __init__(self, dim_in, dim_out):
method forward (line 44) | def forward(self, x):
class FeedForward (line 49) | class FeedForward(nn.Module):
method __init__ (line 50) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
method forward (line 64) | def forward(self, x):
function zero_module (line 68) | def zero_module(module):
function Normalize (line 77) | def Normalize(in_channels):
class LinearAttention (line 83) | class LinearAttention(nn.Module):
method __init__ (line 84) | def __init__(self, dim, heads=4, dim_head=32):
method forward (line 91) | def forward(self, x):
class SpatialSelfAttention (line 106) | class SpatialSelfAttention(nn.Module):
method __init__ (line 107) | def __init__(self, in_channels):
method forward (line 125) | def forward(self, x):
class CrossAttention (line 328) | class CrossAttention(nn.Module):
method __init__ (line 329) | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, ...
method forward (line 345) | def forward(self, x, context=None, mask=None):
class BasicTransformerBlock (line 372) | class BasicTransformerBlock(nn.Module):
method __init__ (line 373) | def __init__(
method forward (line 400) | def forward(self, x, context=None, mask=None):
method _forward (line 408) | def _forward(self, x, context=None, mask=None):
class SpatialTransformer (line 415) | class SpatialTransformer(nn.Module):
method __init__ (line 424) | def __init__(
method forward (line 458) | def forward(self, x, context=None, mask=None):
FILE: semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/model.py
function get_timestep_embedding (line 14) | def get_timestep_embedding(timesteps, embedding_dim):
function nonlinearity (line 35) | def nonlinearity(x):
function Normalize (line 40) | def Normalize(in_channels, num_groups=32):
class Upsample (line 46) | class Upsample(nn.Module):
method __init__ (line 47) | def __init__(self, in_channels, with_conv):
method forward (line 55) | def forward(self, x):
class UpsampleTimeStride4 (line 62) | class UpsampleTimeStride4(nn.Module):
method __init__ (line 63) | def __init__(self, in_channels, with_conv):
method forward (line 71) | def forward(self, x):
class Downsample (line 78) | class Downsample(nn.Module):
method __init__ (line 79) | def __init__(self, in_channels, with_conv):
method forward (line 89) | def forward(self, x):
class DownsampleTimeStride4 (line 99) | class DownsampleTimeStride4(nn.Module):
method __init__ (line 100) | def __init__(self, in_channels, with_conv):
method forward (line 110) | def forward(self, x):
class ResnetBlock (line 120) | class ResnetBlock(nn.Module):
method __init__ (line 121) | def __init__(
method forward (line 157) | def forward(self, x, temb):
class LinAttnBlock (line 180) | class LinAttnBlock(LinearAttention):
method __init__ (line 183) | def __init__(self, in_channels):
class AttnBlock (line 187) | class AttnBlock(nn.Module):
method __init__ (line 188) | def __init__(self, in_channels):
method forward (line 206) | def forward(self, x):
function make_attn (line 235) | def make_attn(in_channels, attn_type="vanilla"):
class Model (line 245) | class Model(nn.Module):
method __init__ (line 246) | def __init__(
method forward (line 367) | def forward(self, x, t=None, context=None):
method get_last_layer (line 416) | def get_last_layer(self):
class Encoder (line 420) | class Encoder(nn.Module):
method __init__ (line 421) | def __init__(
method forward (line 520) | def forward(self, x):
class Decoder (line 547) | class Decoder(nn.Module):
method __init__ (line 548) | def __init__(
method forward (line 654) | def forward(self, z):
class SimpleDecoder (line 690) | class SimpleDecoder(nn.Module):
method __init__ (line 691) | def __init__(self, in_channels, out_channels, *args, **kwargs):
method forward (line 724) | def forward(self, x):
class UpsampleDecoder (line 737) | class UpsampleDecoder(nn.Module):
method __init__ (line 738) | def __init__(
method forward (line 781) | def forward(self, x):
class LatentRescaler (line 795) | class LatentRescaler(nn.Module):
method __init__ (line 796) | def __init__(self, factor, in_channels, mid_channels, out_channels, de...
method forward (line 833) | def forward(self, x):
class MergedRescaleEncoder (line 851) | class MergedRescaleEncoder(nn.Module):
method __init__ (line 852) | def __init__(
method forward (line 889) | def forward(self, x):
class MergedRescaleDecoder (line 895) | class MergedRescaleDecoder(nn.Module):
method __init__ (line 896) | def __init__(
method forward (line 932) | def forward(self, x):
class Upsampler (line 938) | class Upsampler(nn.Module):
method __init__ (line 939) | def __init__(self, in_size, out_size, in_channels, out_channels, ch_mu...
method forward (line 964) | def forward(self, x):
class Resize (line 970) | class Resize(nn.Module):
method __init__ (line 971) | def __init__(self, in_channels=None, learned=False, mode="bilinear"):
method forward (line 986) | def forward(self, x, scale_factor=1.0):
class FirstStagePostProcessor (line 996) | class FirstStagePostProcessor(nn.Module):
method __init__ (line 997) | def __init__(
method instantiate_pretrained (line 1044) | def instantiate_pretrained(self, config):
method encode_with_pretrained (line 1052) | def encode_with_pretrained(self, x):
method forward (line 1058) | def forward(self, x):
FILE: semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/openaimodel.py
function convert_module_to_f16 (line 26) | def convert_module_to_f16(x):
function convert_module_to_f32 (line 30) | def convert_module_to_f32(x):
class AttentionPool2d (line 35) | class AttentionPool2d(nn.Module):
method __init__ (line 40) | def __init__(
method forward (line 56) | def forward(self, x):
class TimestepBlock (line 67) | class TimestepBlock(nn.Module):
method forward (line 73) | def forward(self, x, emb):
class TimestepEmbedSequential (line 79) | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
method forward (line 85) | def forward(self, x, emb, context_list=None, mask_list=None):
class Upsample (line 110) | class Upsample(nn.Module):
method __init__ (line 119) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 130) | def forward(self, x):
class TransposedUpsample (line 143) | class TransposedUpsample(nn.Module):
method __init__ (line 146) | def __init__(self, channels, out_channels=None, ks=5):
method forward (line 155) | def forward(self, x):
class Downsample (line 159) | class Downsample(nn.Module):
method __init__ (line 168) | def __init__(self, channels, use_conv, dims=2, out_channels=None, padd...
method forward (line 188) | def forward(self, x):
class ResBlock (line 193) | class ResBlock(TimestepBlock):
method __init__ (line 209) | def __init__(
method forward (line 273) | def forward(self, x, emb):
method _forward (line 284) | def _forward(self, x, emb):
class AttentionBlock (line 307) | class AttentionBlock(nn.Module):
method __init__ (line 314) | def __init__(
method forward (line 343) | def forward(self, x):
method _forward (line 349) | def _forward(self, x):
function count_flops_attn (line 358) | def count_flops_attn(model, _x, y):
class QKVAttentionLegacy (line 378) | class QKVAttentionLegacy(nn.Module):
method __init__ (line 383) | def __init__(self, n_heads):
method forward (line 387) | def forward(self, qkv):
method count_flops (line 408) | def count_flops(model, _x, y):
class QKVAttention (line 412) | class QKVAttention(nn.Module):
method __init__ (line 417) | def __init__(self, n_heads):
method forward (line 421) | def forward(self, qkv):
method count_flops (line 446) | def count_flops(model, _x, y):
class UNetModel (line 450) | class UNetModel(nn.Module):
method __init__ (line 480) | def __init__(
method convert_to_fp16 (line 825) | def convert_to_fp16(self):
method convert_to_fp32 (line 833) | def convert_to_fp32(self):
method forward (line 841) | def forward(
class EncoderUNetModel (line 891) | class EncoderUNetModel(nn.Module):
method __init__ (line 897) | def __init__(
method convert_to_fp16 (line 1070) | def convert_to_fp16(self):
method convert_to_fp32 (line 1077) | def convert_to_fp32(self):
method forward (line 1084) | def forward(self, x, timesteps):
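Note — QKVAttention and its Legacy variant come from the OpenAI guided-diffusion UNet this file derives from: both consume a fused qkv tensor of shape [N, 3 * heads * channels, T] and return [N, heads * channels, T]. A hedged shape-level sketch, assuming that layout carries over here:

import torch
from semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.openaimodel import QKVAttention

heads, ch, seq = 8, 64, 256
attn = QKVAttention(heads)
qkv = torch.randn(2, 3 * heads * ch, seq)  # q, k, v fused along the channel axis
out = attn(qkv)                            # expected shape: [2, heads * ch, seq]
print(out.shape)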
FILE: semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/util.py
function make_beta_schedule (line 21) | def make_beta_schedule(
function make_ddim_timesteps (line 56) | def make_ddim_timesteps(
function make_ddim_sampling_parameters (line 79) | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbos...
function betas_for_alpha_bar (line 99) | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.9...
function extract_into_tensor (line 118) | def extract_into_tensor(a, t, x_shape):
function checkpoint (line 124) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 141) | class CheckpointFunction(torch.autograd.Function):
method forward (line 143) | def forward(ctx, run_function, length, *args):
method backward (line 153) | def backward(ctx, *output_grads):
function timestep_embedding (line 173) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function zero_module (line 200) | def zero_module(module):
function scale_module (line 209) | def scale_module(module, scale):
function mean_flat (line 218) | def mean_flat(tensor):
function normalization (line 225) | def normalization(channels):
class SiLU (line 235) | class SiLU(nn.Module):
method forward (line 236) | def forward(self, x):
class GroupNorm32 (line 240) | class GroupNorm32(nn.GroupNorm):
method forward (line 241) | def forward(self, x):
function conv_nd (line 245) | def conv_nd(dims, *args, **kwargs):
function linear (line 258) | def linear(*args, **kwargs):
function avg_pool_nd (line 265) | def avg_pool_nd(dims, *args, **kwargs):
class HybridConditioner (line 278) | class HybridConditioner(nn.Module):
method __init__ (line 279) | def __init__(self, c_concat_config, c_crossattn_config):
method forward (line 284) | def forward(self, c_concat, c_crossattn):
function noise_like (line 290) | def noise_like(shape, device, repeat=False):
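Note — timestep_embedding is, per its signature, the standard sinusoidal embedding used throughout diffusion codebases. A self-contained sketch of that standard construction (not a verbatim copy of this file):

import math
import torch

def timestep_embedding_sketch(timesteps, dim, max_period=10000):
    # Sinusoidal timestep embedding: half the channels are cos, half sin,
    # with frequencies spaced geometrically from 1 down to 1/max_period.
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = timesteps[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:  # pad one zero channel when dim is odd
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

print(timestep_embedding_sketch(torch.tensor([0, 250, 999]), 128).shape)  # [3, 128]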
FILE: semanticodec/modules/decoder/latent_diffusion/modules/distributions/distributions.py
class AbstractDistribution (line 5) | class AbstractDistribution:
method sample (line 6) | def sample(self):
method mode (line 9) | def mode(self):
class DiracDistribution (line 13) | class DiracDistribution(AbstractDistribution):
method __init__ (line 14) | def __init__(self, value):
method sample (line 17) | def sample(self):
method mode (line 20) | def mode(self):
class DiagonalGaussianDistribution (line 24) | class DiagonalGaussianDistribution(object):
method __init__ (line 25) | def __init__(self, parameters, deterministic=False):
method sample (line 37) | def sample(self):
method kl (line 43) | def kl(self, other=None):
method nll (line 62) | def nll(self, sample, dims=[1, 2, 3]):
method mode (line 71) | def mode(self):
function normal_kl (line 75) | def normal_kl(mean1, logvar1, mean2, logvar2):
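Note — DiagonalGaussianDistribution is the usual VAE posterior wrapper: the parameters tensor is split channel-wise into mean and log-variance, sample() draws mean + std * noise, and kl() gives the closed-form KL against a standard normal (or another diagonal Gaussian). A hedged usage sketch, assuming the conventional half-and-half channel split:

import torch
from semanticodec.modules.decoder.latent_diffusion.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)

# 16 channels -> 8 mean channels + 8 log-variance channels (assumed split)
params = torch.randn(2, 16, 32, 32)
posterior = DiagonalGaussianDistribution(params)
z = posterior.sample()   # stochastic draw: mean + std * eps
mode = posterior.mode()  # deterministic: just the mean
kl = posterior.kl()      # KL divergence to N(0, I), one value per batch element
print(z.shape, mode.shape, kl.shape)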
FILE: semanticodec/modules/decoder/latent_diffusion/modules/ema.py
class LitEma (line 5) | class LitEma(nn.Module):
method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
method forward (line 29) | def forward(self, model):
method copy_to (line 52) | def copy_to(self, model):
method store (line 61) | def store(self, parameters):
method restore (line 70) | def restore(self, parameters):
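Note — LitEma keeps an exponential moving average of a model's parameters: calling the module updates the shadow weights, copy_to overwrites the live model with them, and store/restore let you temporarily swap the EMA weights in for evaluation. A minimal lifecycle sketch, assuming the usual calling convention for this class:

import torch.nn as nn
from semanticodec.modules.decoder.latent_diffusion.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.9999)

# ... after each optimizer step during training:
ema(model)                       # move shadow params toward the current weights

# Evaluate with EMA weights, then put the raw training weights back:
ema.store(model.parameters())    # stash current weights
ema.copy_to(model)               # load EMA weights into the model
# ... run validation ...
ema.restore(model.parameters())  # restore the training weights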
FILE: semanticodec/modules/decoder/latent_diffusion/modules/mamba.py
function count_parameters (line 6) | def count_parameters(model):
class MambaBlocks (line 19) | class MambaBlocks(nn.Module):
method __init__ (line 20) | def __init__(self, dim, n_block=4):
method forward (line 38) | def forward(self, x):
FILE: semanticodec/modules/decoder/latent_diffusion/modules/nn.py
class GroupNorm32 (line 12) | class GroupNorm32(nn.GroupNorm):
method __init__ (line 13) | def __init__(self, num_groups, num_channels, swish, eps=1e-5):
method forward (line 17) | def forward(self, x):
function conv_nd (line 26) | def conv_nd(dims, *args, **kwargs):
function linear (line 39) | def linear(*args, **kwargs):
function avg_pool_nd (line 46) | def avg_pool_nd(dims, *args, **kwargs):
function update_ema (line 59) | def update_ema(target_params, source_params, rate=0.99):
function zero_module (line 72) | def zero_module(module):
function scale_module (line 81) | def scale_module(module, scale):
function mean_flat (line 90) | def mean_flat(tensor):
function normalization (line 97) | def normalization(channels, swish=0.0):
function timestep_embedding (line 128) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
function checkpoint (line 153) | def checkpoint(func, inputs, params, flag):
class CheckpointFunction (line 172) | class CheckpointFunction(th.autograd.Function):
method forward (line 174) | def forward(ctx, run_function, length, *args):
method backward (line 183) | def backward(ctx, *output_grads):
FILE: semanticodec/modules/decoder/latent_diffusion/modules/x_transformer.py
class AbsolutePositionalEmbedding (line 19) | class AbsolutePositionalEmbedding(nn.Module):
method __init__ (line 20) | def __init__(self, dim, max_seq_len):
method init_ (line 25) | def init_(self):
method forward (line 28) | def forward(self, x):
class FixedPositionalEmbedding (line 33) | class FixedPositionalEmbedding(nn.Module):
method __init__ (line 34) | def __init__(self, dim):
method forward (line 39) | def forward(self, x, seq_dim=1, offset=0):
function exists (line 52) | def exists(val):
function default (line 56) | def default(val, d):
function always (line 62) | def always(val):
function not_equals (line 69) | def not_equals(val):
function equals (line 76) | def equals(val):
function max_neg_value (line 83) | def max_neg_value(tensor):
function pick_and_pop (line 90) | def pick_and_pop(keys, d):
function group_dict_by_key (line 95) | def group_dict_by_key(cond, d):
function string_begins_with (line 104) | def string_begins_with(prefix, str):
function group_by_key_prefix (line 108) | def group_by_key_prefix(prefix, d):
function groupby_prefix_and_trim (line 112) | def groupby_prefix_and_trim(prefix, d):
class Scale (line 123) | class Scale(nn.Module):
method __init__ (line 124) | def __init__(self, value, fn):
method forward (line 129) | def forward(self, x, **kwargs):
class Rezero (line 134) | class Rezero(nn.Module):
method __init__ (line 135) | def __init__(self, fn):
method forward (line 140) | def forward(self, x, **kwargs):
class ScaleNorm (line 145) | class ScaleNorm(nn.Module):
method __init__ (line 146) | def __init__(self, dim, eps=1e-5):
method forward (line 152) | def forward(self, x):
class RMSNorm (line 157) | class RMSNorm(nn.Module):
method __init__ (line 158) | def __init__(self, dim, eps=1e-8):
method forward (line 164) | def forward(self, x):
class Residual (line 169) | class Residual(nn.Module):
method forward (line 170) | def forward(self, x, residual):
class GRUGating (line 174) | class GRUGating(nn.Module):
method __init__ (line 175) | def __init__(self, dim):
method forward (line 179) | def forward(self, x, residual):
class GEGLU (line 190) | class GEGLU(nn.Module):
method __init__ (line 191) | def __init__(self, dim_in, dim_out):
method forward (line 195) | def forward(self, x):
class FeedForward (line 200) | class FeedForward(nn.Module):
method __init__ (line 201) | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
method forward (line 215) | def forward(self, x):
class Attention (line 220) | class Attention(nn.Module):
method __init__ (line 221) | def __init__(
method forward (line 279) | def forward(
class AttentionLayers (line 393) | class AttentionLayers(nn.Module):
method __init__ (line 394) | def __init__(
method forward (line 514) | def forward(
class Encoder (line 587) | class Encoder(AttentionLayers):
method __init__ (line 588) | def __init__(self, **kwargs):
class TransformerWrapper (line 593) | class TransformerWrapper(nn.Module):
method __init__ (line 594) | def __init__(
method init_ (line 649) | def init_(self):
method forward (line 652) | def forward(
FILE: semanticodec/modules/decoder/latent_diffusion/util.py
function disabled_train (line 17) | def disabled_train(self, mode=True):
function get_unconditional_condition (line 23) | def get_unconditional_condition(batchsize, downsampling_rate, device):
function log_txt_as_img (line 31) | def log_txt_as_img(wh, xc, size=10):
function ismap (line 57) | def ismap(x):
function isimage (line 63) | def isimage(x):
function int16_to_float32 (line 69) | def int16_to_float32(x):
function float32_to_int16 (line 73) | def float32_to_int16(x):
function exists (line 78) | def exists(x):
function default (line 82) | def default(val, d):
function mean_flat (line 88) | def mean_flat(tensor):
function count_params (line 96) | def count_params(model, verbose=False):
function instantiate_from_config (line 103) | def instantiate_from_config(config):
function get_obj_from_str (line 113) | def get_obj_from_str(string, reload=False):
function _do_parallel_data_prefetch (line 121) | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
function parallel_data_prefetch (line 133) | def parallel_data_prefetch(
FILE: semanticodec/modules/decoder/latent_encoder/autoencoder.py
class AutoencoderKL (line 21) | class AutoencoderKL(nn.Module):
method __init__ (line 22) | def __init__(
method get_log_dir (line 83) | def get_log_dir(self):
method set_log_dir (line 88) | def set_log_dir(self, save_dir, exp_group_name, exp_name):
method init_from_ckpt (line 93) | def init_from_ckpt(self, path, ignore_keys=list()):
method encode (line 104) | def encode(self, x):
method decode (line 112) | def decode(self, z):
method decode_to_waveform (line 120) | def decode_to_waveform(self, dec):
method forward (line 131) | def forward(self, input, sample_posterior=True):
method freq_split_subband (line 146) | def freq_split_subband(self, fbank):
method freq_merge_subband (line 161) | def freq_merge_subband(self, subband_fbank):
method save_wave (line 168) | def save_wave(self, batch_wav, fname, save_dir):
method get_last_layer (line 176) | def get_last_layer(self):
method log_images (line 180) | def log_images(self, batch, train=True, only_inputs=False, waveform=No...
method tensor2numpy (line 192) | def tensor2numpy(self, tensor):
method to_rgb (line 195) | def to_rgb(self, x):
class IdentityFirstStage (line 204) | class IdentityFirstStage(torch.nn.Module):
method __init__ (line 205) | def __init__(self, *args, vq_interface=False, **kwargs):
method encode (line 209) | def encode(self, x, *args, **kwargs):
method decode (line 212) | def decode(self, x, *args, **kwargs):
method quantize (line 215) | def quantize(self, x, *args, **kwargs):
method forward (line 220) | def forward(self, x, *args, **kwargs):
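Note — AutoencoderKL exposes the familiar first-stage interface: encode returns a DiagonalGaussianDistribution over the latent, decode maps a latent back to the (sub-band) spectrogram domain, and forward chains the two, sampling the posterior by default. A hedged round-trip sketch; the constructor's config is repo-specific, so building the instance is left to the caller:

import torch

def autoencode_roundtrip(ae, x: torch.Tensor) -> torch.Tensor:
    # `ae` is assumed to be a constructed AutoencoderKL and `x` a batched
    # spectrogram tensor matching its expected input layout.
    posterior = ae.encode(x)   # DiagonalGaussianDistribution over latents
    z = posterior.sample()     # or posterior.mode() for a deterministic latent
    return ae.decode(z)        # reconstruction in the spectrogram domain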
FILE: semanticodec/modules/decoder/utilities/audio/audio_processing.py
function window_sumsquare (line 7) | def window_sumsquare(
function griffin_lim (line 66) | def griffin_lim(magnitudes, stft_fn, n_iters=30):
function dynamic_range_compression (line 85) | def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=...
function dynamic_range_decompression (line 94) | def dynamic_range_decompression(x, C=1):
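Note — dynamic_range_compression is, in comparable TTS codebases, a clamped log, roughly log(clamp(x, min=clip_val) * C), with dynamic_range_decompression as its inverse exp(x) / C. A sketch of that assumed pair:

import torch

def compress(x, C=1, clip_val=1e-5, normalize_fun=torch.log):
    # Clamp away zeros, scale, then log-compress the magnitude spectrogram.
    return normalize_fun(torch.clamp(x, min=clip_val) * C)

def decompress(x, C=1):
    # Inverse of compress (up to the clamping, which is lossy near zero).
    return torch.exp(x) / C

mag = torch.rand(80, 100)
assert torch.allclose(decompress(compress(mag)), torch.clamp(mag, min=1e-5), atol=1e-6)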
FILE: semanticodec/modules/decoder/utilities/audio/stft.py
class STFT (line 15) | class STFT(torch.nn.Module):
method __init__ (line 18) | def __init__(self, filter_length, hop_length, win_length, window="hann"):
method transform (line 52) | def transform(self, input_data):
method inverse (line 83) | def inverse(self, magnitude, phase):
method forward (line 124) | def forward(self, input_data):
class TacotronSTFT (line 130) | class TacotronSTFT(torch.nn.Module):
method __init__ (line 131) | def __init__(
method spectral_normalize (line 151) | def spectral_normalize(self, magnitudes, normalize_fun):
method spectral_de_normalize (line 155) | def spectral_de_normalize(self, magnitudes):
method mel_spectrogram (line 159) | def mel_spectrogram(self, y, normalize_fun=torch.log):
FILE: semanticodec/modules/decoder/utilities/audio/tools.py
function get_mel_from_wav (line 9) | def get_mel_from_wav(audio, _stft):
function inv_mel_spec (line 19) | def inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):
FILE: semanticodec/modules/decoder/utilities/model.py
function get_available_checkpoint_keys (line 10) | def get_available_checkpoint_keys(model, ckpt):
function get_param_num (line 30) | def get_param_num(model):
function torch_version_orig_mod_remove (line 35) | def torch_version_orig_mod_remove(state_dict):
function get_vocoder (line 48) | def get_vocoder(config, device, mel_bins):
function vocoder_infer (line 87) | def vocoder_infer(mels, vocoder, lengths=None):
FILE: semanticodec/modules/decoder/utilities/tools.py
function load_json (line 41) | def load_json(fname):
function read_json (line 47) | def read_json(dataset_json_file):
function copy_test_subset_data (line 53) | def copy_test_subset_data(metadata, testset_copy_target_path):
function listdir_nohidden (line 72) | def listdir_nohidden(path):
function get_restore_step (line 77) | def get_restore_step(path):
function download (line 98) | def download(url, local_path, chunk_size=1024):
function md5_hash (line 110) | def md5_hash(path):
function get_ckpt_path (line 116) | def get_ckpt_path(name, root, check=False):
class KeyNotFoundError (line 127) | class KeyNotFoundError(Exception):
method __init__ (line 128) | def __init__(self, cause, keys=None, visited=None):
function retrieve (line 142) | def retrieve(
function to_device (line 225) | def to_device(data, device):
function log (line 276) | def log(logger, step=None, fig=None, audio=None, sampling_rate=22050, ta...
function get_mask_from_lengths (line 304) | def get_mask_from_lengths(lengths, max_len=None):
function expand (line 315) | def expand(values, durations):
function synth_one_sample (line 322) | def synth_one_sample(mel_input, mel_prediction, labels, vocoder):
function pad_1D (line 340) | def pad_1D(inputs, PAD=0):
function pad_2D (line 353) | def pad_2D(inputs, maxlen=None):
function pad (line 374) | def pad(input_ele, mel_max_length=None):
FILE: semanticodec/modules/encoder/encoder.py
class AudioMAEConditionQuantResEncoder (line 17) | class AudioMAEConditionQuantResEncoder(nn.Module):
method __init__ (line 18) | def __init__(
method mark_out_padding (line 97) | def mark_out_padding(self, feature, padding_cutoff_index):
method get_unconditional_condition (line 106) | def get_unconditional_condition(self, batchsize):
method quant_mem_efficient (line 124) | def quant_mem_efficient(
method unquant (line 162) | def unquant(self, tokens):
method indices_utilization_statistic (line 174) | def indices_utilization_statistic(self, indices):
method concate (line 232) | def concate(self, representation):
method get_unconditional_condition (line 247) | def get_unconditional_condition(self, batchsize):
method long_token_split_window (line 281) | def long_token_split_window(self, tokens, window_length=512, overlap=0...
method forward (line 302) | def forward(self, batch):
method _forward (line 326) | def _forward(self, batch):
method token_to_quantized_feature (line 462) | def token_to_quantized_feature(self, tokens):
method wrap_return_dict (line 471) | def wrap_return_dict(self, crossattn_audiomae_pooled, tokens):
FILE: semanticodec/utils.py
function concat_1x2 (line 8) | def concat_1x2(tensor):
function concat_2x2 (line 22) | def concat_2x2(tensor):
function extract_kaldi_fbank_feature (line 35) | def extract_kaldi_fbank_feature(waveform, sampling_rate, target_length=1...
class PositionalEncoding (line 76) | class PositionalEncoding:
method __init__ (line 77) | def __init__(self, seq_length=512, embedding_dim=192):
method __call__ (line 93) | def __call__(self, x):
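Note — extract_kaldi_fbank_feature presumably wraps torchaudio's Kaldi-compatible filterbank computation (torchaudio is imported in the file preview below). A hedged call sketch with a hypothetical input file; the return layout is not visible in the index, so inspect it before relying on a shape:

import torchaudio
from semanticodec.utils import extract_kaldi_fbank_feature

waveform, sr = torchaudio.load("example.wav")      # hypothetical input file
fbank = extract_kaldi_fbank_feature(waveform, sr)  # target_length left at its default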
FILE: setup.py
class UploadCommand (line 67) | class UploadCommand(Command):
method status (line 74) | def status(s):
method initialize_options (line 78) | def initialize_options(self):
method finalize_options (line 81) | def finalize_options(self):
method run (line 84) | def run(self):
FILE: test/test_all_settings.py
function test_semanticodec (line 4) | def test_semanticodec(token_rate, semantic_vocab_size, test_id):
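Note — test/encoding.py (previewed below) exercises the package's top-level entry point. A sketch of that usage; the encode/decode method names and the output layout are assumed from the repository's documented usage rather than verified signatures, and the allowed semantic_vocab_size values follow the assert visible in the config.py preview:

from semanticodec import SemantiCodec
import soundfile as sf

# semantic_vocab_size is asserted against a fixed set in semanticodec/config.py
# (4096, 8192, 16384, and one larger value truncated in the preview).
codec = SemantiCodec(token_rate=100, semantic_vocab_size=16384)

tokens = codec.encode("test.wav")   # hypothetical input path
waveform = codec.decode(tokens)     # diffusion decoder reconstructs the audio
sf.write("output.wav", waveform[0, 0], 16000)  # 16 kHz mono output (assumed layout)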
Condensed preview — 51 files, each showing path, character count, and a content snippet.
[
{
"path": ".gitignore",
"chars": 153,
"preview": "pretrained\n*.npy\n*.wav\ng_*\n*.pyc\n*.pkl\n*.json\ncodebook_idx\nlong_audio*\noutput*\nsample*\n*.json\n*.egg-info\n.ipynb*\ntrim_ch"
},
{
"path": "LICENSE",
"chars": 1072,
"preview": "Copyright (c) 2012-2024 Scott Chacon and others\n\nPermission is hereby granted, free of charge, to any person obtaining\na"
},
{
"path": "README.md",
"chars": 3118,
"preview": "[](https://arxiv.org/abs/2405.0"
},
{
"path": "semanticodec/__init__.py",
"chars": 43,
"preview": "from semanticodec.main import SemantiCodec\n"
},
{
"path": "semanticodec/config.py",
"chars": 7857,
"preview": "\ndef get_config(token_rate=100, vocab_size=None, checkpoint_path=None):\n assert vocab_size in [4096, 8192, 16384, 327"
},
{
"path": "semanticodec/main.py",
"chars": 7249,
"preview": "from configparser import NoSectionError\nimport torch\nimport torch.nn as nn\nimport os\nimport torchaudio\nimport math\n\nfrom"
},
{
"path": "semanticodec/modules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "semanticodec/modules/audiomae/AudioMAE.py",
"chars": 3643,
"preview": "\"\"\"\nReference Repo: https://github.com/facebookresearch/AudioMAE\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nfrom timm.model"
},
{
"path": "semanticodec/modules/audiomae/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "semanticodec/modules/audiomae/models_mae.py",
"chars": 21118,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "semanticodec/modules/audiomae/patch_embed.py",
"chars": 4381,
"preview": "import torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple\n\n\nclass PatchEmbed_org(nn.Module):\n \"\"\"Im"
},
{
"path": "semanticodec/modules/audiomae/pos_embed.py",
"chars": 8651,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "semanticodec/modules/decoder/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "semanticodec/modules/decoder/hifigan/LICENSE",
"chars": 1067,
"preview": "MIT License\n\nCopyright (c) 2020 Jungil Kong\n\nPermission is hereby granted, free of charge, to any person obtaining a cop"
},
{
"path": "semanticodec/modules/decoder/hifigan/__init__.py",
"chars": 230,
"preview": "from .models_v2 import Generator\nfrom .models import Generator as Generator_old\n\n\nclass AttrDict(dict):\n def __init__"
},
{
"path": "semanticodec/modules/decoder/hifigan/models.py",
"chars": 5524,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import Conv1d, ConvTranspose1d\nfrom tor"
},
{
"path": "semanticodec/modules/decoder/hifigan/models_v2.py",
"chars": 6875,
"preview": "import torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom torch.nn import Conv1d, ConvTranspose1d, AvgPool"
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/models/ddim.py",
"chars": 17304,
"preview": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom semanticodec.modules.decoder.latent_di"
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/models/ddpm.py",
"chars": 21894,
"preview": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom contextlib import contextmanager\nfrom functools import partia"
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/models/dpm_solver/__init__.py",
"chars": 38,
"preview": "from .sampler import DPMSolverSampler\n"
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/models/dpm_solver/dpm_solver.py",
"chars": 69050,
"preview": "import torch\nimport torch.nn.functional as F\nimport math\n\n\nclass NoiseScheduleVP:\n def __init__(\n self,\n "
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/attention.py",
"chars": 15781,
"preview": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfro"
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/model.py",
"chars": 34426,
"preview": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom ein"
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/openaimodel.py",
"chars": 40268,
"preview": "from abc import abstractmethod\nfrom functools import partial\nimport math\nfrom typing import Iterable\n\nimport numpy as np"
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/util.py",
"chars": 9901,
"preview": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n#"
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/distributions/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/distributions/distributions.py",
"chars": 3097,
"preview": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n def sample(self):\n raise NotImplementedError()\n"
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/ema.py",
"chars": 3066,
"preview": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n def __init__(self, model, decay=0.9999, use_num_upates="
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/mamba.py",
"chars": 1598,
"preview": "import torch\nfrom mamba_ssm import Mamba\nimport torch.nn as nn\n\n\ndef count_parameters(model):\n \"\"\"\n Calculate the "
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/nn.py",
"chars": 6334,
"preview": "\"\"\"\nVarious utilities for neural networks.\n\"\"\"\n\nimport math\n\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.fu"
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/modules/x_transformer.py",
"chars": 20791,
"preview": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nimport torch\nfrom torch import "
},
{
"path": "semanticodec/modules/decoder/latent_diffusion/util.py",
"chars": 6571,
"preview": "import importlib\n\nimport torch\nimport numpy as np\nfrom collections import abc\nfrom einops import rearrange\nfrom functool"
},
{
"path": "semanticodec/modules/decoder/latent_encoder/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "semanticodec/modules/decoder/latent_encoder/autoencoder.py",
"chars": 7315,
"preview": "import torch\nimport os\n\nimport torch.nn.functional as F\nfrom semanticodec.modules.decoder.latent_diffusion.modules.ema i"
},
{
"path": "semanticodec/modules/decoder/utilities/__init__.py",
"chars": 42,
"preview": "from .tools import *\nfrom .model import *\n"
},
{
"path": "semanticodec/modules/decoder/utilities/audio/__init__.py",
"chars": 73,
"preview": "from .audio_processing import *\nfrom .stft import *\nfrom .tools import *\n"
},
{
"path": "semanticodec/modules/decoder/utilities/audio/audio_processing.py",
"chars": 2642,
"preview": "import torch\nimport numpy as np\nimport librosa.util as librosa_util\nfrom scipy.signal import get_window\n\n\ndef window_sum"
},
{
"path": "semanticodec/modules/decoder/utilities/audio/stft.py",
"chars": 6302,
"preview": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy.signal import get_window\nfrom librosa.util im"
},
{
"path": "semanticodec/modules/decoder/utilities/audio/tools.py",
"chars": 1320,
"preview": "import torch\nimport numpy as np\nfrom scipy.io.wavfile import write\nimport torchaudio\n\nfrom utilities.audio.audio_process"
},
{
"path": "semanticodec/modules/decoder/utilities/model.py",
"chars": 2663,
"preview": "import os\nimport json\n\nimport torch\nimport numpy as np\n\nimport semanticodec.modules.decoder.hifigan as hifigan\n\n\ndef get"
},
{
"path": "semanticodec/modules/decoder/utilities/tools.py",
"chars": 12429,
"preview": "# Author: Haohe Liu\n# Email: haoheliu@gmail.com\n# Date: 11 Feb 2023\n\nimport os\nimport json\n\nimport torch\nimport torch.nn"
},
{
"path": "semanticodec/modules/encoder/__init__.py",
"chars": 1,
"preview": "\n"
},
{
"path": "semanticodec/modules/encoder/encoder.py",
"chars": 18914,
"preview": "import torch\nimport math\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom semanticodec.modu"
},
{
"path": "semanticodec/utils.py",
"chars": 3133,
"preview": "import torch\nimport math\nimport torch.nn as nn\n\nimport torchaudio\n\n\ndef concat_1x2(tensor):\n batchsize, width, height"
},
{
"path": "setup.py",
"chars": 4203,
"preview": "#!/usr/bin/env python\n# -*- encoding: utf-8 -*-\n# python3 setup.py sdist bdist_wheel\n\"\"\"\n@File : setup.py.py \n@C"
},
{
"path": "test/encoding.py",
"chars": 855,
"preview": "from semanticodec import SemantiCodec\nimport soundfile as sf\n\nsemanticodec = SemantiCodec(token_rate=100, semantic_vocab"
},
{
"path": "test/test_all_settings.py",
"chars": 1049,
"preview": "from semanticodec import SemantiCodec\nimport soundfile as sf\n\ndef test_semanticodec(token_rate, semantic_vocab_size, tes"
}
]