[
  {
    "path": ".gitignore",
    "content": "pretrained\n*.npy\n*.wav\ng_*\n*.pyc\n*.pkl\n*.json\ncodebook_idx\nlong_audio*\noutput*\nsample*\n*.json\n*.egg-info\n.ipynb*\ntrim_checkpoint.py\n__pycache*\n.DS*\nbuild"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2012-2024 Scott Chacon and others\n\nPermission is hereby granted, free of charge, to any person obtaining\na copy of this software and associated documentation files (the\n\"Software\"), to deal in the Software without restriction, including\nwithout limitation the rights to use, copy, modify, merge, publish,\ndistribute, sublicense, and/or sell copies of the Software, and to\npermit persons to whom the Software is furnished to do so, subject to\nthe following conditions:\n\nThe above copyright notice and this permission notice shall be\nincluded in all copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND,\nEXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF\nMERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND\nNONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE\nLIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION\nOF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION\nWITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "[![arXiv](https://img.shields.io/badge/arXiv-2405.00233-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2405.00233)  [![githubio](https://img.shields.io/badge/GitHub.io-Audio_Samples-blue?logo=Github&style=flat-square)](https://haoheliu.github.io/SemantiCodec/) \n\n# SemantiCodec\nUltra-low bitrate neural audio codec with a better semantic in the latent space.\n\n**Highlight**\n- Bitrate: 0.31 kbps - 1.40 kbps\n- Token rate: 25, 50, or 100 per second\n- cpu, cuda, and mps are supported\n\n# Usage\n\n## Installation\n\n```bash\npip install git+https://github.com/haoheliu/SemantiCodec-inference.git\n```\n\n## Encoding and decoding\n\n**Checkpoints will be automatically downloaded when you initialize the SemantiCodec with the following code.**\n\n```python\nfrom semanticodec import SemantiCodec\n\nsemanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=16384) \n\nfilepath = \"test/test.wav\" # audio with arbitrary length\n\ntokens = semanticodec.encode(filepath)\nwaveform = semanticodec.decode(tokens)\n\n# Save the reconstruction file\nimport soundfile as sf\nsf.write(\"output.wav\", waveform[0,0], 16000)\n```\n\n## Other Settings\n\n```python\nfrom semanticodec import SemantiCodec\n\n###############Choose one of the following######################\nsemanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768) # 1.40 kbps\nsemanticodec = SemantiCodec(token_rate=50, semantic_vocab_size=32768) # 0.70 kbps\nsemanticodec = SemantiCodec(token_rate=25, semantic_vocab_size=32768) # 0.35 kbps\n\nsemanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=16384) # 1.35 kbps\nsemanticodec = SemantiCodec(token_rate=50, semantic_vocab_size=16384) # 0.68 kbps\nsemanticodec = SemantiCodec(token_rate=25, semantic_vocab_size=16384) # 0.34 kbps\n\nsemanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=8192) # 1.30 kbps\nsemanticodec = SemantiCodec(token_rate=50, semantic_vocab_size=8192) # 0.65 kbps\nsemanticodec = SemantiCodec(token_rate=25, semantic_vocab_size=8192) # 0.33 kbps\n\nsemanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=4096) # 1.25 kbps\nsemanticodec = SemantiCodec(token_rate=50, semantic_vocab_size=4096) # 0.63 kbps\nsemanticodec = SemantiCodec(token_rate=25, semantic_vocab_size=4096) # 0.31 kbps\n#####################################\n\nfilepath = \"test/test.wav\"\n\ntokens = semanticodec.encode(filepath)\nwaveform = semanticodec.decode(tokens)\n\nimport soundfile as sf\nsf.write(\"output.wav\", waveform[0,0], 16000)\n```\n\nIf you are interested in reusing the same evaluation pipeline and data in the paper, please refer to this [zenodo repo](https://zenodo.org/records/11047204).\n\n## Citation\nIf you find this repo helpful, please consider citing in the following format:\n\n```bibtex\n@ARTICLE{semanticodec2024,\n  author={Liu, Haohe and Xu, Xuenan and Yuan, Yi and Wu, Mengyue and Wang, Wenwu and Plumbley, Mark D.},\n  journal={IEEE Journal of Selected Topics in Signal Processing}, \n  title={SemantiCodec: An Ultra Low Bitrate Semantic Audio Codec for General Sound}, \n  year={2024},\n  volume={18},\n  number={8},\n  pages={1448-1461},\n  doi={10.1109/JSTSP.2024.3506286}\n}\n```\n\n\n![result](result.png)\n"
  },
  {
    "path": "semanticodec/__init__.py",
    "content": "from semanticodec.main import SemantiCodec\n"
  },
  {
    "path": "semanticodec/config.py",
    "content": "\ndef get_config(token_rate=100, vocab_size=None, checkpoint_path=None):\n    assert vocab_size in [4096, 8192, 16384, 32768], \"vocab_size must be 4096, 8192, 16384 or 32768\"\n    assert token_rate in [25, 50, 100], \"token_rate must be 25, 50 or 100\"\n\n    if checkpoint_path is not None:\n\n        semantic_codebook = {\n            25: {\n                4096: f\"{checkpoint_path}/codebook_idx/combine_128_audioset_dominate/codebook_2048_0.npy\",\n                8192: f\"{checkpoint_path}/codebook_idx/combine_128_audioset_dominate/codebook_4096_0.npy\",\n                16384: f\"{checkpoint_path}/codebook_idx/combine_128_audioset_dominate/codebook_8192_0.npy\",\n                32768: f\"{checkpoint_path}/codebook_idx/combine_128_audioset_dominate/codebook_16384_0.npy\",\n            },\n            50: {\n                4096: f\"{checkpoint_path}/codebook_idx/combine_256_audioset_dominate/codebook_2048_0.npy\",\n                8192: f\"{checkpoint_path}/codebook_idx/combine_256_audioset_dominate/codebook_4096_0.npy\",\n                16384: f\"{checkpoint_path}/codebook_idx/combine_256_audioset_dominate/codebook_8192_0.npy\",\n                32768: f\"{checkpoint_path}/codebook_idx/combine_256_audioset_dominate/codebook_16384_0.npy\",\n            },\n            100: {\n                4096: f\"{checkpoint_path}/codebook_idx/combine_512_audioset_dominate/codebook_2048_0.npy\",\n                8192: f\"{checkpoint_path}/codebook_idx/combine_512_audioset_dominate/codebook_4096_0.npy\",\n                16384: f\"{checkpoint_path}/codebook_idx/combine_512_audioset_dominate/codebook_8192_0.npy\",\n                32768: f\"{checkpoint_path}/codebook_idx/combine_512_audioset_dominate/codebook_16384_0.npy\",\n            },\n        }\n    else:\n        semantic_codebook = {\n            25: {\n                4096: \"codebook_idx/combine_128_audioset_dominate/codebook_2048_0.npy\",\n                8192: \"codebook_idx/combine_128_audioset_dominate/codebook_4096_0.npy\",\n                16384: \"codebook_idx/combine_128_audioset_dominate/codebook_8192_0.npy\",\n                32768: \"codebook_idx/combine_128_audioset_dominate/codebook_16384_0.npy\",\n            },\n            50: {\n                4096: \"codebook_idx/combine_256_audioset_dominate/codebook_2048_0.npy\",\n                8192: \"codebook_idx/combine_256_audioset_dominate/codebook_4096_0.npy\",\n                16384: \"codebook_idx/combine_256_audioset_dominate/codebook_8192_0.npy\",\n                32768: \"codebook_idx/combine_256_audioset_dominate/codebook_16384_0.npy\",\n            },\n            100: {\n                4096: \"codebook_idx/combine_512_audioset_dominate/codebook_2048_0.npy\",\n                8192: \"codebook_idx/combine_512_audioset_dominate/codebook_4096_0.npy\",\n                16384: \"codebook_idx/combine_512_audioset_dominate/codebook_8192_0.npy\",\n                32768: \"codebook_idx/combine_512_audioset_dominate/codebook_16384_0.npy\",\n            },\n        }\n        \n\n    basic_config = {\n    \"model\": {\n        \"params\": {\n        \"latent_t_size\": 256, \n        \"scale_by_std\": True, \n        \"sampling_rate\": 16000, \n        \"first_stage_config\": {\n            \"params\": {\n            \"monitor\": \"val/rec_loss\", \n            \"image_key\": \"fbank\", \n            \"embed_dim\": 8, \n            \"batchsize\": 16, \n            \"reload_from_ckpt\": 
\"/mnt/bn/lqhaoheliu/exps/checkpoints/audioldm/vae_32k/2023_06_22_vae_16k_64_4/last.ckpt\", \n            \"subband\": 1, \n            \"time_shuffle\": 1, \n            \"sampling_rate\": 16000, \n            \"ddconfig\": {\n                \"ch\": 128, \n                \"double_z\": True, \n                \"out_ch\": 1, \n                \"attn_resolutions\": [], \n                \"dropout\": 0.0, \n                \"mel_bins\": 64, \n                \"ch_mult\": [\n                1, \n                2, \n                4\n                ], \n                \"num_res_blocks\": 2, \n                \"z_channels\": 8, \n                \"downsample_time\": False, \n                \"in_channels\": 1, \n                \"resolution\": 256\n            }, \n            \"lossconfig\": {\n                \"params\": {\n                \"disc_start\": 50001, \n                \"kl_weight\": 1000.0, \n                \"disc_in_channels\": 1, \n                \"disc_weight\": 0.5\n                }, \n                \"target\": \"semanticodec.modules.decoder.latent_diffusion.modules.losses.LPIPSWithDiscriminator\"\n            }\n            }, \n            \"target\": \"semanticodec.modules.decoder.latent_encoder.autoencoder.AutoencoderKL\", \n            \"base_learning_rate\": 8e-06\n        }, \n        \"unet_config\": {\n            \"params\": {\n            \"channel_mult\": [\n                1, \n                2, \n                3, \n                5\n            ], \n            \"out_channels\": 8, \n            \"attention_resolutions\": [\n                8, \n                4, \n                2\n            ], \n            \"context_dim\": [\n                1728\n            ], \n            \"num_res_blocks\": 2, \n            \"in_channels\": 8, \n            \"image_size\": 64, \n            \"transformer_depth\": 1, \n            \"use_spatial_transformer\": True, \n            \"model_channels\": 64, \n            \"num_head_channels\": 32\n            }, \n            \"target\": \"semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.openaimodel.UNetModel\"\n        }, \n        \"base_learning_rate\": 0.0001, \n        \"channels\": 8, \n        \"linear_start\": 0.0015, \n        \"first_stage_key\": \"fbank\", \n        \"parameterization\": \"v\", \n        \"cond_stage_config\": {\n            \"crossattn_audiomae_pooled\": {\n            \"cond_stage_key\": \"ta_kaldi_fbank\", \n            \"params\": {\n                \"use_oracle\": False, \n                \"lstm_bidirectional\": True, \n                \"feature_dimension\": 768, \n                \"codebook_size\": 8192, \n                \"residual_encoder\": \"lstm\", \n                \"rvq_layers\": 0, \n                \"lstm_layer\": 4\n            }, \n            \"target\": \"semanticodec.modules.encoder.encoder.AudioMAEConditionQuantResEncoder\", \n            \"conditioning_key\": \"crossattn\"\n            }\n        }, \n        \"num_timesteps_cond\": 1, \n        \"timesteps\": 1000, \n        \"latent_f_size\": 16, \n        \"linear_end\": 0.0195\n        }, \n        \"target\": \"semanticodec.modules.decoder.latent_diffusion.models.ddpm.LatentDiffusion\"\n    }\n    }\n\n    if token_rate == 50:\n        # modify context_dim\n        basic_config[\"model\"][\"params\"][\"unet_config\"][\"params\"][\"context_dim\"] = [3264]\n        # modify cond_stage_config\n        
basic_config[\"model\"][\"params\"][\"cond_stage_config\"][\"crossattn_audiomae_pooled\"][\"params\"][\"lstm_layer\"] = 3\n        basic_config[\"model\"][\"params\"][\"cond_stage_config\"][\"crossattn_audiomae_pooled\"][\"params\"][\"feature_dimension\"] = 768 * 2\n    elif token_rate == 25:\n        # modify context_dim\n        basic_config[\"model\"][\"params\"][\"unet_config\"][\"params\"][\"context_dim\"] = [6336]\n        # modify cond_stage_config\n        basic_config[\"model\"][\"params\"][\"cond_stage_config\"][\"crossattn_audiomae_pooled\"][\"params\"][\"lstm_layer\"] = 2\n        basic_config[\"model\"][\"params\"][\"cond_stage_config\"][\"crossattn_audiomae_pooled\"][\"params\"][\"feature_dimension\"] = 768 * 4\n    elif token_rate == 100:\n        pass\n    else:\n        raise ValueError(\"token_rate must be 50, 25 or 100\")\n\n    if checkpoint_path is None:\n        checkpoint_path = \"semanticodec_tokenrate_%s\" % token_rate\n    else:\n        print(\"Using custom checkpoint path: %s\" % checkpoint_path)\n\n    feature_dim = basic_config[\"model\"][\"params\"][\"cond_stage_config\"][\"crossattn_audiomae_pooled\"][\"params\"][\"feature_dimension\"]\n    lstm_layers = basic_config[\"model\"][\"params\"][\"cond_stage_config\"][\"crossattn_audiomae_pooled\"][\"params\"][\"lstm_layer\"]\n    return basic_config, checkpoint_path, feature_dim, lstm_layers, semantic_codebook[token_rate][vocab_size]"
  },
  {
    "path": "semanticodec/main.py",
    "content": "from configparser import NoSectionError\nimport torch\nimport torch.nn as nn\nimport os\nimport torchaudio\nimport math\n\nfrom semanticodec.modules.encoder.encoder import AudioMAEConditionQuantResEncoder\nfrom semanticodec.modules.decoder.latent_diffusion.models.ddpm import (\n    extract_encoder_state_dict,\n    overlap_add_waveform,\n)\nfrom semanticodec.config import get_config\nfrom semanticodec.modules.decoder.latent_diffusion.util import instantiate_from_config\nfrom semanticodec.utils import extract_kaldi_fbank_feature\nfrom huggingface_hub import hf_hub_download\n\n# Constants\nSAMPLE_RATE = 16000\nSEGMENT_DURATION = 10.24\nMEL_TARGET_LENGTH = 1024\nAUDIOMAE_PATCH_DURATION = 0.16\nSEGMENT_OVERLAP_RATIO = 0.0625\n\n\nclass SemantiCodec(nn.Module):\n    def __init__(\n        self,\n        token_rate,\n        semantic_vocab_size,\n        ddim_sample_step=50,\n        cfg_scale=2.0,\n        checkpoint_path = None,\n        cache_path=\"pretrained\",\n    ):\n        super().__init__()\n        self.token_rate = token_rate\n        self.stack_factor_K = 100 / self.token_rate\n        self.ddim_sample_step = ddim_sample_step\n        self.cfg_scale = cfg_scale\n\n        if torch.cuda.is_available():\n            self.device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available(): \n            self.device = torch.device(\"mps\")\n        else:\n            self.device = torch.device(\"cpu\")\n\n        # Initialize encoder and decoder\n        config, checkpoint_path, feature_dim, lstm_layers, semanticodebook = get_config(\n            token_rate, semantic_vocab_size, checkpoint_path\n        )\n        encoder_checkpoint_path = os.path.join(checkpoint_path, \"encoder.ckpt\")\n        if not os.path.exists(encoder_checkpoint_path):\n            if not os.path.exists(cache_path):\n                os.makedirs(cache_path)\n                print(f\"checkpoint cache dir '{cache_path}' was created.\")\n            encoder_checkpoint_path = hf_hub_download(repo_id=\"haoheliu/SemantiCodec\",filename=checkpoint_path+\"/encoder.ckpt\",cache_dir=cache_path)\n        decoder_checkpoint_path = os.path.join(checkpoint_path, \"decoder.ckpt\")\n        if not os.path.exists(decoder_checkpoint_path):\n            decoder_checkpoint_path = hf_hub_download(repo_id=\"haoheliu/SemantiCodec\",filename=checkpoint_path+\"/decoder.ckpt\",cache_dir=cache_path)\n\n        if not os.path.exists(semanticodebook):\n            semanticodebook = \"/\".join(semanticodebook.split(\"/\")[-3:])\n            semanticodebook = hf_hub_download(repo_id=\"haoheliu/SemantiCodec\",filename=semanticodebook,cache_dir=cache_path)\n\n        # Initialize encoder\n        print(\"🚀 Loading SemantiCodec encoder\")\n        state_dict = torch.load(encoder_checkpoint_path, map_location=\"cpu\")\n        self.encoder = AudioMAEConditionQuantResEncoder(\n            feature_dimension=feature_dim,\n            lstm_layer=lstm_layers,\n            centroid_npy_path=semanticodebook,\n        )\n        self.encoder.load_state_dict(state_dict)\n        self.encoder = self.encoder.to(self.device)\n        print(\"✅ Encoder loaded\")\n\n        # Initialize decoder\n        print(\"🚀 Loading SemantiCodec decoder\")\n        self.decoder = instantiate_from_config(config[\"model\"])\n        checkpoint = torch.load(decoder_checkpoint_path, map_location=\"cpu\")\n        self.decoder.load_state_dict(checkpoint)\n        self.decoder = self.decoder.to(self.device)\n        print(\"✅ Decoder loaded\")\n\n    
def load_audio(self, filepath):\n        if not os.path.exists(filepath):\n            raise FileNotFoundError(f\"{filepath} does not exist\")\n\n        assert isinstance(filepath, str)\n        waveform, sr = torchaudio.load(filepath)\n        # resample to 16000\n        if sr != SAMPLE_RATE:\n            waveform = torchaudio.functional.resample(waveform, sr, SAMPLE_RATE)\n            sr = SAMPLE_RATE\n        # if stereo to mono\n        if waveform.shape[0] > 1:\n            waveform = waveform[0:1]\n        # Calculate the original duration\n        original_duration = waveform.shape[1] / sr\n        # This is to pad the audio to the multiplication of 0.16 seconds so that the original audio can be reconstructed\n        original_duration = original_duration + (\n            AUDIOMAE_PATCH_DURATION - original_duration % AUDIOMAE_PATCH_DURATION\n        )\n        # Calculate the token length in theory\n        target_token_len = (\n            8 * original_duration / AUDIOMAE_PATCH_DURATION / self.stack_factor_K\n        )\n        segment_sample_length = int(SAMPLE_RATE * SEGMENT_DURATION)\n        # Pad audio to the multiplication of 10.24 seconds for easier segmentations\n\n        if waveform.shape[1] % segment_sample_length < segment_sample_length:\n            waveform = torch.cat(\n                [\n                    waveform,\n                    torch.zeros(\n                        1,\n                        int(\n                            segment_sample_length\n                            - waveform.shape[1] % segment_sample_length\n                        ),\n                    ),\n                ],\n                dim=1,\n            )\n\n        mel_target_length = MEL_TARGET_LENGTH * int(\n            waveform.shape[1] / segment_sample_length\n        )\n        # Calculate the mel spectrogram\n        mel = extract_kaldi_fbank_feature(\n            waveform, sr, target_length=mel_target_length\n        )[\"ta_kaldi_fbank\"].unsqueeze(0)\n        mel = mel.squeeze(1)\n        assert mel.shape[-1] == 128 and mel.shape[-2] % 1024 == 0\n        return mel, target_token_len\n\n    def encode(self, filepath):\n        mel, target_token_len = self.load_audio(filepath)\n        tokens = self.encoder(mel.to(self.device))\n        tokens = tokens[:, : math.ceil(target_token_len), :]\n        return tokens\n\n    def decode(self, tokens):\n        windowed_token_list = self.encoder.long_token_split_window(\n            tokens,\n            window_length=int(512 / self.stack_factor_K),\n            overlap=SEGMENT_OVERLAP_RATIO,\n        )\n        windowed_waveform = []\n        for _, windowed_token in enumerate(windowed_token_list):\n            latent = self.encoder.token_to_quantized_feature(windowed_token)\n            latent = torch.cat(\n                [\n                    latent,\n                    torch.ones(\n                        latent.shape[0],\n                        int(512 / self.stack_factor_K) - latent.shape[1],\n                        latent.shape[2],\n                    ).to(latent.device)\n                    * -1,\n                ],\n                dim=1,\n            )\n            waveform = self.decoder.generate_sample(\n                latent,\n                ddim_steps=self.ddim_sample_step,\n                unconditional_guidance_scale=self.cfg_scale,\n            )\n            windowed_waveform.append(waveform)\n        output = overlap_add_waveform(\n            windowed_waveform, overlap_duration=SEGMENT_DURATION * 
SEGMENT_OVERLAP_RATIO\n        )\n        # Each patch step corresponds to 16 mel frames, and each mel frame covers 0.01 seconds\n        trim_duration = (tokens.shape[1] / 8) * 16 * 0.01 * self.stack_factor_K\n        return output[..., : int(trim_duration * SAMPLE_RATE)]\n\n    def forward(self, filepath):\n        tokens = self.encode(filepath)\n        waveform = self.decode(tokens)\n        return waveform\n"
  },
  {
    "path": "semanticodec/modules/__init__.py",
    "content": ""
  },
  {
    "path": "semanticodec/modules/audiomae/AudioMAE.py",
    "content": "\"\"\"\nReference Repo: https://github.com/facebookresearch/AudioMAE\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple\nimport semanticodec.modules.audiomae.models_mae as models_mae\n\n# model = mae_vit_base_patch16(in_chans=1, audio_exp=True, img_size=(1024, 128))\n\n\nclass PatchEmbed_new(nn.Module):\n    \"\"\"Flexible Image to Patch Embedding\"\"\"\n\n    def __init__(\n        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10\n    ):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        stride = to_2tuple(stride)\n\n        self.img_size = img_size\n        self.patch_size = patch_size\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=stride\n        )  # with overlapped patches\n        # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n        # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])\n        # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        _, _, h, w = self.get_output_shape(img_size)  # n, emb_dim, h, w\n        self.patch_hw = (h, w)\n        self.num_patches = h * w\n\n    def get_output_shape(self, img_size):\n        # todo: don't be lazy..\n        return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        # assert H == self.img_size[0] and W == self.img_size[1], \\\n        #    f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x)\n        x = x.flatten(2).transpose(1, 2)\n        return x\n\n\nclass Vanilla_AudioMAE(nn.Module):\n    \"\"\"Audio Masked Autoencoder (MAE) pre-trained on AudioSet (for AudioLDM)\"\"\"\n\n    def __init__(\n        self,\n    ):\n        super().__init__()\n        model = models_mae.__dict__[\"mae_vit_base_patch16\"](\n            in_chans=1, audio_exp=True, img_size=(1024, 128)\n        )\n\n        # checkpoint_path = \"/mnt/bn/lqhaoheliu/exps/checkpoints/audiomae/pretrained.pth\"\n        # checkpoint = torch.load(checkpoint_path, map_location=\"cpu\")\n        # model.load_state_dict(checkpoint[\"model\"], strict=False)\n\n        # Skip the missing keys of decoder modules (not required)\n        # print(f'Load AudioMAE from {checkpoint_path} / message: {msg}')\n\n        self.model = model.eval()\n\n    def forward(self, x, mask_ratio=0.0, no_mask=False, no_average=False):\n        \"\"\"\n        x: mel fbank [Batch, 1, 1024 (T), 128 (F)]\n        mask_ratio: 'masking ratio (percentage of removed patches).'\n        \"\"\"\n        with torch.no_grad():\n            # embed: [B, 513, 768] for mask_ratio=0.0\n            if no_mask:\n                if no_average:\n                    raise RuntimeError(\"This function is deprecated\")\n                    embed = self.model.forward_encoder_no_random_mask_no_average(\n                        x\n                    )  # mask_ratio\n                else:\n                    embed = self.model.forward_encoder_no_mask(x)  # mask_ratio\n            else:\n                raise RuntimeError(\"This function is deprecated\")\n                embed, _, _, _ = self.model.forward_encoder(x, mask_ratio=mask_ratio)\n        return embed\n\n\nif __name__ == \"__main__\":\n    model = 
Vanilla_AudioMAE().cuda()\n    input = torch.randn(4, 1, 1024, 128).cuda()\n    print(\"The first run\")\n    embed = model(input, mask_ratio=0.0, no_mask=True)\n    print(embed)\n    print(\"The second run\")\n    # no_mask=True is required here as well; the masked encoder path above is deprecated and raises\n    embed = model(input, mask_ratio=0.0, no_mask=True)\n    print(embed)\n"
  },
  {
    "path": "semanticodec/modules/audiomae/__init__.py",
    "content": ""
  },
  {
    "path": "semanticodec/modules/audiomae/models_mae.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n# --------------------------------------------------------\n# References:\n# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm\n# DeiT: https://github.com/facebookresearch/deit\n# --------------------------------------------------------\n\nfrom functools import partial\nfrom json import encoder\n\nimport torch\nimport torch.nn as nn\n\nfrom timm.models.vision_transformer import Block\nfrom semanticodec.modules.audiomae.pos_embed import (\n    get_2d_sincos_pos_embed,\n    get_2d_sincos_pos_embed_flexible,\n    get_1d_sincos_pos_embed_from_grid,\n)\nfrom semanticodec.modules.audiomae.patch_embed import PatchEmbed_new, PatchEmbed_org\n\n\nclass MaskedAutoencoderViT(nn.Module):\n    \"\"\"Masked Autoencoder with VisionTransformer backbone\"\"\"\n\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        stride=10,\n        in_chans=3,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        decoder_embed_dim=512,\n        decoder_depth=8,\n        decoder_num_heads=16,\n        mlp_ratio=4.0,\n        norm_layer=nn.LayerNorm,\n        norm_pix_loss=False,\n        audio_exp=False,\n        alpha=0.0,\n        temperature=0.2,\n        mode=0,\n        contextual_depth=8,\n        use_custom_patch=False,\n        split_pos=False,\n        pos_trainable=False,\n        use_nce=False,\n        beta=4.0,\n        decoder_mode=0,\n        mask_t_prob=0.6,\n        mask_f_prob=0.5,\n        mask_2d=False,\n        epoch=0,\n        no_shift=False,\n    ):\n        super().__init__()\n\n        self.audio_exp = audio_exp\n        self.embed_dim = embed_dim\n        self.decoder_embed_dim = decoder_embed_dim\n        # --------------------------------------------------------------------------\n        # MAE encoder specifics\n        if use_custom_patch:\n            print(\n                f\"Use custom patch_emb with patch size: {patch_size}, stride: {stride}\"\n            )\n            self.patch_embed = PatchEmbed_new(\n                img_size=img_size,\n                patch_size=patch_size,\n                in_chans=in_chans,\n                embed_dim=embed_dim,\n                stride=stride,\n            )\n        else:\n            self.patch_embed = PatchEmbed_org(img_size, patch_size, in_chans, embed_dim)\n        self.use_custom_patch = use_custom_patch\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n\n        # self.split_pos = split_pos # not useful\n        self.pos_embed = nn.Parameter(\n            torch.zeros(1, num_patches + 1, embed_dim), requires_grad=pos_trainable\n        )  # fixed sin-cos embedding\n\n        self.encoder_depth = depth\n        self.contextual_depth = contextual_depth\n        self.blocks = nn.ModuleList(\n            [\n                Block(\n                    embed_dim,\n                    num_heads,\n                    mlp_ratio,\n                    qkv_bias=True,\n                    norm_layer=norm_layer,\n                )  # qk_scale=None\n                for i in range(depth)\n            ]\n        )\n        self.norm = norm_layer(embed_dim)\n\n        # --------------------------------------------------------------------------\n        # MAE decoder specifics\n        
self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)\n\n        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))\n        self.decoder_pos_embed = nn.Parameter(\n            torch.zeros(1, num_patches + 1, decoder_embed_dim),\n            requires_grad=pos_trainable,\n        )  # fixed sin-cos embedding\n\n        self.no_shift = no_shift\n\n        self.decoder_mode = decoder_mode\n        if (\n            self.use_custom_patch\n        ):  # overlapped patches as in AST. Similar performance yet compute heavy\n            window_size = (6, 6)\n            feat_size = (102, 12)\n        else:\n            window_size = (4, 4)\n            feat_size = (64, 8)\n        if self.decoder_mode == 1:\n            decoder_modules = []\n            for index in range(16):\n                if self.no_shift:\n                    shift_size = (0, 0)\n                else:\n                    if (index % 2) == 0:\n                        shift_size = (0, 0)\n                    else:\n                        shift_size = (2, 0)\n                    # shift_size = tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size])\n                decoder_modules.append(\n                    SwinTransformerBlock(\n                        dim=decoder_embed_dim,\n                        num_heads=16,\n                        feat_size=feat_size,\n                        window_size=window_size,\n                        shift_size=shift_size,\n                        mlp_ratio=mlp_ratio,\n                        drop=0.0,\n                        drop_attn=0.0,\n                        drop_path=0.0,\n                        extra_norm=False,\n                        sequential_attn=False,\n                        norm_layer=norm_layer,  # nn.LayerNorm,\n                    )\n                )\n            self.decoder_blocks = nn.ModuleList(decoder_modules)\n        else:\n            # Transfomer\n            self.decoder_blocks = nn.ModuleList(\n                [\n                    Block(\n                        decoder_embed_dim,\n                        decoder_num_heads,\n                        mlp_ratio,\n                        qkv_bias=True,\n                        norm_layer=norm_layer,\n                    )  # qk_scale=None,\n                    for i in range(decoder_depth)\n                ]\n            )\n\n        self.decoder_norm = norm_layer(decoder_embed_dim)\n        self.decoder_pred = nn.Linear(\n            decoder_embed_dim, patch_size**2 * in_chans, bias=True\n        )  # decoder to patch\n\n        # --------------------------------------------------------------------------\n\n        self.norm_pix_loss = norm_pix_loss\n\n        self.patch_size = patch_size\n        self.stride = stride\n\n        # audio exps\n        self.alpha = alpha\n        self.T = temperature\n        self.mode = mode\n        self.use_nce = use_nce\n        self.beta = beta\n\n        self.log_softmax = nn.LogSoftmax(dim=-1)\n\n        self.mask_t_prob = mask_t_prob\n        self.mask_f_prob = mask_f_prob\n        self.mask_2d = mask_2d\n\n        self.epoch = epoch\n\n        self.initialize_weights()\n\n    def initialize_weights(self):\n        # initialization\n        # initialize (and freeze) pos_embed by sin-cos embedding\n        if self.audio_exp:\n            pos_embed = get_2d_sincos_pos_embed_flexible(\n                self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True\n            )\n        else:\n            pos_embed = 
get_2d_sincos_pos_embed(\n                self.pos_embed.shape[-1],\n                int(self.patch_embed.num_patches**0.5),\n                cls_token=True,\n            )\n        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))\n\n        if self.audio_exp:\n            decoder_pos_embed = get_2d_sincos_pos_embed_flexible(\n                self.decoder_pos_embed.shape[-1],\n                self.patch_embed.patch_hw,\n                cls_token=True,\n            )\n        else:\n            decoder_pos_embed = get_2d_sincos_pos_embed(\n                self.decoder_pos_embed.shape[-1],\n                int(self.patch_embed.num_patches**0.5),\n                cls_token=True,\n            )\n        self.decoder_pos_embed.data.copy_(\n            torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)\n        )\n\n        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)\n        w = self.patch_embed.proj.weight.data\n        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))\n\n        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)\n        torch.nn.init.normal_(self.cls_token, std=0.02)\n        torch.nn.init.normal_(self.mask_token, std=0.02)\n\n        # initialize nn.Linear and nn.LayerNorm\n        self.apply(self._init_weights)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            # we use xavier_uniform following official JAX ViT:\n            torch.nn.init.xavier_uniform_(m.weight)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n        elif isinstance(m, nn.LayerNorm):\n            nn.init.constant_(m.bias, 0)\n            nn.init.constant_(m.weight, 1.0)\n\n    def patchify(self, imgs):\n        \"\"\"\n        imgs: (N, 3, H, W)\n        x: (N, L, patch_size**2 *3)\n        L = (H/p)*(W/p)\n        \"\"\"\n        p = self.patch_embed.patch_size[0]\n        # assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0\n\n        if self.audio_exp:\n            if self.use_custom_patch:  # overlapped patch\n                h, w = self.patch_embed.patch_hw\n                # todo: fixed h/w patch size and stride size. 
Make hw custom in the future\n                x = imgs.unfold(2, self.patch_size, self.stride).unfold(\n                    3, self.patch_size, self.stride\n                )  # n,1,H,W -> n,1,h,w,p,p\n                x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))\n                # x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))\n                # x = torch.einsum('nchpwq->nhwpqc', x)\n                # x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))\n            else:\n                h = imgs.shape[2] // p\n                w = imgs.shape[3] // p\n                # h,w = self.patch_embed.patch_hw\n                x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p))\n                x = torch.einsum(\"nchpwq->nhwpqc\", x)\n                x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1))\n        else:\n            h = w = imgs.shape[2] // p\n            x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))\n            x = torch.einsum(\"nchpwq->nhwpqc\", x)\n            x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))\n\n        return x\n\n    def unpatchify(self, x):\n        \"\"\"\n        x: (N, L, patch_size**2 *3)\n        specs: (N, 1, H, W)\n        \"\"\"\n        p = self.patch_embed.patch_size[0]\n        h = 1024 // p\n        w = 128 // p\n        x = x.reshape(shape=(x.shape[0], h, w, p, p, 1))\n        x = torch.einsum(\"nhwpqc->nchpwq\", x)\n        specs = x.reshape(shape=(x.shape[0], 1, h * p, w * p))\n        return specs\n\n    def random_masking(self, x, mask_ratio):\n        \"\"\"\n        Perform per-sample random masking by per-sample shuffling.\n        Per-sample shuffling is done by argsort random noise.\n        x: [N, L, D], sequence\n        \"\"\"\n        N, L, D = x.shape  # batch, length, dim\n        len_keep = int(L * (1 - mask_ratio))\n\n        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]\n\n        # sort noise for each sample\n        ids_shuffle = torch.argsort(\n            noise, dim=1\n        )  # ascend: small is keep, large is remove\n        ids_restore = torch.argsort(ids_shuffle, dim=1)\n\n        # keep the first subset\n        ids_keep = ids_shuffle[:, :len_keep]\n        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))\n\n        # generate the binary mask: 0 is keep, 1 is remove\n        mask = torch.ones([N, L], device=x.device)\n        mask[:, :len_keep] = 0\n        # unshuffle to get the binary mask\n        mask = torch.gather(mask, dim=1, index=ids_restore)\n\n        return x_masked, mask, ids_restore\n\n    def random_masking_2d(self, x, mask_t_prob, mask_f_prob):\n        \"\"\"\n        2D: Spectrogram (msking t and f under mask_t_prob and mask_f_prob)\n        Perform per-sample random masking by per-sample shuffling.\n        Per-sample shuffling is done by argsort random noise.\n        x: [N, L, D], sequence\n        \"\"\"\n        N, L, D = x.shape  # batch, length, dim\n        if self.use_custom_patch:  # overlapped patch\n            T = 101\n            F = 12\n        else:\n            T = 64\n            F = 8\n        # x = x.reshape(N, T, F, D)\n        len_keep_t = int(T * (1 - mask_t_prob))\n        len_keep_f = int(F * (1 - mask_f_prob))\n\n        # noise for mask in time\n        noise_t = torch.rand(N, T, device=x.device)  # noise in [0, 1]\n        # sort noise for each sample aling time\n        ids_shuffle_t = torch.argsort(\n            noise_t, dim=1\n        )  # ascend: small is keep, large is remove\n        
ids_restore_t = torch.argsort(ids_shuffle_t, dim=1)\n        ids_keep_t = ids_shuffle_t[:, :len_keep_t]\n        # noise mask in freq\n        noise_f = torch.rand(N, F, device=x.device)  # noise in [0, 1]\n        ids_shuffle_f = torch.argsort(\n            noise_f, dim=1\n        )  # ascend: small is keep, large is remove\n        ids_restore_f = torch.argsort(ids_shuffle_f, dim=1)\n        ids_keep_f = ids_shuffle_f[:, :len_keep_f]  #\n\n        # generate the binary mask: 0 is keep, 1 is remove\n        # mask in freq\n        mask_f = torch.ones(N, F, device=x.device)\n        mask_f[:, :len_keep_f] = 0\n        mask_f = (\n            torch.gather(mask_f, dim=1, index=ids_restore_f)\n            .unsqueeze(1)\n            .repeat(1, T, 1)\n        )  # N,T,F\n        # mask in time\n        mask_t = torch.ones(N, T, device=x.device)\n        mask_t[:, :len_keep_t] = 0\n        mask_t = (\n            torch.gather(mask_t, dim=1, index=ids_restore_t)\n            .unsqueeze(1)\n            .repeat(1, F, 1)\n            .permute(0, 2, 1)\n        )  # N,T,F\n        mask = 1 - (1 - mask_t) * (1 - mask_f)  # N, T, F\n\n        # get masked x\n        id2res = torch.Tensor(list(range(N * T * F))).reshape(N, T, F).to(x.device)\n        id2res = id2res + 999 * mask  # add a large value for masked elements\n        id2res2 = torch.argsort(id2res.flatten(start_dim=1))\n        ids_keep = id2res2.flatten(start_dim=1)[:, : len_keep_f * len_keep_t]\n        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))\n\n        ids_restore = torch.argsort(id2res2.flatten(start_dim=1))\n        mask = mask.flatten(start_dim=1)\n\n        return x_masked, mask, ids_restore\n\n    def forward_encoder(self, x, mask_ratio, mask_2d=False):\n        # embed patches\n        x = self.patch_embed(x)\n        # add pos embed w/o cls token\n        x = x + self.pos_embed[:, 1:, :]\n\n        # masking: length -> length * mask_ratio\n        if mask_2d:\n            x, mask, ids_restore = self.random_masking_2d(\n                x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob\n            )\n        else:\n            x, mask, ids_restore = self.random_masking(x, mask_ratio)\n\n        # append cls token\n        cls_token = self.cls_token + self.pos_embed[:, :1, :]\n        cls_tokens = cls_token.expand(x.shape[0], -1, -1)\n        x = torch.cat((cls_tokens, x), dim=1)\n\n        # apply Transformer blocks\n        for blk in self.blocks:\n            x = blk(x)\n        x = self.norm(x)\n\n        return x, mask, ids_restore, None\n\n    def forward_encoder_no_random_mask_no_average(self, x):\n        # embed patches\n        x = self.patch_embed(x)\n        # add pos embed w/o cls token\n        x = x + self.pos_embed[:, 1:, :]\n\n        # masking: length -> length * mask_ratio\n        # if mask_2d:\n        #     x, mask, ids_restore = self.random_masking_2d(x, mask_t_prob=self.mask_t_prob, mask_f_prob=self.mask_f_prob)\n        # else:\n        #     x, mask, ids_restore = self.random_masking(x, mask_ratio)\n\n        # append cls token\n        cls_token = self.cls_token + self.pos_embed[:, :1, :]\n        cls_tokens = cls_token.expand(x.shape[0], -1, -1)\n        x = torch.cat((cls_tokens, x), dim=1)\n\n        # apply Transformer blocks\n        for blk in self.blocks:\n            x = blk(x)\n        x = self.norm(x)\n\n        return x\n\n    def forward_encoder_no_mask(self, x):\n        # embed patches\n        x = self.patch_embed(x)\n\n        # add pos 
embed w/o cls token\n        x = x + self.pos_embed[:, 1:, :]\n\n        # masking: length -> length * mask_ratio\n        # x, mask, ids_restore = self.random_masking(x, mask_ratio)\n        # append cls token\n        cls_token = self.cls_token + self.pos_embed[:, :1, :]\n        cls_tokens = cls_token.expand(x.shape[0], -1, -1)\n        x = torch.cat((cls_tokens, x), dim=1)\n\n        # apply Transformer blocks\n        contextual_embs = []\n        for n, blk in enumerate(self.blocks):\n            x = blk(x)\n            if n > self.contextual_depth:\n                contextual_embs.append(self.norm(x))\n        # x = self.norm(x)\n        contextual_emb = torch.stack(contextual_embs, dim=0).mean(dim=0)\n\n        return contextual_emb\n\n    def forward_decoder(self, x, ids_restore):\n        # embed tokens\n        x = self.decoder_embed(x)\n\n        # append mask tokens to sequence\n        mask_tokens = self.mask_token.repeat(\n            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1\n        )\n        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token\n        x_ = torch.gather(\n            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])\n        )  # unshuffle\n        x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token\n\n        # add pos embed\n        x = x + self.decoder_pos_embed\n\n        if self.decoder_mode != 0:\n            B, L, D = x.shape\n            x = x[:, 1:, :]\n            if self.use_custom_patch:\n                x = x.reshape(B, 101, 12, D)\n                x = torch.cat([x, x[:, -1, :].unsqueeze(1)], dim=1)  # hack\n                x = x.reshape(B, 1224, D)\n        if self.decoder_mode > 3:  # mvit\n            x = self.decoder_blocks(x)\n        else:\n            # apply Transformer blocks\n            for blk in self.decoder_blocks:\n                x = blk(x)\n        x = self.decoder_norm(x)\n\n        # predictor projection\n        pred = self.decoder_pred(x)\n\n        # remove cls token\n        if self.decoder_mode != 0:\n            if self.use_custom_patch:\n                pred = pred.reshape(B, 102, 12, 256)\n                pred = pred[:, :101, :, :]\n                pred = pred.reshape(B, 1212, 256)\n            else:\n                pred = pred\n        else:\n            pred = pred[:, 1:, :]\n        return pred, None, None  # emb, emb_pixel\n\n    def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):\n        \"\"\"\n        imgs: [N, 3, H, W]\n        pred: [N, L, p*p*3]\n        mask: [N, L], 0 is keep, 1 is remove,\n        \"\"\"\n        target = self.patchify(imgs)\n        if norm_pix_loss:\n            mean = target.mean(dim=-1, keepdim=True)\n            var = target.var(dim=-1, keepdim=True)\n            target = (target - mean) / (var + 1.0e-6) ** 0.5\n\n        loss = (pred - target) ** 2\n        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch\n\n        loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches\n        return loss\n\n    def forward(self, imgs, mask_ratio=0.8):\n        emb_enc, mask, ids_restore, _ = self.forward_encoder(\n            imgs, mask_ratio, mask_2d=self.mask_2d\n        )\n        pred, _, _ = self.forward_decoder(emb_enc, ids_restore)  # [N, L, p*p*3]\n        loss_recon = self.forward_loss(\n            imgs, pred, mask, norm_pix_loss=self.norm_pix_loss\n        )\n        loss_contrastive = torch.FloatTensor([0.0]).cuda()\n        return loss_recon, pred, mask, loss_contrastive\n\n\ndef 
mae_vit_small_patch16_dec512d8b(**kwargs):\n    model = MaskedAutoencoderViT(\n        patch_size=16,\n        embed_dim=384,\n        depth=12,\n        num_heads=6,\n        decoder_embed_dim=512,\n        decoder_num_heads=16,\n        mlp_ratio=4,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        **kwargs,\n    )\n    return model\n\n\ndef mae_vit_base_patch16_dec512d8b(**kwargs):\n    model = MaskedAutoencoderViT(\n        patch_size=16,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        decoder_embed_dim=512,\n        decoder_num_heads=16,\n        mlp_ratio=4,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        **kwargs,\n    )\n    return model\n\n\ndef mae_vit_large_patch16_dec512d8b(**kwargs):\n    model = MaskedAutoencoderViT(\n        patch_size=16,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        decoder_embed_dim=512,\n        decoder_num_heads=16,\n        mlp_ratio=4,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        **kwargs,\n    )\n    return model\n\n\ndef mae_vit_huge_patch14_dec512d8b(**kwargs):\n    model = MaskedAutoencoderViT(\n        patch_size=14,\n        embed_dim=1280,\n        depth=32,\n        num_heads=16,\n        decoder_embed_dim=512,\n        decoder_num_heads=16,\n        mlp_ratio=4,\n        norm_layer=partial(nn.LayerNorm, eps=1e-6),\n        **kwargs,\n    )\n    return model\n\n\n# set recommended archs\nmae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b  # decoder: 512 dim, 8 blocks\nmae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b  # decoder: 512 dim, 8 blocks\nmae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b  # decoder: 512 dim, 8 blocks\nmae_vit_small_patch16 = mae_vit_small_patch16_dec512d8b  # decoder: 512 dim, 8 blocks\n"
  },
  {
    "path": "semanticodec/modules/audiomae/patch_embed.py",
    "content": "import torch\nimport torch.nn as nn\nfrom timm.models.layers import to_2tuple\n\n\nclass PatchEmbed_org(nn.Module):\n    \"\"\"Image to Patch Embedding\"\"\"\n\n    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])\n        self.img_size = img_size\n        self.patch_size = patch_size\n        self.num_patches = num_patches\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size\n        )\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        # assert H == self.img_size[0] and W == self.img_size[1], \\\n        #    f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        x = self.proj(x)\n        y = x.flatten(2).transpose(1, 2)\n        return y\n\n\nclass PatchEmbed_new(nn.Module):\n    \"\"\"Flexible Image to Patch Embedding\"\"\"\n\n    def __init__(\n        self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=10\n    ):\n        super().__init__()\n        img_size = to_2tuple(img_size)\n        patch_size = to_2tuple(patch_size)\n        stride = to_2tuple(stride)\n\n        self.img_size = img_size\n        self.patch_size = patch_size\n\n        self.proj = nn.Conv2d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=stride\n        )  # with overlapped patches\n        # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)\n\n        # self.patch_hw = (img_size[1] // patch_size[1], img_size[0] // patch_size[0])\n        # self.num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])\n        _, _, h, w = self.get_output_shape(img_size)  # n, emb_dim, h, w\n        self.patch_hw = (h, w)\n        self.num_patches = h * w\n\n    def get_output_shape(self, img_size):\n        # todo: don't be lazy..\n        return self.proj(torch.randn(1, 1, img_size[0], img_size[1])).shape\n\n    def forward(self, x):\n        B, C, H, W = x.shape\n        # FIXME look at relaxing size constraints\n        # assert H == self.img_size[0] and W == self.img_size[1], \\\n        #    f\"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).\"\n        # x = self.proj(x).flatten(2).transpose(1, 2)\n        x = self.proj(x)  # 32, 1, 1024, 128 -> 32, 768, 101, 12\n        x = x.flatten(2)  # 32, 768, 101, 12 -> 32, 768, 1212\n        x = x.transpose(1, 2)  # 32, 768, 1212 -> 32, 1212, 768\n        return x\n\n\nclass PatchEmbed3D_new(nn.Module):\n    \"\"\"Flexible Image to Patch Embedding\"\"\"\n\n    def __init__(\n        self,\n        video_size=(16, 224, 224),\n        patch_size=(2, 16, 16),\n        in_chans=3,\n        embed_dim=768,\n        stride=(2, 16, 16),\n    ):\n        super().__init__()\n\n        self.video_size = video_size\n        self.patch_size = patch_size\n        self.in_chans = in_chans\n\n        self.proj = nn.Conv3d(\n            in_chans, embed_dim, kernel_size=patch_size, stride=stride\n        )\n        _, _, t, h, w = self.get_output_shape(video_size)  # n, emb_dim, h, w\n        self.patch_thw = (t, h, w)\n        self.num_patches = t * h * w\n\n    def 
get_output_shape(self, video_size):\n        # todo: don't be lazy..\n        return self.proj(\n            torch.randn(1, self.in_chans, video_size[0], video_size[1], video_size[2])\n        ).shape\n\n    def forward(self, x):\n        B, C, T, H, W = x.shape\n        x = self.proj(x)  # 32, 3, 16, 224, 224 -> 32, 768, 8, 14, 14\n        x = x.flatten(2)  # 32, 768, 1568\n        x = x.transpose(1, 2)  # 32, 768, 1568 -> 32, 1568, 768\n        return x\n\n\nif __name__ == \"__main__\":\n    # patch_emb = PatchEmbed_new(img_size=224, patch_size=16, in_chans=1, embed_dim=64, stride=(16,16))\n    # input = torch.rand(8,1,1024,128)\n    # output = patch_emb(input)\n    # print(output.shape) # (8,512,64)\n\n    patch_emb = PatchEmbed3D_new(\n        video_size=(6, 224, 224),\n        patch_size=(2, 16, 16),\n        in_chans=3,\n        embed_dim=768,\n        stride=(2, 16, 16),\n    )\n    input = torch.rand(8, 3, 6, 224, 224)\n    output = patch_emb(input)\n    print(output.shape)  # (8,64)\n"
  },
  {
    "path": "semanticodec/modules/audiomae/pos_embed.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n# --------------------------------------------------------\n# Position embedding utils\n# --------------------------------------------------------\n\nimport numpy as np\n\nimport torch\n\n\n# --------------------------------------------------------\n# 2D sine-cosine position embedding\n# References:\n# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py\n# MoCo v3: https://github.com/facebookresearch/moco-v3\n# --------------------------------------------------------\ndef get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):\n    \"\"\"\n    grid_size: int of the grid height and width\n    return:\n    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)\n    \"\"\"\n    grid_h = np.arange(grid_size, dtype=np.float32)\n    grid_w = np.arange(grid_size, dtype=np.float32)\n    grid = np.meshgrid(grid_w, grid_h)  # here w goes first\n    grid = np.stack(grid, axis=0)\n\n    grid = grid.reshape([2, 1, grid_size, grid_size])\n    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n    if cls_token:\n        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)\n    return pos_embed\n\n\ndef get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):\n    \"\"\"\n    grid_size: int of the grid height and width\n    return:\n    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)\n    \"\"\"\n    grid_h = np.arange(grid_size[0], dtype=np.float32)\n    grid_w = np.arange(grid_size[1], dtype=np.float32)\n    grid = np.meshgrid(grid_w, grid_h)  # here w goes first\n    grid = np.stack(grid, axis=0)\n\n    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])\n    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)\n    if cls_token:\n        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)\n    return pos_embed\n\n\ndef get_2d_sincos_pos_embed_from_grid(embed_dim, grid):\n    assert embed_dim % 2 == 0\n\n    # use half of dimensions to encode grid_h\n    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)\n    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)\n\n    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)\n    return emb\n\n\ndef get_1d_sincos_pos_embed_from_grid(embed_dim, pos):\n    \"\"\"\n    embed_dim: output dimension for each position\n    pos: a list of positions to be encoded: size (M,)\n    out: (M, D)\n    \"\"\"\n    assert embed_dim % 2 == 0\n    # omega = np.arange(embed_dim // 2, dtype=np.float)\n    omega = np.arange(embed_dim // 2, dtype=float)\n    omega /= embed_dim / 2.0\n    omega = 1.0 / 10000**omega  # (D/2,)\n\n    pos = pos.reshape(-1)  # (M,)\n    out = np.einsum(\"m,d->md\", pos, omega)  # (M, D/2), outer product\n\n    emb_sin = np.sin(out)  # (M, D/2)\n    emb_cos = np.cos(out)  # (M, D/2)\n\n    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)\n    return emb\n\n\n# --------------------------------------------------------\n# Interpolate position embeddings for high-resolution\n# References:\n# DeiT: https://github.com/facebookresearch/deit\n# --------------------------------------------------------\ndef 
interpolate_pos_embed(model, checkpoint_model):\n    if \"pos_embed\" in checkpoint_model:\n        pos_embed_checkpoint = checkpoint_model[\"pos_embed\"]\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = model.patch_embed.num_patches\n        num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        # height (== width) for the checkpoint position embedding\n        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n        # height (== width) for the new position embedding\n        new_size = int(num_patches**0.5)\n        # class_token and dist_token are kept unchanged\n        if orig_size != new_size:\n            print(\n                \"Position interpolate from %dx%d to %dx%d\"\n                % (orig_size, orig_size, new_size, new_size)\n            )\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n            # only the position tokens are interpolated\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n            pos_tokens = pos_tokens.reshape(\n                -1, orig_size, orig_size, embedding_size\n            ).permute(0, 3, 1, 2)\n            pos_tokens = torch.nn.functional.interpolate(\n                pos_tokens,\n                size=(new_size, new_size),\n                mode=\"bicubic\",\n                align_corners=False,\n            )\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n            checkpoint_model[\"pos_embed\"] = new_pos_embed\n\n\ndef interpolate_pos_embed_img2audio(model, checkpoint_model, orig_size, new_size):\n    if \"pos_embed\" in checkpoint_model:\n        pos_embed_checkpoint = checkpoint_model[\"pos_embed\"]\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = model.patch_embed.num_patches\n        num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        # height (== width) for the checkpoint position embedding\n        # orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)\n        # height (== width) for the new position embedding\n        # new_size = int(num_patches ** 0.5)\n        # class_token and dist_token are kept unchanged\n        if orig_size != new_size:\n            print(\n                \"Position interpolate from %dx%d to %dx%d\"\n                % (orig_size[0], orig_size[1], new_size[0], new_size[1])\n            )\n            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n            # only the position tokens are interpolated\n            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]\n            pos_tokens = pos_tokens.reshape(\n                -1, orig_size[0], orig_size[1], embedding_size\n            ).permute(0, 3, 1, 2)\n            pos_tokens = torch.nn.functional.interpolate(\n                pos_tokens,\n                size=(new_size[0], new_size[1]),\n                mode=\"bicubic\",\n                align_corners=False,\n            )\n            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)\n            checkpoint_model[\"pos_embed\"] = new_pos_embed\n\n\ndef interpolate_pos_embed_audio(model, checkpoint_model, orig_size, new_size):\n    if \"pos_embed\" in checkpoint_model:\n        pos_embed_checkpoint = checkpoint_model[\"pos_embed\"]\n        embedding_size = pos_embed_checkpoint.shape[-1]\n        num_patches = 
model.patch_embed.num_patches\n        num_extra_tokens = model.pos_embed.shape[-2] - num_patches\n        if orig_size != new_size:\n            print(\n                \"Position interpolate from %dx%d to %dx%d\"\n                % (orig_size[0], orig_size[1], new_size[0], new_size[1])\n            )\n            # extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]\n            # only the position tokens are interpolated\n            cls_token = pos_embed_checkpoint[:, 0, :].unsqueeze(1)\n            pos_tokens = pos_embed_checkpoint[:, 1:, :]  # remove\n            pos_tokens = pos_tokens.reshape(\n                -1, orig_size[0], orig_size[1], embedding_size\n            )  # .permute(0, 3, 1, 2)\n            # pos_tokens = torch.nn.functional.interpolate(\n            #    pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)\n\n            # pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)\n            pos_tokens = pos_tokens[:, :, : new_size[1], :]  # assume only time diff\n            pos_tokens = pos_tokens.flatten(1, 2)\n            new_pos_embed = torch.cat((cls_token, pos_tokens), dim=1)\n            checkpoint_model[\"pos_embed\"] = new_pos_embed\n\n\ndef interpolate_patch_embed_audio(\n    model,\n    checkpoint_model,\n    orig_channel,\n    new_channel=1,\n    kernel_size=(16, 16),\n    stride=(16, 16),\n    padding=(0, 0),\n):\n    if orig_channel != new_channel:\n        if \"patch_embed.proj.weight\" in checkpoint_model:\n            # aggregate 3 channels in rgb ckpt to 1 channel for audio\n            new_proj_weight = torch.nn.Parameter(\n                torch.sum(checkpoint_model[\"patch_embed.proj.weight\"], dim=1).unsqueeze(\n                    1\n                )\n            )\n            checkpoint_model[\"patch_embed.proj.weight\"] = new_proj_weight\n"
  },
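The position-embedding helpers above build a 2-D sinusoidal table by computing one 1-D sin/cos embedding over the grid height and another over the width, then concatenating the two halves along the feature dimension. A minimal numpy sketch of that recipe (the local helper name, the 2x3 grid and `embed_dim=8` are illustrative values, not imports from the module above):

```python
import numpy as np


def sincos_1d(embed_dim, pos):
    # Same recipe as get_1d_sincos_pos_embed_from_grid: outer product of the
    # positions with 1/10000**(2i/D), then sin and cos halves concatenated.
    omega = np.arange(embed_dim // 2, dtype=float) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega                                  # (D/2,)
    out = np.einsum("m,d->md", pos.reshape(-1), omega)          # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)   # (M, D)


h, w, embed_dim = 2, 3, 8                                       # toy rectangular grid
grid_w, grid_h = np.meshgrid(np.arange(w, dtype=float), np.arange(h, dtype=float))
emb_h = sincos_1d(embed_dim // 2, grid_h)                       # (h*w, D/2)
emb_w = sincos_1d(embed_dim // 2, grid_w)                       # (h*w, D/2)
pos_embed = np.concatenate([emb_h, emb_w], axis=1)
print(pos_embed.shape)                                          # (6, 8) == (h*w, embed_dim)
```

The `_flexible` variant in the file applies the same construction to rectangular grids, which is what non-square spectrogram patch layouts need.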
  {
    "path": "semanticodec/modules/decoder/__init__.py",
    "content": ""
  },
  {
    "path": "semanticodec/modules/decoder/hifigan/LICENSE",
    "content": "MIT License\n\nCopyright (c) 2020 Jungil Kong\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "semanticodec/modules/decoder/hifigan/__init__.py",
    "content": "from .models_v2 import Generator\nfrom .models import Generator as Generator_old\n\n\nclass AttrDict(dict):\n    def __init__(self, *args, **kwargs):\n        super(AttrDict, self).__init__(*args, **kwargs)\n        self.__dict__ = self\n"
  },
  {
    "path": "semanticodec/modules/decoder/hifigan/models.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import Conv1d, ConvTranspose1d\nfrom torch.nn.utils import weight_norm, remove_weight_norm\n\nLRELU_SLOPE = 0.1\n\n\ndef init_weights(m, mean=0.0, std=0.01):\n    classname = m.__class__.__name__\n    if classname.find(\"Conv\") != -1:\n        m.weight.data.normal_(mean, std)\n\n\ndef get_padding(kernel_size, dilation=1):\n    return int((kernel_size * dilation - dilation) / 2)\n\n\nclass ResBlock(torch.nn.Module):\n    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):\n        super(ResBlock, self).__init__()\n        self.h = h\n        self.convs1 = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[0],\n                        padding=get_padding(kernel_size, dilation[0]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[1],\n                        padding=get_padding(kernel_size, dilation[1]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[2],\n                        padding=get_padding(kernel_size, dilation[2]),\n                    )\n                ),\n            ]\n        )\n        self.convs1.apply(init_weights)\n\n        self.convs2 = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=1,\n                        padding=get_padding(kernel_size, 1),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=1,\n                        padding=get_padding(kernel_size, 1),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=1,\n                        padding=get_padding(kernel_size, 1),\n                    )\n                ),\n            ]\n        )\n        self.convs2.apply(init_weights)\n\n    def forward(self, x):\n        for c1, c2 in zip(self.convs1, self.convs2):\n            xt = F.leaky_relu(x, LRELU_SLOPE)\n            xt = c1(xt)\n            xt = F.leaky_relu(xt, LRELU_SLOPE)\n            xt = c2(xt)\n            x = xt + x\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs1:\n            remove_weight_norm(l)\n        for l in self.convs2:\n            remove_weight_norm(l)\n\n\nclass Generator(torch.nn.Module):\n    def __init__(self, h):\n        super(Generator, self).__init__()\n        self.h = h\n      
  self.num_kernels = len(h.resblock_kernel_sizes)\n        self.num_upsamples = len(h.upsample_rates)\n        self.conv_pre = weight_norm(\n            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)\n        )\n        resblock = ResBlock\n\n        self.ups = nn.ModuleList()\n        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):\n            self.ups.append(\n                weight_norm(\n                    ConvTranspose1d(\n                        h.upsample_initial_channel // (2**i),\n                        h.upsample_initial_channel // (2 ** (i + 1)),\n                        k,\n                        u,\n                        padding=(k - u) // 2,\n                    )\n                )\n            )\n\n        self.resblocks = nn.ModuleList()\n        for i in range(len(self.ups)):\n            ch = h.upsample_initial_channel // (2 ** (i + 1))\n            for j, (k, d) in enumerate(\n                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)\n            ):\n                self.resblocks.append(resblock(h, ch, k, d))\n\n        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))\n        self.ups.apply(init_weights)\n        self.conv_post.apply(init_weights)\n\n    def forward(self, x):\n        x = self.conv_pre(x)\n        for i in range(self.num_upsamples):\n            x = F.leaky_relu(x, LRELU_SLOPE)\n            x = self.ups[i](x)\n            xs = None\n            for j in range(self.num_kernels):\n                if xs is None:\n                    xs = self.resblocks[i * self.num_kernels + j](x)\n                else:\n                    xs += self.resblocks[i * self.num_kernels + j](x)\n            x = xs / self.num_kernels\n        x = F.leaky_relu(x)\n        x = self.conv_post(x)\n        x = torch.tanh(x)\n\n        return x\n\n    def remove_weight_norm(self):\n        # print(\"Removing weight norm...\")\n        for l in self.ups:\n            remove_weight_norm(l)\n        for l in self.resblocks:\n            l.remove_weight_norm()\n        remove_weight_norm(self.conv_pre)\n        remove_weight_norm(self.conv_post)\n"
  },
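The v1 HiFi-GAN `Generator` above reads every hyper-parameter from an attribute-style config. A small usage sketch, assuming the package is importable; the config values below are toy numbers chosen only for shape checking, not the shipped checkpoint settings:

```python
import torch
from semanticodec.modules.decoder.hifigan import AttrDict, Generator_old

# Toy config; the real values come from the vocoder checkpoint's config.
h = AttrDict(
    num_mels=64,
    upsample_initial_channel=128,
    upsample_rates=[4, 4],
    upsample_kernel_sizes=[8, 8],
    resblock_kernel_sizes=[3],
    resblock_dilation_sizes=[[1, 3, 5]],
)
vocoder = Generator_old(h)
mel = torch.randn(1, 64, 50)           # [batch, num_mels, frames]
with torch.no_grad():
    wav = vocoder(mel)                 # upsampled by 4 * 4 = 16x
print(wav.shape)                       # torch.Size([1, 1, 800])
```

Each upsample stage halves the channel count while the transposed convolution stretches the time axis by its stride, so the total upsampling factor is the product of `upsample_rates`.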
  {
    "path": "semanticodec/modules/decoder/hifigan/models_v2.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d\nfrom torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm\n\nLRELU_SLOPE = 0.1\n\n\ndef init_weights(m, mean=0.0, std=0.01):\n    classname = m.__class__.__name__\n    if classname.find(\"Conv\") != -1:\n        m.weight.data.normal_(mean, std)\n\n\ndef get_padding(kernel_size, dilation=1):\n    return int((kernel_size * dilation - dilation) / 2)\n\n\nclass ResBlock1(torch.nn.Module):\n    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):\n        super(ResBlock1, self).__init__()\n        self.h = h\n        self.convs1 = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[0],\n                        padding=get_padding(kernel_size, dilation[0]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[1],\n                        padding=get_padding(kernel_size, dilation[1]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[2],\n                        padding=get_padding(kernel_size, dilation[2]),\n                    )\n                ),\n            ]\n        )\n        self.convs1.apply(init_weights)\n\n        self.convs2 = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=1,\n                        padding=get_padding(kernel_size, 1),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=1,\n                        padding=get_padding(kernel_size, 1),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=1,\n                        padding=get_padding(kernel_size, 1),\n                    )\n                ),\n            ]\n        )\n        self.convs2.apply(init_weights)\n\n    def forward(self, x):\n        for c1, c2 in zip(self.convs1, self.convs2):\n            xt = F.leaky_relu(x, LRELU_SLOPE)\n            xt = c1(xt)\n            xt = F.leaky_relu(xt, LRELU_SLOPE)\n            xt = c2(xt)\n            x = xt + x\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs1:\n            remove_weight_norm(l)\n        for l in self.convs2:\n            remove_weight_norm(l)\n\n\nclass ResBlock2(torch.nn.Module):\n    def __init__(self, h, channels, kernel_size=3, 
dilation=(1, 3)):\n        super(ResBlock2, self).__init__()\n        self.h = h\n        self.convs = nn.ModuleList(\n            [\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[0],\n                        padding=get_padding(kernel_size, dilation[0]),\n                    )\n                ),\n                weight_norm(\n                    Conv1d(\n                        channels,\n                        channels,\n                        kernel_size,\n                        1,\n                        dilation=dilation[1],\n                        padding=get_padding(kernel_size, dilation[1]),\n                    )\n                ),\n            ]\n        )\n        self.convs.apply(init_weights)\n\n    def forward(self, x):\n        for c in self.convs:\n            xt = F.leaky_relu(x, LRELU_SLOPE)\n            xt = c(xt)\n            x = xt + x\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.convs:\n            remove_weight_norm(l)\n\n\nclass Generator(torch.nn.Module):\n    def __init__(self, h):\n        super(Generator, self).__init__()\n        self.h = h\n        self.num_kernels = len(h.resblock_kernel_sizes)\n        self.num_upsamples = len(h.upsample_rates)\n        self.conv_pre = weight_norm(\n            Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3)\n        )\n        resblock = ResBlock1 if h.resblock == \"1\" else ResBlock2\n\n        self.ups = nn.ModuleList()\n        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):\n            self.ups.append(\n                weight_norm(\n                    ConvTranspose1d(\n                        h.upsample_initial_channel // (2**i),\n                        h.upsample_initial_channel // (2 ** (i + 1)),\n                        u * 2,\n                        u,\n                        padding=u // 2 + u % 2,\n                        output_padding=u % 2,\n                    )\n                )\n            )\n\n        self.resblocks = nn.ModuleList()\n        for i in range(len(self.ups)):\n            ch = h.upsample_initial_channel // (2 ** (i + 1))\n            for j, (k, d) in enumerate(\n                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)\n            ):\n                self.resblocks.append(resblock(h, ch, k, d))\n\n        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))\n        self.ups.apply(init_weights)\n        self.conv_post.apply(init_weights)\n\n    def forward(self, x):\n        # import ipdb; ipdb.set_trace()\n        x = self.conv_pre(x)\n        for i in range(self.num_upsamples):\n            x = F.leaky_relu(x, LRELU_SLOPE)\n            x = self.ups[i](x)\n            xs = None\n            for j in range(self.num_kernels):\n                if xs is None:\n                    xs = self.resblocks[i * self.num_kernels + j](x)\n                else:\n                    xs += self.resblocks[i * self.num_kernels + j](x)\n            x = xs / self.num_kernels\n        x = F.leaky_relu(x)\n        x = self.conv_post(x)\n        x = torch.tanh(x)\n\n        return x\n\n    def remove_weight_norm(self):\n        for l in self.ups:\n            remove_weight_norm(l)\n        for l in self.resblocks:\n            l.remove_weight_norm()\n        remove_weight_norm(self.conv_pre)\n        
remove_weight_norm(self.conv_post)\n"
  },
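`models_v2.Generator` differs from the v1 model in two visible ways: the input projection is hard-wired to 256 input channels rather than `h.num_mels`, and `h.resblock` selects between `ResBlock1` and `ResBlock2`, with each transposed-conv kernel derived as `2*u` plus output padding for odd strides. A sketch under the same toy-config assumption as before:

```python
import torch
from semanticodec.modules.decoder.hifigan import AttrDict, Generator  # models_v2 Generator

h = AttrDict(
    resblock="2",                       # "1" -> ResBlock1, anything else -> ResBlock2
    upsample_initial_channel=128,
    upsample_rates=[5, 2],
    upsample_kernel_sizes=[10, 4],      # kept for the zip; kernels are derived as 2*u
    resblock_kernel_sizes=[3],
    resblock_dilation_sizes=[[1, 3]],
)
net = Generator(h)
latent = torch.randn(1, 256, 40)        # models_v2 expects 256 input channels
with torch.no_grad():
    out = net(latent)
print(out.shape)                        # torch.Size([1, 1, 400]) -- 5 * 2 = 10x upsampling
```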
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/__init__.py",
    "content": ""
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/models/__init__.py",
    "content": ""
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/models/ddim.py",
    "content": "\"\"\"SAMPLING ONLY.\"\"\"\n\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.util import (\n    make_ddim_sampling_parameters,\n    make_ddim_timesteps,\n    noise_like,\n    extract_into_tensor,\n)\n\n\nclass DDIMSampler(object):\n    def __init__(self, model, schedule=\"linear\", device=torch.device(\"cuda\"), **kwargs):\n        super().__init__()\n        self.model = model\n        self.ddpm_num_timesteps = model.num_timesteps\n        self.schedule = schedule\n        self.device = device\n\n    def register_buffer(self, name, attr):\n        if type(attr) == torch.Tensor:\n            if attr.device != self.device:\n                attr = attr.to(self.device)\n        setattr(self, name, attr)\n\n    def make_schedule(\n        self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0.0, verbose=True\n    ):\n        self.ddim_timesteps = make_ddim_timesteps(\n            ddim_discr_method=ddim_discretize,\n            num_ddim_timesteps=ddim_num_steps,\n            num_ddpm_timesteps=self.ddpm_num_timesteps,\n            verbose=verbose,\n        )\n        alphas_cumprod = self.model.alphas_cumprod\n        assert (\n            alphas_cumprod.shape[0] == self.ddpm_num_timesteps\n        ), \"alphas have to be defined for each timestep\"\n        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n\n        self.register_buffer(\"betas\", to_torch(self.model.betas))\n        self.register_buffer(\"alphas_cumprod\", to_torch(alphas_cumprod))\n        self.register_buffer(\n            \"alphas_cumprod_prev\", to_torch(self.model.alphas_cumprod_prev)\n        )\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer(\n            \"sqrt_alphas_cumprod\", to_torch(np.sqrt(alphas_cumprod.cpu()))\n        )\n        self.register_buffer(\n            \"sqrt_one_minus_alphas_cumprod\",\n            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),\n        )\n        self.register_buffer(\n            \"log_one_minus_alphas_cumprod\", to_torch(np.log(1.0 - alphas_cumprod.cpu()))\n        )\n        self.register_buffer(\n            \"sqrt_recip_alphas_cumprod\", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))\n        )\n        self.register_buffer(\n            \"sqrt_recipm1_alphas_cumprod\",\n            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),\n        )\n\n        # ddim sampling parameters\n        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(\n            alphacums=alphas_cumprod.cpu(),\n            ddim_timesteps=self.ddim_timesteps,\n            eta=ddim_eta,\n            verbose=verbose,\n        )\n        \n        if torch.backends.mps.is_available():\n            ddim_sigmas = ddim_sigmas.to(torch.float32)\n            ddim_alphas = ddim_alphas.to(torch.float32)\n            ddim_alphas_prev = ddim_alphas_prev.astype(np.float32)\n\n        self.register_buffer(\"ddim_sigmas\", ddim_sigmas)\n        self.register_buffer(\"ddim_alphas\", ddim_alphas)\n        self.register_buffer(\"ddim_alphas_prev\", ddim_alphas_prev)\n        self.register_buffer(\"ddim_sqrt_one_minus_alphas\", np.sqrt(1.0 - ddim_alphas))\n        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n            (1 - self.alphas_cumprod_prev)\n            / (1 - self.alphas_cumprod)\n            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)\n        )\n        
self.register_buffer(\n            \"ddim_sigmas_for_original_num_steps\", sigmas_for_original_sampling_steps\n        )\n\n    @torch.no_grad()\n    def sample(\n        self,\n        S,\n        batch_size,\n        shape,\n        conditioning=None,\n        callback=None,\n        normals_sequence=None,\n        img_callback=None,\n        quantize_x0=False,\n        eta=0.0,\n        mask=None,\n        x0=None,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        verbose=True,\n        x_T=None,\n        log_every_t=100,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,  # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...\n        dynamic_threshold=None,\n        ucg_schedule=None,\n        **kwargs\n    ):\n        # if conditioning is not None:\n        #     if isinstance(conditioning, dict):\n        #         ctmp = conditioning[list(conditioning.keys())[0]]\n        #         while isinstance(ctmp, list): ctmp = ctmp[0]\n        #         cbs = ctmp.shape[0]\n        #         if cbs != batch_size:\n        #             print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n\n        #     elif isinstance(conditioning, list):\n        #         for ctmp in conditioning:\n        #             if ctmp.shape[0] != batch_size:\n        #                 print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n\n        #     else:\n        #         if conditioning.shape[0] != batch_size:\n        #             print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n\n        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n        # sampling\n        C, H, W = shape\n        size = (batch_size, C, H, W)\n\n        samples, intermediates = self.ddim_sampling(\n            conditioning,\n            size,\n            callback=callback,\n            img_callback=img_callback,\n            quantize_denoised=quantize_x0,\n            mask=mask,\n            x0=x0,\n            ddim_use_original_steps=False,\n            noise_dropout=noise_dropout,\n            temperature=temperature,\n            score_corrector=score_corrector,\n            corrector_kwargs=corrector_kwargs,\n            x_T=x_T,\n            log_every_t=log_every_t,\n            unconditional_guidance_scale=unconditional_guidance_scale,\n            unconditional_conditioning=unconditional_conditioning,\n            dynamic_threshold=dynamic_threshold,\n            ucg_schedule=ucg_schedule,\n        )\n        return samples, intermediates\n\n    @torch.no_grad()\n    def ddim_sampling(\n        self,\n        cond,\n        shape,\n        x_T=None,\n        ddim_use_original_steps=False,\n        callback=None,\n        timesteps=None,\n        quantize_denoised=False,\n        mask=None,\n        x0=None,\n        img_callback=None,\n        log_every_t=100,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n        dynamic_threshold=None,\n        ucg_schedule=None,\n    ):\n        device = self.model.betas.device\n        b = shape[0]\n        if x_T is None:\n            img = torch.randn(shape, device=device)\n        else:\n            img = x_T\n\n        if timesteps is None:\n            timesteps = (\n      
          self.ddpm_num_timesteps\n                if ddim_use_original_steps\n                else self.ddim_timesteps\n            )\n        elif timesteps is not None and not ddim_use_original_steps:\n            subset_end = (\n                int(\n                    min(timesteps / self.ddim_timesteps.shape[0], 1)\n                    * self.ddim_timesteps.shape[0]\n                )\n                - 1\n            )\n            timesteps = self.ddim_timesteps[:subset_end]\n\n        intermediates = {\"x_inter\": [img], \"pred_x0\": [img]}\n        time_range = (\n            reversed(range(0, timesteps))\n            if ddim_use_original_steps\n            else np.flip(timesteps)\n        )\n        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n\n        iterator = tqdm(time_range, desc=\"DDIM Sampler\", total=total_steps)\n\n        for i, step in enumerate(iterator):\n            index = total_steps - i - 1\n            ts = torch.full((b,), step, device=device, dtype=torch.long)\n\n            if mask is not None:\n                assert x0 is not None\n                img_orig = self.model.q_sample(\n                    x0, ts\n                )  # TODO: deterministic forward pass?\n                img = img_orig * mask + (1.0 - mask) * img\n\n            if ucg_schedule is not None:\n                assert len(ucg_schedule) == len(time_range)\n                unconditional_guidance_scale = ucg_schedule[i]\n\n            outs = self.p_sample_ddim(\n                img,\n                cond,\n                ts,\n                index=index,\n                use_original_steps=ddim_use_original_steps,\n                quantize_denoised=quantize_denoised,\n                temperature=temperature,\n                noise_dropout=noise_dropout,\n                score_corrector=score_corrector,\n                corrector_kwargs=corrector_kwargs,\n                unconditional_guidance_scale=unconditional_guidance_scale,\n                unconditional_conditioning=unconditional_conditioning,\n                dynamic_threshold=dynamic_threshold,\n            )\n            img, pred_x0 = outs\n            if callback:\n                callback(i)\n            if img_callback:\n                img_callback(pred_x0, i)\n\n            if index % log_every_t == 0 or index == total_steps - 1:\n                intermediates[\"x_inter\"].append(img)\n                intermediates[\"pred_x0\"].append(pred_x0)\n\n        return img, intermediates\n\n    @torch.no_grad()\n    def p_sample_ddim(\n        self,\n        x,\n        c,\n        t,\n        index,\n        repeat_noise=False,\n        use_original_steps=False,\n        quantize_denoised=False,\n        temperature=1.0,\n        noise_dropout=0.0,\n        score_corrector=None,\n        corrector_kwargs=None,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n        dynamic_threshold=None,\n    ):\n        b, *_, device = *x.shape, x.device\n\n        if unconditional_conditioning is None or unconditional_guidance_scale == 1.0:\n            model_output = self.model.apply_model(x, t, c)\n        else:\n            x_in = x\n            t_in = t\n\n            assert isinstance(c, dict)\n            assert isinstance(unconditional_conditioning, dict)\n\n            model_uncond = self.model.apply_model(\n                x_in, t_in, unconditional_conditioning\n            )\n            model_t = self.model.apply_model(x_in, t_in, c)\n\n            model_output = 
model_uncond + unconditional_guidance_scale * (\n                model_t - model_uncond\n            )\n\n        if self.model.parameterization == \"v\":\n            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)\n        else:\n            e_t = model_output\n\n        if score_corrector is not None:\n            assert self.model.parameterization == \"eps\", \"not implemented\"\n            e_t = score_corrector.modify_score(\n                self.model, e_t, x, t, c, **corrector_kwargs\n            )\n\n        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n        alphas_prev = (\n            self.model.alphas_cumprod_prev\n            if use_original_steps\n            else self.ddim_alphas_prev\n        )\n        sqrt_one_minus_alphas = (\n            self.model.sqrt_one_minus_alphas_cumprod\n            if use_original_steps\n            else self.ddim_sqrt_one_minus_alphas\n        )\n        sigmas = (\n            self.model.ddim_sigmas_for_original_num_steps\n            if use_original_steps\n            else self.ddim_sigmas\n        )\n        # select parameters corresponding to the currently considered timestep\n        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n        sqrt_one_minus_at = torch.full(\n            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device\n        )\n\n        # current prediction for x_0\n        if self.model.parameterization != \"v\":\n            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n        else:\n            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)\n\n        if quantize_denoised:\n            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n\n        if dynamic_threshold is not None:\n            raise NotImplementedError()\n\n        # direction pointing to x_t\n        dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t\n        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n        if noise_dropout > 0.0:\n            noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n        return x_prev, pred_x0\n\n    @torch.no_grad()\n    def encode(\n        self,\n        x0,\n        c,\n        t_enc,\n        use_original_steps=False,\n        return_intermediates=None,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n        callback=None,\n    ):\n        num_reference_steps = (\n            self.ddpm_num_timesteps\n            if use_original_steps\n            else self.ddim_timesteps.shape[0]\n        )\n\n        assert t_enc <= num_reference_steps\n        num_steps = t_enc\n\n        if use_original_steps:\n            alphas_next = self.alphas_cumprod[:num_steps]\n            alphas = self.alphas_cumprod_prev[:num_steps]\n        else:\n            alphas_next = self.ddim_alphas[:num_steps]\n            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])\n\n        x_next = x0\n        intermediates = []\n        inter_steps = []\n        for i in tqdm(range(num_steps), desc=\"Encoding Image\"):\n            t = torch.full(\n                (x0.shape[0],), i, device=self.model.device, dtype=torch.long\n            )\n            if unconditional_guidance_scale == 1.0:\n                noise_pred = 
self.model.apply_model(x_next, t, c)\n            else:\n                assert unconditional_conditioning is not None\n                e_t_uncond, noise_pred = torch.chunk(\n                    self.model.apply_model(\n                        torch.cat((x_next, x_next)),\n                        torch.cat((t, t)),\n                        torch.cat((unconditional_conditioning, c)),\n                    ),\n                    2,\n                )\n                noise_pred = e_t_uncond + unconditional_guidance_scale * (\n                    noise_pred - e_t_uncond\n                )\n\n            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next\n            weighted_noise_pred = (\n                alphas_next[i].sqrt()\n                * ((1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt())\n                * noise_pred\n            )\n            x_next = xt_weighted + weighted_noise_pred\n            if (\n                return_intermediates\n                and i % (num_steps // return_intermediates) == 0\n                and i < num_steps - 1\n            ):\n                intermediates.append(x_next)\n                inter_steps.append(i)\n            elif return_intermediates and i >= num_steps - 2:\n                intermediates.append(x_next)\n                inter_steps.append(i)\n            if callback:\n                callback(i)\n\n        out = {\"x_encoded\": x_next, \"intermediate_steps\": inter_steps}\n        if return_intermediates:\n            out.update({\"intermediates\": intermediates})\n        return x_next, out\n\n    @torch.no_grad()\n    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):\n        # fast, but does not allow for exact reconstruction\n        # t serves as an index to gather the correct alphas\n        if use_original_steps:\n            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod\n            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod\n        else:\n            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)\n            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas\n\n        if noise is None:\n            noise = torch.randn_like(x0)\n        return (\n            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0\n            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise\n        )\n\n    @torch.no_grad()\n    def decode(\n        self,\n        x_latent,\n        cond,\n        t_start,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n        use_original_steps=False,\n        callback=None,\n    ):\n        timesteps = (\n            np.arange(self.ddpm_num_timesteps)\n            if use_original_steps\n            else self.ddim_timesteps\n        )\n        timesteps = timesteps[:t_start]\n\n        time_range = np.flip(timesteps)\n        total_steps = timesteps.shape[0]\n\n        iterator = tqdm(time_range, desc=\"Decoding image\", total=total_steps)\n        x_dec = x_latent\n        for i, step in enumerate(iterator):\n            index = total_steps - i - 1\n            ts = torch.full(\n                (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long\n            )\n            x_dec, _ = self.p_sample_ddim(\n                x_dec,\n                cond,\n                ts,\n                index=index,\n                use_original_steps=use_original_steps,\n                unconditional_guidance_scale=unconditional_guidance_scale,\n          
      unconditional_conditioning=unconditional_conditioning,\n            )\n            if callback:\n                callback(i)\n        return x_dec\n"
  },
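Inside `p_sample_ddim`, each reverse step (in the eps parameterization branch) first reconstructs an estimate of x_0 from the noise prediction, then re-noises it toward the previous timestep. A self-contained sketch of that single update with made-up schedule scalars (in the sampler these come from `ddim_alphas`, `ddim_alphas_prev` and `ddim_sigmas`):

```python
import torch

b, c, t_len, f_len = 2, 8, 16, 4
x   = torch.randn(b, c, t_len, f_len)          # current latent x_t
e_t = torch.randn(b, c, t_len, f_len)          # model's noise prediction
a_t     = torch.full((b, 1, 1, 1), 0.5)        # cumulative alpha at this step
a_prev  = torch.full((b, 1, 1, 1), 0.7)        # cumulative alpha at the previous step
sigma_t = torch.full((b, 1, 1, 1), 0.0)        # eta = 0 -> deterministic DDIM

pred_x0 = (x - (1.0 - a_t).sqrt() * e_t) / a_t.sqrt()   # current estimate of x_0
dir_xt  = (1.0 - a_prev - sigma_t**2).sqrt() * e_t      # "direction pointing to x_t"
x_prev  = a_prev.sqrt() * pred_x0 + dir_xt + sigma_t * torch.randn_like(x)
print(x_prev.shape)                                     # torch.Size([2, 8, 16, 4])
```

With `eta=0.0` (the default in `make_schedule`) the sigma term vanishes and the trajectory is deterministic, which is the usual DDIM setting.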
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/models/ddpm.py",
    "content": "import torch\nimport torch.nn as nn\nimport numpy as np\nfrom contextlib import contextmanager\nfrom functools import partial\nfrom tqdm import tqdm\n\nfrom semanticodec.modules.decoder.latent_diffusion.util import (\n    exists,\n    default,\n    count_params,\n    instantiate_from_config,\n)\nfrom semanticodec.modules.decoder.latent_diffusion.modules.ema import LitEma\n\nfrom semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.util import (\n    make_beta_schedule,\n    extract_into_tensor,\n    noise_like,\n)\n\nfrom semanticodec.modules.decoder.latent_diffusion.models.ddim import DDIMSampler\nfrom semanticodec.modules.decoder.latent_diffusion.util import disabled_train\nfrom semanticodec.utils import PositionalEncoding\n\n\nclass DDPM(nn.Module):\n    # classic DDPM with Gaussian diffusion, in image space\n    def __init__(\n        self,\n        unet_config,\n        sampling_rate=None,\n        timesteps=1000,\n        beta_schedule=\"linear\",\n        use_ema=True,\n        first_stage_key=\"image\",\n        latent_t_size=256,\n        latent_f_size=16,\n        channels=3,\n        clip_denoised=True,\n        linear_start=1e-4,\n        linear_end=2e-2,\n        cosine_s=8e-3,\n        given_betas=None,\n        v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta\n        conditioning_key=None,\n        parameterization=\"eps\",  # all assuming fixed variance schedules\n        logvar_init=0.0,\n    ):\n        super().__init__()\n        assert parameterization in [\n            \"eps\",\n            \"x0\",\n            \"v\",\n        ], 'currently only supporting \"eps\" and \"x0\" and \"v\"'\n        self.parameterization = parameterization\n        self.state = None\n        assert sampling_rate is not None\n        self.validation_folder_name = \"temp_name\"\n        self.clip_denoised = clip_denoised\n        self.first_stage_key = first_stage_key\n        self.sampling_rate = sampling_rate\n\n        self.latent_t_size = latent_t_size\n        self.latent_f_size = latent_f_size\n        self.v_posterior = v_posterior\n\n        self.channels = channels\n        self.model = DiffusionWrapper(unet_config, conditioning_key)\n        count_params(self.model, verbose=True)\n        self.use_ema = use_ema\n        if self.use_ema:\n            self.model_ema = LitEma(self.model)\n            # print(f\"Keeping EMAs of {len(list(self.model_ema.buffers()))}.\")\n\n        self.register_schedule(\n            given_betas=given_betas,\n            beta_schedule=beta_schedule,\n            timesteps=timesteps,\n            linear_start=linear_start,\n            linear_end=linear_end,\n            cosine_s=cosine_s,\n        )\n\n        self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))\n        self.logvar = nn.Parameter(self.logvar, requires_grad=False)\n        self.pos_embed = PositionalEncoding(seq_length=512, embedding_dim=192)\n\n    def register_schedule(\n        self,\n        given_betas=None,\n        beta_schedule=\"linear\",\n        timesteps=1000,\n        linear_start=1e-4,\n        linear_end=2e-2,\n        cosine_s=8e-3,\n    ):\n        if exists(given_betas):\n            betas = given_betas\n        else:\n            betas = make_beta_schedule(\n                beta_schedule,\n                timesteps,\n                linear_start=linear_start,\n                linear_end=linear_end,\n                cosine_s=cosine_s,\n            )\n        
alphas = 1.0 - betas\n        alphas_cumprod = np.cumprod(alphas, axis=0)\n        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])\n\n        (timesteps,) = betas.shape\n        self.num_timesteps = int(timesteps)\n        self.linear_start = linear_start\n        self.linear_end = linear_end\n        assert (\n            alphas_cumprod.shape[0] == self.num_timesteps\n        ), \"alphas have to be defined for each timestep\"\n\n        to_torch = partial(torch.tensor, dtype=torch.float32)\n\n        self.register_buffer(\"betas\", to_torch(betas))\n        self.register_buffer(\"alphas_cumprod\", to_torch(alphas_cumprod))\n        self.register_buffer(\"alphas_cumprod_prev\", to_torch(alphas_cumprod_prev))\n\n        # calculations for diffusion q(x_t | x_{t-1}) and others\n        self.register_buffer(\"sqrt_alphas_cumprod\", to_torch(np.sqrt(alphas_cumprod)))\n        self.register_buffer(\n            \"sqrt_one_minus_alphas_cumprod\", to_torch(np.sqrt(1.0 - alphas_cumprod))\n        )\n        self.register_buffer(\n            \"log_one_minus_alphas_cumprod\", to_torch(np.log(1.0 - alphas_cumprod))\n        )\n        self.register_buffer(\n            \"sqrt_recip_alphas_cumprod\", to_torch(np.sqrt(1.0 / alphas_cumprod))\n        )\n        self.register_buffer(\n            \"sqrt_recipm1_alphas_cumprod\", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))\n        )\n\n        # calculations for posterior q(x_{t-1} | x_t, x_0)\n        posterior_variance = (1 - self.v_posterior) * betas * (\n            1.0 - alphas_cumprod_prev\n        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas\n        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)\n        self.register_buffer(\"posterior_variance\", to_torch(posterior_variance))\n        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain\n        self.register_buffer(\n            \"posterior_log_variance_clipped\",\n            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),\n        )\n        self.register_buffer(\n            \"posterior_mean_coef1\",\n            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),\n        )\n        self.register_buffer(\n            \"posterior_mean_coef2\",\n            to_torch(\n                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)\n            ),\n        )\n\n        if self.parameterization == \"eps\":\n            lvlb_weights = self.betas**2 / (\n                2\n                * self.posterior_variance\n                * to_torch(alphas)\n                * (1 - self.alphas_cumprod)\n            )\n        elif self.parameterization == \"x0\":\n            lvlb_weights = (\n                0.5\n                * np.sqrt(torch.Tensor(alphas_cumprod))\n                / (2.0 * 1 - torch.Tensor(alphas_cumprod))\n            )\n        elif self.parameterization == \"v\":\n            lvlb_weights = torch.ones_like(\n                self.betas**2\n                / (\n                    2\n                    * self.posterior_variance\n                    * to_torch(alphas)\n                    * (1 - self.alphas_cumprod)\n                )\n            )\n        else:\n            raise NotImplementedError(\"mu not supported\")\n        # TODO how to choose this term\n        lvlb_weights[0] = lvlb_weights[1]\n        self.register_buffer(\"lvlb_weights\", lvlb_weights, persistent=False)\n        assert not 
torch.isnan(self.lvlb_weights).all()\n\n    @contextmanager\n    def ema_scope(self, context=None):\n        if self.use_ema:\n            self.model_ema.store(self.model.parameters())\n            self.model_ema.copy_to(self.model)\n            # if context is not None:\n            # print(f\"{context}: Switched to EMA weights\")\n        try:\n            yield None\n        finally:\n            if self.use_ema:\n                self.model_ema.restore(self.model.parameters())\n                # if context is not None:\n                # print(f\"{context}: Restored training weights\")\n\n    def q_mean_variance(self, x_start, t):\n        \"\"\"\n        Get the distribution q(x_t | x_0).\n        :param x_start: the [N x C x ...] tensor of noiseless inputs.\n        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.\n        :return: A tuple (mean, variance, log_variance), all of x_start's shape.\n        \"\"\"\n        mean = extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n        variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)\n        log_variance = extract_into_tensor(\n            self.log_one_minus_alphas_cumprod, t, x_start.shape\n        )\n        return mean, variance, log_variance\n\n    def predict_start_from_noise(self, x_t, t, noise):\n        return (\n            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t\n            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)\n            * noise\n        )\n\n    def q_posterior(self, x_start, x_t, t):\n        posterior_mean = (\n            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start\n            + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t\n        )\n        posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)\n        posterior_log_variance_clipped = extract_into_tensor(\n            self.posterior_log_variance_clipped, t, x_t.shape\n        )\n        return posterior_mean, posterior_variance, posterior_log_variance_clipped\n\n    def p_mean_variance(self, x, t, clip_denoised: bool):\n        model_out = self.model(x, t)\n        if self.parameterization == \"eps\":\n            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)\n        elif self.parameterization == \"x0\":\n            x_recon = model_out\n        if clip_denoised:\n            x_recon.clamp_(-1.0, 1.0)\n\n        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(\n            x_start=x_recon, x_t=x, t=t\n        )\n        return model_mean, posterior_variance, posterior_log_variance\n\n    @torch.no_grad()\n    def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):\n        b, *_, device = *x.shape, x.device\n        model_mean, _, model_log_variance = self.p_mean_variance(\n            x=x, t=t, clip_denoised=clip_denoised\n        )\n        noise = noise_like(x.shape, device, repeat_noise)\n        # no noise when t == 0\n        nonzero_mask = (\n            (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))).contiguous()\n        )\n        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise\n\n    @torch.no_grad()\n    def p_sample_loop(self, shape):\n        device = self.betas.device\n        b = shape[0]\n        img = torch.randn(shape, device=device)\n        for i in tqdm(\n            reversed(range(0, self.num_timesteps)),\n            
desc=\"Sampling t\",\n            total=self.num_timesteps,\n        ):\n            img = self.p_sample(\n                img,\n                torch.full((b,), i, device=device, dtype=torch.long),\n                clip_denoised=self.clip_denoised,\n            )\n        return img\n\n    @torch.no_grad()\n    def sample(self, batch_size=16, return_intermediates=False):\n        shape = (batch_size, channels, self.latent_t_size, self.latent_f_size)\n        channels = self.channels\n        return self.p_sample_loop(shape, return_intermediates=return_intermediates)\n\n    def q_sample(self, x_start, t, noise=None):\n        noise = default(noise, lambda: torch.randn_like(x_start))\n        return (\n            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start\n            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)\n            * noise\n        )\n\n    def predict_start_from_z_and_v(self, x_t, t, v):\n        # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))\n        # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))\n        return (\n            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t\n            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v\n        )\n\n    def predict_eps_from_z_and_v(self, x_t, t, v):\n        return (\n            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v\n            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)\n            * x_t\n        )\n\n    def get_v(self, x, noise, t):\n        return (\n            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise\n            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x\n        )\n\n\nclass LatentDiffusion(DDPM):\n    \"\"\"main class\"\"\"\n\n    def __init__(\n        self,\n        first_stage_config,\n        cond_stage_config=None,\n        num_timesteps_cond=None,\n        scale_factor=1.0,\n        evaluation_params={},\n        scale_by_std=False,\n        base_learning_rate=None,\n        *args,\n        **kwargs,\n    ):\n        if torch.cuda.is_available():\n            self.device = torch.device(\"cuda\")\n        elif torch.backends.mps.is_available():\n            self.device = torch.device(\"mps\")\n        else:\n            self.device = torch.device(\"cpu\")\n\n        self.learning_rate = base_learning_rate\n        self.num_timesteps_cond = default(num_timesteps_cond, 1)\n        self.scale_by_std = scale_by_std\n\n        self.evaluation_params = evaluation_params\n        assert self.num_timesteps_cond <= kwargs[\"timesteps\"]\n\n        conditioning_key = list(cond_stage_config.keys())\n\n        self.conditioning_key = conditioning_key\n\n        super().__init__(conditioning_key=conditioning_key, *args, **kwargs)\n\n        try:\n            self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1\n        except:\n            self.num_downs = 0\n\n        if not scale_by_std:\n            self.scale_factor = scale_factor\n        else:\n            self.register_buffer(\"scale_factor\", torch.tensor(scale_factor))\n        self.instantiate_first_stage(first_stage_config)\n        self.cond_stage_models = nn.ModuleList([])\n        self.clip_denoised = False\n        self.bbox_tokenizer = None\n        self.conditional_dry_run_finished = False\n        self.restarted_from_ckpt = False\n\n    def 
make_cond_schedule(\n        self,\n    ):\n        self.cond_ids = torch.full(\n            size=(self.num_timesteps,),\n            fill_value=self.num_timesteps - 1,\n            dtype=torch.long,\n        )\n        ids = torch.round(\n            torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)\n        ).long()\n        self.cond_ids[: self.num_timesteps_cond] = ids\n\n    def register_schedule(\n        self,\n        given_betas=None,\n        beta_schedule=\"linear\",\n        timesteps=1000,\n        linear_start=1e-4,\n        linear_end=2e-2,\n        cosine_s=8e-3,\n    ):\n        super().register_schedule(\n            given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s\n        )\n\n        self.shorten_cond_schedule = self.num_timesteps_cond > 1\n        if self.shorten_cond_schedule:\n            self.make_cond_schedule()\n\n    def instantiate_first_stage(self, config):\n        model = instantiate_from_config(config)\n        self.first_stage_model = model.eval()\n        self.first_stage_model.train = disabled_train\n        for param in self.first_stage_model.parameters():\n            param.requires_grad = False\n\n    def decode_first_stage(self, z):\n        with torch.no_grad():\n            z = 1.0 / self.scale_factor * z\n            decoding = self.first_stage_model.decode(z)\n        return decoding\n\n    def mel_spectrogram_to_waveform(self, mel):\n        # Mel: [bs, 1, t-steps, fbins]\n        if len(mel.size()) == 4:\n            mel = mel.squeeze(1)\n        mel = mel.permute(0, 2, 1)\n        waveform = self.first_stage_model.vocoder(mel)\n        waveform = waveform.cpu().detach().numpy()\n        return waveform\n\n    def encode_first_stage(self, x):\n        with torch.no_grad():\n            return self.first_stage_model.encode(x)\n\n    @torch.no_grad()\n    def sample_log(\n        self,\n        cond,\n        batch_size,\n        ddim_steps,\n        unconditional_guidance_scale=1.0,\n        unconditional_conditioning=None,\n        mask=None,\n        **kwargs,\n    ):\n        if mask is not None:\n            shape = (self.channels, mask.size()[-2], mask.size()[-1])\n        else:\n            shape = (self.channels, self.latent_t_size, self.latent_f_size)\n\n        # print(\"Use ddim sampler\")\n\n        ddim_sampler = DDIMSampler(self, device = self.device)\n        samples, intermediates = ddim_sampler.sample(\n            ddim_steps,\n            batch_size,\n            shape,\n            cond,\n            verbose=False,\n            unconditional_guidance_scale=unconditional_guidance_scale,\n            unconditional_conditioning=unconditional_conditioning,\n            mask=mask,\n            **kwargs,\n        )\n        return samples, intermediates\n\n    def apply_model(self, x_noisy, t, cond, return_ids=False):\n        x_recon = self.model(x_noisy, t, cond_dict=cond)\n\n        if isinstance(x_recon, tuple) and not return_ids:\n            return x_recon[0]\n        else:\n            return x_recon\n\n    @torch.no_grad()\n    def generate_sample(\n        self,\n        quanized_feature,\n        ddim_steps=200,\n        ddim_eta=1.0,\n        x_T=None,\n        unconditional_guidance_scale=1.0,\n    ):\n        batch_size = quanized_feature.shape[0]\n\n        pe = self.pos_embed(quanized_feature)\n\n        unconditional_conditioning = {}\n        if unconditional_guidance_scale != 1.0:\n            unconditional_quanized_feature = torch.cat(\n                [\n                   
 quanized_feature * 0.0,\n                    pe.repeat(quanized_feature.size(0), 1, 1).to(\n                        quanized_feature.device\n                    ),\n                ],\n                dim=-1,\n            )\n            unconditional_conditioning = {\n                \"crossattn_audiomae_pooled\": [\n                    unconditional_quanized_feature,\n                    torch.ones(\n                        (\n                            unconditional_quanized_feature.size(0),\n                            unconditional_quanized_feature.size(1),\n                        )\n                    )\n                    .to(unconditional_quanized_feature.device)\n                    .float(),\n                ]\n            }\n\n        quanized_feature = torch.cat(\n            [\n                quanized_feature,\n                pe.repeat(quanized_feature.size(0), 1, 1).to(quanized_feature.device),\n            ],\n            dim=-1,\n        )\n        latent = {\n            \"crossattn_audiomae_pooled\": [\n                quanized_feature,\n                torch.ones((quanized_feature.size(0), quanized_feature.size(1)))\n                .to(quanized_feature.device)\n                .float(),\n            ]\n        }\n\n        samples, _ = self.sample_log(\n            cond=latent,\n            batch_size=batch_size,\n            x_T=x_T,\n            ddim=True,\n            ddim_steps=ddim_steps,\n            eta=ddim_eta,\n            unconditional_guidance_scale=unconditional_guidance_scale,\n            unconditional_conditioning=unconditional_conditioning,\n        )\n\n        mel = self.decode_first_stage(samples)\n\n        return self.mel_spectrogram_to_waveform(mel)\n\n\nclass DiffusionWrapper(nn.Module):\n    def __init__(self, diff_model_config, conditioning_key):\n        super().__init__()\n        self.diffusion_model = instantiate_from_config(diff_model_config)\n        self.conditioning_key = conditioning_key\n\n    def forward(self, x, t, cond_dict: dict = {}):\n        x = x.contiguous()\n        t = t.contiguous()\n        context_list, attn_mask_list = [], []\n        context, attn_mask = cond_dict[\"crossattn_audiomae_pooled\"]\n        context_list.append(context)\n        attn_mask_list.append(attn_mask)\n        out = self.diffusion_model(\n            x,\n            t,\n            context_list=context_list,\n            y=None,\n            context_attn_mask_list=attn_mask_list,\n        )\n        return out\n\n\ndef extract_encoder_state_dict(checkpoint_path):\n    state_dict = torch.load(checkpoint_path)[\"state_dict\"]\n    new_state_dict = {}\n    for key in state_dict.keys():\n        if \"cond_stage_models.0\" in key:\n            if \"pos_embed.pe\" in key:\n                continue\n            new_key_name = key.replace(\"cond_stage_models.0.\", \"\")\n            new_state_dict[new_key_name] = state_dict[key]\n    return new_state_dict\n\n\ndef overlap_add_waveform(windowed_waveforms, overlap_duration=0.64):\n    \"\"\"\n    Concatenates a series of windowed waveforms with overlap, applying fade-in and fade-out effects to the overlaps.\n\n    Parameters:\n    - windowed_waveforms: a list of numpy arrays with shape (1, 1, samples_per_waveform)\n\n    Returns:\n    - A single waveform numpy array resulting from the overlap-add process.\n    \"\"\"\n    # Assuming a sampling rate of 16000 Hz and 0.64 seconds overlap\n    if overlap_duration < 1e-4:\n        return np.concatenate(windowed_waveforms, axis=-1)\n\n    sampling_rate = 
16000\n    overlap_samples = int(overlap_duration * sampling_rate)\n\n    # Initialize the output waveform\n    output_waveform = np.array([]).reshape(1, 1, -1)\n\n    for i, waveform in enumerate(windowed_waveforms):\n        # If not the first waveform, apply fade-in at the beginning\n        if i > 0:\n            fade_in = np.linspace(0, 1, overlap_samples).reshape(1, 1, -1)\n            waveform[:, :, :overlap_samples] *= fade_in\n\n        # If output waveform already has content, apply fade-out to its last overlap and add the overlapping parts\n        if output_waveform.size > 0:\n            fade_out = np.linspace(1, 0, overlap_samples).reshape(1, 1, -1)\n            # Apply fade-out to the end of the output waveform\n            output_waveform[:, :, -overlap_samples:] *= fade_out\n            # Add the faded-in start of the current waveform to the faded-out end of the output waveform\n            output_waveform[:, :, -overlap_samples:] += waveform[:, :, :overlap_samples]\n\n        # Concatenate the current waveform (minus the initial overlap if not the first) to the output\n        if output_waveform.size == 0:\n            output_waveform = waveform\n        else:\n            output_waveform = np.concatenate(\n                (output_waveform, waveform[:, :, overlap_samples:]), axis=2\n            )\n\n    return output_waveform\n"
  },
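  {
    "path": "examples/overlap_add_crossfade_sketch.py",
    "content": "# NOTE: Editorial example; not part of the original SemantiCodec release.\n# The file path and every name below are illustrative only. This is a minimal,\n# self-contained sketch of the linear cross-fade overlap-add that\n# `overlap_add_waveform` (defined in the decoder code above) applies when\n# stitching windowed waveforms back together: the head of each new chunk is\n# faded in, the tail of the running output is faded out, and the overlap is\n# summed.\nimport numpy as np\n\nSAMPLING_RATE = 16000\nOVERLAP_SECONDS = 0.64  # overlap assumed by the decoder\nOVERLAP_SAMPLES = int(OVERLAP_SECONDS * SAMPLING_RATE)\n\n\ndef crossfade_concat(chunks, overlap_samples):\n    # chunks: list of arrays shaped (1, 1, samples), as in overlap_add_waveform.\n    fade_in = np.linspace(0.0, 1.0, overlap_samples).reshape(1, 1, -1)\n    fade_out = np.linspace(1.0, 0.0, overlap_samples).reshape(1, 1, -1)\n    out = chunks[0].copy()\n    for chunk in chunks[1:]:\n        chunk = chunk.copy()\n        chunk[:, :, :overlap_samples] *= fade_in  # fade the new chunk in\n        out[:, :, -overlap_samples:] *= fade_out  # fade the running tail out\n        out[:, :, -overlap_samples:] += chunk[:, :, :overlap_samples]\n        out = np.concatenate([out, chunk[:, :, overlap_samples:]], axis=2)\n    return out\n\n\nif __name__ == \"__main__\":\n    # Two 2-second constant chunks: the linear fades sum to one inside the\n    # overlap, so the merged signal stays constant and its length is\n    # 2 * chunk_len - overlap samples.\n    chunk = np.ones((1, 1, 2 * SAMPLING_RATE))\n    merged = crossfade_concat([chunk, chunk], OVERLAP_SAMPLES)\n    print(merged.shape)  # (1, 1, 53760)\n    print(np.allclose(merged, 1.0))  # True\n"
  },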
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/models/dpm_solver/__init__.py",
    "content": "from .sampler import DPMSolverSampler\n"
  },
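  {
    "path": "examples/dpm_solver_sampling_sketch.py",
    "content": "# NOTE: Editorial example; not part of the original SemantiCodec release.\n# The file path and the dummy model below are illustrative only. The sketch\n# follows the usage shown in the dpm_solver.py docstrings: build a discrete-time\n# NoiseScheduleVP from a beta array, wrap an epsilon-prediction model with\n# model_wrapper so it accepts continuous time labels, then run the multistep\n# DPM-Solver++ configuration (predict_x0=True, order=2, uniform time steps)\n# that the docstrings recommend for guided sampling.\nimport torch\n\nfrom semanticodec.modules.decoder.latent_diffusion.models.dpm_solver.dpm_solver import (\n    DPM_Solver,\n    NoiseScheduleVP,\n    model_wrapper,\n)\n\n\ndef dummy_noise_model(x, t_input):\n    # Stand-in for a trained noise-prediction network: model(x, t_input) -> noise.\n    return torch.zeros_like(x)\n\n\nif __name__ == \"__main__\":\n    # Discrete-time schedule built from a linear beta array (N = 1000 steps).\n    betas = torch.linspace(1e-4, 2e-2, 1000)\n    ns = NoiseScheduleVP(\"discrete\", betas=betas)\n\n    # Unconditional wrapper; for classifier-free guidance, pass guidance_type,\n    # condition, unconditional_condition and guidance_scale instead.\n    model_fn = model_wrapper(dummy_noise_model, ns, model_type=\"noise\", guidance_type=\"uncond\")\n\n    # DPM-Solver++ (data-prediction mode), 20 function evaluations.\n    solver = DPM_Solver(model_fn, ns, predict_x0=True)\n    x_T = torch.randn(1, 8, 16, 16)\n    x_0 = solver.sample(x_T, steps=20, order=2, skip_type=\"time_uniform\", method=\"multistep\")\n    print(x_0.shape)  # torch.Size([1, 8, 16, 16])\n"
  },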
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/models/dpm_solver/dpm_solver.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport math\n\n\nclass NoiseScheduleVP:\n    def __init__(\n        self,\n        schedule=\"discrete\",\n        betas=None,\n        alphas_cumprod=None,\n        continuous_beta_0=0.1,\n        continuous_beta_1=20.0,\n    ):\n        \"\"\"Create a wrapper class for the forward SDE (VP type).\n\n        ***\n        Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t.\n                We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images.\n        ***\n\n        The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).\n        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).\n        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:\n\n            log_alpha_t = self.marginal_log_mean_coeff(t)\n            sigma_t = self.marginal_std(t)\n            lambda_t = self.marginal_lambda(t)\n\n        Moreover, as lambda(t) is an invertible function, we also support its inverse function:\n\n            t = self.inverse_lambda(lambda_t)\n\n        ===============================================================\n\n        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).\n\n        1. For discrete-time DPMs:\n\n            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:\n                t_i = (i + 1) / N\n            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.\n            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.\n\n            Args:\n                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)\n                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)\n\n            Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.\n\n            **Important**:  Please pay special attention for the args for `alphas_cumprod`:\n                The `alphas_cumprod` is the \\hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that\n                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \\sqrt{\\hat{alpha_n}} * x_0, (1 - \\hat{alpha_n}) * I ).\n                Therefore, the notation \\hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have\n                    alpha_{t_n} = \\sqrt{\\hat{alpha_n}},\n                and\n                    log(alpha_{t_n}) = 0.5 * log(\\hat{alpha_n}).\n\n\n        2. For continuous-time DPMs:\n\n            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise\n            schedule are the default settings in DDPM and improved-DDPM:\n\n            Args:\n                beta_min: A `float` number. The smallest beta for the linear schedule.\n                beta_max: A `float` number. The largest beta for the linear schedule.\n                cosine_s: A `float` number. The hyperparameter in the cosine schedule.\n                cosine_beta_max: A `float` number. 
The hyperparameter in the cosine schedule.\n                T: A `float` number. The ending time of the forward process.\n\n        ===============================================================\n\n        Args:\n            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,\n                    'linear' or 'cosine' for continuous-time DPMs.\n        Returns:\n            A wrapper object of the forward SDE (VP type).\n\n        ===============================================================\n\n        Example:\n\n        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):\n        >>> ns = NoiseScheduleVP('discrete', betas=betas)\n\n        # For discrete-time DPMs, given alphas_cumprod (the \\hat{alpha_n} array for n = 0, 1, ..., N - 1):\n        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)\n\n        # For continuous-time DPMs (VPSDE), linear schedule:\n        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)\n\n        \"\"\"\n\n        if schedule not in [\"discrete\", \"linear\", \"cosine\"]:\n            raise ValueError(\n                \"Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'\".format(\n                    schedule\n                )\n            )\n\n        self.schedule = schedule\n        if schedule == \"discrete\":\n            if betas is not None:\n                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)\n            else:\n                assert alphas_cumprod is not None\n                log_alphas = 0.5 * torch.log(alphas_cumprod)\n            self.total_N = len(log_alphas)\n            self.T = 1.0\n            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape(\n                (1, -1)\n            )\n            self.log_alpha_array = log_alphas.reshape(\n                (\n                    1,\n                    -1,\n                )\n            )\n        else:\n            self.total_N = 1000\n            self.beta_0 = continuous_beta_0\n            self.beta_1 = continuous_beta_1\n            self.cosine_s = 0.008\n            self.cosine_beta_max = 999.0\n            self.cosine_t_max = (\n                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)\n                * 2.0\n                * (1.0 + self.cosine_s)\n                / math.pi\n                - self.cosine_s\n            )\n            self.cosine_log_alpha_0 = math.log(\n                math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)\n            )\n            self.schedule = schedule\n            if schedule == \"cosine\":\n                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.\n                # Note that T = 0.9946 may be not the optimal setting. 
However, we find it works well.\n                self.T = 0.9946\n            else:\n                self.T = 1.0\n\n    def marginal_log_mean_coeff(self, t):\n        \"\"\"\n        Compute log(alpha_t) of a given continuous-time label t in [0, T].\n        \"\"\"\n        if self.schedule == \"discrete\":\n            return interpolate_fn(\n                t.reshape((-1, 1)),\n                self.t_array.to(t.device),\n                self.log_alpha_array.to(t.device),\n            ).reshape((-1))\n        elif self.schedule == \"linear\":\n            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0\n        elif self.schedule == \"cosine\":\n            log_alpha_fn = lambda s: torch.log(\n                torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)\n            )\n            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0\n            return log_alpha_t\n\n    def marginal_alpha(self, t):\n        \"\"\"\n        Compute alpha_t of a given continuous-time label t in [0, T].\n        \"\"\"\n        return torch.exp(self.marginal_log_mean_coeff(t))\n\n    def marginal_std(self, t):\n        \"\"\"\n        Compute sigma_t of a given continuous-time label t in [0, T].\n        \"\"\"\n        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))\n\n    def marginal_lambda(self, t):\n        \"\"\"\n        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].\n        \"\"\"\n        log_mean_coeff = self.marginal_log_mean_coeff(t)\n        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))\n        return log_mean_coeff - log_std\n\n    def inverse_lambda(self, lamb):\n        \"\"\"\n        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.\n        \"\"\"\n        if self.schedule == \"linear\":\n            tmp = (\n                2.0\n                * (self.beta_1 - self.beta_0)\n                * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))\n            )\n            Delta = self.beta_0**2 + tmp\n            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)\n        elif self.schedule == \"discrete\":\n            log_alpha = -0.5 * torch.logaddexp(\n                torch.zeros((1,)).to(lamb.device), -2.0 * lamb\n            )\n            t = interpolate_fn(\n                log_alpha.reshape((-1, 1)),\n                torch.flip(self.log_alpha_array.to(lamb.device), [1]),\n                torch.flip(self.t_array.to(lamb.device), [1]),\n            )\n            return t.reshape((-1,))\n        else:\n            log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))\n            t_fn = (\n                lambda log_alpha_t: torch.arccos(\n                    torch.exp(log_alpha_t + self.cosine_log_alpha_0)\n                )\n                * 2.0\n                * (1.0 + self.cosine_s)\n                / math.pi\n                - self.cosine_s\n            )\n            t = t_fn(log_alpha)\n            return t\n\n\ndef model_wrapper(\n    model,\n    noise_schedule,\n    model_type=\"noise\",\n    model_kwargs={},\n    guidance_type=\"uncond\",\n    condition=None,\n    unconditional_condition=None,\n    guidance_scale=1.0,\n    classifier_fn=None,\n    classifier_kwargs={},\n):\n    \"\"\"Create a wrapper function for the noise prediction model.\n\n    DPM-Solver needs to solve the continuous-time diffusion ODEs. 
For DPMs trained on discrete-time labels, we need to\n    firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.\n\n    We support four types of the diffusion model by setting `model_type`:\n\n        1. \"noise\": noise prediction model. (Trained by predicting noise).\n\n        2. \"x_start\": data prediction model. (Trained by predicting the data x_0 at time 0).\n\n        3. \"v\": velocity prediction model. (Trained by predicting the velocity).\n            The \"v\" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].\n\n            [1] Salimans, Tim, and Jonathan Ho. \"Progressive distillation for fast sampling of diffusion models.\"\n                arXiv preprint arXiv:2202.00512 (2022).\n            [2] Ho, Jonathan, et al. \"Imagen Video: High Definition Video Generation with Diffusion Models.\"\n                arXiv preprint arXiv:2210.02303 (2022).\n\n        4. \"score\": marginal score function. (Trained by denoising score matching).\n            Note that the score function and the noise prediction model follows a simple relationship:\n            ```\n                noise(x_t, t) = -sigma_t * score(x_t, t)\n            ```\n\n    We support three types of guided sampling by DPMs by setting `guidance_type`:\n        1. \"uncond\": unconditional sampling by DPMs.\n            The input `model` has the following format:\n            ``\n                model(x, t_input, **model_kwargs) -> noise | x_start | v | score\n            ``\n\n        2. \"classifier\": classifier guidance sampling [3] by DPMs and another classifier.\n            The input `model` has the following format:\n            ``\n                model(x, t_input, **model_kwargs) -> noise | x_start | v | score\n            ``\n\n            The input `classifier_fn` has the following format:\n            ``\n                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)\n            ``\n\n            [3] P. Dhariwal and A. Q. Nichol, \"Diffusion models beat GANs on image synthesis,\"\n                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.\n\n        3. \"classifier-free\": classifier-free guidance sampling by conditional DPMs.\n            The input `model` has the following format:\n            ``\n                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score\n            ``\n            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.\n\n            [4] Ho, Jonathan, and Tim Salimans. \"Classifier-free diffusion guidance.\"\n                arXiv preprint arXiv:2207.12598 (2022).\n\n\n    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)\n    or continuous-time labels (i.e. epsilon to T).\n\n    We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:\n    ``\n        def model_fn(x, t_continuous) -> noise:\n            t_input = get_model_input_time(t_continuous)\n            return noise_pred(model, x, t_input, **model_kwargs)\n    ``\n    where `t_continuous` is the continuous time labels (i.e. epsilon to T). 
And we use `model_fn` for DPM-Solver.\n\n    ===============================================================\n\n    Args:\n        model: A diffusion model with the corresponding format described above.\n        noise_schedule: A noise schedule object, such as NoiseScheduleVP.\n        model_type: A `str`. The parameterization type of the diffusion model.\n                    \"noise\" or \"x_start\" or \"v\" or \"score\".\n        model_kwargs: A `dict`. A dict for the other inputs of the model function.\n        guidance_type: A `str`. The type of the guidance for sampling.\n                    \"uncond\" or \"classifier\" or \"classifier-free\".\n        condition: A pytorch tensor. The condition for the guided sampling.\n                    Only used for \"classifier\" or \"classifier-free\" guidance type.\n        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.\n                    Only used for \"classifier-free\" guidance type.\n        guidance_scale: A `float`. The scale for the guided sampling.\n        classifier_fn: A classifier function. Only used for the classifier guidance.\n        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.\n    Returns:\n        A noise prediction model that accepts the noised data and the continuous time as the inputs.\n    \"\"\"\n\n    def get_model_input_time(t_continuous):\n        \"\"\"\n        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.\n        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].\n        For continuous-time DPMs, we just use `t_continuous`.\n        \"\"\"\n        if noise_schedule.schedule == \"discrete\":\n            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0\n        else:\n            return t_continuous\n\n    def noise_pred_fn(x, t_continuous, cond=None):\n        if t_continuous.reshape((-1,)).shape[0] == 1:\n            t_continuous = t_continuous.expand((x.shape[0]))\n        t_input = get_model_input_time(t_continuous)\n        if cond is None:\n            output = model(x, t_input, **model_kwargs)\n        else:\n            output = model(x, t_input, cond, **model_kwargs)\n        if model_type == \"noise\":\n            return output\n        elif model_type == \"x_start\":\n            alpha_t, sigma_t = noise_schedule.marginal_alpha(\n                t_continuous\n            ), noise_schedule.marginal_std(t_continuous)\n            dims = x.dim()\n            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(\n                sigma_t, dims\n            )\n        elif model_type == \"v\":\n            alpha_t, sigma_t = noise_schedule.marginal_alpha(\n                t_continuous\n            ), noise_schedule.marginal_std(t_continuous)\n            dims = x.dim()\n            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x\n        elif model_type == \"score\":\n            sigma_t = noise_schedule.marginal_std(t_continuous)\n            dims = x.dim()\n            return -expand_dims(sigma_t, dims) * output\n\n    def cond_grad_fn(x, t_input):\n        \"\"\"\n        Compute the gradient of the classifier, i.e. 
nabla_{x} log p_t(cond | x_t).\n        \"\"\"\n        with torch.enable_grad():\n            x_in = x.detach().requires_grad_(True)\n            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)\n            return torch.autograd.grad(log_prob.sum(), x_in)[0]\n\n    def model_fn(x, t_continuous):\n        \"\"\"\n        The noise predicition model function that is used for DPM-Solver.\n        \"\"\"\n        if t_continuous.reshape((-1,)).shape[0] == 1:\n            t_continuous = t_continuous.expand((x.shape[0]))\n        if guidance_type == \"uncond\":\n            return noise_pred_fn(x, t_continuous)\n        elif guidance_type == \"classifier\":\n            assert classifier_fn is not None\n            t_input = get_model_input_time(t_continuous)\n            cond_grad = cond_grad_fn(x, t_input)\n            sigma_t = noise_schedule.marginal_std(t_continuous)\n            noise = noise_pred_fn(x, t_continuous)\n            return (\n                noise\n                - guidance_scale\n                * expand_dims(sigma_t, dims=cond_grad.dim())\n                * cond_grad\n            )\n        elif guidance_type == \"classifier-free\":\n            if guidance_scale == 1.0 or unconditional_condition is None:\n                return noise_pred_fn(x, t_continuous, cond=condition)\n            else:\n                x_in = torch.cat([x] * 2)\n                t_in = torch.cat([t_continuous] * 2)\n                c_in = torch.cat([unconditional_condition, condition])\n                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)\n                return noise_uncond + guidance_scale * (noise - noise_uncond)\n\n    assert model_type in [\"noise\", \"x_start\", \"v\"]\n    assert guidance_type in [\"uncond\", \"classifier\", \"classifier-free\"]\n    return model_fn\n\n\nclass DPM_Solver:\n    def __init__(\n        self,\n        model_fn,\n        noise_schedule,\n        predict_x0=False,\n        thresholding=False,\n        max_val=1.0,\n    ):\n        \"\"\"Construct a DPM-Solver.\n\n        We support both the noise prediction model (\"predicting epsilon\") and the data prediction model (\"predicting x0\").\n        If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).\n        If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).\n            In such case, we further support the \"dynamic thresholding\" in [1] when `thresholding` is True.\n            The \"dynamic thresholding\" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.\n\n        Args:\n            model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):\n                ``\n                def model_fn(x, t_continuous):\n                    return noise\n                ``\n            noise_schedule: A noise schedule object, such as NoiseScheduleVP.\n            predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.\n            thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the \"dynamic thresholding\" in [1].\n            max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.\n\n        [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. 
Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.\n        \"\"\"\n        self.model = model_fn\n        self.noise_schedule = noise_schedule\n        self.predict_x0 = predict_x0\n        self.thresholding = thresholding\n        self.max_val = max_val\n\n    def noise_prediction_fn(self, x, t):\n        \"\"\"\n        Return the noise prediction model.\n        \"\"\"\n        return self.model(x, t)\n\n    def data_prediction_fn(self, x, t):\n        \"\"\"\n        Return the data prediction model (with thresholding).\n        \"\"\"\n        noise = self.noise_prediction_fn(x, t)\n        dims = x.dim()\n        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(\n            t\n        ), self.noise_schedule.marginal_std(t)\n        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)\n        if self.thresholding:\n            p = 0.995  # A hyperparameter in the paper of \"Imagen\" [1].\n            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)\n            s = expand_dims(\n                torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims\n            )\n            x0 = torch.clamp(x0, -s, s) / s\n        return x0\n\n    def model_fn(self, x, t):\n        \"\"\"\n        Convert the model to the noise prediction model or the data prediction model.\n        \"\"\"\n        if self.predict_x0:\n            return self.data_prediction_fn(x, t)\n        else:\n            return self.noise_prediction_fn(x, t)\n\n    def get_time_steps(self, skip_type, t_T, t_0, N, device):\n        \"\"\"Compute the intermediate time steps for sampling.\n\n        Args:\n            skip_type: A `str`. The type for the spacing of the time steps. We support three types:\n                - 'logSNR': uniform logSNR for the time steps.\n                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)\n                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)\n            t_T: A `float`. The starting time of the sampling (default is T).\n            t_0: A `float`. The ending time of the sampling (default is epsilon).\n            N: A `int`. 
The total number of the spacing of the time steps.\n            device: A torch device.\n        Returns:\n            A pytorch tensor of the time steps, with the shape (N + 1,).\n        \"\"\"\n        if skip_type == \"logSNR\":\n            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))\n            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))\n            logSNR_steps = torch.linspace(\n                lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1\n            ).to(device)\n            return self.noise_schedule.inverse_lambda(logSNR_steps)\n        elif skip_type == \"time_uniform\":\n            return torch.linspace(t_T, t_0, N + 1).to(device)\n        elif skip_type == \"time_quadratic\":\n            t_order = 2\n            t = (\n                torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1)\n                .pow(t_order)\n                .to(device)\n            )\n            return t\n        else:\n            raise ValueError(\n                \"Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'\".format(\n                    skip_type\n                )\n            )\n\n    def get_orders_and_timesteps_for_singlestep_solver(\n        self, steps, order, skip_type, t_T, t_0, device\n    ):\n        \"\"\"\n        Get the order of each step for sampling by the singlestep DPM-Solver.\n\n        We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as \"DPM-Solver-fast\".\n        Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:\n            - If order == 1:\n                We take `steps` of DPM-Solver-1 (i.e. DDIM).\n            - If order == 2:\n                - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.\n                - If steps % 2 == 0, we use K steps of DPM-Solver-2.\n                - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.\n            - If order == 3:\n                - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.\n                - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.\n                - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.\n                - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.\n\n        ============================================\n        Args:\n            order: A `int`. The max order for the solver (2 or 3).\n            steps: A `int`. The total number of function evaluations (NFE).\n            skip_type: A `str`. The type for the spacing of the time steps. We support three types:\n                - 'logSNR': uniform logSNR for the time steps.\n                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)\n                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)\n            t_T: A `float`. The starting time of the sampling (default is T).\n            t_0: A `float`. 
The ending time of the sampling (default is epsilon).\n            device: A torch device.\n        Returns:\n            orders: A list of the solver order of each step.\n        \"\"\"\n        if order == 3:\n            K = steps // 3 + 1\n            if steps % 3 == 0:\n                orders = [\n                    3,\n                ] * (\n                    K - 2\n                ) + [2, 1]\n            elif steps % 3 == 1:\n                orders = [\n                    3,\n                ] * (\n                    K - 1\n                ) + [1]\n            else:\n                orders = [\n                    3,\n                ] * (\n                    K - 1\n                ) + [2]\n        elif order == 2:\n            if steps % 2 == 0:\n                K = steps // 2\n                orders = [\n                    2,\n                ] * K\n            else:\n                K = steps // 2 + 1\n                orders = [\n                    2,\n                ] * (\n                    K - 1\n                ) + [1]\n        elif order == 1:\n            K = 1\n            orders = [\n                1,\n            ] * steps\n        else:\n            raise ValueError(\"'order' must be '1' or '2' or '3'.\")\n        if skip_type == \"logSNR\":\n            # To reproduce the results in DPM-Solver paper\n            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)\n        else:\n            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[\n                torch.cumsum(\n                    torch.tensor(\n                        [\n                            0,\n                        ]\n                        + orders\n                    )\n                ).to(device)\n            ]\n        return timesteps_outer, orders\n\n    def denoise_to_zero_fn(self, x, s):\n        \"\"\"\n        Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.\n        \"\"\"\n        return self.data_prediction_fn(x, s)\n\n    def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):\n        \"\"\"\n        DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.\n\n        Args:\n            x: A pytorch tensor. The initial value at time `s`.\n            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).\n            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).\n            model_s: A pytorch tensor. The model function evaluated at time `s`.\n                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.\n            return_intermediate: A `bool`. If true, also return the model value at time `s`.\n        Returns:\n            x_t: A pytorch tensor. 
The approximated solution at time `t`.\n        \"\"\"\n        ns = self.noise_schedule\n        dims = x.dim()\n        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)\n        h = lambda_t - lambda_s\n        log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(\n            s\n        ), ns.marginal_log_mean_coeff(t)\n        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)\n        alpha_t = torch.exp(log_alpha_t)\n\n        if self.predict_x0:\n            phi_1 = torch.expm1(-h)\n            if model_s is None:\n                model_s = self.model_fn(x, s)\n            x_t = (\n                expand_dims(sigma_t / sigma_s, dims) * x\n                - expand_dims(alpha_t * phi_1, dims) * model_s\n            )\n            if return_intermediate:\n                return x_t, {\"model_s\": model_s}\n            else:\n                return x_t\n        else:\n            phi_1 = torch.expm1(h)\n            if model_s is None:\n                model_s = self.model_fn(x, s)\n            x_t = (\n                expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x\n                - expand_dims(sigma_t * phi_1, dims) * model_s\n            )\n            if return_intermediate:\n                return x_t, {\"model_s\": model_s}\n            else:\n                return x_t\n\n    def singlestep_dpm_solver_second_update(\n        self,\n        x,\n        s,\n        t,\n        r1=0.5,\n        model_s=None,\n        return_intermediate=False,\n        solver_type=\"dpm_solver\",\n    ):\n        \"\"\"\n        Singlestep solver DPM-Solver-2 from time `s` to time `t`.\n\n        Args:\n            x: A pytorch tensor. The initial value at time `s`.\n            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).\n            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).\n            r1: A `float`. The hyperparameter of the second-order solver.\n            model_s: A pytorch tensor. The model function evaluated at time `s`.\n                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.\n            return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).\n            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.\n                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.\n        Returns:\n            x_t: A pytorch tensor. 
The approximated solution at time `t`.\n        \"\"\"\n        if solver_type not in [\"dpm_solver\", \"taylor\"]:\n            raise ValueError(\n                \"'solver_type' must be either 'dpm_solver' or 'taylor', got {}\".format(\n                    solver_type\n                )\n            )\n        if r1 is None:\n            r1 = 0.5\n        ns = self.noise_schedule\n        dims = x.dim()\n        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)\n        h = lambda_t - lambda_s\n        lambda_s1 = lambda_s + r1 * h\n        s1 = ns.inverse_lambda(lambda_s1)\n        log_alpha_s, log_alpha_s1, log_alpha_t = (\n            ns.marginal_log_mean_coeff(s),\n            ns.marginal_log_mean_coeff(s1),\n            ns.marginal_log_mean_coeff(t),\n        )\n        sigma_s, sigma_s1, sigma_t = (\n            ns.marginal_std(s),\n            ns.marginal_std(s1),\n            ns.marginal_std(t),\n        )\n        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)\n\n        if self.predict_x0:\n            phi_11 = torch.expm1(-r1 * h)\n            phi_1 = torch.expm1(-h)\n\n            if model_s is None:\n                model_s = self.model_fn(x, s)\n            x_s1 = (\n                expand_dims(sigma_s1 / sigma_s, dims) * x\n                - expand_dims(alpha_s1 * phi_11, dims) * model_s\n            )\n            model_s1 = self.model_fn(x_s1, s1)\n            if solver_type == \"dpm_solver\":\n                x_t = (\n                    expand_dims(sigma_t / sigma_s, dims) * x\n                    - expand_dims(alpha_t * phi_1, dims) * model_s\n                    - (0.5 / r1)\n                    * expand_dims(alpha_t * phi_1, dims)\n                    * (model_s1 - model_s)\n                )\n            elif solver_type == \"taylor\":\n                x_t = (\n                    expand_dims(sigma_t / sigma_s, dims) * x\n                    - expand_dims(alpha_t * phi_1, dims) * model_s\n                    + (1.0 / r1)\n                    * expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims)\n                    * (model_s1 - model_s)\n                )\n        else:\n            phi_11 = torch.expm1(r1 * h)\n            phi_1 = torch.expm1(h)\n\n            if model_s is None:\n                model_s = self.model_fn(x, s)\n            x_s1 = (\n                expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x\n                - expand_dims(sigma_s1 * phi_11, dims) * model_s\n            )\n            model_s1 = self.model_fn(x_s1, s1)\n            if solver_type == \"dpm_solver\":\n                x_t = (\n                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x\n                    - expand_dims(sigma_t * phi_1, dims) * model_s\n                    - (0.5 / r1)\n                    * expand_dims(sigma_t * phi_1, dims)\n                    * (model_s1 - model_s)\n                )\n            elif solver_type == \"taylor\":\n                x_t = (\n                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x\n                    - expand_dims(sigma_t * phi_1, dims) * model_s\n                    - (1.0 / r1)\n                    * expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims)\n                    * (model_s1 - model_s)\n                )\n        if return_intermediate:\n            return x_t, {\"model_s\": model_s, \"model_s1\": model_s1}\n        else:\n            return x_t\n\n    def singlestep_dpm_solver_third_update(\n        self,\n 
       x,\n        s,\n        t,\n        r1=1.0 / 3.0,\n        r2=2.0 / 3.0,\n        model_s=None,\n        model_s1=None,\n        return_intermediate=False,\n        solver_type=\"dpm_solver\",\n    ):\n        \"\"\"\n        Singlestep solver DPM-Solver-3 from time `s` to time `t`.\n\n        Args:\n            x: A pytorch tensor. The initial value at time `s`.\n            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).\n            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).\n            r1: A `float`. The hyperparameter of the third-order solver.\n            r2: A `float`. The hyperparameter of the third-order solver.\n            model_s: A pytorch tensor. The model function evaluated at time `s`.\n                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.\n            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).\n                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.\n            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).\n            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.\n                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.\n        Returns:\n            x_t: A pytorch tensor. The approximated solution at time `t`.\n        \"\"\"\n        if solver_type not in [\"dpm_solver\", \"taylor\"]:\n            raise ValueError(\n                \"'solver_type' must be either 'dpm_solver' or 'taylor', got {}\".format(\n                    solver_type\n                )\n            )\n        if r1 is None:\n            r1 = 1.0 / 3.0\n        if r2 is None:\n            r2 = 2.0 / 3.0\n        ns = self.noise_schedule\n        dims = x.dim()\n        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)\n        h = lambda_t - lambda_s\n        lambda_s1 = lambda_s + r1 * h\n        lambda_s2 = lambda_s + r2 * h\n        s1 = ns.inverse_lambda(lambda_s1)\n        s2 = ns.inverse_lambda(lambda_s2)\n        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (\n            ns.marginal_log_mean_coeff(s),\n            ns.marginal_log_mean_coeff(s1),\n            ns.marginal_log_mean_coeff(s2),\n            ns.marginal_log_mean_coeff(t),\n        )\n        sigma_s, sigma_s1, sigma_s2, sigma_t = (\n            ns.marginal_std(s),\n            ns.marginal_std(s1),\n            ns.marginal_std(s2),\n            ns.marginal_std(t),\n        )\n        alpha_s1, alpha_s2, alpha_t = (\n            torch.exp(log_alpha_s1),\n            torch.exp(log_alpha_s2),\n            torch.exp(log_alpha_t),\n        )\n\n        if self.predict_x0:\n            phi_11 = torch.expm1(-r1 * h)\n            phi_12 = torch.expm1(-r2 * h)\n            phi_1 = torch.expm1(-h)\n            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0\n            phi_2 = phi_1 / h + 1.0\n            phi_3 = phi_2 / h - 0.5\n\n            if model_s is None:\n                model_s = self.model_fn(x, s)\n            if model_s1 is None:\n                x_s1 = (\n                    expand_dims(sigma_s1 / sigma_s, dims) * x\n                    - expand_dims(alpha_s1 * phi_11, dims) * model_s\n                )\n                model_s1 = self.model_fn(x_s1, s1)\n            x_s2 = (\n                expand_dims(sigma_s2 / sigma_s, dims) * x\n     
           - expand_dims(alpha_s2 * phi_12, dims) * model_s\n                + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)\n            )\n            model_s2 = self.model_fn(x_s2, s2)\n            if solver_type == \"dpm_solver\":\n                x_t = (\n                    expand_dims(sigma_t / sigma_s, dims) * x\n                    - expand_dims(alpha_t * phi_1, dims) * model_s\n                    + (1.0 / r2)\n                    * expand_dims(alpha_t * phi_2, dims)\n                    * (model_s2 - model_s)\n                )\n            elif solver_type == \"taylor\":\n                D1_0 = (1.0 / r1) * (model_s1 - model_s)\n                D1_1 = (1.0 / r2) * (model_s2 - model_s)\n                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)\n                D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)\n                x_t = (\n                    expand_dims(sigma_t / sigma_s, dims) * x\n                    - expand_dims(alpha_t * phi_1, dims) * model_s\n                    + expand_dims(alpha_t * phi_2, dims) * D1\n                    - expand_dims(alpha_t * phi_3, dims) * D2\n                )\n        else:\n            phi_11 = torch.expm1(r1 * h)\n            phi_12 = torch.expm1(r2 * h)\n            phi_1 = torch.expm1(h)\n            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0\n            phi_2 = phi_1 / h - 1.0\n            phi_3 = phi_2 / h - 0.5\n\n            if model_s is None:\n                model_s = self.model_fn(x, s)\n            if model_s1 is None:\n                x_s1 = (\n                    expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x\n                    - expand_dims(sigma_s1 * phi_11, dims) * model_s\n                )\n                model_s1 = self.model_fn(x_s1, s1)\n            x_s2 = (\n                expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x\n                - expand_dims(sigma_s2 * phi_12, dims) * model_s\n                - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)\n            )\n            model_s2 = self.model_fn(x_s2, s2)\n            if solver_type == \"dpm_solver\":\n                x_t = (\n                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x\n                    - expand_dims(sigma_t * phi_1, dims) * model_s\n                    - (1.0 / r2)\n                    * expand_dims(sigma_t * phi_2, dims)\n                    * (model_s2 - model_s)\n                )\n            elif solver_type == \"taylor\":\n                D1_0 = (1.0 / r1) * (model_s1 - model_s)\n                D1_1 = (1.0 / r2) * (model_s2 - model_s)\n                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)\n                D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)\n                x_t = (\n                    expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x\n                    - expand_dims(sigma_t * phi_1, dims) * model_s\n                    - expand_dims(sigma_t * phi_2, dims) * D1\n                    - expand_dims(sigma_t * phi_3, dims) * D2\n                )\n\n        if return_intermediate:\n            return x_t, {\"model_s\": model_s, \"model_s1\": model_s1, \"model_s2\": model_s2}\n        else:\n            return x_t\n\n    def multistep_dpm_solver_second_update(\n        self, x, model_prev_list, t_prev_list, t, solver_type=\"dpm_solver\"\n    ):\n        \"\"\"\n        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.\n\n        Args:\n            x: A pytorch tensor. 
The initial value at time `s`.\n            model_prev_list: A list of pytorch tensor. The previous computed model values.\n            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)\n            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).\n            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.\n                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.\n        Returns:\n            x_t: A pytorch tensor. The approximated solution at time `t`.\n        \"\"\"\n        if solver_type not in [\"dpm_solver\", \"taylor\"]:\n            raise ValueError(\n                \"'solver_type' must be either 'dpm_solver' or 'taylor', got {}\".format(\n                    solver_type\n                )\n            )\n        ns = self.noise_schedule\n        dims = x.dim()\n        model_prev_1, model_prev_0 = model_prev_list\n        t_prev_1, t_prev_0 = t_prev_list\n        lambda_prev_1, lambda_prev_0, lambda_t = (\n            ns.marginal_lambda(t_prev_1),\n            ns.marginal_lambda(t_prev_0),\n            ns.marginal_lambda(t),\n        )\n        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(\n            t_prev_0\n        ), ns.marginal_log_mean_coeff(t)\n        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)\n        alpha_t = torch.exp(log_alpha_t)\n\n        h_0 = lambda_prev_0 - lambda_prev_1\n        h = lambda_t - lambda_prev_0\n        r0 = h_0 / h\n        D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)\n        if self.predict_x0:\n            if solver_type == \"dpm_solver\":\n                x_t = (\n                    expand_dims(sigma_t / sigma_prev_0, dims) * x\n                    - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0\n                    - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * D1_0\n                )\n            elif solver_type == \"taylor\":\n                x_t = (\n                    expand_dims(sigma_t / sigma_prev_0, dims) * x\n                    - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0\n                    + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims)\n                    * D1_0\n                )\n        else:\n            if solver_type == \"dpm_solver\":\n                x_t = (\n                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x\n                    - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0\n                    - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * D1_0\n                )\n            elif solver_type == \"taylor\":\n                x_t = (\n                    expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x\n                    - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0\n                    - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims)\n                    * D1_0\n                )\n        return x_t\n\n    def multistep_dpm_solver_third_update(\n        self, x, model_prev_list, t_prev_list, t, solver_type=\"dpm_solver\"\n    ):\n        \"\"\"\n        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.\n\n        Args:\n            x: A pytorch tensor. The initial value at time `s`.\n            model_prev_list: A list of pytorch tensor. 
The previous computed model values.\n            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)\n            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).\n            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.\n                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.\n        Returns:\n            x_t: A pytorch tensor. The approximated solution at time `t`.\n        \"\"\"\n        ns = self.noise_schedule\n        dims = x.dim()\n        model_prev_2, model_prev_1, model_prev_0 = model_prev_list\n        t_prev_2, t_prev_1, t_prev_0 = t_prev_list\n        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (\n            ns.marginal_lambda(t_prev_2),\n            ns.marginal_lambda(t_prev_1),\n            ns.marginal_lambda(t_prev_0),\n            ns.marginal_lambda(t),\n        )\n        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(\n            t_prev_0\n        ), ns.marginal_log_mean_coeff(t)\n        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)\n        alpha_t = torch.exp(log_alpha_t)\n\n        h_1 = lambda_prev_1 - lambda_prev_2\n        h_0 = lambda_prev_0 - lambda_prev_1\n        h = lambda_t - lambda_prev_0\n        r0, r1 = h_0 / h, h_1 / h\n        D1_0 = expand_dims(1.0 / r0, dims) * (model_prev_0 - model_prev_1)\n        D1_1 = expand_dims(1.0 / r1, dims) * (model_prev_1 - model_prev_2)\n        D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)\n        D2 = expand_dims(1.0 / (r0 + r1), dims) * (D1_0 - D1_1)\n        if self.predict_x0:\n            x_t = (\n                expand_dims(sigma_t / sigma_prev_0, dims) * x\n                - expand_dims(alpha_t * (torch.exp(-h) - 1.0), dims) * model_prev_0\n                + expand_dims(alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0), dims) * D1\n                - expand_dims(\n                    alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5), dims\n                )\n                * D2\n            )\n        else:\n            x_t = (\n                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x\n                - expand_dims(sigma_t * (torch.exp(h) - 1.0), dims) * model_prev_0\n                - expand_dims(sigma_t * ((torch.exp(h) - 1.0) / h - 1.0), dims) * D1\n                - expand_dims(sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5), dims)\n                * D2\n            )\n        return x_t\n\n    def singlestep_dpm_solver_update(\n        self,\n        x,\n        s,\n        t,\n        order,\n        return_intermediate=False,\n        solver_type=\"dpm_solver\",\n        r1=None,\n        r2=None,\n    ):\n        \"\"\"\n        Singlestep DPM-Solver with the order `order` from time `s` to time `t`.\n\n        Args:\n            x: A pytorch tensor. The initial value at time `s`.\n            s: A pytorch tensor. The starting time, with the shape (x.shape[0],).\n            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).\n            order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.\n            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).\n            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.\n                The type slightly impacts the performance. 
We recommend to use 'dpm_solver' type.\n            r1: A `float`. The hyperparameter of the second-order or third-order solver.\n            r2: A `float`. The hyperparameter of the third-order solver.\n        Returns:\n            x_t: A pytorch tensor. The approximated solution at time `t`.\n        \"\"\"\n        if order == 1:\n            return self.dpm_solver_first_update(\n                x, s, t, return_intermediate=return_intermediate\n            )\n        elif order == 2:\n            return self.singlestep_dpm_solver_second_update(\n                x,\n                s,\n                t,\n                return_intermediate=return_intermediate,\n                solver_type=solver_type,\n                r1=r1,\n            )\n        elif order == 3:\n            return self.singlestep_dpm_solver_third_update(\n                x,\n                s,\n                t,\n                return_intermediate=return_intermediate,\n                solver_type=solver_type,\n                r1=r1,\n                r2=r2,\n            )\n        else:\n            raise ValueError(\"Solver order must be 1 or 2 or 3, got {}\".format(order))\n\n    def multistep_dpm_solver_update(\n        self, x, model_prev_list, t_prev_list, t, order, solver_type=\"dpm_solver\"\n    ):\n        \"\"\"\n        Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.\n\n        Args:\n            x: A pytorch tensor. The initial value at time `s`.\n            model_prev_list: A list of pytorch tensor. The previous computed model values.\n            t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],)\n            t: A pytorch tensor. The ending time, with the shape (x.shape[0],).\n            order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.\n            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.\n                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.\n        Returns:\n            x_t: A pytorch tensor. The approximated solution at time `t`.\n        \"\"\"\n        if order == 1:\n            return self.dpm_solver_first_update(\n                x, t_prev_list[-1], t, model_s=model_prev_list[-1]\n            )\n        elif order == 2:\n            return self.multistep_dpm_solver_second_update(\n                x, model_prev_list, t_prev_list, t, solver_type=solver_type\n            )\n        elif order == 3:\n            return self.multistep_dpm_solver_third_update(\n                x, model_prev_list, t_prev_list, t, solver_type=solver_type\n            )\n        else:\n            raise ValueError(\"Solver order must be 1 or 2 or 3, got {}\".format(order))\n\n    def dpm_solver_adaptive(\n        self,\n        x,\n        order,\n        t_T,\n        t_0,\n        h_init=0.05,\n        atol=0.0078,\n        rtol=0.05,\n        theta=0.9,\n        t_err=1e-5,\n        solver_type=\"dpm_solver\",\n    ):\n        \"\"\"\n        The adaptive step size solver based on singlestep DPM-Solver.\n\n        Args:\n            x: A pytorch tensor. The initial value at time `t_T`.\n            order: A `int`. The (higher) order of the solver. We only support order == 2 or 3.\n            t_T: A `float`. The starting time of the sampling (default is T).\n            t_0: A `float`. The ending time of the sampling (default is epsilon).\n            h_init: A `float`. 
The initial step size (for logSNR).\n            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1].\n            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.\n            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1].\n            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the\n                current time and `t_0` is less than `t_err`. The default setting is 1e-5.\n            solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.\n                The type slightly impacts the performance. We recommend to use 'dpm_solver' type.\n        Returns:\n            x_0: A pytorch tensor. The approximated solution at time `t_0`.\n\n        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, \"Gotta go fast when generating data with score-based models,\" arXiv preprint arXiv:2105.14080, 2021.\n        \"\"\"\n        ns = self.noise_schedule\n        s = t_T * torch.ones((x.shape[0],)).to(x)\n        lambda_s = ns.marginal_lambda(s)\n        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))\n        h = h_init * torch.ones_like(s).to(x)\n        x_prev = x\n        nfe = 0\n        if order == 2:\n            r1 = 0.5\n            lower_update = lambda x, s, t: self.dpm_solver_first_update(\n                x, s, t, return_intermediate=True\n            )\n            higher_update = (\n                lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(\n                    x, s, t, r1=r1, solver_type=solver_type, **kwargs\n                )\n            )\n        elif order == 3:\n            r1, r2 = 1.0 / 3.0, 2.0 / 3.0\n            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(\n                x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type\n            )\n            higher_update = (\n                lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(\n                    x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs\n                )\n            )\n        else:\n            raise ValueError(\n                \"For adaptive step size solver, order must be 2 or 3, got {}\".format(\n                    order\n                )\n            )\n        while torch.abs((s - t_0)).mean() > t_err:\n            t = ns.inverse_lambda(lambda_s + h)\n            x_lower, lower_noise_kwargs = lower_update(x, s, t)\n            x_higher = higher_update(x, s, t, **lower_noise_kwargs)\n            delta = torch.max(\n                torch.ones_like(x).to(x) * atol,\n                rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)),\n            )\n            norm_fn = lambda v: torch.sqrt(\n                torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)\n            )\n            E = norm_fn((x_higher - x_lower) / delta).max()\n            if torch.all(E <= 1.0):\n                x = x_higher\n                s = t\n                x_prev = x_lower\n                lambda_s = ns.marginal_lambda(s)\n            h = torch.min(\n                theta * h * torch.float_power(E, -1.0 / order).float(),\n                lambda_0 - lambda_s,\n            )\n            nfe += order\n        print(\"adaptive solver nfe\", nfe)\n        return x\n\n    def sample(\n        self,\n        x,\n    
    steps=20,\n        t_start=None,\n        t_end=None,\n        order=3,\n        skip_type=\"time_uniform\",\n        method=\"singlestep\",\n        lower_order_final=True,\n        denoise_to_zero=False,\n        solver_type=\"dpm_solver\",\n        atol=0.0078,\n        rtol=0.05,\n    ):\n        \"\"\"\n        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.\n\n        =====================================================\n\n        We support the following algorithms for both noise prediction model and data prediction model:\n            - 'singlestep':\n                Singlestep DPM-Solver (i.e. \"DPM-Solver-fast\" in the paper), which combines different orders of singlestep DPM-Solver.\n                We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).\n                The total number of function evaluations (NFE) == `steps`.\n                Given a fixed NFE == `steps`, the sampling procedure is:\n                    - If `order` == 1:\n                        - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).\n                    - If `order` == 2:\n                        - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.\n                        - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.\n                        - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.\n                    - If `order` == 3:\n                        - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.\n                        - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.\n                        - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.\n                        - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.\n            - 'multistep':\n                Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.\n                We initialize the first `order` values by lower order multistep solvers.\n                Given a fixed NFE == `steps`, the sampling procedure is:\n                    Denote K = steps.\n                    - If `order` == 1:\n                        - We use K steps of DPM-Solver-1 (i.e. DDIM).\n                    - If `order` == 2:\n                        - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2.\n                    - If `order` == 3:\n                        - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3.\n            - 'singlestep_fixed':\n                Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).\n                We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.\n            - 'adaptive':\n                Adaptive step size DPM-Solver (i.e. 
\"DPM-Solver-12\" and \"DPM-Solver-23\" in the paper).\n                We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.\n                You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs\n                (NFE) and the sample quality.\n                    - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.\n                    - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.\n\n        =====================================================\n\n        Some advices for choosing the algorithm:\n            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:\n                Use singlestep DPM-Solver (\"DPM-Solver-fast\" in the paper) with `order = 3`.\n                e.g.\n                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)\n                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,\n                            skip_type='time_uniform', method='singlestep')\n            - For **guided sampling with large guidance scale** by DPMs:\n                Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.\n                e.g.\n                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)\n                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,\n                            skip_type='time_uniform', method='multistep')\n\n        We support three types of `skip_type`:\n            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**\n            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.\n            - 'time_quadratic': quadratic time for the time steps.\n\n        =====================================================\n        Args:\n            x: A pytorch tensor. The initial value at time `t_start`\n                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.\n            steps: A `int`. The total number of function evaluations (NFE).\n            t_start: A `float`. The starting time of the sampling.\n                If `T` is None, we use self.noise_schedule.T (default is 1.0).\n            t_end: A `float`. The ending time of the sampling.\n                If `t_end` is None, we use 1. / self.noise_schedule.total_N.\n                e.g. if total_N == 1000, we have `t_end` == 1e-3.\n                For discrete-time DPMs:\n                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.\n                For continuous-time DPMs:\n                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.\n            order: A `int`. The order of DPM-Solver.\n            skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.\n            method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.\n            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.\n                Default is `False`. 
If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).\n\n                This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and\n                score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID\n                for diffusion models sampling by diffusion SDEs for low-resolutional images\n                (such as CIFAR-10). However, we observed that such trick does not matter for\n                high-resolutional images. As it needs an additional NFE, we do not recommend\n                it for high-resolutional images.\n            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.\n                Only valid for `method=multistep` and `steps < 15`. We empirically find that\n                this trick is a key to stabilizing the sampling by DPM-Solver with very few steps\n                (especially for steps <= 10). So we recommend to set it to be `True`.\n            solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.\n            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.\n            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.\n        Returns:\n            x_end: A pytorch tensor. The approximated solution at time `t_end`.\n\n        \"\"\"\n        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end\n        t_T = self.noise_schedule.T if t_start is None else t_start\n        device = x.device\n        if method == \"adaptive\":\n            with torch.no_grad():\n                x = self.dpm_solver_adaptive(\n                    x,\n                    order=order,\n                    t_T=t_T,\n                    t_0=t_0,\n                    atol=atol,\n                    rtol=rtol,\n                    solver_type=solver_type,\n                )\n        elif method == \"multistep\":\n            assert steps >= order\n            timesteps = self.get_time_steps(\n                skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device\n            )\n            assert timesteps.shape[0] - 1 == steps\n            with torch.no_grad():\n                vec_t = timesteps[0].expand((x.shape[0]))\n                model_prev_list = [self.model_fn(x, vec_t)]\n                t_prev_list = [vec_t]\n                # Init the first `order` values by lower order multistep DPM-Solver.\n                for init_order in range(1, order):\n                    vec_t = timesteps[init_order].expand(x.shape[0])\n                    x = self.multistep_dpm_solver_update(\n                        x,\n                        model_prev_list,\n                        t_prev_list,\n                        vec_t,\n                        init_order,\n                        solver_type=solver_type,\n                    )\n                    model_prev_list.append(self.model_fn(x, vec_t))\n                    t_prev_list.append(vec_t)\n                # Compute the remaining values by `order`-th order multistep DPM-Solver.\n                for step in range(order, steps + 1):\n                    vec_t = timesteps[step].expand(x.shape[0])\n                    if lower_order_final and steps < 15:\n                        step_order = min(order, steps + 1 - step)\n                    else:\n                        step_order = order\n                    x = self.multistep_dpm_solver_update(\n      
                  x,\n                        model_prev_list,\n                        t_prev_list,\n                        vec_t,\n                        step_order,\n                        solver_type=solver_type,\n                    )\n                    for i in range(order - 1):\n                        t_prev_list[i] = t_prev_list[i + 1]\n                        model_prev_list[i] = model_prev_list[i + 1]\n                    t_prev_list[-1] = vec_t\n                    # We do not need to evaluate the final model value.\n                    if step < steps:\n                        model_prev_list[-1] = self.model_fn(x, vec_t)\n        elif method in [\"singlestep\", \"singlestep_fixed\"]:\n            if method == \"singlestep\":\n                (\n                    timesteps_outer,\n                    orders,\n                ) = self.get_orders_and_timesteps_for_singlestep_solver(\n                    steps=steps,\n                    order=order,\n                    skip_type=skip_type,\n                    t_T=t_T,\n                    t_0=t_0,\n                    device=device,\n                )\n            elif method == \"singlestep_fixed\":\n                K = steps // order\n                orders = [\n                    order,\n                ] * K\n                timesteps_outer = self.get_time_steps(\n                    skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device\n                )\n            for i, order in enumerate(orders):\n                t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]\n                timesteps_inner = self.get_time_steps(\n                    skip_type=skip_type,\n                    t_T=t_T_inner.item(),\n                    t_0=t_0_inner.item(),\n                    N=order,\n                    device=device,\n                )\n                lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)\n                vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])\n                h = lambda_inner[-1] - lambda_inner[0]\n                r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h\n                r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h\n                x = self.singlestep_dpm_solver_update(\n                    x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2\n                )\n        if denoise_to_zero:\n            x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)\n        return x\n\n\n#############################################################\n# other utility functions\n#############################################################\n\n\ndef interpolate_fn(x, xp, yp):\n    \"\"\"\n    A piecewise linear function y = f(x), using xp and yp as keypoints.\n    We implement f(x) in a differentiable way (i.e. applicable for autograd).\n    The function f(x) is well-defined for all x-axis. 
(For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)\n\n    Args:\n        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).\n        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.\n        yp: PyTorch tensor with shape [C, K].\n    Returns:\n        The function values f(x), with shape [N, C].\n    \"\"\"\n    N, K = x.shape[0], xp.shape[1]\n    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)\n    sorted_all_x, x_indices = torch.sort(all_x, dim=2)\n    x_idx = torch.argmin(x_indices, dim=2)\n    cand_start_idx = x_idx - 1\n    start_idx = torch.where(\n        torch.eq(x_idx, 0),\n        torch.tensor(1, device=x.device),\n        torch.where(\n            torch.eq(x_idx, K),\n            torch.tensor(K - 2, device=x.device),\n            cand_start_idx,\n        ),\n    )\n    end_idx = torch.where(\n        torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1\n    )\n    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)\n    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)\n    start_idx2 = torch.where(\n        torch.eq(x_idx, 0),\n        torch.tensor(0, device=x.device),\n        torch.where(\n            torch.eq(x_idx, K),\n            torch.tensor(K - 2, device=x.device),\n            cand_start_idx,\n        ),\n    )\n    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)\n    start_y = torch.gather(\n        y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)\n    ).squeeze(2)\n    end_y = torch.gather(\n        y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)\n    ).squeeze(2)\n    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)\n    return cand\n\n\ndef expand_dims(v, dims):\n    \"\"\"\n    Expand the tensor `v` to the dim `dims`.\n\n    Args:\n        `v`: a PyTorch tensor with shape [N].\n        `dim`: a `int`.\n    Returns:\n        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.\n    \"\"\"\n    return v[(...,) + (None,) * (dims - 1)]\n"
  },
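Editorial note (not part of the serialized repository): the `sample` docstring above recommends multistep DPM-Solver with `predict_x0=True` and `order=2` for guided sampling with a large guidance scale, and singlestep order 3 otherwise. The sketch below shows what such a call might look like if placed at the bottom of the same module, so that the `DPM_Solver` class above is in scope; `model_fn` (a noise-prediction wrapper) and `noise_schedule` are assumed to be constructed earlier in the file, and every name outside the docstring's own examples is illustrative only.

import torch

def sample_with_dpm_solver(model_fn, noise_schedule, shape, device="cpu"):
    # Hypothetical driver mirroring the docstring's guided-sampling advice:
    # multistep DPM-Solver, order 2, data (x0) prediction.
    x_T = torch.randn(shape, device=device)  # initial noise at t = T
    solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
    return solver.sample(
        x_T,
        steps=20,                   # total number of function evaluations (NFE)
        order=2,
        skip_type="time_uniform",   # recommended in the docstring for high-resolution data
        method="multistep",
        lower_order_final=True,     # stabilizes sampling when steps are few
    )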
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/__init__.py",
    "content": ""
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/attention.py",
    "content": "from inspect import isfunction\nimport math\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn, einsum\nfrom einops import rearrange, repeat\n\nfrom semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.util import (\n    checkpoint,\n)\n\n\ndef exists(val):\n    return val is not None\n\n\ndef uniq(arr):\n    return {el: True for el in arr}.keys()\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef max_neg_value(t):\n    return -torch.finfo(t.dtype).max\n\n\ndef init_(tensor):\n    dim = tensor.shape[-1]\n    std = 1 / math.sqrt(dim)\n    tensor.uniform_(-std, std)\n    return tensor\n\n\n# feedforward\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, dim)\n        project_in = (\n            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())\n            if not glu\n            else GEGLU(dim, inner_dim)\n        )\n\n        self.net = nn.Sequential(\n            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef Normalize(in_channels):\n    return torch.nn.GroupNorm(\n        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True\n    )\n\n\nclass LinearAttention(nn.Module):\n    def __init__(self, dim, heads=4, dim_head=32):\n        super().__init__()\n        self.heads = heads\n        hidden_dim = dim_head * heads\n        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)\n        self.to_out = nn.Conv2d(hidden_dim, dim, 1)\n\n    def forward(self, x):\n        b, c, h, w = x.shape\n        qkv = self.to_qkv(x)\n        q, k, v = rearrange(\n            qkv, \"b (qkv heads c) h w -> qkv b heads c (h w)\", heads=self.heads, qkv=3\n        )\n        k = k.softmax(dim=-1)\n        context = torch.einsum(\"bhdn,bhen->bhde\", k, v)\n        out = torch.einsum(\"bhde,bhdn->bhen\", context, q)\n        out = rearrange(\n            out, \"b heads c (h w) -> b (heads c) h w\", heads=self.heads, h=h, w=w\n        )\n        return self.to_out(out)\n\n\nclass SpatialSelfAttention(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.k = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.v = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.proj_out = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        
# compute attention\n        b, c, h, w = q.shape\n        q = rearrange(q, \"b c h w -> b (h w) c\")\n        k = rearrange(k, \"b c h w -> b c (h w)\")\n        w_ = torch.einsum(\"bij,bjk->bik\", q, k)\n\n        w_ = w_ * (int(c) ** (-0.5))\n        w_ = torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = rearrange(v, \"b c h w -> b c (h w)\")\n        w_ = rearrange(w_, \"b i j -> b j i\")\n        h_ = torch.einsum(\"bij,bjk->bik\", v, w_)\n        h_ = rearrange(h_, \"b c (h w) -> b c h w\", h=h)\n        h_ = self.proj_out(h_)\n\n        return x + h_\n\n\n# class CrossAttention(nn.Module):\n#     \"\"\"\n#     ### Cross Attention Layer\n#     This falls-back to self-attention when conditional embeddings are not specified.\n#     \"\"\"\n\n#     use_flash_attention: bool = True\n\n#     # use_flash_attention: bool = False\n#     def __init__(\n#         self,\n#         query_dim,\n#         context_dim=None,\n#         heads=8,\n#         dim_head=64,\n#         dropout=0.0,\n#         is_inplace: bool = True,\n#     ):\n#         # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):\n#         \"\"\"\n#         :param d_model: is the input embedding size\n#         :param n_heads: is the number of attention heads\n#         :param d_head: is the size of a attention head\n#         :param d_cond: is the size of the conditional embeddings\n#         :param is_inplace: specifies whether to perform the attention softmax computation inplace to\n#             save memory\n#         \"\"\"\n#         super().__init__()\n\n#         self.is_inplace = is_inplace\n#         self.n_heads = heads\n#         self.d_head = dim_head\n\n#         # Attention scaling factor\n#         self.scale = dim_head**-0.5\n\n#         # The normal self-attention layer\n#         if context_dim is None:\n#             context_dim = query_dim\n\n#         # Query, key and value mappings\n#         d_attn = dim_head * heads\n#         self.to_q = nn.Linear(query_dim, d_attn, bias=False)\n#         self.to_k = nn.Linear(context_dim, d_attn, bias=False)\n#         self.to_v = nn.Linear(context_dim, d_attn, bias=False)\n\n#         # Final linear layer\n#         self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))\n\n#         # Setup [flash attention](https://github.com/HazyResearch/flash-attention).\n#         # Flash attention is only used if it's installed\n#         # and `CrossAttention.use_flash_attention` is set to `True`.\n#         try:\n#             # You can install flash attention by cloning their Github repo,\n#             # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)\n#             # and then running `python setup.py install`\n#             from flash_attn.flash_attention import FlashAttention\n\n#             self.flash = FlashAttention()\n#             # Set the scale for scaled dot-product attention.\n#             self.flash.softmax_scale = self.scale\n#         # Set to `None` if it's not installed\n#         except ImportError:\n#             self.flash = None\n\n#     def forward(self, x, context=None, mask=None):\n#         \"\"\"\n#         :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`\n#         :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`\n#         \"\"\"\n\n#         # If `cond` is `None` we perform self attention\n#         has_cond = context is not 
None\n#         if not has_cond:\n#             context = x\n\n#         # Get query, key and value vectors\n#         q = self.to_q(x)\n#         k = self.to_k(context)\n#         v = self.to_v(context)\n\n#         # Use flash attention if it's available and the head size is less than or equal to `128`\n#         if (\n#             CrossAttention.use_flash_attention\n#             and self.flash is not None\n#             and not has_cond\n#             and self.d_head <= 128\n#         ):\n#             return self.flash_attention(q, k, v)\n#         # Otherwise, fallback to normal attention\n#         else:\n#             return self.normal_attention(q, k, v)\n\n#     def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):\n#         \"\"\"\n#         #### Flash Attention\n#         :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`\n#         :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`\n#         :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`\n#         \"\"\"\n\n#         # Get batch size and number of elements along sequence axis (`width * height`)\n#         batch_size, seq_len, _ = q.shape\n\n#         # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of\n#         # shape `[batch_size, seq_len, 3, n_heads * d_head]`\n#         qkv = torch.stack((q, k, v), dim=2)\n#         # Split the heads\n#         qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)\n\n#         # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to\n#         # fit this size.\n#         if self.d_head <= 32:\n#             pad = 32 - self.d_head\n#         elif self.d_head <= 64:\n#             pad = 64 - self.d_head\n#         elif self.d_head <= 128:\n#             pad = 128 - self.d_head\n#         else:\n#             raise ValueError(f\"Head size ${self.d_head} too large for Flash Attention\")\n\n#         # Pad the heads\n#         if pad:\n#             qkv = torch.cat(\n#                 (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1\n#             )\n\n#         # Compute attention\n#         # $$\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V$$\n#         # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`\n#         # TODO here I add the dtype changing\n#         out, _ = self.flash(qkv.type(torch.float16))\n#         # Truncate the extra head size\n#         out = out[:, :, :, : self.d_head].float()\n#         # Reshape to `[batch_size, seq_len, n_heads * d_head]`\n#         out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)\n\n#         # Map to `[batch_size, height * width, d_model]` with a linear layer\n#         return self.to_out(out)\n\n#     def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):\n#         \"\"\"\n#         #### Normal Attention\n\n#         :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`\n#         :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`\n#         :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`\n#         \"\"\"\n\n#         # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`\n#         q = q.view(*q.shape[:2], self.n_heads, -1)  # [bs, 64, 20, 32]\n#         k = 
k.view(*k.shape[:2], self.n_heads, -1)  # [bs, 1, 20, 32]\n#         v = v.view(*v.shape[:2], self.n_heads, -1)\n\n#         # Calculate attention $\\frac{Q K^\\top}{\\sqrt{d_{key}}}$\n#         attn = torch.einsum(\"bihd,bjhd->bhij\", q, k) * self.scale\n\n#         # Compute softmax\n#         # $$\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)$$\n#         if self.is_inplace:\n#             half = attn.shape[0] // 2\n#             attn[half:] = attn[half:].softmax(dim=-1)\n#             attn[:half] = attn[:half].softmax(dim=-1)\n#         else:\n#             attn = attn.softmax(dim=-1)\n\n#         # Compute attention output\n#         # $$\\underset{seq}{softmax}\\Bigg(\\frac{Q K^\\top}{\\sqrt{d_{key}}}\\Bigg)V$$\n#         # attn: [bs, 20, 64, 1]\n#         # v: [bs, 1, 20, 32]\n#         out = torch.einsum(\"bhij,bjhd->bihd\", attn, v)\n#         # Reshape to `[batch_size, height * width, n_heads * d_head]`\n#         out = out.reshape(*out.shape[:2], -1)\n#         # Map to `[batch_size, height * width, d_model]` with a linear layer\n#         return self.to_out(out)\n\n\nclass CrossAttention(nn.Module):\n    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):\n        super().__init__()\n        inner_dim = dim_head * heads\n        context_dim = default(context_dim, query_dim)\n\n        self.scale = dim_head**-0.5\n        self.heads = heads\n\n        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)\n\n        self.to_out = nn.Sequential(\n            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)\n        )\n\n    def forward(self, x, context=None, mask=None):\n        h = self.heads\n\n        q = self.to_q(x)\n        context = default(context, x)\n\n        k = self.to_k(context)\n        v = self.to_v(context)\n\n        q, k, v = map(lambda t: rearrange(t, \"b n (h d) -> (b h) n d\", h=h), (q, k, v))\n\n        sim = einsum(\"b i d, b j d -> b i j\", q, k) * self.scale\n\n        if exists(mask):\n            mask = rearrange(mask, \"b ... 
-> b (...)\")\n            max_neg_value = -torch.finfo(sim.dtype).max\n            mask = repeat(mask, \"b j -> (b h) () j\", h=h)\n            sim.masked_fill_(~(mask == 1), max_neg_value)\n\n        # attention, what we cannot get enough of\n        attn = sim.softmax(dim=-1)\n\n        out = einsum(\"b i j, b j d -> b i d\", attn, v)\n        out = rearrange(out, \"(b h) n d -> b n (h d)\", h=h)\n        return self.to_out(out)\n\n\nclass BasicTransformerBlock(nn.Module):\n    def __init__(\n        self,\n        dim,\n        n_heads,\n        d_head,\n        dropout=0.0,\n        context_dim=None,\n        gated_ff=True,\n        checkpoint=True,\n    ):\n        super().__init__()\n        self.attn1 = CrossAttention(\n            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout\n        )  # is a self-attention\n        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)\n        self.attn2 = CrossAttention(\n            query_dim=dim,\n            context_dim=context_dim,\n            heads=n_heads,\n            dim_head=d_head,\n            dropout=dropout,\n        )  # is self-attn if context is none\n        self.norm1 = nn.LayerNorm(dim)\n        self.norm2 = nn.LayerNorm(dim)\n        self.norm3 = nn.LayerNorm(dim)\n        self.checkpoint = checkpoint\n\n    def forward(self, x, context=None, mask=None):\n        if context is None:\n            return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)\n        else:\n            return checkpoint(\n                self._forward, (x, context, mask), self.parameters(), self.checkpoint\n            )\n\n    def _forward(self, x, context=None, mask=None):\n        x = self.attn1(self.norm1(x)) + x\n        x = self.attn2(self.norm2(x), context=context, mask=mask) + x\n        x = self.ff(self.norm3(x)) + x\n        return x\n\n\nclass SpatialTransformer(nn.Module):\n    \"\"\"\n    Transformer block for image-like data.\n    First, project the input (aka embedding)\n    and reshape to b, t, d.\n    Then apply standard transformer action.\n    Finally, reshape to image\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        n_heads,\n        d_head,\n        depth=1,\n        dropout=0.0,\n        context_dim=None,\n    ):\n        super().__init__()\n\n        context_dim = context_dim\n\n        self.in_channels = in_channels\n        inner_dim = n_heads * d_head\n        self.norm = Normalize(in_channels)\n\n        self.proj_in = nn.Conv2d(\n            in_channels, inner_dim, kernel_size=1, stride=1, padding=0\n        )\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                BasicTransformerBlock(\n                    inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim\n                )\n                for d in range(depth)\n            ]\n        )\n\n        self.proj_out = zero_module(\n            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)\n        )\n\n    def forward(self, x, context=None, mask=None):\n        # note: if no context is given, cross-attention defaults to self-attention\n        b, c, h, w = x.shape\n        x_in = x\n        x = self.norm(x)\n        x = self.proj_in(x)\n        x = rearrange(x, \"b c h w -> b (h w) c\")\n        for block in self.transformer_blocks:\n            x = block(x, context=context, mask=mask)\n        x = rearrange(x, \"b (h w) c -> b c h w\", h=h, w=w)\n        x = self.proj_out(x)\n        return x + x_in\n"
  },
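Editorial sketch (not part of the repository): a minimal shape check for the `SpatialTransformer` defined in attention.py above, illustrating the `b c h w -> b (h w) c` projection, the optional cross-attention context, and the residual output. The tensor sizes are arbitrary examples; the only constraint taken from the code is that `in_channels` must be divisible by the GroupNorm group count of 32 used in `Normalize`.

import torch

def smoke_test_spatial_transformer():
    # Assumes this runs in a context where SpatialTransformer from attention.py is importable.
    block = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, context_dim=128)
    x = torch.randn(2, 64, 8, 8)        # image-like latent: [batch, channels, height, width]
    context = torch.randn(2, 10, 128)   # conditioning tokens: [batch, tokens, context_dim]
    out = block(x, context=context)     # cross-attention against `context`; self-attention if None
    assert out.shape == x.shape         # residual connection preserves the input shape
    return out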
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/__init__.py",
    "content": ""
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/model.py",
    "content": "# pytorch_diffusion + derived encoder decoder\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import rearrange\n\nfrom semanticodec.modules.decoder.latent_diffusion.util import instantiate_from_config\nfrom semanticodec.modules.decoder.latent_diffusion.modules.attention import (\n    LinearAttention,\n)\n\n\ndef get_timestep_embedding(timesteps, embedding_dim):\n    \"\"\"\n    This matches the implementation in Denoising Diffusion Probabilistic Models:\n    From Fairseq.\n    Build sinusoidal embeddings.\n    This matches the implementation in tensor2tensor, but differs slightly\n    from the description in Section 3.5 of \"Attention Is All You Need\".\n    \"\"\"\n    assert len(timesteps.shape) == 1\n\n    half_dim = embedding_dim // 2\n    emb = math.log(10000) / (half_dim - 1)\n    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)\n    emb = emb.to(device=timesteps.device)\n    emb = timesteps.float()[:, None] * emb[None, :]\n    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n    if embedding_dim % 2 == 1:  # zero pad\n        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))\n    return emb\n\n\ndef nonlinearity(x):\n    # swish\n    return x * torch.sigmoid(x)\n\n\ndef Normalize(in_channels, num_groups=32):\n    return torch.nn.GroupNorm(\n        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True\n    )\n\n\nclass Upsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            self.conv = torch.nn.Conv2d(\n                in_channels, in_channels, kernel_size=3, stride=1, padding=1\n            )\n\n    def forward(self, x):\n        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode=\"nearest\")\n        if self.with_conv:\n            x = self.conv(x)\n        return x\n\n\nclass UpsampleTimeStride4(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            self.conv = torch.nn.Conv2d(\n                in_channels, in_channels, kernel_size=5, stride=1, padding=2\n            )\n\n    def forward(self, x):\n        x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode=\"nearest\")\n        if self.with_conv:\n            x = self.conv(x)\n        return x\n\n\nclass Downsample(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            # Do time downsampling here\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(\n                in_channels, in_channels, kernel_size=3, stride=2, padding=0\n            )\n\n    def forward(self, x):\n        if self.with_conv:\n            pad = (0, 1, 0, 1)\n            x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n            x = self.conv(x)\n        else:\n            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)\n        return x\n\n\nclass DownsampleTimeStride4(nn.Module):\n    def __init__(self, in_channels, with_conv):\n        super().__init__()\n        self.with_conv = with_conv\n        if self.with_conv:\n            # Do time downsampling here\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(\n                in_channels, 
in_channels, kernel_size=5, stride=(4, 2), padding=1\n            )\n\n    def forward(self, x):\n        if self.with_conv:\n            pad = (0, 1, 0, 1)\n            x = torch.nn.functional.pad(x, pad, mode=\"constant\", value=0)\n            x = self.conv(x)\n        else:\n            x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))\n        return x\n\n\nclass ResnetBlock(nn.Module):\n    def __init__(\n        self,\n        *,\n        in_channels,\n        out_channels=None,\n        conv_shortcut=False,\n        dropout,\n        temb_channels=512,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        out_channels = in_channels if out_channels is None else out_channels\n        self.out_channels = out_channels\n        self.use_conv_shortcut = conv_shortcut\n\n        self.norm1 = Normalize(in_channels)\n        self.conv1 = torch.nn.Conv2d(\n            in_channels, out_channels, kernel_size=3, stride=1, padding=1\n        )\n        if temb_channels > 0:\n            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)\n        self.norm2 = Normalize(out_channels)\n        self.dropout = torch.nn.Dropout(dropout)\n        self.conv2 = torch.nn.Conv2d(\n            out_channels, out_channels, kernel_size=3, stride=1, padding=1\n        )\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                self.conv_shortcut = torch.nn.Conv2d(\n                    in_channels, out_channels, kernel_size=3, stride=1, padding=1\n                )\n            else:\n                self.nin_shortcut = torch.nn.Conv2d(\n                    in_channels, out_channels, kernel_size=1, stride=1, padding=0\n                )\n\n    def forward(self, x, temb):\n        h = x\n        h = self.norm1(h)\n        h = nonlinearity(h)\n        h = self.conv1(h)\n\n        if temb is not None:\n            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]\n\n        h = self.norm2(h)\n        h = nonlinearity(h)\n        h = self.dropout(h)\n        h = self.conv2(h)\n\n        if self.in_channels != self.out_channels:\n            if self.use_conv_shortcut:\n                x = self.conv_shortcut(x)\n            else:\n                x = self.nin_shortcut(x)\n\n        return x + h\n\n\nclass LinAttnBlock(LinearAttention):\n    \"\"\"to match AttnBlock usage\"\"\"\n\n    def __init__(self, in_channels):\n        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)\n\n\nclass AttnBlock(nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.in_channels = in_channels\n\n        self.norm = Normalize(in_channels)\n        self.q = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.k = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.v = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n        self.proj_out = torch.nn.Conv2d(\n            in_channels, in_channels, kernel_size=1, stride=1, padding=0\n        )\n\n    def forward(self, x):\n        h_ = x\n        h_ = self.norm(h_)\n        q = self.q(h_)\n        k = self.k(h_)\n        v = self.v(h_)\n\n        # compute attention\n        b, c, h, w = q.shape\n        q = q.reshape(b, c, h * w).contiguous()\n        q = q.permute(0, 2, 1).contiguous()  # b,hw,c\n        k = k.reshape(b, c, h * 
w).contiguous()  # b,c,hw\n        w_ = torch.bmm(q, k).contiguous()  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]\n        w_ = w_ * (int(c) ** (-0.5))\n        w_ = torch.nn.functional.softmax(w_, dim=2)\n\n        # attend to values\n        v = v.reshape(b, c, h * w).contiguous()\n        w_ = w_.permute(0, 2, 1).contiguous()  # b,hw,hw (first hw of k, second of q)\n        h_ = torch.bmm(\n            v, w_\n        ).contiguous()  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]\n        h_ = h_.reshape(b, c, h, w).contiguous()\n\n        h_ = self.proj_out(h_)\n\n        return x + h_\n\n\ndef make_attn(in_channels, attn_type=\"vanilla\"):\n    assert attn_type in [\"vanilla\", \"linear\", \"none\"], f\"attn_type {attn_type} unknown\"\n    if attn_type == \"vanilla\":\n        return AttnBlock(in_channels)\n    elif attn_type == \"none\":\n        return nn.Identity(in_channels)\n    else:\n        return LinAttnBlock(in_channels)\n\n\nclass Model(nn.Module):\n    def __init__(\n        self,\n        *,\n        ch,\n        out_ch,\n        ch_mult=(1, 2, 4, 8),\n        num_res_blocks,\n        attn_resolutions,\n        dropout=0.0,\n        resamp_with_conv=True,\n        in_channels,\n        resolution,\n        use_timestep=True,\n        use_linear_attn=False,\n        attn_type=\"vanilla\",\n    ):\n        super().__init__()\n        if use_linear_attn:\n            attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = self.ch * 4\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n\n        self.use_timestep = use_timestep\n        if self.use_timestep:\n            # timestep embedding\n            self.temb = nn.Module()\n            self.temb.dense = nn.ModuleList(\n                [\n                    torch.nn.Linear(self.ch, self.temb_ch),\n                    torch.nn.Linear(self.temb_ch, self.temb_ch),\n                ]\n            )\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(\n            in_channels, self.ch, kernel_size=3, stride=1, padding=1\n        )\n\n        curr_res = resolution\n        in_ch_mult = (1,) + tuple(ch_mult)\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch * in_ch_mult[i_level]\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(\n                    ResnetBlock(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                    )\n                )\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions - 1:\n                down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = curr_res // 2\n            self.down.append(down)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            
dropout=dropout,\n        )\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch * ch_mult[i_level]\n            skip_in = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks + 1):\n                if i_block == self.num_res_blocks:\n                    skip_in = ch * in_ch_mult[i_level]\n                block.append(\n                    ResnetBlock(\n                        in_channels=block_in + skip_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                    )\n                )\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up)  # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(\n            block_in, out_ch, kernel_size=3, stride=1, padding=1\n        )\n\n    def forward(self, x, t=None, context=None):\n        # assert x.shape[2] == x.shape[3] == self.resolution\n        if context is not None:\n            # assume aligned context, cat along channel axis\n            x = torch.cat((x, context), dim=1)\n        if self.use_timestep:\n            # timestep embedding\n            assert t is not None\n            temb = get_timestep_embedding(t, self.ch)\n            temb = self.temb.dense[0](temb)\n            temb = nonlinearity(temb)\n            temb = self.temb.dense[1](temb)\n        else:\n            temb = None\n\n        # downsampling\n        hs = [self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1], temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n                hs.append(h)\n            if i_level != self.num_resolutions - 1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks + 1):\n                h = self.up[i_level].block[i_block](\n                    torch.cat([h, hs.pop()], dim=1), temb\n                )\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n    def get_last_layer(self):\n    
    return self.conv_out.weight\n\n\nclass Encoder(nn.Module):\n    def __init__(\n        self,\n        *,\n        ch,\n        out_ch,\n        ch_mult=(1, 2, 4, 8),\n        num_res_blocks,\n        attn_resolutions,\n        dropout=0.0,\n        resamp_with_conv=True,\n        in_channels,\n        resolution,\n        z_channels,\n        double_z=True,\n        use_linear_attn=False,\n        attn_type=\"vanilla\",\n        downsample_time_stride4_levels=[],\n        **ignore_kwargs,\n    ):\n        super().__init__()\n        if use_linear_attn:\n            attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.downsample_time_stride4_levels = downsample_time_stride4_levels\n\n        if len(self.downsample_time_stride4_levels) > 0:\n            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (\n                \"The level to perform downsample 4 operation need to be smaller than the total resolution number %s\"\n                % str(self.num_resolutions)\n            )\n\n        # downsampling\n        self.conv_in = torch.nn.Conv2d(\n            in_channels, self.ch, kernel_size=3, stride=1, padding=1\n        )\n\n        curr_res = resolution\n        in_ch_mult = (1,) + tuple(ch_mult)\n        self.in_ch_mult = in_ch_mult\n        self.down = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_in = ch * in_ch_mult[i_level]\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks):\n                block.append(\n                    ResnetBlock(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                    )\n                )\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            down = nn.Module()\n            down.block = block\n            down.attn = attn\n            if i_level != self.num_resolutions - 1:\n                if i_level in self.downsample_time_stride4_levels:\n                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)\n                else:\n                    down.downsample = Downsample(block_in, resamp_with_conv)\n                curr_res = curr_res // 2\n            self.down.append(down)\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(\n            block_in,\n            2 * z_channels if double_z else z_channels,\n            kernel_size=3,\n            stride=1,\n            padding=1,\n        )\n\n    def forward(self, x):\n    
    # timestep embedding\n        temb = None\n        # downsampling\n        hs = [self.conv_in(x)]\n        for i_level in range(self.num_resolutions):\n            for i_block in range(self.num_res_blocks):\n                h = self.down[i_level].block[i_block](hs[-1], temb)\n                if len(self.down[i_level].attn) > 0:\n                    h = self.down[i_level].attn[i_block](h)\n                hs.append(h)\n            if i_level != self.num_resolutions - 1:\n                hs.append(self.down[i_level].downsample(hs[-1]))\n\n        # middle\n        h = hs[-1]\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # end\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass Decoder(nn.Module):\n    def __init__(\n        self,\n        *,\n        ch,\n        out_ch,\n        ch_mult=(1, 2, 4, 8),\n        num_res_blocks,\n        attn_resolutions,\n        dropout=0.0,\n        resamp_with_conv=True,\n        in_channels,\n        resolution,\n        z_channels,\n        give_pre_end=False,\n        tanh_out=False,\n        use_linear_attn=False,\n        downsample_time_stride4_levels=[],\n        attn_type=\"vanilla\",\n        **ignorekwargs,\n    ):\n        super().__init__()\n        if use_linear_attn:\n            attn_type = \"linear\"\n        self.ch = ch\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        self.resolution = resolution\n        self.in_channels = in_channels\n        self.give_pre_end = give_pre_end\n        self.tanh_out = tanh_out\n        self.downsample_time_stride4_levels = downsample_time_stride4_levels\n\n        if len(self.downsample_time_stride4_levels) > 0:\n            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (\n                \"The level to perform downsample 4 operation need to be smaller than the total resolution number %s\"\n                % str(self.num_resolutions)\n            )\n\n        # compute in_ch_mult, block_in and curr_res at lowest res\n        in_ch_mult = (1,) + tuple(ch_mult)\n        block_in = ch * ch_mult[self.num_resolutions - 1]\n        curr_res = resolution // 2 ** (self.num_resolutions - 1)\n        self.z_shape = (1, z_channels, curr_res, curr_res)\n        # print(\n        #     \"Working with z of shape {} = {} dimensions.\".format(\n        #         self.z_shape, np.prod(self.z_shape)\n        #     )\n        # )\n\n        # z to block_in\n        self.conv_in = torch.nn.Conv2d(\n            z_channels, block_in, kernel_size=3, stride=1, padding=1\n        )\n\n        # middle\n        self.mid = nn.Module()\n        self.mid.block_1 = ResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)\n        self.mid.block_2 = ResnetBlock(\n            in_channels=block_in,\n            out_channels=block_in,\n            temb_channels=self.temb_ch,\n            dropout=dropout,\n        )\n\n        # upsampling\n        self.up = nn.ModuleList()\n        for i_level in reversed(range(self.num_resolutions)):\n            block = nn.ModuleList()\n            attn = nn.ModuleList()\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks + 1):\n           
     block.append(\n                    ResnetBlock(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                    )\n                )\n                block_in = block_out\n                if curr_res in attn_resolutions:\n                    attn.append(make_attn(block_in, attn_type=attn_type))\n            up = nn.Module()\n            up.block = block\n            up.attn = attn\n            if i_level != 0:\n                if i_level - 1 in self.downsample_time_stride4_levels:\n                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)\n                else:\n                    up.upsample = Upsample(block_in, resamp_with_conv)\n                curr_res = curr_res * 2\n            self.up.insert(0, up)  # prepend to get consistent order\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(\n            block_in, out_ch, kernel_size=3, stride=1, padding=1\n        )\n\n    def forward(self, z):\n        # assert z.shape[1:] == self.z_shape[1:]\n        self.last_z_shape = z.shape\n\n        # timestep embedding\n        temb = None\n\n        # z to block_in\n        h = self.conv_in(z)\n\n        # middle\n        h = self.mid.block_1(h, temb)\n        h = self.mid.attn_1(h)\n        h = self.mid.block_2(h, temb)\n\n        # upsampling\n        for i_level in reversed(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks + 1):\n                h = self.up[i_level].block[i_block](h, temb)\n                if len(self.up[i_level].attn) > 0:\n                    h = self.up[i_level].attn[i_block](h)\n            if i_level != 0:\n                h = self.up[i_level].upsample(h)\n\n        # end\n        if self.give_pre_end:\n            return h\n\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        if self.tanh_out:\n            h = torch.tanh(h)\n        return h\n\n\nclass SimpleDecoder(nn.Module):\n    def __init__(self, in_channels, out_channels, *args, **kwargs):\n        super().__init__()\n        self.model = nn.ModuleList(\n            [\n                nn.Conv2d(in_channels, in_channels, 1),\n                ResnetBlock(\n                    in_channels=in_channels,\n                    out_channels=2 * in_channels,\n                    temb_channels=0,\n                    dropout=0.0,\n                ),\n                ResnetBlock(\n                    in_channels=2 * in_channels,\n                    out_channels=4 * in_channels,\n                    temb_channels=0,\n                    dropout=0.0,\n                ),\n                ResnetBlock(\n                    in_channels=4 * in_channels,\n                    out_channels=2 * in_channels,\n                    temb_channels=0,\n                    dropout=0.0,\n                ),\n                nn.Conv2d(2 * in_channels, in_channels, 1),\n                Upsample(in_channels, with_conv=True),\n            ]\n        )\n        # end\n        self.norm_out = Normalize(in_channels)\n        self.conv_out = torch.nn.Conv2d(\n            in_channels, out_channels, kernel_size=3, stride=1, padding=1\n        )\n\n    def forward(self, x):\n        for i, layer in enumerate(self.model):\n            if i in [1, 2, 3]:\n                x = layer(x, None)\n            else:\n                x = layer(x)\n\n        h = 
self.norm_out(x)\n        h = nonlinearity(h)\n        x = self.conv_out(h)\n        return x\n\n\nclass UpsampleDecoder(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        ch,\n        num_res_blocks,\n        resolution,\n        ch_mult=(2, 2),\n        dropout=0.0,\n    ):\n        super().__init__()\n        # upsampling\n        self.temb_ch = 0\n        self.num_resolutions = len(ch_mult)\n        self.num_res_blocks = num_res_blocks\n        block_in = in_channels\n        curr_res = resolution // 2 ** (self.num_resolutions - 1)\n        self.res_blocks = nn.ModuleList()\n        self.upsample_blocks = nn.ModuleList()\n        for i_level in range(self.num_resolutions):\n            res_block = []\n            block_out = ch * ch_mult[i_level]\n            for i_block in range(self.num_res_blocks + 1):\n                res_block.append(\n                    ResnetBlock(\n                        in_channels=block_in,\n                        out_channels=block_out,\n                        temb_channels=self.temb_ch,\n                        dropout=dropout,\n                    )\n                )\n                block_in = block_out\n            self.res_blocks.append(nn.ModuleList(res_block))\n            if i_level != self.num_resolutions - 1:\n                self.upsample_blocks.append(Upsample(block_in, True))\n                curr_res = curr_res * 2\n\n        # end\n        self.norm_out = Normalize(block_in)\n        self.conv_out = torch.nn.Conv2d(\n            block_in, out_channels, kernel_size=3, stride=1, padding=1\n        )\n\n    def forward(self, x):\n        # upsampling\n        h = x\n        for k, i_level in enumerate(range(self.num_resolutions)):\n            for i_block in range(self.num_res_blocks + 1):\n                h = self.res_blocks[i_level][i_block](h, None)\n            if i_level != self.num_resolutions - 1:\n                h = self.upsample_blocks[k](h)\n        h = self.norm_out(h)\n        h = nonlinearity(h)\n        h = self.conv_out(h)\n        return h\n\n\nclass LatentRescaler(nn.Module):\n    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):\n        super().__init__()\n        # residual block, interpolate, residual block\n        self.factor = factor\n        self.conv_in = nn.Conv2d(\n            in_channels, mid_channels, kernel_size=3, stride=1, padding=1\n        )\n        self.res_block1 = nn.ModuleList(\n            [\n                ResnetBlock(\n                    in_channels=mid_channels,\n                    out_channels=mid_channels,\n                    temb_channels=0,\n                    dropout=0.0,\n                )\n                for _ in range(depth)\n            ]\n        )\n        self.attn = AttnBlock(mid_channels)\n        self.res_block2 = nn.ModuleList(\n            [\n                ResnetBlock(\n                    in_channels=mid_channels,\n                    out_channels=mid_channels,\n                    temb_channels=0,\n                    dropout=0.0,\n                )\n                for _ in range(depth)\n            ]\n        )\n\n        self.conv_out = nn.Conv2d(\n            mid_channels,\n            out_channels,\n            kernel_size=1,\n        )\n\n    def forward(self, x):\n        x = self.conv_in(x)\n        for block in self.res_block1:\n            x = block(x, None)\n        x = torch.nn.functional.interpolate(\n            x,\n            size=(\n                int(round(x.shape[2] * 
self.factor)),\n                int(round(x.shape[3] * self.factor)),\n            ),\n        )\n        x = self.attn(x).contiguous()\n        for block in self.res_block2:\n            x = block(x, None)\n        x = self.conv_out(x)\n        return x\n\n\nclass MergedRescaleEncoder(nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        ch,\n        resolution,\n        out_ch,\n        num_res_blocks,\n        attn_resolutions,\n        dropout=0.0,\n        resamp_with_conv=True,\n        ch_mult=(1, 2, 4, 8),\n        rescale_factor=1.0,\n        rescale_module_depth=1,\n    ):\n        super().__init__()\n        intermediate_chn = ch * ch_mult[-1]\n        self.encoder = Encoder(\n            in_channels=in_channels,\n            num_res_blocks=num_res_blocks,\n            ch=ch,\n            ch_mult=ch_mult,\n            z_channels=intermediate_chn,\n            double_z=False,\n            resolution=resolution,\n            attn_resolutions=attn_resolutions,\n            dropout=dropout,\n            resamp_with_conv=resamp_with_conv,\n            out_ch=None,\n        )\n        self.rescaler = LatentRescaler(\n            factor=rescale_factor,\n            in_channels=intermediate_chn,\n            mid_channels=intermediate_chn,\n            out_channels=out_ch,\n            depth=rescale_module_depth,\n        )\n\n    def forward(self, x):\n        x = self.encoder(x)\n        x = self.rescaler(x)\n        return x\n\n\nclass MergedRescaleDecoder(nn.Module):\n    def __init__(\n        self,\n        z_channels,\n        out_ch,\n        resolution,\n        num_res_blocks,\n        attn_resolutions,\n        ch,\n        ch_mult=(1, 2, 4, 8),\n        dropout=0.0,\n        resamp_with_conv=True,\n        rescale_factor=1.0,\n        rescale_module_depth=1,\n    ):\n        super().__init__()\n        tmp_chn = z_channels * ch_mult[-1]\n        self.decoder = Decoder(\n            out_ch=out_ch,\n            z_channels=tmp_chn,\n            attn_resolutions=attn_resolutions,\n            dropout=dropout,\n            resamp_with_conv=resamp_with_conv,\n            in_channels=None,\n            num_res_blocks=num_res_blocks,\n            ch_mult=ch_mult,\n            resolution=resolution,\n            ch=ch,\n        )\n        self.rescaler = LatentRescaler(\n            factor=rescale_factor,\n            in_channels=z_channels,\n            mid_channels=tmp_chn,\n            out_channels=tmp_chn,\n            depth=rescale_module_depth,\n        )\n\n    def forward(self, x):\n        x = self.rescaler(x)\n        x = self.decoder(x)\n        return x\n\n\nclass Upsampler(nn.Module):\n    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):\n        super().__init__()\n        assert out_size >= in_size\n        num_blocks = int(np.log2(out_size // in_size)) + 1\n        factor_up = 1.0 + (out_size % in_size)\n        print(\n            f\"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}\"\n        )\n        self.rescaler = LatentRescaler(\n            factor=factor_up,\n            in_channels=in_channels,\n            mid_channels=2 * in_channels,\n            out_channels=in_channels,\n        )\n        self.decoder = Decoder(\n            out_ch=out_channels,\n            resolution=out_size,\n            z_channels=in_channels,\n            num_res_blocks=2,\n            attn_resolutions=[],\n            in_channels=None,\n            ch=in_channels,\n          
  ch_mult=[ch_mult for _ in range(num_blocks)],\n        )\n\n    def forward(self, x):\n        x = self.rescaler(x)\n        x = self.decoder(x)\n        return x\n\n\nclass Resize(nn.Module):\n    def __init__(self, in_channels=None, learned=False, mode=\"bilinear\"):\n        super().__init__()\n        self.with_conv = learned\n        self.mode = mode\n        if self.with_conv:\n            print(\n                f\"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode\"\n            )\n            raise NotImplementedError()\n            assert in_channels is not None\n            # no asymmetric padding in torch conv, must do it ourselves\n            self.conv = torch.nn.Conv2d(\n                in_channels, in_channels, kernel_size=4, stride=2, padding=1\n            )\n\n    def forward(self, x, scale_factor=1.0):\n        if scale_factor == 1.0:\n            return x\n        else:\n            x = torch.nn.functional.interpolate(\n                x, mode=self.mode, align_corners=False, scale_factor=scale_factor\n            )\n        return x\n\n\nclass FirstStagePostProcessor(nn.Module):\n    def __init__(\n        self,\n        ch_mult: list,\n        in_channels,\n        pretrained_model: nn.Module = None,\n        reshape=False,\n        n_channels=None,\n        dropout=0.0,\n        pretrained_config=None,\n    ):\n        super().__init__()\n        if pretrained_config is None:\n            assert (\n                pretrained_model is not None\n            ), 'Either \"pretrained_model\" or \"pretrained_config\" must not be None'\n            self.pretrained_model = pretrained_model\n        else:\n            assert (\n                pretrained_config is not None\n            ), 'Either \"pretrained_model\" or \"pretrained_config\" must not be None'\n            self.instantiate_pretrained(pretrained_config)\n\n        self.do_reshape = reshape\n\n        if n_channels is None:\n            n_channels = self.pretrained_model.encoder.ch\n\n        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)\n        self.proj = nn.Conv2d(\n            in_channels, n_channels, kernel_size=3, stride=1, padding=1\n        )\n\n        blocks = []\n        downs = []\n        ch_in = n_channels\n        for m in ch_mult:\n            blocks.append(\n                ResnetBlock(\n                    in_channels=ch_in, out_channels=m * n_channels, dropout=dropout\n                )\n            )\n            ch_in = m * n_channels\n            downs.append(Downsample(ch_in, with_conv=False))\n\n        self.model = nn.ModuleList(blocks)\n        self.downsampler = nn.ModuleList(downs)\n\n    def instantiate_pretrained(self, config):\n        model = instantiate_from_config(config)\n        self.pretrained_model = model.eval()\n        # self.pretrained_model.train = False\n        for param in self.pretrained_model.parameters():\n            param.requires_grad = False\n\n    @torch.no_grad()\n    def encode_with_pretrained(self, x):\n        c = self.pretrained_model.encode(x)\n        if isinstance(c, DiagonalGaussianDistribution):\n            c = c.mode()\n        return c\n\n    def forward(self, x):\n        z_fs = self.encode_with_pretrained(x)\n        z = self.proj_norm(z_fs)\n        z = self.proj(z)\n        z = nonlinearity(z)\n\n        for submodel, downmodel in zip(self.model, self.downsampler):\n            z = submodel(z, temb=None)\n            z = downmodel(z)\n\n        if self.do_reshape:\n            z = rearrange(z, \"b c h w -> b (h w) c\")\n        return z\n"
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/openaimodel.py",
    "content": "from abc import abstractmethod\nfrom functools import partial\nimport math\nfrom typing import Iterable\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.util import (\n    checkpoint,\n    conv_nd,\n    linear,\n    avg_pool_nd,\n    zero_module,\n    normalization,\n    timestep_embedding,\n)\nfrom semanticodec.modules.decoder.latent_diffusion.modules.attention import (\n    SpatialTransformer,\n)\n\n\n# dummy replace\ndef convert_module_to_f16(x):\n    pass\n\n\ndef convert_module_to_f32(x):\n    pass\n\n\n## go\nclass AttentionPool2d(nn.Module):\n    \"\"\"\n    Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py\n    \"\"\"\n\n    def __init__(\n        self,\n        spacial_dim: int,\n        embed_dim: int,\n        num_heads_channels: int,\n        output_dim: int = None,\n    ):\n        super().__init__()\n        self.positional_embedding = nn.Parameter(\n            th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5\n        )\n        self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)\n        self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)\n        self.num_heads = embed_dim // num_heads_channels\n        self.attention = QKVAttention(self.num_heads)\n\n    def forward(self, x):\n        b, c, *_spatial = x.shape\n        x = x.reshape(b, c, -1).contiguous()  # NC(HW)\n        x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1)  # NC(HW+1)\n        x = x + self.positional_embedding[None, :, :].to(x.dtype)  # NC(HW+1)\n        x = self.qkv_proj(x)\n        x = self.attention(x)\n        x = self.c_proj(x)\n        return x[:, :, 0]\n\n\nclass TimestepBlock(nn.Module):\n    \"\"\"\n    Any module where forward() takes timestep embeddings as a second argument.\n    \"\"\"\n\n    @abstractmethod\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the module to `x` given `emb` timestep embeddings.\n        \"\"\"\n\n\nclass TimestepEmbedSequential(nn.Sequential, TimestepBlock):\n    \"\"\"\n    A sequential module that passes timestep embeddings to the children that\n    support it as an extra input.\n    \"\"\"\n\n    def forward(self, x, emb, context_list=None, mask_list=None):\n        # The first spatial transformer block does not have context\n        spatial_transformer_id = 0\n        context_list = [None] + context_list\n        mask_list = [None] + mask_list\n\n        for layer in self:\n            if isinstance(layer, TimestepBlock):\n                x = layer(x, emb)\n            elif isinstance(layer, SpatialTransformer):\n                if spatial_transformer_id >= len(context_list):\n                    context, mask = None, None\n                else:\n                    context, mask = (\n                        context_list[spatial_transformer_id],\n                        mask_list[spatial_transformer_id],\n                    )\n\n                x = layer(x, context, mask=mask)\n                spatial_transformer_id += 1\n            else:\n                x = layer(x)\n        return x\n\n\nclass Upsample(nn.Module):\n    \"\"\"\n    An upsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. 
If 3D, then\n                 upsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        if use_conv:\n            self.conv = conv_nd(\n                dims, self.channels, self.out_channels, 3, padding=padding\n            )\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        if self.dims == 3:\n            x = F.interpolate(\n                x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode=\"nearest\"\n            )\n        else:\n            x = F.interpolate(x, scale_factor=2, mode=\"nearest\")\n        if self.use_conv:\n            x = self.conv(x)\n        return x\n\n\nclass TransposedUpsample(nn.Module):\n    \"Learned 2x upsampling without padding\"\n\n    def __init__(self, channels, out_channels=None, ks=5):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n\n        self.up = nn.ConvTranspose2d(\n            self.channels, self.out_channels, kernel_size=ks, stride=2\n        )\n\n    def forward(self, x):\n        return self.up(x)\n\n\nclass Downsample(nn.Module):\n    \"\"\"\n    A downsampling layer with an optional convolution.\n    :param channels: channels in the inputs and outputs.\n    :param use_conv: a bool determining if a convolution is applied.\n    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then\n                 downsampling occurs in the inner-two dimensions.\n    \"\"\"\n\n    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):\n        super().__init__()\n        self.channels = channels\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.dims = dims\n        stride = 2 if dims != 3 else (1, 2, 2)\n        if use_conv:\n            self.op = conv_nd(\n                dims,\n                self.channels,\n                self.out_channels,\n                3,\n                stride=stride,\n                padding=padding,\n            )\n        else:\n            assert self.channels == self.out_channels\n            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)\n\n    def forward(self, x):\n        assert x.shape[1] == self.channels\n        return self.op(x)\n\n\nclass ResBlock(TimestepBlock):\n    \"\"\"\n    A residual block that can optionally change the number of channels.\n    :param channels: the number of input channels.\n    :param emb_channels: the number of timestep embedding channels.\n    :param dropout: the rate of dropout.\n    :param out_channels: if specified, the number of out channels.\n    :param use_conv: if True and out_channels is specified, use a spatial\n        convolution instead of a smaller 1x1 convolution to change the\n        channels in the skip connection.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param use_checkpoint: if True, use gradient checkpointing on this module.\n    :param up: if True, use this block for upsampling.\n    :param down: if True, use this block for downsampling.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        emb_channels,\n        dropout,\n        out_channels=None,\n        use_conv=False,\n        use_scale_shift_norm=False,\n        dims=2,\n        
use_checkpoint=False,\n        up=False,\n        down=False,\n    ):\n        super().__init__()\n        self.channels = channels\n        self.emb_channels = emb_channels\n        self.dropout = dropout\n        self.out_channels = out_channels or channels\n        self.use_conv = use_conv\n        self.use_checkpoint = use_checkpoint\n        self.use_scale_shift_norm = use_scale_shift_norm\n\n        self.in_layers = nn.Sequential(\n            normalization(channels),\n            nn.SiLU(),\n            conv_nd(dims, channels, self.out_channels, 3, padding=1),\n        )\n\n        self.updown = up or down\n\n        if up:\n            self.h_upd = Upsample(channels, False, dims)\n            self.x_upd = Upsample(channels, False, dims)\n        elif down:\n            self.h_upd = Downsample(channels, False, dims)\n            self.x_upd = Downsample(channels, False, dims)\n        else:\n            self.h_upd = self.x_upd = nn.Identity()\n\n        self.emb_layers = nn.Sequential(\n            nn.SiLU(),\n            linear(\n                emb_channels,\n                2 * self.out_channels if use_scale_shift_norm else self.out_channels,\n            ),\n        )\n        self.out_layers = nn.Sequential(\n            normalization(self.out_channels),\n            nn.SiLU(),\n            nn.Dropout(p=dropout),\n            zero_module(\n                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)\n            ),\n        )\n\n        if self.out_channels == channels:\n            self.skip_connection = nn.Identity()\n        elif use_conv:\n            self.skip_connection = conv_nd(\n                dims, channels, self.out_channels, 3, padding=1\n            )\n        else:\n            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)\n\n    def forward(self, x, emb):\n        \"\"\"\n        Apply the block to a Tensor, conditioned on a timestep embedding.\n        :param x: an [N x C x ...] Tensor of features.\n        :param emb: an [N x emb_channels] Tensor of timestep embeddings.\n        :return: an [N x C x ...] 
Tensor of outputs.\n        \"\"\"\n        return checkpoint(\n            self._forward, (x, emb), self.parameters(), self.use_checkpoint\n        )\n\n    def _forward(self, x, emb):\n        if self.updown:\n            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]\n            h = in_rest(x)\n            h = self.h_upd(h)\n            x = self.x_upd(x)\n            h = in_conv(h)\n        else:\n            h = self.in_layers(x)\n        emb_out = self.emb_layers(emb).type(h.dtype)\n        while len(emb_out.shape) < len(h.shape):\n            emb_out = emb_out[..., None]\n        if self.use_scale_shift_norm:\n            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]\n            scale, shift = th.chunk(emb_out, 2, dim=1)\n            h = out_norm(h) * (1 + scale) + shift\n            h = out_rest(h)\n        else:\n            h = h + emb_out\n            h = self.out_layers(h)\n        return self.skip_connection(x) + h\n\n\nclass AttentionBlock(nn.Module):\n    \"\"\"\n    An attention block that allows spatial positions to attend to each other.\n    Originally ported from here, but adapted to the N-d case.\n    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.\n    \"\"\"\n\n    def __init__(\n        self,\n        channels,\n        num_heads=1,\n        num_head_channels=-1,\n        use_checkpoint=False,\n        use_new_attention_order=False,\n    ):\n        super().__init__()\n        self.channels = channels\n        if num_head_channels == -1:\n            self.num_heads = num_heads\n        else:\n            assert (\n                channels % num_head_channels == 0\n            ), f\"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}\"\n            self.num_heads = channels // num_head_channels\n        self.use_checkpoint = use_checkpoint\n        self.norm = normalization(channels)\n        self.qkv = conv_nd(1, channels, channels * 3, 1)\n        if use_new_attention_order:\n            # split qkv before split heads\n            self.attention = QKVAttention(self.num_heads)\n        else:\n            # split heads before split qkv\n            self.attention = QKVAttentionLegacy(self.num_heads)\n\n        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))\n\n    def forward(self, x):\n        return checkpoint(\n            self._forward, (x,), self.parameters(), True\n        )  # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!\n        # return pt_checkpoint(self._forward, x)  # pytorch\n\n    def _forward(self, x):\n        b, c, *spatial = x.shape\n        x = x.reshape(b, c, -1).contiguous()\n        qkv = self.qkv(self.norm(x)).contiguous()\n        h = self.attention(qkv).contiguous()\n        h = self.proj_out(h).contiguous()\n        return (x + h).reshape(b, c, *spatial).contiguous()\n\n\ndef count_flops_attn(model, _x, y):\n    \"\"\"\n    A counter for the `thop` package to count the operations in an\n    attention operation.\n    Meant to be used like:\n        macs, params = thop.profile(\n            model,\n            inputs=(inputs, timestamps),\n            custom_ops={QKVAttention: QKVAttention.count_flops},\n        )\n    \"\"\"\n    b, c, *spatial = y[0].shape\n    num_spatial = int(np.prod(spatial))\n    # We perform two matmuls with the same number of ops.\n    # The first computes the weight matrix, the second computes\n    # the combination of the value vectors.\n    
matmul_ops = 2 * b * (num_spatial**2) * c\n    model.total_ops += th.DoubleTensor([matmul_ops])\n\n\nclass QKVAttentionLegacy(nn.Module):\n    \"\"\"\n    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping\n    \"\"\"\n\n    def __init__(self, n_heads):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv):\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, v = (\n            qkv.reshape(bs * self.n_heads, ch * 3, length).contiguous().split(ch, dim=1)\n        )\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\n            \"bct,bcs->bts\", q * scale, k * scale\n        )  # More stable with f16 than dividing afterwards\n        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = th.einsum(\"bts,bcs->bct\", weight, v)\n        return a.reshape(bs, -1, length).contiguous()\n\n    @staticmethod\n    def count_flops(model, _x, y):\n        return count_flops_attn(model, _x, y)\n\n\nclass QKVAttention(nn.Module):\n    \"\"\"\n    A module which performs QKV attention and splits in a different order.\n    \"\"\"\n\n    def __init__(self, n_heads):\n        super().__init__()\n        self.n_heads = n_heads\n\n    def forward(self, qkv):\n        \"\"\"\n        Apply QKV attention.\n        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.\n        :return: an [N x (H * C) x T] tensor after attention.\n        \"\"\"\n        bs, width, length = qkv.shape\n        assert width % (3 * self.n_heads) == 0\n        ch = width // (3 * self.n_heads)\n        q, k, v = qkv.chunk(3, dim=1)\n        scale = 1 / math.sqrt(math.sqrt(ch))\n        weight = th.einsum(\n            \"bct,bcs->bts\",\n            (q * scale).view(bs * self.n_heads, ch, length),\n            (k * scale).view(bs * self.n_heads, ch, length),\n        )  # More stable with f16 than dividing afterwards\n        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)\n        a = th.einsum(\n            \"bts,bcs->bct\",\n            weight,\n            v.reshape(bs * self.n_heads, ch, length).contiguous(),\n        )\n        return a.reshape(bs, -1, length).contiguous()\n\n    @staticmethod\n    def count_flops(model, _x, y):\n        return count_flops_attn(model, _x, y)\n\n\nclass UNetModel(nn.Module):\n    \"\"\"\n    The full UNet model with attention and timestep embedding.\n    :param in_channels: channels in the input Tensor.\n    :param model_channels: base channel count for the model.\n    :param out_channels: channels in the output Tensor.\n    :param num_res_blocks: number of residual blocks per downsample.\n    :param attention_resolutions: a collection of downsample rates at which\n        attention will take place. 
May be a set, list, or tuple.\n        For example, if this contains 4, then at 4x downsampling, attention\n        will be used.\n    :param dropout: the dropout probability.\n    :param channel_mult: channel multiplier for each level of the UNet.\n    :param conv_resample: if True, use learned convolutions for upsampling and\n        downsampling.\n    :param dims: determines if the signal is 1D, 2D, or 3D.\n    :param num_classes: if specified (as an int), then this model will be\n        class-conditional with `num_classes` classes.\n    :param use_checkpoint: use gradient checkpointing to reduce memory usage.\n    :param num_heads: the number of attention heads in each attention layer.\n    :param num_heads_channels: if specified, ignore num_heads and instead use\n                               a fixed channel width per attention head.\n    :param num_heads_upsample: works with num_heads to set a different number\n                               of heads for upsampling. Deprecated.\n    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.\n    :param resblock_updown: use residual blocks for up/downsampling.\n    :param use_new_attention_order: use a different attention pattern for potentially\n                                    increased efficiency.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size,\n        in_channels,\n        model_channels,\n        out_channels,\n        num_res_blocks,\n        attention_resolutions,\n        dropout=0,\n        channel_mult=(1, 2, 4, 8),\n        conv_resample=True,\n        dims=2,\n        extra_sa_layer=True,\n        num_classes=None,\n        extra_film_condition_dim=None,\n        use_checkpoint=False,\n        use_fp16=False,\n        num_heads=-1,\n        num_head_channels=-1,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        resblock_updown=False,\n        use_new_attention_order=False,\n        use_spatial_transformer=True,  # custom transformer support\n        transformer_depth=1,  # custom transformer support\n        context_dim=None,  # custom transformer support\n        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model\n        legacy=True,\n    ):\n        super().__init__()\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        if num_heads == -1:\n            assert (\n                num_head_channels != -1\n            ), \"Either num_heads or num_head_channels has to be set\"\n\n        if num_head_channels == -1:\n            assert (\n                num_heads != -1\n            ), \"Either num_heads or num_head_channels has to be set\"\n\n        self.image_size = image_size\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.num_classes = num_classes\n        self.extra_film_condition_dim = extra_film_condition_dim\n        self.use_checkpoint = use_checkpoint\n        self.dtype = th.float16 if use_fp16 else th.float32\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n        self.predict_codebook_ids = n_embed is not None\n        time_embed_dim = model_channels * 
4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        # assert not (\n        #     self.num_classes is not None and self.extra_film_condition_dim is not None\n        # ), \"As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim.\"\n\n        if self.num_classes is not None:\n            self.label_emb = nn.Embedding(num_classes, time_embed_dim)\n\n        self.use_extra_film_by_concat = self.extra_film_condition_dim is not None\n\n        if self.extra_film_condition_dim is not None:\n            self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim)\n            print(\n                \"+ Use extra condition on UNet channel using Film. Extra condition dimension is %s. \"\n                % self.extra_film_condition_dim\n            )\n\n        if context_dim is not None and not use_spatial_transformer:\n            assert (\n                use_spatial_transformer\n            ), \"Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...\"\n\n        if context_dim is not None and not isinstance(context_dim, list):\n            context_dim = [context_dim]\n        elif context_dim is None:\n            context_dim = [None]  # At least use one spatial transformer\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for _ in range(num_res_blocks):\n                layers = [\n                    ResBlock(\n                        ch,\n                        time_embed_dim\n                        if (not self.use_extra_film_by_concat)\n                        else time_embed_dim * 2,\n                        dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n                    if legacy:\n                        dim_head = (\n                            ch // num_heads\n                            if use_spatial_transformer\n                            else num_head_channels\n                        )\n                    if extra_sa_layer:\n                        layers.append(\n                            SpatialTransformer(\n                                ch,\n                                num_heads,\n                                dim_head,\n                                depth=transformer_depth,\n                                context_dim=None,\n                            )\n                        )\n                    for 
context_dim_id in range(len(context_dim)):\n                        layers.append(\n                            AttentionBlock(\n                                ch,\n                                use_checkpoint=use_checkpoint,\n                                num_heads=num_heads,\n                                num_head_channels=dim_head,\n                                use_new_attention_order=use_new_attention_order,\n                            )\n                            if not use_spatial_transformer\n                            else SpatialTransformer(\n                                ch,\n                                num_heads,\n                                dim_head,\n                                depth=transformer_depth,\n                                context_dim=context_dim[context_dim_id],\n                            )\n                        )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        ResBlock(\n                            ch,\n                            time_embed_dim\n                            if (not self.use_extra_film_by_concat)\n                            else time_embed_dim * 2,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n                        )\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        if num_head_channels == -1:\n            dim_head = ch // num_heads\n        else:\n            num_heads = ch // num_head_channels\n            dim_head = num_head_channels\n        if legacy:\n            # num_heads = 1\n            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels\n        middle_layers = [\n            ResBlock(\n                ch,\n                time_embed_dim\n                if (not self.use_extra_film_by_concat)\n                else time_embed_dim * 2,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            )\n        ]\n        if extra_sa_layer:\n            middle_layers.append(\n                SpatialTransformer(\n                    ch, num_heads, dim_head, depth=transformer_depth, context_dim=None\n                )\n            )\n        for context_dim_id in range(len(context_dim)):\n            middle_layers.append(\n                AttentionBlock(\n                    ch,\n                    use_checkpoint=use_checkpoint,\n                    num_heads=num_heads,\n                    num_head_channels=dim_head,\n                    use_new_attention_order=use_new_attention_order,\n                )\n                if not use_spatial_transformer\n                else 
SpatialTransformer(\n                    ch,\n                    num_heads,\n                    dim_head,\n                    depth=transformer_depth,\n                    context_dim=context_dim[context_dim_id],\n                )\n            )\n        middle_layers.append(\n            ResBlock(\n                ch,\n                time_embed_dim\n                if (not self.use_extra_film_by_concat)\n                else time_embed_dim * 2,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            )\n        )\n        self.middle_block = TimestepEmbedSequential(*middle_layers)\n\n        self._feature_size += ch\n\n        self.output_blocks = nn.ModuleList([])\n        for level, mult in list(enumerate(channel_mult))[::-1]:\n            for i in range(num_res_blocks + 1):\n                ich = input_block_chans.pop()\n                layers = [\n                    ResBlock(\n                        ch + ich,\n                        time_embed_dim\n                        if (not self.use_extra_film_by_concat)\n                        else time_embed_dim * 2,\n                        dropout,\n                        out_channels=model_channels * mult,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = model_channels * mult\n                if ds in attention_resolutions:\n                    if num_head_channels == -1:\n                        dim_head = ch // num_heads\n                    else:\n                        num_heads = ch // num_head_channels\n                        dim_head = num_head_channels\n                    if legacy:\n                        # num_heads = 1\n                        dim_head = (\n                            ch // num_heads\n                            if use_spatial_transformer\n                            else num_head_channels\n                        )\n                    if extra_sa_layer:\n                        layers.append(\n                            SpatialTransformer(\n                                ch,\n                                num_heads,\n                                dim_head,\n                                depth=transformer_depth,\n                                context_dim=None,\n                            )\n                        )\n                    for context_dim_id in range(len(context_dim)):\n                        layers.append(\n                            AttentionBlock(\n                                ch,\n                                use_checkpoint=use_checkpoint,\n                                num_heads=num_heads_upsample,\n                                num_head_channels=dim_head,\n                                use_new_attention_order=use_new_attention_order,\n                            )\n                            if not use_spatial_transformer\n                            else SpatialTransformer(\n                                ch,\n                                num_heads,\n                                dim_head,\n                                depth=transformer_depth,\n                                context_dim=context_dim[context_dim_id],\n                            )\n                        )\n                if level and i == num_res_blocks:\n                    out_ch = 
ch\n                    layers.append(\n                        ResBlock(\n                            ch,\n                            time_embed_dim\n                            if (not self.use_extra_film_by_concat)\n                            else time_embed_dim * 2,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            up=True,\n                        )\n                        if resblock_updown\n                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)\n                    )\n                    ds //= 2\n                self.output_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n\n        self.out = nn.Sequential(\n            normalization(ch),\n            nn.SiLU(),\n            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),\n        )\n        if self.predict_codebook_ids:\n            self.id_predictor = nn.Sequential(\n                normalization(ch),\n                conv_nd(dims, model_channels, n_embed, 1),\n                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits\n            )\n\n        self.shape_reported = False\n\n    def convert_to_fp16(self):\n        \"\"\"\n        Convert the torso of the model to float16.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f16)\n        self.middle_block.apply(convert_module_to_f16)\n        self.output_blocks.apply(convert_module_to_f16)\n\n    def convert_to_fp32(self):\n        \"\"\"\n        Convert the torso of the model to float32.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f32)\n        self.middle_block.apply(convert_module_to_f32)\n        self.output_blocks.apply(convert_module_to_f32)\n\n    def forward(\n        self,\n        x,\n        timesteps=None,\n        y=None,\n        context_list=None,\n        context_attn_mask_list=None,\n        **kwargs,\n    ):\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [N x C x ...] Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :param context: conditioning plugged in via crossattn\n        :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional\n        :return: an [N x C x ...] 
Tensor of outputs.\n        \"\"\"\n        if not self.shape_reported:\n            self.shape_reported = True\n\n        assert (y is not None) == (\n            self.num_classes is not None or self.extra_film_condition_dim is not None\n        ), \"must specify y if and only if the model is class-conditional or film embedding conditional\"\n        hs = []\n        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)\n        emb = self.time_embed(t_emb)\n\n        # if self.num_classes is not None:\n        #     assert y.shape == (x.shape[0],)\n        #     emb = emb + self.label_emb(y)\n\n        if self.use_extra_film_by_concat:\n            emb = th.cat([emb, self.film_emb(y)], dim=-1)\n\n        h = x.type(self.dtype)\n        for module in self.input_blocks:\n            h = module(h, emb, context_list, context_attn_mask_list)\n            hs.append(h)\n        h = self.middle_block(h, emb, context_list, context_attn_mask_list)\n        for module in self.output_blocks:\n            concate_tensor = hs.pop()\n            h = th.cat([h, concate_tensor], dim=1)\n            h = module(h, emb, context_list, context_attn_mask_list)\n        h = h.type(x.dtype)\n        if self.predict_codebook_ids:\n            return self.id_predictor(h)\n        else:\n            return self.out(h)\n\n\nclass EncoderUNetModel(nn.Module):\n    \"\"\"\n    The half UNet model with attention and timestep embedding.\n    For usage, see UNet.\n    \"\"\"\n\n    def __init__(\n        self,\n        image_size,\n        in_channels,\n        model_channels,\n        out_channels,\n        num_res_blocks,\n        attention_resolutions,\n        dropout=0,\n        channel_mult=(1, 2, 4, 8),\n        conv_resample=True,\n        dims=2,\n        use_checkpoint=False,\n        use_fp16=False,\n        num_heads=1,\n        num_head_channels=-1,\n        num_heads_upsample=-1,\n        use_scale_shift_norm=False,\n        resblock_updown=False,\n        use_new_attention_order=False,\n        pool=\"adaptive\",\n        *args,\n        **kwargs,\n    ):\n        super().__init__()\n\n        if num_heads_upsample == -1:\n            num_heads_upsample = num_heads\n\n        self.in_channels = in_channels\n        self.model_channels = model_channels\n        self.out_channels = out_channels\n        self.num_res_blocks = num_res_blocks\n        self.attention_resolutions = attention_resolutions\n        self.dropout = dropout\n        self.channel_mult = channel_mult\n        self.conv_resample = conv_resample\n        self.use_checkpoint = use_checkpoint\n        self.dtype = th.float16 if use_fp16 else th.float32\n        self.num_heads = num_heads\n        self.num_head_channels = num_head_channels\n        self.num_heads_upsample = num_heads_upsample\n\n        time_embed_dim = model_channels * 4\n        self.time_embed = nn.Sequential(\n            linear(model_channels, time_embed_dim),\n            nn.SiLU(),\n            linear(time_embed_dim, time_embed_dim),\n        )\n\n        self.input_blocks = nn.ModuleList(\n            [\n                TimestepEmbedSequential(\n                    conv_nd(dims, in_channels, model_channels, 3, padding=1)\n                )\n            ]\n        )\n        self._feature_size = model_channels\n        input_block_chans = [model_channels]\n        ch = model_channels\n        ds = 1\n        for level, mult in enumerate(channel_mult):\n            for _ in range(num_res_blocks):\n                layers = [\n                    
ResBlock(\n                        ch,\n                        time_embed_dim,\n                        dropout,\n                        out_channels=mult * model_channels,\n                        dims=dims,\n                        use_checkpoint=use_checkpoint,\n                        use_scale_shift_norm=use_scale_shift_norm,\n                    )\n                ]\n                ch = mult * model_channels\n                if ds in attention_resolutions:\n                    layers.append(\n                        AttentionBlock(\n                            ch,\n                            use_checkpoint=use_checkpoint,\n                            num_heads=num_heads,\n                            num_head_channels=num_head_channels,\n                            use_new_attention_order=use_new_attention_order,\n                        )\n                    )\n                self.input_blocks.append(TimestepEmbedSequential(*layers))\n                self._feature_size += ch\n                input_block_chans.append(ch)\n            if level != len(channel_mult) - 1:\n                out_ch = ch\n                self.input_blocks.append(\n                    TimestepEmbedSequential(\n                        ResBlock(\n                            ch,\n                            time_embed_dim,\n                            dropout,\n                            out_channels=out_ch,\n                            dims=dims,\n                            use_checkpoint=use_checkpoint,\n                            use_scale_shift_norm=use_scale_shift_norm,\n                            down=True,\n                        )\n                        if resblock_updown\n                        else Downsample(\n                            ch, conv_resample, dims=dims, out_channels=out_ch\n                        )\n                    )\n                )\n                ch = out_ch\n                input_block_chans.append(ch)\n                ds *= 2\n                self._feature_size += ch\n\n        self.middle_block = TimestepEmbedSequential(\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n            AttentionBlock(\n                ch,\n                use_checkpoint=use_checkpoint,\n                num_heads=num_heads,\n                num_head_channels=num_head_channels,\n                use_new_attention_order=use_new_attention_order,\n            ),\n            ResBlock(\n                ch,\n                time_embed_dim,\n                dropout,\n                dims=dims,\n                use_checkpoint=use_checkpoint,\n                use_scale_shift_norm=use_scale_shift_norm,\n            ),\n        )\n        self._feature_size += ch\n        self.pool = pool\n        if pool == \"adaptive\":\n            self.out = nn.Sequential(\n                normalization(ch),\n                nn.SiLU(),\n                nn.AdaptiveAvgPool2d((1, 1)),\n                zero_module(conv_nd(dims, ch, out_channels, 1)),\n                nn.Flatten(),\n            )\n        elif pool == \"attention\":\n            assert num_head_channels != -1\n            self.out = nn.Sequential(\n                normalization(ch),\n                nn.SiLU(),\n                AttentionPool2d(\n                    (image_size // ds), ch, num_head_channels, out_channels\n                ),\n    
        )\n        elif pool == \"spatial\":\n            self.out = nn.Sequential(\n                nn.Linear(self._feature_size, 2048),\n                nn.ReLU(),\n                nn.Linear(2048, self.out_channels),\n            )\n        elif pool == \"spatial_v2\":\n            self.out = nn.Sequential(\n                nn.Linear(self._feature_size, 2048),\n                normalization(2048),\n                nn.SiLU(),\n                nn.Linear(2048, self.out_channels),\n            )\n        else:\n            raise NotImplementedError(f\"Unexpected {pool} pooling\")\n\n    def convert_to_fp16(self):\n        \"\"\"\n        Convert the torso of the model to float16.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f16)\n        self.middle_block.apply(convert_module_to_f16)\n\n    def convert_to_fp32(self):\n        \"\"\"\n        Convert the torso of the model to float32.\n        \"\"\"\n        self.input_blocks.apply(convert_module_to_f32)\n        self.middle_block.apply(convert_module_to_f32)\n\n    def forward(self, x, timesteps):\n        \"\"\"\n        Apply the model to an input batch.\n        :param x: an [N x C x ...] Tensor of inputs.\n        :param timesteps: a 1-D batch of timesteps.\n        :return: an [N x K] Tensor of outputs.\n        \"\"\"\n        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))\n\n        results = []\n        h = x.type(self.dtype)\n        for module in self.input_blocks:\n            h = module(h, emb)\n            if self.pool.startswith(\"spatial\"):\n                results.append(h.type(x.dtype).mean(dim=(2, 3)))\n        h = self.middle_block(h, emb)\n        if self.pool.startswith(\"spatial\"):\n            results.append(h.type(x.dtype).mean(dim=(2, 3)))\n            h = th.cat(results, axis=-1)\n            return self.out(h)\n        else:\n            h = h.type(x.dtype)\n            return self.out(h)\n"
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/diffusionmodules/util.py",
    "content": "# adopted from\n# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py\n# and\n# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py\n# and\n# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py\n#\n# thanks!\n\n\nimport os\nimport math\nimport torch\nimport torch.nn as nn\nimport numpy as np\nfrom einops import repeat\n\nfrom semanticodec.modules.decoder.latent_diffusion.util import instantiate_from_config\n\n\ndef make_beta_schedule(\n    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3\n):\n    if schedule == \"linear\":\n        betas = (\n            torch.linspace(\n                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64\n            )\n            ** 2\n        )\n\n    elif schedule == \"cosine\":\n        timesteps = (\n            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s\n        )\n        alphas = timesteps / (1 + cosine_s) * np.pi / 2\n        alphas = torch.cos(alphas).pow(2)\n        alphas = alphas / alphas[0]\n        betas = 1 - alphas[1:] / alphas[:-1]\n        # betas = np.clip(betas, a_min=0, a_max=0.999)\n\n    elif schedule == \"sqrt_linear\":\n        betas = torch.linspace(\n            linear_start, linear_end, n_timestep, dtype=torch.float64\n        )\n    elif schedule == \"sqrt\":\n        betas = (\n            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)\n            ** 0.5\n        )\n    else:\n        raise ValueError(f\"schedule '{schedule}' unknown.\")\n    return betas.numpy()\n\n\ndef make_ddim_timesteps(\n    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True\n):\n    if ddim_discr_method == \"uniform\":\n        c = num_ddpm_timesteps // num_ddim_timesteps\n        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))\n    elif ddim_discr_method == \"quad\":\n        ddim_timesteps = (\n            (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2\n        ).astype(int)\n    else:\n        raise NotImplementedError(\n            f'There is no ddim discretization method called \"{ddim_discr_method}\"'\n        )\n\n    # assert ddim_timesteps.shape[0] == num_ddim_timesteps\n    # add one to get the final alpha values right (the ones from first scale to data during sampling)\n    steps_out = ddim_timesteps + 1\n    if verbose:\n        print(f\"Selected timesteps for ddim sampler: {steps_out}\")\n    return steps_out\n\n\ndef make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):\n    # select alphas for computing the variance schedule\n    alphas = alphacums[ddim_timesteps]\n    alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())\n\n    # according the the formula provided in https://arxiv.org/abs/2010.02502\n    sigmas = eta * np.sqrt(\n        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)\n    )\n    if verbose:\n        print(\n            f\"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}\"\n        )\n        print(\n            f\"For the chosen value of eta, which is {eta}, \"\n            f\"this results in the following sigma_t schedule for ddim sampler {sigmas}\"\n        )\n    return sigmas, alphas, alphas_prev\n\n\ndef 
betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):\n    \"\"\"\n    Create a beta schedule that discretizes the given alpha_t_bar function,\n    which defines the cumulative product of (1-beta) over time from t = [0,1].\n    :param num_diffusion_timesteps: the number of betas to produce.\n    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and\n                      produces the cumulative product of (1-beta) up to that\n                      part of the diffusion process.\n    :param max_beta: the maximum beta to use; use values lower than 1 to\n                     prevent singularities.\n    \"\"\"\n    betas = []\n    for i in range(num_diffusion_timesteps):\n        t1 = i / num_diffusion_timesteps\n        t2 = (i + 1) / num_diffusion_timesteps\n        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))\n    return np.array(betas)\n\n\ndef extract_into_tensor(a, t, x_shape):\n    b, *_ = t.shape\n    out = a.gather(-1, t).contiguous()\n    return out.reshape(b, *((1,) * (len(x_shape) - 1))).contiguous()\n\n\ndef checkpoint(func, inputs, params, flag):\n    \"\"\"\n    Evaluate a function without caching intermediate activations, allowing for\n    reduced memory at the expense of extra compute in the backward pass.\n    :param func: the function to evaluate.\n    :param inputs: the argument sequence to pass to `func`.\n    :param params: a sequence of parameters `func` depends on but does not\n                   explicitly take as arguments.\n    :param flag: if False, disable gradient checkpointing.\n    \"\"\"\n    if flag:\n        args = tuple(inputs) + tuple(params)\n        return CheckpointFunction.apply(func, len(inputs), *args)\n    else:\n        return func(*inputs)\n\n\nclass CheckpointFunction(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.input_tensors = list(args[:length])\n        ctx.input_params = list(args[length:])\n\n        with torch.no_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, *output_grads):\n        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]\n        with torch.enable_grad():\n            # Fixes a bug where the first op in run_function modifies the\n            # Tensor storage in place, which is not allowed for detach()'d\n            # Tensors.\n            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]\n            output_tensors = ctx.run_function(*shallow_copies)\n        input_grads = torch.autograd.grad(\n            output_tensors,\n            ctx.input_tensors + ctx.input_params,\n            output_grads,\n            allow_unused=True,\n        )\n        del ctx.input_tensors\n        del ctx.input_params\n        del output_tensors\n        return (None, None) + input_grads\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    if not repeat_only:\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period)\n           
 * torch.arange(start=0, end=half, dtype=torch.float32)\n            / half\n        ).to(device=timesteps.device)\n        args = timesteps[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat(\n                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1\n            )\n    else:\n        embedding = repeat(timesteps, \"b -> b d\", d=dim)\n    return embedding\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef scale_module(module, scale):\n    \"\"\"\n    Scale the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().mul_(scale)\n    return module\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef normalization(channels):\n    \"\"\"\n    Make a standard normalization layer.\n    :param channels: number of input channels.\n    :return: an nn.Module for normalization.\n    \"\"\"\n    return GroupNorm32(32, channels)\n\n\n# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.\nclass SiLU(nn.Module):\n    def forward(self, x):\n        return x * torch.sigmoid(x)\n\n\nclass GroupNorm32(nn.GroupNorm):\n    def forward(self, x):\n        return super().forward(x.float()).type(x.dtype)\n\n\ndef conv_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D convolution module.\n    \"\"\"\n    if dims == 1:\n        return nn.Conv1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.Conv2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.Conv3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef linear(*args, **kwargs):\n    \"\"\"\n    Create a linear module.\n    \"\"\"\n    return nn.Linear(*args, **kwargs)\n\n\ndef avg_pool_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D average pooling module.\n    \"\"\"\n    if dims == 1:\n        return nn.AvgPool1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.AvgPool2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.AvgPool3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\nclass HybridConditioner(nn.Module):\n    def __init__(self, c_concat_config, c_crossattn_config):\n        super().__init__()\n        self.concat_conditioner = instantiate_from_config(c_concat_config)\n        self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)\n\n    def forward(self, c_concat, c_crossattn):\n        c_concat = self.concat_conditioner(c_concat)\n        c_crossattn = self.crossattn_conditioner(c_crossattn)\n        return {\"c_concat\": [c_concat], \"c_crossattn\": [c_crossattn]}\n\n\ndef noise_like(shape, device, repeat=False):\n    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(\n        shape[0], *((1,) * (len(shape) - 1))\n    )\n    noise = lambda: torch.randn(shape, device=device)\n    return repeat_noise() if repeat else noise()\n"
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/distributions/__init__.py",
    "content": ""
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/distributions/distributions.py",
    "content": "import torch\nimport numpy as np\n\n\nclass AbstractDistribution:\n    def sample(self):\n        raise NotImplementedError()\n\n    def mode(self):\n        raise NotImplementedError()\n\n\nclass DiracDistribution(AbstractDistribution):\n    def __init__(self, value):\n        self.value = value\n\n    def sample(self):\n        return self.value\n\n    def mode(self):\n        return self.value\n\n\nclass DiagonalGaussianDistribution(object):\n    def __init__(self, parameters, deterministic=False):\n        self.parameters = parameters\n        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)\n        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)\n        self.deterministic = deterministic\n        self.std = torch.exp(0.5 * self.logvar)\n        self.var = torch.exp(self.logvar)\n        if self.deterministic:\n            self.var = self.std = torch.zeros_like(self.mean).to(\n                device=self.parameters.device\n            )\n\n    def sample(self):\n        x = self.mean + self.std * torch.randn(self.mean.shape).to(\n            device=self.parameters.device\n        )\n        return x\n\n    def kl(self, other=None):\n        if self.deterministic:\n            return torch.Tensor([0.0])\n        else:\n            if other is None:\n                return 0.5 * torch.mean(\n                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,\n                    dim=[1, 2, 3],\n                )\n            else:\n                return 0.5 * torch.mean(\n                    torch.pow(self.mean - other.mean, 2) / other.var\n                    + self.var / other.var\n                    - 1.0\n                    - self.logvar\n                    + other.logvar,\n                    dim=[1, 2, 3],\n                )\n\n    def nll(self, sample, dims=[1, 2, 3]):\n        if self.deterministic:\n            return torch.Tensor([0.0])\n        logtwopi = np.log(2.0 * np.pi)\n        return 0.5 * torch.sum(\n            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,\n            dim=dims,\n        )\n\n    def mode(self):\n        return self.mean\n\n\ndef normal_kl(mean1, logvar1, mean2, logvar2):\n    \"\"\"\n    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12\n    Compute the KL divergence between two gaussians.\n    Shapes are automatically broadcasted, so batches can be compared to\n    scalars, among other use cases.\n    \"\"\"\n    tensor = None\n    for obj in (mean1, logvar1, mean2, logvar2):\n        if isinstance(obj, torch.Tensor):\n            tensor = obj\n            break\n    assert tensor is not None, \"at least one argument must be a Tensor\"\n\n    # Force variances to be Tensors. Broadcasting helps convert scalars to\n    # Tensors, but it does not work for torch.exp().\n    logvar1, logvar2 = [\n        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)\n        for x in (logvar1, logvar2)\n    ]\n\n    return 0.5 * (\n        -1.0\n        + logvar2\n        - logvar1\n        + torch.exp(logvar1 - logvar2)\n        + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)\n    )\n"
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/ema.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_upates=True):\n        super().__init__()\n        if decay < 0.0 or decay > 1.0:\n            raise ValueError(\"Decay must be between 0 and 1\")\n\n        self.m_name2s_name = {}\n        self.register_buffer(\"decay\", torch.tensor(decay, dtype=torch.float32))\n        self.register_buffer(\n            \"num_updates\",\n            torch.tensor(0, dtype=torch.int)\n            if use_num_upates\n            else torch.tensor(-1, dtype=torch.int),\n        )\n\n        for name, p in model.named_parameters():\n            if p.requires_grad:\n                # remove as '.'-character is not allowed in buffers\n                s_name = name.replace(\".\", \"\")\n                self.m_name2s_name.update({name: s_name})\n                self.register_buffer(s_name, p.clone().detach().data)\n\n        self.collected_params = []\n\n    def forward(self, model):\n        decay = self.decay\n\n        if self.num_updates >= 0:\n            self.num_updates += 1\n            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))\n\n        one_minus_decay = 1.0 - decay\n\n        with torch.no_grad():\n            m_param = dict(model.named_parameters())\n            shadow_params = dict(self.named_buffers())\n\n            for key in m_param:\n                if m_param[key].requires_grad:\n                    sname = self.m_name2s_name[key]\n                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])\n                    shadow_params[sname].sub_(\n                        one_minus_decay * (shadow_params[sname] - m_param[key])\n                    )\n                else:\n                    assert not key in self.m_name2s_name\n\n    def copy_to(self, model):\n        m_param = dict(model.named_parameters())\n        shadow_params = dict(self.named_buffers())\n        for key in m_param:\n            if m_param[key].requires_grad:\n                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)\n            else:\n                assert not key in self.m_name2s_name\n\n    def store(self, parameters):\n        \"\"\"\n        Save the current parameters for restoring later.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            temporarily stored.\n        \"\"\"\n        self.collected_params = [param.clone() for param in parameters]\n\n    def restore(self, parameters):\n        \"\"\"\n        Restore the parameters stored with the `store` method.\n        Useful to validate the model with EMA parameters without affecting the\n        original optimization process. Store the parameters before the\n        `copy_to` method. After validation (or model saving), use this to\n        restore the former parameters.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            updated with the stored parameters.\n        \"\"\"\n        for c_param, param in zip(self.collected_params, parameters):\n            param.data.copy_(c_param.data)\n"
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/mamba.py",
    "content": "import torch\nfrom mamba_ssm import Mamba\nimport torch.nn as nn\n\n\ndef count_parameters(model):\n    \"\"\"\n    Calculate the total number of parameters in a PyTorch model.\n\n    Parameters:\n    - model (nn.Module): The PyTorch model.\n\n    Returns:\n    - int: The total number of parameters in the model.\n    \"\"\"\n    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n\nclass MambaBlocks(nn.Module):\n    def __init__(self, dim, n_block=4):\n        super(MambaBlocks, self).__init__()\n        self.mamba_blocks = nn.ModuleList(\n            [\n                Mamba(\n                    # This module uses roughly 3 * expand * d_model^2 parameters\n                    d_model=dim,  # Model dimension d_model\n                    d_state=256,  # SSM state expansion factor\n                    d_conv=4,  # Local convolution width\n                    expand=16,  # Block expansion factor\n                )\n                for i in range(n_block)\n            ]\n        )\n        self.mamba_norm = nn.ModuleList(\n            [nn.LayerNorm(dim, eps=1e-6) for i in range(n_block)]\n        )\n\n    def forward(self, x):\n        for i, (block, norm) in enumerate(zip(self.mamba_blocks, self.mamba_norm)):\n            x = block(x) + x\n            if i != len(self.mamba_blocks) - 1:\n                x = norm(x)\n        return x\n\n\nif __name__ == \"__main__\":\n    batch, length, dim = 2, 512, 768\n    x = torch.randn(batch, length, dim).to(\"cuda\")\n    model = MambaBlocks(n_block=4).to(\"cuda\")\n\n    print(\"Number of parameters:\", count_parameters(model))\n\n    y = model(x)\n\n    assert y.shape == x.shape\n"
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/nn.py",
    "content": "\"\"\"\nVarious utilities for neural networks.\n\"\"\"\n\nimport math\n\nimport torch as th\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass GroupNorm32(nn.GroupNorm):\n    def __init__(self, num_groups, num_channels, swish, eps=1e-5):\n        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)\n        self.swish = swish\n\n    def forward(self, x):\n        y = super().forward(x.float()).to(x.dtype)\n        if self.swish == 1.0:\n            y = F.silu(y)\n        elif self.swish:\n            y = y * F.sigmoid(y * float(self.swish))\n        return y\n\n\ndef conv_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D convolution module.\n    \"\"\"\n    if dims == 1:\n        return nn.Conv1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.Conv2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.Conv3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef linear(*args, **kwargs):\n    \"\"\"\n    Create a linear module.\n    \"\"\"\n    return nn.Linear(*args, **kwargs)\n\n\ndef avg_pool_nd(dims, *args, **kwargs):\n    \"\"\"\n    Create a 1D, 2D, or 3D average pooling module.\n    \"\"\"\n    if dims == 1:\n        return nn.AvgPool1d(*args, **kwargs)\n    elif dims == 2:\n        return nn.AvgPool2d(*args, **kwargs)\n    elif dims == 3:\n        return nn.AvgPool3d(*args, **kwargs)\n    raise ValueError(f\"unsupported dimensions: {dims}\")\n\n\ndef update_ema(target_params, source_params, rate=0.99):\n    \"\"\"\n    Update target parameters to be closer to those of source parameters using\n    an exponential moving average.\n\n    :param target_params: the target parameter sequence.\n    :param source_params: the source parameter sequence.\n    :param rate: the EMA rate (closer to 1 means slower).\n    \"\"\"\n    for targ, src in zip(target_params, source_params):\n        targ.detach().mul_(rate).add_(src, alpha=1 - rate)\n\n\ndef zero_module(module):\n    \"\"\"\n    Zero out the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().zero_()\n    return module\n\n\ndef scale_module(module, scale):\n    \"\"\"\n    Scale the parameters of a module and return it.\n    \"\"\"\n    for p in module.parameters():\n        p.detach().mul_(scale)\n    return module\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef normalization(channels, swish=0.0):\n    \"\"\"\n    Make a standard normalization layer, with an optional swish activation.\n\n    :param channels: number of input channels.\n    :return: an nn.Module for normalization.\n    \"\"\"\n    return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)\n\n\n# def timestep_embedding(timesteps, dim, max_period=10000):\n#    \"\"\"\n#    Create sinusoidal timestep embeddings.\n\n#    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n#                      These may be fractional.\n#    :param dim: the dimension of the output.\n#    :param max_period: controls the minimum frequency of the embeddings.\n#    :return: an [N x dim] Tensor of positional embeddings.\n#    \"\"\"\n#    half = dim // 2\n#    freqs = th.exp(\n#        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half\n#    ).to(device=timesteps.device)\n#    args = timesteps[:, None].float() * freqs[None]\n#    embedding = 
th.cat([th.cos(args), th.sin(args)], dim=-1)\n#    if dim % 2:\n#        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)\n#    return embedding\n\n\ndef timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):\n    \"\"\"\n    Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D Tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param dim: the dimension of the output.\n    :param max_period: controls the minimum frequency of the embeddings.\n    :return: an [N x dim] Tensor of positional embeddings.\n    \"\"\"\n    if not repeat_only:\n        half = dim // 2\n        freqs = th.exp(\n            -math.log(max_period)\n            * th.arange(start=0, end=half, dtype=th.float32)\n            / half\n        ).to(device=timesteps.device)\n        args = timesteps[:, None].float() * freqs[None]\n        embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)\n    else:\n        embedding = repeat(timesteps, \"b -> b d\", d=dim)\n    return embedding\n\n\ndef checkpoint(func, inputs, params, flag):\n    \"\"\"\n    Evaluate a function without caching intermediate activations, allowing for\n    reduced memory at the expense of extra compute in the backward pass.\n\n    :param func: the function to evaluate.\n    :param inputs: the argument sequence to pass to `func`.\n    :param params: a sequence of parameters `func` depends on but does not\n                   explicitly take as arguments.\n    :param flag: if False, disable gradient checkpointing.\n    \"\"\"\n    # flag = False\n    if flag:\n        args = tuple(inputs) + tuple(params)\n        return CheckpointFunction.apply(func, len(inputs), *args)\n    else:\n        return func(*inputs)\n\n\nclass CheckpointFunction(th.autograd.Function):\n    @staticmethod\n    def forward(ctx, run_function, length, *args):\n        ctx.run_function = run_function\n        ctx.input_tensors = list(args[:length])\n        ctx.input_params = list(args[length:])\n        with th.no_grad():\n            output_tensors = ctx.run_function(*ctx.input_tensors)\n        return output_tensors\n\n    @staticmethod\n    def backward(ctx, *output_grads):\n        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]\n        with th.enable_grad():\n            # Fixes a bug where the first op in run_function modifies the\n            # Tensor storage in place, which is not allowed for detach()'d\n            # Tensors.\n            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]\n            output_tensors = ctx.run_function(*shallow_copies)\n        input_grads = th.autograd.grad(\n            output_tensors,\n            ctx.input_tensors + ctx.input_params,\n            output_grads,\n            allow_unused=True,\n        )\n        del ctx.input_tensors\n        del ctx.input_params\n        del output_tensors\n        return (None, None) + input_grads\n"
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/modules/x_transformer.py",
    "content": "\"\"\"shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers\"\"\"\nimport torch\nfrom torch import nn, einsum\nimport torch.nn.functional as F\nfrom functools import partial\nfrom inspect import isfunction\nfrom collections import namedtuple\nfrom einops import rearrange, repeat, reduce\n\n# constants\n\nDEFAULT_DIM_HEAD = 64\n\nIntermediates = namedtuple(\"Intermediates\", [\"pre_softmax_attn\", \"post_softmax_attn\"])\n\nLayerIntermediates = namedtuple(\"Intermediates\", [\"hiddens\", \"attn_intermediates\"])\n\n\nclass AbsolutePositionalEmbedding(nn.Module):\n    def __init__(self, dim, max_seq_len):\n        super().__init__()\n        self.emb = nn.Embedding(max_seq_len, dim)\n        self.init_()\n\n    def init_(self):\n        nn.init.normal_(self.emb.weight, std=0.02)\n\n    def forward(self, x):\n        n = torch.arange(x.shape[1], device=x.device)\n        return self.emb(n)[None, :, :]\n\n\nclass FixedPositionalEmbedding(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))\n        self.register_buffer(\"inv_freq\", inv_freq)\n\n    def forward(self, x, seq_dim=1, offset=0):\n        t = (\n            torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)\n            + offset\n        )\n        sinusoid_inp = torch.einsum(\"i , j -> i j\", t, self.inv_freq)\n        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)\n        return emb[None, :, :]\n\n\n# helpers\n\n\ndef exists(val):\n    return val is not None\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef always(val):\n    def inner(*args, **kwargs):\n        return val\n\n    return inner\n\n\ndef not_equals(val):\n    def inner(x):\n        return x != val\n\n    return inner\n\n\ndef equals(val):\n    def inner(x):\n        return x == val\n\n    return inner\n\n\ndef max_neg_value(tensor):\n    return -torch.finfo(tensor.dtype).max\n\n\n# keyword argument helpers\n\n\ndef pick_and_pop(keys, d):\n    values = list(map(lambda key: d.pop(key), keys))\n    return dict(zip(keys, values))\n\n\ndef group_dict_by_key(cond, d):\n    return_val = [dict(), dict()]\n    for key in d.keys():\n        match = bool(cond(key))\n        ind = int(not match)\n        return_val[ind][key] = d[key]\n    return (*return_val,)\n\n\ndef string_begins_with(prefix, str):\n    return str.startswith(prefix)\n\n\ndef group_by_key_prefix(prefix, d):\n    return group_dict_by_key(partial(string_begins_with, prefix), d)\n\n\ndef groupby_prefix_and_trim(prefix, d):\n    kwargs_with_prefix, kwargs = group_dict_by_key(\n        partial(string_begins_with, prefix), d\n    )\n    kwargs_without_prefix = dict(\n        map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))\n    )\n    return kwargs_without_prefix, kwargs\n\n\n# classes\nclass Scale(nn.Module):\n    def __init__(self, value, fn):\n        super().__init__()\n        self.value = value\n        self.fn = fn\n\n    def forward(self, x, **kwargs):\n        x, *rest = self.fn(x, **kwargs)\n        return (x * self.value, *rest)\n\n\nclass Rezero(nn.Module):\n    def __init__(self, fn):\n        super().__init__()\n        self.fn = fn\n        self.g = nn.Parameter(torch.zeros(1))\n\n    def forward(self, x, **kwargs):\n        x, *rest = self.fn(x, **kwargs)\n        return (x * self.g, *rest)\n\n\nclass ScaleNorm(nn.Module):\n    def 
__init__(self, dim, eps=1e-5):\n        super().__init__()\n        self.scale = dim**-0.5\n        self.eps = eps\n        self.g = nn.Parameter(torch.ones(1))\n\n    def forward(self, x):\n        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale\n        return x / norm.clamp(min=self.eps) * self.g\n\n\nclass RMSNorm(nn.Module):\n    def __init__(self, dim, eps=1e-8):\n        super().__init__()\n        self.scale = dim**-0.5\n        self.eps = eps\n        self.g = nn.Parameter(torch.ones(dim))\n\n    def forward(self, x):\n        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale\n        return x / norm.clamp(min=self.eps) * self.g\n\n\nclass Residual(nn.Module):\n    def forward(self, x, residual):\n        return x + residual\n\n\nclass GRUGating(nn.Module):\n    def __init__(self, dim):\n        super().__init__()\n        self.gru = nn.GRUCell(dim, dim)\n\n    def forward(self, x, residual):\n        gated_output = self.gru(\n            rearrange(x, \"b n d -> (b n) d\"), rearrange(residual, \"b n d -> (b n) d\")\n        )\n\n        return gated_output.reshape_as(x)\n\n\n# feedforward\n\n\nclass GEGLU(nn.Module):\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.proj = nn.Linear(dim_in, dim_out * 2)\n\n    def forward(self, x):\n        x, gate = self.proj(x).chunk(2, dim=-1)\n        return x * F.gelu(gate)\n\n\nclass FeedForward(nn.Module):\n    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):\n        super().__init__()\n        inner_dim = int(dim * mult)\n        dim_out = default(dim_out, dim)\n        project_in = (\n            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())\n            if not glu\n            else GEGLU(dim, inner_dim)\n        )\n\n        self.net = nn.Sequential(\n            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)\n        )\n\n    def forward(self, x):\n        return self.net(x)\n\n\n# attention.\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim,\n        dim_head=DEFAULT_DIM_HEAD,\n        heads=8,\n        causal=False,\n        mask=None,\n        talking_heads=False,\n        sparse_topk=None,\n        use_entmax15=False,\n        num_mem_kv=0,\n        dropout=0.0,\n        on_attn=False,\n    ):\n        super().__init__()\n        if use_entmax15:\n            raise NotImplementedError(\n                \"Check out entmax activation instead of softmax activation!\"\n            )\n        self.scale = dim_head**-0.5\n        self.heads = heads\n        self.causal = causal\n        self.mask = mask\n\n        inner_dim = dim_head * heads\n\n        self.to_q = nn.Linear(dim, inner_dim, bias=False)\n        self.to_k = nn.Linear(dim, inner_dim, bias=False)\n        self.to_v = nn.Linear(dim, inner_dim, bias=False)\n        self.dropout = nn.Dropout(dropout)\n\n        # talking heads\n        self.talking_heads = talking_heads\n        if talking_heads:\n            self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))\n            self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))\n\n        # explicit topk sparse attention\n        self.sparse_topk = sparse_topk\n\n        # entmax\n        # self.attn_fn = entmax15 if use_entmax15 else F.softmax\n        self.attn_fn = F.softmax\n\n        # add memory key / values\n        self.num_mem_kv = num_mem_kv\n        if num_mem_kv > 0:\n            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))\n            self.mem_v = 
nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))\n\n        # attention on attention\n        self.attn_on_attn = on_attn\n        self.to_out = (\n            nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU())\n            if on_attn\n            else nn.Linear(inner_dim, dim)\n        )\n\n    def forward(\n        self,\n        x,\n        context=None,\n        mask=None,\n        context_mask=None,\n        rel_pos=None,\n        sinusoidal_emb=None,\n        prev_attn=None,\n        mem=None,\n    ):\n        b, n, _, h, talking_heads, device = (\n            *x.shape,\n            self.heads,\n            self.talking_heads,\n            x.device,\n        )\n        kv_input = default(context, x)\n\n        q_input = x\n        k_input = kv_input\n        v_input = kv_input\n\n        if exists(mem):\n            k_input = torch.cat((mem, k_input), dim=-2)\n            v_input = torch.cat((mem, v_input), dim=-2)\n\n        if exists(sinusoidal_emb):\n            # in shortformer, the query would start at a position offset depending on the past cached memory\n            offset = k_input.shape[-2] - q_input.shape[-2]\n            q_input = q_input + sinusoidal_emb(q_input, offset=offset)\n            k_input = k_input + sinusoidal_emb(k_input)\n\n        q = self.to_q(q_input)\n        k = self.to_k(k_input)\n        v = self.to_v(v_input)\n\n        q, k, v = map(lambda t: rearrange(t, \"b n (h d) -> b h n d\", h=h), (q, k, v))\n\n        input_mask = None\n        if any(map(exists, (mask, context_mask))):\n            q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())\n            k_mask = q_mask if not exists(context) else context_mask\n            k_mask = default(\n                k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool()\n            )\n            q_mask = rearrange(q_mask, \"b i -> b () i ()\")\n            k_mask = rearrange(k_mask, \"b j -> b () () j\")\n            input_mask = q_mask * k_mask\n\n        if self.num_mem_kv > 0:\n            mem_k, mem_v = map(\n                lambda t: repeat(t, \"h n d -> b h n d\", b=b), (self.mem_k, self.mem_v)\n            )\n            k = torch.cat((mem_k, k), dim=-2)\n            v = torch.cat((mem_v, v), dim=-2)\n            if exists(input_mask):\n                input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)\n\n        dots = einsum(\"b h i d, b h j d -> b h i j\", q, k) * self.scale\n        mask_value = max_neg_value(dots)\n\n        if exists(prev_attn):\n            dots = dots + prev_attn\n\n        pre_softmax_attn = dots\n\n        if talking_heads:\n            dots = einsum(\n                \"b h i j, h k -> b k i j\", dots, self.pre_softmax_proj\n            ).contiguous()\n\n        if exists(rel_pos):\n            dots = rel_pos(dots)\n\n        if exists(input_mask):\n            dots.masked_fill_(~input_mask, mask_value)\n            del input_mask\n\n        if self.causal:\n            i, j = dots.shape[-2:]\n            r = torch.arange(i, device=device)\n            mask = rearrange(r, \"i -> () () i ()\") < rearrange(r, \"j -> () () () j\")\n            mask = F.pad(mask, (j - i, 0), value=False)\n            dots.masked_fill_(mask, mask_value)\n            del mask\n\n        if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:\n            top, _ = dots.topk(self.sparse_topk, dim=-1)\n            vk = top[..., -1].unsqueeze(-1).expand_as(dots)\n            mask = dots < vk\n            dots.masked_fill_(mask, 
mask_value)\n            del mask\n\n        attn = self.attn_fn(dots, dim=-1)\n        post_softmax_attn = attn\n\n        attn = self.dropout(attn)\n\n        if talking_heads:\n            attn = einsum(\n                \"b h i j, h k -> b k i j\", attn, self.post_softmax_proj\n            ).contiguous()\n\n        out = einsum(\"b h i j, b h j d -> b h i d\", attn, v)\n        out = rearrange(out, \"b h n d -> b n (h d)\")\n\n        intermediates = Intermediates(\n            pre_softmax_attn=pre_softmax_attn, post_softmax_attn=post_softmax_attn\n        )\n\n        return self.to_out(out), intermediates\n\n\nclass AttentionLayers(nn.Module):\n    def __init__(\n        self,\n        dim,\n        depth,\n        heads=8,\n        causal=False,\n        cross_attend=False,\n        only_cross=False,\n        use_scalenorm=False,\n        use_rmsnorm=False,\n        use_rezero=False,\n        rel_pos_num_buckets=32,\n        rel_pos_max_distance=128,\n        position_infused_attn=False,\n        custom_layers=None,\n        sandwich_coef=None,\n        par_ratio=None,\n        residual_attn=False,\n        cross_residual_attn=False,\n        macaron=False,\n        pre_norm=True,\n        gate_residual=False,\n        **kwargs,\n    ):\n        super().__init__()\n        ff_kwargs, kwargs = groupby_prefix_and_trim(\"ff_\", kwargs)\n        attn_kwargs, _ = groupby_prefix_and_trim(\"attn_\", kwargs)\n\n        dim_head = attn_kwargs.get(\"dim_head\", DEFAULT_DIM_HEAD)\n\n        self.dim = dim\n        self.depth = depth\n        self.layers = nn.ModuleList([])\n\n        self.has_pos_emb = position_infused_attn\n        self.pia_pos_emb = (\n            FixedPositionalEmbedding(dim) if position_infused_attn else None\n        )\n        self.rotary_pos_emb = always(None)\n\n        assert (\n            rel_pos_num_buckets <= rel_pos_max_distance\n        ), \"number of relative position buckets must be less than the relative position max distance\"\n        self.rel_pos = None\n\n        self.pre_norm = pre_norm\n\n        self.residual_attn = residual_attn\n        self.cross_residual_attn = cross_residual_attn\n\n        norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm\n        norm_class = RMSNorm if use_rmsnorm else norm_class\n        norm_fn = partial(norm_class, dim)\n\n        norm_fn = nn.Identity if use_rezero else norm_fn\n        branch_fn = Rezero if use_rezero else None\n\n        if cross_attend and not only_cross:\n            default_block = (\"a\", \"c\", \"f\")\n        elif cross_attend and only_cross:\n            default_block = (\"c\", \"f\")\n        else:\n            default_block = (\"a\", \"f\")\n\n        if macaron:\n            default_block = (\"f\",) + default_block\n\n        if exists(custom_layers):\n            layer_types = custom_layers\n        elif exists(par_ratio):\n            par_depth = depth * len(default_block)\n            assert 1 < par_ratio <= par_depth, \"par ratio out of range\"\n            default_block = tuple(filter(not_equals(\"f\"), default_block))\n            par_attn = par_depth // par_ratio\n            depth_cut = (\n                par_depth * 2 // 3\n            )  # 2 / 3 attention layer cutoff suggested by PAR paper\n            par_width = (depth_cut + depth_cut // par_attn) // par_attn\n            assert (\n                len(default_block) <= par_width\n            ), \"default block is too large for par_ratio\"\n            par_block = default_block + (\"f\",) * (par_width - len(default_block))\n  
          par_head = par_block * par_attn\n            layer_types = par_head + (\"f\",) * (par_depth - len(par_head))\n        elif exists(sandwich_coef):\n            assert (\n                sandwich_coef > 0 and sandwich_coef <= depth\n            ), \"sandwich coefficient should be less than the depth\"\n            layer_types = (\n                (\"a\",) * sandwich_coef\n                + default_block * (depth - sandwich_coef)\n                + (\"f\",) * sandwich_coef\n            )\n        else:\n            layer_types = default_block * depth\n\n        self.layer_types = layer_types\n        self.num_attn_layers = len(list(filter(equals(\"a\"), layer_types)))\n\n        for layer_type in self.layer_types:\n            if layer_type == \"a\":\n                layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)\n            elif layer_type == \"c\":\n                layer = Attention(dim, heads=heads, **attn_kwargs)\n            elif layer_type == \"f\":\n                layer = FeedForward(dim, **ff_kwargs)\n                layer = layer if not macaron else Scale(0.5, layer)\n            else:\n                raise Exception(f\"invalid layer type {layer_type}\")\n\n            if isinstance(layer, Attention) and exists(branch_fn):\n                layer = branch_fn(layer)\n\n            if gate_residual:\n                residual_fn = GRUGating(dim)\n            else:\n                residual_fn = Residual()\n\n            self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))\n\n    def forward(\n        self,\n        x,\n        context=None,\n        mask=None,\n        context_mask=None,\n        mems=None,\n        return_hiddens=False,\n    ):\n        hiddens = []\n        intermediates = []\n        prev_attn = None\n        prev_cross_attn = None\n\n        mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers\n\n        for ind, (layer_type, (norm, block, residual_fn)) in enumerate(\n            zip(self.layer_types, self.layers)\n        ):\n            is_last = ind == (len(self.layers) - 1)\n\n            if layer_type == \"a\":\n                hiddens.append(x)\n                layer_mem = mems.pop(0)\n\n            residual = x\n\n            if self.pre_norm:\n                x = norm(x)\n\n            if layer_type == \"a\":\n                out, inter = block(\n                    x,\n                    mask=mask,\n                    sinusoidal_emb=self.pia_pos_emb,\n                    rel_pos=self.rel_pos,\n                    prev_attn=prev_attn,\n                    mem=layer_mem,\n                )\n            elif layer_type == \"c\":\n                out, inter = block(\n                    x,\n                    context=context,\n                    mask=mask,\n                    context_mask=context_mask,\n                    prev_attn=prev_cross_attn,\n                )\n            elif layer_type == \"f\":\n                out = block(x)\n\n            x = residual_fn(out, residual)\n\n            if layer_type in (\"a\", \"c\"):\n                intermediates.append(inter)\n\n            if layer_type == \"a\" and self.residual_attn:\n                prev_attn = inter.pre_softmax_attn\n            elif layer_type == \"c\" and self.cross_residual_attn:\n                prev_cross_attn = inter.pre_softmax_attn\n\n            if not self.pre_norm and not is_last:\n                x = norm(x)\n\n        if return_hiddens:\n            intermediates = LayerIntermediates(\n                
hiddens=hiddens, attn_intermediates=intermediates\n            )\n\n            return x, intermediates\n\n        return x\n\n\nclass Encoder(AttentionLayers):\n    def __init__(self, **kwargs):\n        assert \"causal\" not in kwargs, \"cannot set causality on encoder\"\n        super().__init__(causal=False, **kwargs)\n\n\nclass TransformerWrapper(nn.Module):\n    def __init__(\n        self,\n        *,\n        num_tokens,\n        max_seq_len,\n        attn_layers,\n        emb_dim=None,\n        max_mem_len=0.0,\n        emb_dropout=0.0,\n        num_memory_tokens=None,\n        tie_embedding=False,\n        use_pos_emb=True,\n    ):\n        super().__init__()\n        assert isinstance(\n            attn_layers, AttentionLayers\n        ), \"attention layers must be one of Encoder or Decoder\"\n\n        dim = attn_layers.dim\n        emb_dim = default(emb_dim, dim)\n\n        self.max_seq_len = max_seq_len\n        self.max_mem_len = max_mem_len\n        self.num_tokens = num_tokens\n\n        self.token_emb = nn.Embedding(num_tokens, emb_dim)\n        self.pos_emb = (\n            AbsolutePositionalEmbedding(emb_dim, max_seq_len)\n            if (use_pos_emb and not attn_layers.has_pos_emb)\n            else always(0)\n        )\n        self.emb_dropout = nn.Dropout(emb_dropout)\n\n        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()\n        self.attn_layers = attn_layers\n        self.norm = nn.LayerNorm(dim)\n\n        self.init_()\n\n        self.to_logits = (\n            nn.Linear(dim, num_tokens)\n            if not tie_embedding\n            else lambda t: t @ self.token_emb.weight.t()\n        )\n\n        # memory tokens (like [cls]) from Memory Transformers paper\n        num_memory_tokens = default(num_memory_tokens, 0)\n        self.num_memory_tokens = num_memory_tokens\n        if num_memory_tokens > 0:\n            self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))\n\n            # let funnel encoder know number of memory tokens, if specified\n            if hasattr(attn_layers, \"num_memory_tokens\"):\n                attn_layers.num_memory_tokens = num_memory_tokens\n\n    def init_(self):\n        nn.init.normal_(self.token_emb.weight, std=0.02)\n\n    def forward(\n        self,\n        x,\n        return_embeddings=False,\n        mask=None,\n        return_mems=False,\n        return_attn=False,\n        mems=None,\n        **kwargs,\n    ):\n        b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens\n        x = self.token_emb(x)\n        x += self.pos_emb(x)\n        x = self.emb_dropout(x)\n\n        x = self.project_emb(x)\n\n        if num_mem > 0:\n            mem = repeat(self.memory_tokens, \"n d -> b n d\", b=b)\n            x = torch.cat((mem, x), dim=1)\n\n            # auto-handle masking after appending memory tokens\n            if exists(mask):\n                mask = F.pad(mask, (num_mem, 0), value=True)\n\n        x, intermediates = self.attn_layers(\n            x, mask=mask, mems=mems, return_hiddens=True, **kwargs\n        )\n        x = self.norm(x)\n\n        mem, x = x[:, :num_mem], x[:, num_mem:]\n\n        out = self.to_logits(x) if not return_embeddings else x\n\n        if return_mems:\n            hiddens = intermediates.hiddens\n            new_mems = (\n                list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens)))\n                if exists(mems)\n                else hiddens\n            )\n            new_mems = list(\n              
  map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems)\n            )\n            return out, new_mems\n\n        if return_attn:\n            attn_maps = list(\n                map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)\n            )\n            return out, attn_maps\n\n        return out\n"
  },
  {
    "path": "semanticodec/modules/decoder/latent_diffusion/util.py",
    "content": "import importlib\n\nimport torch\nimport numpy as np\nfrom collections import abc\nfrom einops import rearrange\nfrom functools import partial\n\nimport multiprocessing as mp\nfrom threading import Thread\nfrom queue import Queue\n\nfrom inspect import isfunction\nfrom PIL import Image, ImageDraw, ImageFont\n\n\ndef disabled_train(self, mode=True):\n    \"\"\"Overwrite model.train with this function to make sure train/eval mode\n    does not change anymore.\"\"\"\n    return self\n\n\ndef get_unconditional_condition(batchsize, downsampling_rate, device):\n    token_num = 512 // downsampling_rate\n    representation_quant = (\n        torch.zeros((batchsize, token_num, 768 * downsampling_rate)).to(device).float()\n    )\n    return [representation_quant, torch.ones((batchsize, token_num)).to(device).float()]\n\n\ndef log_txt_as_img(wh, xc, size=10):\n    # wh a tuple of (width, height)\n    # xc a list of captions to plot\n    b = len(xc)\n    txts = list()\n    for bi in range(b):\n        txt = Image.new(\"RGB\", wh, color=\"white\")\n        draw = ImageDraw.Draw(txt)\n        font = ImageFont.truetype(\"data/DejaVuSans.ttf\", size=size)\n        nc = int(40 * (wh[0] / 256))\n        lines = \"\\n\".join(\n            xc[bi][start : start + nc] for start in range(0, len(xc[bi]), nc)\n        )\n\n        try:\n            draw.text((0, 0), lines, fill=\"black\", font=font)\n        except UnicodeEncodeError:\n            print(\"Cant encode string for logging. Skipping.\")\n\n        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0\n        txts.append(txt)\n    txts = np.stack(txts)\n    txts = torch.tensor(txts)\n    return txts\n\n\ndef ismap(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] > 3)\n\n\ndef isimage(x):\n    if not isinstance(x, torch.Tensor):\n        return False\n    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)\n\n\ndef int16_to_float32(x):\n    return (x / 32767.0).astype(np.float32)\n\n\ndef float32_to_int16(x):\n    x = np.clip(x, a_min=-1.0, a_max=1.0)\n    return (x * 32767.0).astype(np.int16)\n\n\ndef exists(x):\n    return x is not None\n\n\ndef default(val, d):\n    if exists(val):\n        return val\n    return d() if isfunction(d) else d\n\n\ndef mean_flat(tensor):\n    \"\"\"\n    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86\n    Take the mean over all non-batch dimensions.\n    \"\"\"\n    return tensor.mean(dim=list(range(1, len(tensor.shape))))\n\n\ndef count_params(model, verbose=False):\n    total_params = sum(p.numel() for p in model.parameters())\n    if verbose:\n        print(f\"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.\")\n    return total_params\n\n\ndef instantiate_from_config(config):\n    if not \"target\" in config:\n        if config == \"__is_first_stage__\":\n            return None\n        elif config == \"__is_unconditional__\":\n            return None\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    return get_obj_from_str(config[\"target\"])(**config.get(\"params\", dict()))\n\n\ndef get_obj_from_str(string, reload=False):\n    module, cls = string.rsplit(\".\", 1)\n    if reload:\n        module_imp = importlib.import_module(module)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=None), cls)\n\n\ndef _do_parallel_data_prefetch(func, Q, data, idx, 
idx_to_fn=False):\n    # create dummy dataset instance\n\n    # run prefetching\n    if idx_to_fn:\n        res = func(data, worker_id=idx)\n    else:\n        res = func(data)\n    Q.put([idx, res])\n    Q.put(\"Done\")\n\n\ndef parallel_data_prefetch(\n    func: callable,\n    data,\n    n_proc,\n    target_data_type=\"ndarray\",\n    cpu_intensive=True,\n    use_worker_id=False,\n):\n    # if target_data_type not in [\"ndarray\", \"list\"]:\n    #     raise ValueError(\n    #         \"Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray.\"\n    #     )\n    if isinstance(data, np.ndarray) and target_data_type == \"list\":\n        raise ValueError(\"list expected but function got ndarray.\")\n    elif isinstance(data, abc.Iterable):\n        if isinstance(data, dict):\n            print(\n                f'WARNING:\"data\" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'\n            )\n            data = list(data.values())\n        if target_data_type == \"ndarray\":\n            data = np.asarray(data)\n        else:\n            data = list(data)\n    else:\n        raise TypeError(\n            f\"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}.\"\n        )\n\n    if cpu_intensive:\n        Q = mp.Queue(1000)\n        proc = mp.Process\n    else:\n        Q = Queue(1000)\n        proc = Thread\n    # spawn processes\n    if target_data_type == \"ndarray\":\n        arguments = [\n            [func, Q, part, i, use_worker_id]\n            for i, part in enumerate(np.array_split(data, n_proc))\n        ]\n    else:\n        step = (\n            int(len(data) / n_proc + 1)\n            if len(data) % n_proc != 0\n            else int(len(data) / n_proc)\n        )\n        arguments = [\n            [func, Q, part, i, use_worker_id]\n            for i, part in enumerate(\n                [data[i : i + step] for i in range(0, len(data), step)]\n            )\n        ]\n    processes = []\n    for i in range(n_proc):\n        p = proc(target=_do_parallel_data_prefetch, args=arguments[i])\n        processes += [p]\n\n    # start processes\n    print(f\"Start prefetching...\")\n    import time\n\n    start = time.time()\n    gather_res = [[] for _ in range(n_proc)]\n    try:\n        for p in processes:\n            p.start()\n\n        k = 0\n        while k < n_proc:\n            # get result\n            res = Q.get()\n            if res == \"Done\":\n                k += 1\n            else:\n                gather_res[res[0]] = res[1]\n\n    except Exception as e:\n        print(\"Exception: \", e)\n        for p in processes:\n            p.terminate()\n\n        raise e\n    finally:\n        for p in processes:\n            p.join()\n        print(f\"Prefetching complete. [{time.time() - start} sec.]\")\n\n    if target_data_type == \"ndarray\":\n        if not isinstance(gather_res[0], np.ndarray):\n            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)\n\n        # order outputs\n        return np.concatenate(gather_res, axis=0)\n    elif target_data_type == \"list\":\n        out = []\n        for r in gather_res:\n            out.extend(r)\n        return out\n    else:\n        return gather_res\n"
  },
  {
    "path": "semanticodec/modules/decoder/latent_encoder/__init__.py",
    "content": ""
  },
  {
    "path": "semanticodec/modules/decoder/latent_encoder/autoencoder.py",
    "content": "import torch\nimport os\n\nimport torch.nn.functional as F\nfrom semanticodec.modules.decoder.latent_diffusion.modules.ema import *\n\nfrom semanticodec.modules.decoder.latent_diffusion.modules.diffusionmodules.model import (\n    Encoder,\n    Decoder,\n)\nfrom semanticodec.modules.decoder.latent_diffusion.modules.distributions.distributions import (\n    DiagonalGaussianDistribution,\n)\n\nimport soundfile as sf\n\nfrom semanticodec.modules.decoder.utilities.model import get_vocoder\nfrom semanticodec.modules.decoder.utilities.tools import synth_one_sample\n\n\nclass AutoencoderKL(nn.Module):\n    def __init__(\n        self,\n        ddconfig=None,\n        lossconfig=None,\n        batchsize=None,\n        embed_dim=None,\n        time_shuffle=1,\n        subband=1,\n        sampling_rate=16000,\n        ckpt_path=None,\n        reload_from_ckpt=None,\n        ignore_keys=[],\n        image_key=\"fbank\",\n        colorize_nlabels=None,\n        monitor=None,\n        base_learning_rate=1e-5,\n    ):\n        super().__init__()\n        self.automatic_optimization = False\n        assert (\n            \"mel_bins\" in ddconfig.keys()\n        ), \"mel_bins is not specified in the Autoencoder config\"\n        num_mel = ddconfig[\"mel_bins\"]\n        self.image_key = image_key\n        self.sampling_rate = sampling_rate\n        self.encoder = Encoder(**ddconfig)\n        self.decoder = Decoder(**ddconfig)\n        self.subband = int(subband)\n\n        if self.subband > 1:\n            print(\"Use subband decomposition %s\" % self.subband)\n\n        assert ddconfig[\"double_z\"]\n        self.quant_conv = torch.nn.Conv2d(2 * ddconfig[\"z_channels\"], 2 * embed_dim, 1)\n        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig[\"z_channels\"], 1)\n\n        if self.image_key == \"fbank\":\n            self.vocoder = get_vocoder(None, \"cpu\", num_mel)\n        self.embed_dim = embed_dim\n        if colorize_nlabels is not None:\n            assert type(colorize_nlabels) == int\n            self.register_buffer(\"colorize\", torch.randn(3, colorize_nlabels, 1, 1))\n        if monitor is not None:\n            self.monitor = monitor\n        if ckpt_path is not None:\n            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)\n        self.learning_rate = float(base_learning_rate)\n\n        self.time_shuffle = time_shuffle\n        self.reload_from_ckpt = reload_from_ckpt\n        self.reloaded = False\n        self.mean, self.std = None, None\n\n        self.feature_cache = None\n        self.flag_first_run = True\n        self.train_step = 0\n\n        self.logger_save_dir = None\n        self.logger_exp_name = None\n        self.logger_exp_group_name = None\n\n    def get_log_dir(self):\n        return os.path.join(\n            self.logger_save_dir, self.logger_exp_group_name, self.logger_exp_name\n        )\n\n    def set_log_dir(self, save_dir, exp_group_name, exp_name):\n        self.logger_save_dir = save_dir\n        self.logger_exp_name = exp_name\n        self.logger_exp_group_name = exp_group_name\n\n    def init_from_ckpt(self, path, ignore_keys=list()):\n        sd = torch.load(path, map_location=\"cpu\")[\"state_dict\"]\n        keys = list(sd.keys())\n        for k in keys:\n            for ik in ignore_keys:\n                if k.startswith(ik):\n                    print(\"Deleting key {} from state_dict.\".format(k))\n                    del sd[k]\n        self.load_state_dict(sd, strict=False)\n        print(f\"Restored from 
{path}\")\n\n    def encode(self, x):\n        # x = self.time_shuffle_operation(x)\n        x = self.freq_split_subband(x)\n        h = self.encoder(x)\n        moments = self.quant_conv(h)\n        posterior = DiagonalGaussianDistribution(moments)\n        return posterior\n\n    def decode(self, z):\n        z = self.post_quant_conv(z)\n        dec = self.decoder(z)\n        # bs, ch, shuffled_timesteps, fbins = dec.size()\n        # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins)\n        dec = self.freq_merge_subband(dec)\n        return dec\n\n    def decode_to_waveform(self, dec):\n        from utilities.model import vocoder_infer\n\n        if self.image_key == \"fbank\":\n            dec = dec.squeeze(1).permute(0, 2, 1)\n            wav_reconstruction = vocoder_infer(dec, self.vocoder)\n        elif self.image_key == \"stft\":\n            dec = dec.squeeze(1).permute(0, 2, 1)\n            wav_reconstruction = self.wave_decoder(dec)\n        return wav_reconstruction\n\n    def forward(self, input, sample_posterior=True):\n        posterior = self.encode(input)\n        if sample_posterior:\n            z = posterior.sample()\n        else:\n            z = posterior.mode()\n\n        if self.flag_first_run:\n            print(\"Latent size: \", z.size())\n            self.flag_first_run = False\n\n        dec = self.decode(z)\n\n        return dec, posterior\n\n    def freq_split_subband(self, fbank):\n        if self.subband == 1 or self.image_key != \"stft\":\n            return fbank\n\n        bs, ch, tstep, fbins = fbank.size()\n\n        assert fbank.size(-1) % self.subband == 0\n        assert ch == 1\n\n        return (\n            fbank.squeeze(1)\n            .reshape(bs, tstep, self.subband, fbins // self.subband)\n            .permute(0, 2, 1, 3)\n        )\n\n    def freq_merge_subband(self, subband_fbank):\n        if self.subband == 1 or self.image_key != \"stft\":\n            return subband_fbank\n        assert subband_fbank.size(1) == self.subband  # Channel dimension\n        bs, sub_ch, tstep, fbins = subband_fbank.size()\n        return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1)\n\n    def save_wave(self, batch_wav, fname, save_dir):\n        os.makedirs(save_dir, exist_ok=True)\n\n        for wav, name in zip(batch_wav, fname):\n            name = os.path.basename(name)\n\n            sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate)\n\n    def get_last_layer(self):\n        return self.decoder.conv_out.weight\n\n    @torch.no_grad()\n    def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs):\n        log = dict()\n        x = batch.to(self.device)\n        if not only_inputs:\n            xrec, posterior = self(x)\n            log[\"samples\"] = self.decode(posterior.sample())\n            log[\"reconstructions\"] = xrec\n\n        log[\"inputs\"] = x\n        wavs = self._log_img(log, train=train, index=0, waveform=waveform)\n        return wavs\n\n    def tensor2numpy(self, tensor):\n        return tensor.cpu().detach().numpy()\n\n    def to_rgb(self, x):\n        assert self.image_key == \"segmentation\"\n        if not hasattr(self, \"colorize\"):\n            self.register_buffer(\"colorize\", torch.randn(3, x.shape[1], 1, 1).to(x))\n        x = F.conv2d(x, weight=self.colorize)\n        x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0\n        return x\n\n\nclass IdentityFirstStage(torch.nn.Module):\n    def __init__(self, *args, 
vq_interface=False, **kwargs):\n        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff\n        super().__init__()\n\n    def encode(self, x, *args, **kwargs):\n        return x\n\n    def decode(self, x, *args, **kwargs):\n        return x\n\n    def quantize(self, x, *args, **kwargs):\n        if self.vq_interface:\n            return x, None, [None, None, None]\n        return x\n\n    def forward(self, x, *args, **kwargs):\n        return x\n"
  },
  {
    "path": "semanticodec/modules/decoder/utilities/__init__.py",
    "content": "from .tools import *\nfrom .model import *\n"
  },
  {
    "path": "semanticodec/modules/decoder/utilities/audio/__init__.py",
    "content": "from .audio_processing import *\nfrom .stft import *\nfrom .tools import *\n"
  },
  {
    "path": "semanticodec/modules/decoder/utilities/audio/audio_processing.py",
    "content": "import torch\nimport numpy as np\nimport librosa.util as librosa_util\nfrom scipy.signal import get_window\n\n\ndef window_sumsquare(\n    window,\n    n_frames,\n    hop_length,\n    win_length,\n    n_fft,\n    dtype=np.float32,\n    norm=None,\n):\n    \"\"\"\n    # from librosa 0.6\n    Compute the sum-square envelope of a window function at a given hop length.\n\n    This is used to estimate modulation effects induced by windowing\n    observations in short-time fourier transforms.\n\n    Parameters\n    ----------\n    window : string, tuple, number, callable, or list-like\n        Window specification, as in `get_window`\n\n    n_frames : int > 0\n        The number of analysis frames\n\n    hop_length : int > 0\n        The number of samples to advance between frames\n\n    win_length : [optional]\n        The length of the window function.  By default, this matches `n_fft`.\n\n    n_fft : int > 0\n        The length of each analysis frame.\n\n    dtype : np.dtype\n        The data type of the output\n\n    Returns\n    -------\n    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`\n        The sum-squared envelope of the window function\n    \"\"\"\n    if win_length is None:\n        win_length = n_fft\n\n    n = n_fft + hop_length * (n_frames - 1)\n    x = np.zeros(n, dtype=dtype)\n\n    # Compute the squared window at the desired length\n    win_sq = get_window(window, win_length, fftbins=True)\n    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2\n    win_sq = librosa_util.pad_center(win_sq, n_fft)\n\n    # Fill the envelope\n    for i in range(n_frames):\n        sample = i * hop_length\n        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]\n    return x\n\n\ndef griffin_lim(magnitudes, stft_fn, n_iters=30):\n    \"\"\"\n    PARAMS\n    ------\n    magnitudes: spectrogram magnitudes\n    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods\n    \"\"\"\n\n    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))\n    angles = angles.astype(np.float32)\n    angles = torch.autograd.Variable(torch.from_numpy(angles))\n    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)\n\n    for i in range(n_iters):\n        _, angles = stft_fn.transform(signal)\n        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)\n    return signal\n\n\ndef dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):\n    \"\"\"\n    PARAMS\n    ------\n    C: compression factor\n    \"\"\"\n    return normalize_fun(torch.clamp(x, min=clip_val) * C)\n\n\ndef dynamic_range_decompression(x, C=1):\n    \"\"\"\n    PARAMS\n    ------\n    C: compression factor used to compress\n    \"\"\"\n    return torch.exp(x) / C\n"
  },
  {
    "path": "semanticodec/modules/decoder/utilities/audio/stft.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy.signal import get_window\nfrom librosa.util import pad_center, tiny\nfrom librosa.filters import mel as librosa_mel_fn\n\nfrom utilities.audio.audio_processing import (\n    dynamic_range_compression,\n    dynamic_range_decompression,\n    window_sumsquare,\n)\n\n\nclass STFT(torch.nn.Module):\n    \"\"\"adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft\"\"\"\n\n    def __init__(self, filter_length, hop_length, win_length, window=\"hann\"):\n        super(STFT, self).__init__()\n        self.filter_length = filter_length\n        self.hop_length = hop_length\n        self.win_length = win_length\n        self.window = window\n        self.forward_transform = None\n        scale = self.filter_length / self.hop_length\n        fourier_basis = np.fft.fft(np.eye(self.filter_length))\n\n        cutoff = int((self.filter_length / 2 + 1))\n        fourier_basis = np.vstack(\n            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]\n        )\n\n        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])\n        inverse_basis = torch.FloatTensor(\n            np.linalg.pinv(scale * fourier_basis).T[:, None, :]\n        )\n\n        if window is not None:\n            assert filter_length >= win_length\n            # get window and zero center pad it to filter_length\n            fft_window = get_window(window, win_length, fftbins=True)\n            fft_window = pad_center(fft_window, filter_length)\n            fft_window = torch.from_numpy(fft_window).float()\n\n            # window the bases\n            forward_basis *= fft_window\n            inverse_basis *= fft_window\n\n        self.register_buffer(\"forward_basis\", forward_basis.float())\n        self.register_buffer(\"inverse_basis\", inverse_basis.float())\n\n    def transform(self, input_data):\n        num_batches = input_data.size(0)\n        num_samples = input_data.size(1)\n\n        self.num_samples = num_samples\n\n        # similar to librosa, reflect-pad the input\n        input_data = input_data.view(num_batches, 1, num_samples)\n        input_data = F.pad(\n            input_data.unsqueeze(1),\n            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),\n            mode=\"reflect\",\n        )\n        input_data = input_data.squeeze(1)\n\n        forward_transform = F.conv1d(\n            input_data,\n            torch.autograd.Variable(self.forward_basis, requires_grad=False),\n            stride=self.hop_length,\n            padding=0,\n        ).cpu()\n\n        cutoff = int((self.filter_length / 2) + 1)\n        real_part = forward_transform[:, :cutoff, :]\n        imag_part = forward_transform[:, cutoff:, :]\n\n        magnitude = torch.sqrt(real_part**2 + imag_part**2)\n        phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))\n\n        return magnitude, phase\n\n    def inverse(self, magnitude, phase):\n        recombine_magnitude_phase = torch.cat(\n            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1\n        )\n\n        inverse_transform = F.conv_transpose1d(\n            recombine_magnitude_phase,\n            torch.autograd.Variable(self.inverse_basis, requires_grad=False),\n            stride=self.hop_length,\n            padding=0,\n        )\n\n        if self.window is not None:\n            window_sum = window_sumsquare(\n                self.window,\n                
magnitude.size(-1),\n                hop_length=self.hop_length,\n                win_length=self.win_length,\n                n_fft=self.filter_length,\n                dtype=np.float32,\n            )\n            # remove modulation effects\n            approx_nonzero_indices = torch.from_numpy(\n                np.where(window_sum > tiny(window_sum))[0]\n            )\n            window_sum = torch.autograd.Variable(\n                torch.from_numpy(window_sum), requires_grad=False\n            )\n            window_sum = window_sum\n            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[\n                approx_nonzero_indices\n            ]\n\n            # scale by hop ratio\n            inverse_transform *= float(self.filter_length) / self.hop_length\n\n        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]\n        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]\n\n        return inverse_transform\n\n    def forward(self, input_data):\n        self.magnitude, self.phase = self.transform(input_data)\n        reconstruction = self.inverse(self.magnitude, self.phase)\n        return reconstruction\n\n\nclass TacotronSTFT(torch.nn.Module):\n    def __init__(\n        self,\n        filter_length,\n        hop_length,\n        win_length,\n        n_mel_channels,\n        sampling_rate,\n        mel_fmin,\n        mel_fmax,\n    ):\n        super(TacotronSTFT, self).__init__()\n        self.n_mel_channels = n_mel_channels\n        self.sampling_rate = sampling_rate\n        self.stft_fn = STFT(filter_length, hop_length, win_length)\n        mel_basis = librosa_mel_fn(\n            sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax\n        )\n        mel_basis = torch.from_numpy(mel_basis).float()\n        self.register_buffer(\"mel_basis\", mel_basis)\n\n    def spectral_normalize(self, magnitudes, normalize_fun):\n        output = dynamic_range_compression(magnitudes, normalize_fun)\n        return output\n\n    def spectral_de_normalize(self, magnitudes):\n        output = dynamic_range_decompression(magnitudes)\n        return output\n\n    def mel_spectrogram(self, y, normalize_fun=torch.log):\n        \"\"\"Computes mel-spectrograms from a batch of waves\n        PARAMS\n        ------\n        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]\n\n        RETURNS\n        -------\n        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)\n        \"\"\"\n        assert torch.min(y.data) >= -1, torch.min(y.data)\n        assert torch.max(y.data) <= 1, torch.max(y.data)\n\n        magnitudes, phases = self.stft_fn.transform(y)\n        magnitudes = magnitudes.data\n        mel_output = torch.matmul(self.mel_basis, magnitudes)\n        mel_output = self.spectral_normalize(mel_output, normalize_fun)\n        energy = torch.norm(magnitudes, dim=1)\n\n        return mel_output, magnitudes, phases, energy\n"
  },
  {
    "path": "semanticodec/modules/decoder/utilities/audio/tools.py",
    "content": "import torch\nimport numpy as np\nfrom scipy.io.wavfile import write\nimport torchaudio\n\nfrom utilities.audio.audio_processing import griffin_lim\n\n\ndef get_mel_from_wav(audio, _stft):\n    audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)\n    audio = torch.autograd.Variable(audio, requires_grad=False)\n    melspec, magnitudes, phases, energy = _stft.mel_spectrogram(audio)\n    melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)\n    magnitudes = torch.squeeze(magnitudes, 0).numpy().astype(np.float32)\n    energy = torch.squeeze(energy, 0).numpy().astype(np.float32)\n    return melspec, magnitudes, energy\n\n\ndef inv_mel_spec(mel, out_filename, _stft, griffin_iters=60):\n    mel = torch.stack([mel])\n    mel_decompress = _stft.spectral_de_normalize(mel)\n    mel_decompress = mel_decompress.transpose(1, 2).data.cpu()\n    spec_from_mel_scaling = 1000\n    spec_from_mel = torch.mm(mel_decompress[0], _stft.mel_basis)\n    spec_from_mel = spec_from_mel.transpose(0, 1).unsqueeze(0)\n    spec_from_mel = spec_from_mel * spec_from_mel_scaling\n\n    audio = griffin_lim(\n        torch.autograd.Variable(spec_from_mel[:, :, :-1]), _stft._stft_fn, griffin_iters\n    )\n\n    audio = audio.squeeze()\n    audio = audio.cpu().numpy()\n    audio_path = out_filename\n    write(audio_path, _stft.sampling_rate, audio)\n"
  },
  {
    "path": "semanticodec/modules/decoder/utilities/model.py",
    "content": "import os\nimport json\n\nimport torch\nimport numpy as np\n\nimport semanticodec.modules.decoder.hifigan as hifigan\n\n\ndef get_available_checkpoint_keys(model, ckpt):\n    print(\"==> Attemp to reload from %s\" % ckpt)\n    state_dict = torch.load(ckpt)[\"state_dict\"]\n    current_state_dict = model.state_dict()\n    new_state_dict = {}\n    for k in state_dict.keys():\n        if (\n            k in current_state_dict.keys()\n            and current_state_dict[k].size() == state_dict[k].size()\n        ):\n            new_state_dict[k] = state_dict[k]\n        else:\n            print(\"==> WARNING: Skipping %s\" % k)\n    print(\n        \"%s out of %s keys are matched\"\n        % (len(new_state_dict.keys()), len(state_dict.keys()))\n    )\n    return new_state_dict\n\n\ndef get_param_num(model):\n    num_param = sum(param.numel() for param in model.parameters())\n    return num_param\n\n\ndef torch_version_orig_mod_remove(state_dict):\n    new_state_dict = {}\n    new_state_dict[\"generator\"] = {}\n    for key in state_dict[\"generator\"].keys():\n        if \"_orig_mod.\" in key:\n            new_state_dict[\"generator\"][key.replace(\"_orig_mod.\", \"\")] = state_dict[\n                \"generator\"\n            ][key]\n        else:\n            new_state_dict[\"generator\"][key] = state_dict[\"generator\"][key]\n    return new_state_dict\n\n\ndef get_vocoder(config, device, mel_bins):\n    config = {\n        \"resblock\": \"1\",\n        \"batch_size\": 16,\n        \"learning_rate\": 0.0002,\n        \"adam_b1\": 0.8,\n        \"adam_b2\": 0.99,\n        \"lr_decay\": 0.999,\n        \"seed\": 1234,\n        \"upsample_rates\": [5, 4, 2, 2, 2],\n        \"upsample_kernel_sizes\": [16, 16, 8, 4, 4],\n        \"upsample_initial_channel\": 1024,\n        \"resblock_kernel_sizes\": [3, 7, 11],\n        \"resblock_dilation_sizes\": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],\n        \"segment_size\": 8192,\n        \"num_mels\": 64,\n        \"num_freq\": 1025,\n        \"n_fft\": 1024,\n        \"hop_size\": 160,\n        \"win_size\": 1024,\n        \"sampling_rate\": 16000,\n        \"fmin\": 0,\n        \"fmax\": 8000,\n        \"fmax_for_loss\": None,\n        \"num_workers\": 4,\n        \"dist_config\": {\n            \"dist_backend\": \"nccl\",\n            \"dist_url\": \"tcp://localhost:54321\",\n            \"world_size\": 1\n        }\n    }\n    config = hifigan.AttrDict(config)\n    vocoder = hifigan.Generator_old(config)\n    vocoder.eval()\n    vocoder.remove_weight_norm()\n    vocoder.to(device)\n    return vocoder\n\n\ndef vocoder_infer(mels, vocoder, lengths=None):\n    with torch.no_grad():\n        wavs = vocoder(mels).squeeze(1)\n\n    wavs = (wavs.cpu().numpy() * 32768).astype(\"int16\")\n\n    if lengths is not None:\n        wavs = wavs[:, :lengths]\n\n    return wavs\n"
  },
  {
    "path": "semanticodec/modules/decoder/utilities/tools.py",
    "content": "# Author: Haohe Liu\n# Email: haoheliu@gmail.com\n# Date: 11 Feb 2023\n\nimport os\nimport json\n\nimport torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy.io import wavfile\n\n\nimport hashlib\nimport os\n\nimport requests\nfrom tqdm import tqdm\n\nURL_MAP = {\n    \"vggishish_lpaps\": \"https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt\",\n    \"vggishish_mean_std_melspec_10s_22050hz\": \"https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt\",\n    \"melception\": \"https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt\",\n}\n\nCKPT_MAP = {\n    \"vggishish_lpaps\": \"vggishish16.pt\",\n    \"vggishish_mean_std_melspec_10s_22050hz\": \"train_means_stds_melspec_10s_22050hz.txt\",\n    \"melception\": \"melception-21-05-10T09-28-40.pt\",\n}\n\nMD5_MAP = {\n    \"vggishish_lpaps\": \"197040c524a07ccacf7715d7080a80bd\",\n    \"vggishish_mean_std_melspec_10s_22050hz\": \"f449c6fd0e248936c16f6d22492bb625\",\n    \"melception\": \"a71a41041e945b457c7d3d814bbcf72d\",\n}\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef load_json(fname):\n    with open(fname, \"r\") as f:\n        data = json.load(f)\n        return data\n\n\ndef read_json(dataset_json_file):\n    with open(dataset_json_file, \"r\") as fp:\n        data_json = json.load(fp)\n    return data_json[\"data\"]\n\n\ndef copy_test_subset_data(metadata, testset_copy_target_path):\n    # metadata = read_json(testset_metadata)\n    os.makedirs(testset_copy_target_path, exist_ok=True)\n    if len(os.listdir(testset_copy_target_path)) == len(metadata):\n        return\n    else:\n        # delete files in folder testset_copy_target_path\n        for file in os.listdir(testset_copy_target_path):\n            try:\n                os.remove(os.path.join(testset_copy_target_path, file))\n            except Exception as e:\n                print(e)\n\n    print(\"Copying test subset data to {}\".format(testset_copy_target_path))\n    for each in tqdm(metadata):\n        cmd = \"cp {} {}\".format(each[\"wav\"], os.path.join(testset_copy_target_path))\n        os.system(cmd)\n\n\ndef listdir_nohidden(path):\n    for f in os.listdir(path):\n        if not f.startswith(\".\"):\n            yield f\n\ndef get_restore_step(path):\n    checkpoints = os.listdir(path)\n    if os.path.exists(os.path.join(path, \"final.ckpt\")):\n        return \"final.ckpt\", 0\n    elif not os.path.exists(os.path.join(path, \"last.ckpt\")):\n        steps = [int(x.split(\".ckpt\")[0].split(\"step=\")[1]) for x in checkpoints]\n        return checkpoints[np.argmax(steps)], np.max(steps)\n    else:\n        steps = []\n        for x in checkpoints:\n            if \"last\" in x:\n                if \"-v\" not in x:\n                    fname = \"last.ckpt\"\n                else:\n                    this_version = int(x.split(\".ckpt\")[0].split(\"-v\")[1])\n                    steps.append(this_version)\n                    if len(steps) == 0 or this_version > np.max(steps):\n                        fname = \"last-v%s.ckpt\" % this_version\n        return fname, 0\n\n\ndef download(url, local_path, chunk_size=1024):\n    os.makedirs(os.path.split(local_path)[0], exist_ok=True)\n    with requests.get(url, stream=True) as r:\n        total_size = int(r.headers.get(\"content-length\", 0))\n        with tqdm(total=total_size, unit=\"B\", 
unit_scale=True) as pbar:\n            with open(local_path, \"wb\") as f:\n                for data in r.iter_content(chunk_size=chunk_size):\n                    if data:\n                        f.write(data)\n                        pbar.update(chunk_size)\n\n\ndef md5_hash(path):\n    with open(path, \"rb\") as f:\n        content = f.read()\n    return hashlib.md5(content).hexdigest()\n\n\ndef get_ckpt_path(name, root, check=False):\n    assert name in URL_MAP\n    path = os.path.join(root, CKPT_MAP[name])\n    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):\n        print(\"Downloading {} model from {} to {}\".format(name, URL_MAP[name], path))\n        download(URL_MAP[name], path)\n        md5 = md5_hash(path)\n        assert md5 == MD5_MAP[name], md5\n    return path\n\n\nclass KeyNotFoundError(Exception):\n    def __init__(self, cause, keys=None, visited=None):\n        self.cause = cause\n        self.keys = keys\n        self.visited = visited\n        messages = list()\n        if keys is not None:\n            messages.append(\"Key not found: {}\".format(keys))\n        if visited is not None:\n            messages.append(\"Visited: {}\".format(visited))\n        messages.append(\"Cause:\\n{}\".format(cause))\n        message = \"\\n\".join(messages)\n        super().__init__(message)\n\n\ndef retrieve(\n    list_or_dict, key, splitval=\"/\", default=None, expand=True, pass_success=False\n):\n    \"\"\"Given a nested list or dict return the desired value at key expanding\n    callable nodes if necessary and :attr:`expand` is ``True``. The expansion\n    is done in-place.\n\n    Parameters\n    ----------\n        list_or_dict : list or dict\n            Possibly nested list or dictionary.\n        key : str\n            key/to/value, path like string describing all keys necessary to\n            consider to get to the desired value. 
List indices can also be\n            passed here.\n        splitval : str\n            String that defines the delimiter between keys of the\n            different depth levels in `key`.\n        default : obj\n            Value returned if :attr:`key` is not found.\n        expand : bool\n            Whether to expand callable nodes on the path or not.\n\n    Returns\n    -------\n        The desired value or if :attr:`default` is not ``None`` and the\n        :attr:`key` is not found returns ``default``.\n\n    Raises\n    ------\n        Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is\n        ``None``.\n    \"\"\"\n\n    keys = key.split(splitval)\n\n    success = True\n    try:\n        visited = []\n        parent = None\n        last_key = None\n        for key in keys:\n            if callable(list_or_dict):\n                if not expand:\n                    raise KeyNotFoundError(\n                        ValueError(\n                            \"Trying to get past callable node with expand=False.\"\n                        ),\n                        keys=keys,\n                        visited=visited,\n                    )\n                list_or_dict = list_or_dict()\n                parent[last_key] = list_or_dict\n\n            last_key = key\n            parent = list_or_dict\n\n            try:\n                if isinstance(list_or_dict, dict):\n                    list_or_dict = list_or_dict[key]\n                else:\n                    list_or_dict = list_or_dict[int(key)]\n            except (KeyError, IndexError, ValueError) as e:\n                raise KeyNotFoundError(e, keys=keys, visited=visited)\n\n            visited += [key]\n        # final expansion of retrieved value\n        if expand and callable(list_or_dict):\n            list_or_dict = list_or_dict()\n            parent[last_key] = list_or_dict\n    except KeyNotFoundError as e:\n        if default is None:\n            raise e\n        else:\n            list_or_dict = default\n            success = False\n\n    if not pass_success:\n        return list_or_dict\n    else:\n        return list_or_dict, success\n\n\ndef to_device(data, device):\n    if len(data) == 12:\n        (\n            ids,\n            raw_texts,\n            speakers,\n            texts,\n            src_lens,\n            max_src_len,\n            mels,\n            mel_lens,\n            max_mel_len,\n            pitches,\n            energies,\n            durations,\n        ) = data\n\n        speakers = torch.from_numpy(speakers).long().to(device)\n        texts = torch.from_numpy(texts).long().to(device)\n        src_lens = torch.from_numpy(src_lens).to(device)\n        mels = torch.from_numpy(mels).float().to(device)\n        mel_lens = torch.from_numpy(mel_lens).to(device)\n        pitches = torch.from_numpy(pitches).float().to(device)\n        energies = torch.from_numpy(energies).to(device)\n        durations = torch.from_numpy(durations).long().to(device)\n\n        return (\n            ids,\n            raw_texts,\n            speakers,\n            texts,\n            src_lens,\n            max_src_len,\n            mels,\n            mel_lens,\n            max_mel_len,\n            pitches,\n            energies,\n            durations,\n        )\n\n    if len(data) == 6:\n        (ids, raw_texts, speakers, texts, src_lens, max_src_len) = data\n\n        speakers = torch.from_numpy(speakers).long().to(device)\n        texts = torch.from_numpy(texts).long().to(device)\n        src_lens 
= torch.from_numpy(src_lens).to(device)\n\n        return (ids, raw_texts, speakers, texts, src_lens, max_src_len)\n\n\ndef log(logger, step=None, fig=None, audio=None, sampling_rate=22050, tag=\"\"):\n    # if losses is not None:\n    #     logger.add_scalar(\"Loss/total_loss\", losses[0], step)\n    #     logger.add_scalar(\"Loss/mel_loss\", losses[1], step)\n    #     logger.add_scalar(\"Loss/mel_postnet_loss\", losses[2], step)\n    #     logger.add_scalar(\"Loss/pitch_loss\", losses[3], step)\n    #     logger.add_scalar(\"Loss/energy_loss\", losses[4], step)\n    #     logger.add_scalar(\"Loss/duration_loss\", losses[5], step)\n    #     if(len(losses) > 6):\n    #         logger.add_scalar(\"Loss/disc_loss\", losses[6], step)\n    #         logger.add_scalar(\"Loss/fmap_loss\", losses[7], step)\n    #         logger.add_scalar(\"Loss/r_loss\", losses[8], step)\n    #         logger.add_scalar(\"Loss/g_loss\", losses[9], step)\n    #         logger.add_scalar(\"Loss/gen_loss\", losses[10], step)\n    #         logger.add_scalar(\"Loss/diff_loss\", losses[11], step)\n\n    if fig is not None:\n        logger.add_figure(tag, fig)\n\n    if audio is not None:\n        audio = audio / (max(abs(audio)) * 1.1)\n        logger.add_audio(\n            tag,\n            audio,\n            sample_rate=sampling_rate,\n        )\n\n\ndef get_mask_from_lengths(lengths, max_len=None):\n    batch_size = lengths.shape[0]\n    if max_len is None:\n        max_len = torch.max(lengths).item()\n\n    ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(device)\n    mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)\n\n    return mask\n\n\ndef expand(values, durations):\n    out = list()\n    for value, d in zip(values, durations):\n        out += [value] * max(0, int(d))\n    return np.array(out)\n\n\ndef synth_one_sample(mel_input, mel_prediction, labels, vocoder):\n    if vocoder is not None:\n        from .model import vocoder_infer\n\n        wav_reconstruction = vocoder_infer(\n            mel_input.permute(0, 2, 1),\n            vocoder,\n        )\n        wav_prediction = vocoder_infer(\n            mel_prediction.permute(0, 2, 1),\n            vocoder,\n        )\n    else:\n        wav_reconstruction = wav_prediction = None\n\n    return wav_reconstruction, wav_prediction\n\n\ndef pad_1D(inputs, PAD=0):\n    def pad_data(x, length, PAD):\n        x_padded = np.pad(\n            x, (0, length - x.shape[0]), mode=\"constant\", constant_values=PAD\n        )\n        return x_padded\n\n    max_len = max((len(x) for x in inputs))\n    padded = np.stack([pad_data(x, max_len, PAD) for x in inputs])\n\n    return padded\n\n\ndef pad_2D(inputs, maxlen=None):\n    def pad(x, max_len):\n        PAD = 0\n        if np.shape(x)[0] > max_len:\n            raise ValueError(\"not max_len\")\n\n        s = np.shape(x)[1]\n        x_padded = np.pad(\n            x, (0, max_len - np.shape(x)[0]), mode=\"constant\", constant_values=PAD\n        )\n        return x_padded[:, :s]\n\n    if maxlen:\n        output = np.stack([pad(x, maxlen) for x in inputs])\n    else:\n        max_len = max(np.shape(x)[0] for x in inputs)\n        output = np.stack([pad(x, max_len) for x in inputs])\n\n    return output\n\n\ndef pad(input_ele, mel_max_length=None):\n    if mel_max_length:\n        max_len = mel_max_length\n    else:\n        max_len = max([input_ele[i].size(0) for i in range(len(input_ele))])\n\n    out_list = list()\n    for i, batch in enumerate(input_ele):\n        if len(batch.shape) == 
1:\n            one_batch_padded = F.pad(\n                batch, (0, max_len - batch.size(0)), \"constant\", 0.0\n            )\n        elif len(batch.shape) == 2:\n            one_batch_padded = F.pad(\n                batch, (0, 0, 0, max_len - batch.size(0)), \"constant\", 0.0\n            )\n        out_list.append(one_batch_padded)\n    out_padded = torch.stack(out_list)\n    return out_padded\n"
  },
  {
    "path": "semanticodec/modules/encoder/__init__.py",
    "content": "\n"
  },
  {
    "path": "semanticodec/modules/encoder/encoder.py",
    "content": "import torch\nimport math\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom semanticodec.modules.audiomae.AudioMAE import Vanilla_AudioMAE\nfrom vector_quantize_pytorch import VectorQuantize\nfrom vector_quantize_pytorch import ResidualVQ\nfrom semanticodec.utils import (\n    concat_1x2,\n    concat_2x2,\n    PositionalEncoding,\n    extract_kaldi_fbank_feature,\n)\n\n\nclass AudioMAEConditionQuantResEncoder(nn.Module):\n    def __init__(\n        self,\n        centroid_npy_path=None,\n        feature_dimension=768,\n        codebook_size=8192,\n        codebook_dim=None,\n        use_cosine_sim=False,\n        decay=0.9,\n        residual_encoder=\"lstm\",\n        lstm_layer=2,\n        lstm_bidirectional=True,\n        commitment_weight=1.0,\n        rvq_layers=0,\n        use_oracle=False,\n        use_positional_embedding=True,\n    ):\n        super().__init__()\n        self.use_oracle = use_oracle\n        self.use_positional_embedding = use_positional_embedding\n        self.residual_encoder = residual_encoder\n        self.downsampling_rate = feature_dimension // 768\n        self.feature_dimension = feature_dimension\n        self.device = None\n        self.pos_embed = PositionalEncoding(seq_length=512, embedding_dim=192)\n\n        assert centroid_npy_path is not None, \"centroid_npy_path is required\"\n        self.centroid_npy = torch.from_numpy(np.load(centroid_npy_path))\n\n        self.centroid_npy.requires_grad = False\n        self.audiomae = Vanilla_AudioMAE()\n        self.audiomae.eval()\n        for p in self.audiomae.parameters():\n            p.requires_grad = False\n\n        self.no_audiomae_mask = True\n        self.no_audiomae_average = False\n\n        if self.residual_encoder == \"lstm\":\n            self.encoder = nn.LSTM(\n                input_size=feature_dimension * 2,\n                hidden_size=feature_dimension * 2,\n                num_layers=lstm_layer,\n                bias=True,\n                batch_first=True,\n                bidirectional=lstm_bidirectional,\n            )\n        else:\n            raise ValueError(\"Invalid model name %s\" % self.residual_encoder)\n\n        self.encoder_output_linear = nn.Linear(\n            in_features=feature_dimension * 2\n            if not lstm_bidirectional\n            else feature_dimension * 4,\n            out_features=feature_dimension,\n            bias=False,\n        )\n\n        self.rvq_layers = rvq_layers\n        self.codebook_size = codebook_size\n        if rvq_layers <= 0:\n            self.quantizer = VectorQuantize(\n                dim=feature_dimension,\n                codebook_size=codebook_size,\n                decay=decay,\n                commitment_weight=commitment_weight,\n                codebook_dim=codebook_dim,\n                use_cosine_sim=use_cosine_sim,\n            )\n        else:\n            self.quantizer = ResidualVQ(\n                dim=feature_dimension,\n                num_quantizers=rvq_layers,  # specify number of quantizers\n                codebook_size=codebook_size,  # codebook size\n            )\n\n        self.indices_statistic_count = 0\n        self.indices_statistic = {}\n        self.eval()\n\n    def mark_out_padding(self, feature, padding_cutoff_index):\n        feature_temporal_dim = feature.shape[-2]\n        for i, index in enumerate(padding_cutoff_index):\n            feature_cutoff_index = math.ceil(feature_temporal_dim * index)\n            feature[i, int(feature_cutoff_index) 
:] *= 0.0\n            feature[i, int(feature_cutoff_index) :] -= 1.0\n        return feature\n\n    # Required\n    def get_unconditional_condition(self, batchsize):\n        param = next(self.audiomae.parameters())\n        assert param.requires_grad == False\n        device = param.device\n        token_num = 512\n        representation_quant = (\n            torch.zeros((batchsize, token_num, 768)).to(device).float()\n        )\n        if self.use_positional_embedding:\n            pe = self.pos_embed(representation_quant)\n            representation_quant = torch.cat(\n                [representation_quant, pe.repeat(batchsize, 1, 1)], dim=-1\n            )\n        return [\n            representation_quant,\n            torch.ones((batchsize, token_num)).to(device).float(),\n        ]\n\n    def quant_mem_efficient(\n        self, representation, first_token_removed=False, feature_dim=768\n    ):\n        assert representation.size(-1) % 768 == 0\n        # Removing the first token and keeping the shape as [batch_size, seq_length - 1, 768] for clarity\n\n        if not first_token_removed:\n            representation = representation[\n                :, 1:, :\n            ]  # Shape: [batch_size, seq_length - 1, 768]\n\n        # Compute squared norms of each row in representation\n        norm_rep = representation.pow(2).sum(\n            dim=2, keepdim=True\n        )  # Shape: [batch_size, seq_length - 1, 1]\n\n        # Compute squared norms of centroids\n        norm_cent = self.centroid_npy.pow(2).sum(\n            dim=1, keepdim=True\n        )  # Shape: [2048, 1]\n\n        # Compute dot products\n        # Reshape representation for batch matrix multiplication: [batch_size * (seq_length - 1), 768]\n        rep_flat = representation.reshape(-1, feature_dim)\n        # Dot product, need to transpose centroids: [batch_size * (seq_length - 1), 2048]\n        dot_product = torch.mm(rep_flat, self.centroid_npy.t())\n        dot_product = dot_product.reshape(\n            representation.shape[0], representation.shape[1], -1\n        )  # Reshape back\n\n        # Compute L2 distance using the formula: ||a-b||^2 = ||a||^2 + ||b||^2 - 2*a.b\n        distances = norm_rep + norm_cent.t() - 2 * dot_product  # Correct broadcasting\n\n        # Find the index of the closest centroid for each vector\n        _, tokens = torch.min(distances, dim=2)  # Shape: [batch_size, seq_length - 1]\n\n        return tokens\n\n    def unquant(self, tokens):\n        \"\"\"\n        Project the quantized tokens into continuous representation with self.centroid_npy.\n        Args:\n            tokens (torch.Tensor): The quantized tokens, shape [batch_size, seq_length - 1]\n        Returns:\n            torch.Tensor: The continuous representation, shape [batch_size, seq_length - 1, feature_dim]\n        \"\"\"\n        return F.embedding(\n            tokens, self.centroid_npy\n        )  # Shape: [batch_size, seq_length - 1, 768]\n\n    def indices_utilization_statistic(self, indices):\n        # indices shape: [batchsize, 256, self.rvq_layers], values are integer codebook indices\n        if indices.dim() == 2:\n            indices = indices.unsqueeze(-1)\n\n        # Update statistics with current indices\n        batch_size, _, rvq_layers = indices.shape\n\n        # Initialize the statistic data structure if not already done\n        if not self.indices_statistic:\n            # Create a list of dictionaries, one for each RVQ layer\n            self.indices_statistic = [{} for _ in 
range(rvq_layers)]\n\n        # Process each RVQ layer separately\n        for layer in range(rvq_layers):\n            layer_indices = (\n                indices[:, :, layer].view(-1).cpu().tolist()\n            )  # Flatten and convert to list for easy counting\n            for idx in layer_indices:\n                if idx in self.indices_statistic[layer]:\n                    self.indices_statistic[layer][idx] += 1\n                else:\n                    self.indices_statistic[layer][idx] = 1\n\n        # Update count and possibly calculate statistics\n        if self.indices_statistic_count % 10000 == 0:\n            # Calculate and print statistics for each codebook\n            for layer, stats in enumerate(self.indices_statistic):\n                utilization_rate = len(list(stats.keys())) / self.codebook_size\n                utilizations = list(stats.values())\n                print(\n                    f\"\\n\\nLayer {layer} Utilization Rate: {utilization_rate}\",\n                    \"max utilization\",\n                    {max(utilizations)},\n                    \"min utilization\",\n                    {min(utilizations)},\n                    \"std\",\n                    np.std(utilizations),\n                    \"median\",\n                    np.median(utilizations),\n                    \"\\n\\n\",\n                )\n                metrics = {\n                    f\"codec/{layer}_utilization\": utilization_rate,\n                    f\"codec/{layer}_utilization_max\": max(utilizations),\n                    f\"codec/{layer}_utilization_min\": min(utilizations),\n                    f\"codec/{layer}_utilization_std\": np.std(utilizations),\n                    f\"codec/{layer}_utilization_median\": np.median(utilizations),\n                }\n                print(\"\\n\")\n                print(metrics)\n                print(\"\\n\")\n\n            self.indices_statistic = [{} for _ in range(rvq_layers)]\n            self.indices_statistic_count = 0\n\n        self.indices_statistic_count += 1\n\n    def concate(self, representation):\n        assert representation.size(-1) == 768\n        representation = representation[:, 1:, :].transpose(1, 2)\n        bs, embedding_dim, token_num = representation.size()\n        representation = representation.reshape(bs, embedding_dim, 64, 8).permute(\n            0, 2, 3, 1\n        )\n        if self.downsampling_rate == 2:\n            concatenated = concat_1x2(representation)\n        elif self.downsampling_rate == 4:\n            concatenated = concat_2x2(representation)\n        else:\n            raise ValueError(\"Invalid downsampling rate %s\" % self.downsampling_rate)\n        return concatenated  # [bs, token_num, embedding_dim]\n\n    def get_unconditional_condition(self, batchsize):\n        param = next(self.audiomae.parameters())\n        assert param.requires_grad == False\n        device = param.device\n        token_num = 512 // self.downsampling_rate\n        representation_quant = (\n            torch.zeros((batchsize, token_num, self.feature_dimension))\n            .to(device)\n            .float()\n        )\n        if self.use_positional_embedding:\n            pe = self.pos_embed(representation_quant)\n            if not self.use_oracle:\n                representation_quant = torch.cat(\n                    [\n                        representation_quant,\n                        representation_quant,\n                        pe.repeat(batchsize, 1, 1),\n                    ],\n                 
   dim=-1,\n                )\n            else:\n                representation_quant = torch.cat(\n                    [representation_quant, pe.repeat(batchsize, 1, 1)],\n                    dim=-1,\n                )\n        return self.wrap_return_dict(\n            crossattn_audiomae_pooled=[\n                representation_quant,\n                torch.ones((batchsize, token_num)).to(device).float(),\n            ],\n            commit_loss=torch.zeros((1,)).to(device),\n        )\n\n    def long_token_split_window(self, tokens, window_length=512, overlap=0.0625):\n        # Overlap 0.64 seconds\n        # batch: [batchsize, token_length, embedding_dimension]\n        # Split into segments with overlap\n        _, token_length, _ = tokens.size()\n        overlap = int(window_length * overlap)\n        current_start = 0\n        token_window_list = []\n        while current_start + window_length < token_length:\n            current_batch = tokens[:, current_start : current_start + window_length, :]\n            token_window_list.append(current_batch)\n            current_start += window_length - overlap\n\n        remaining_batch = tokens[:, current_start:, :]\n\n        if remaining_batch.size(-2) > 0:\n            # Pad to window length\n            # remaining_batch = F.pad(remaining_batch, (0, 0, 0, window_length - remaining_batch.size(-2), 0, 0))\n            token_window_list.append(remaining_batch)\n        return token_window_list\n\n    def forward(self, batch):\n        # Perform padding before this function\n        # Trim the audio token after this function\n        assert batch.size(-1) == 128 and batch.size(-2) % 1024 == 0\n        if self.device is None:\n            self.device = batch.device\n            self.centroid_npy = self.centroid_npy.to(self.device)\n\n        window_length = 1024\n        current_start = 0\n        total_length_batch = batch.size(-2)\n\n        tokens_list = []\n        quantized_feature_list = []\n        while current_start + window_length <= total_length_batch:\n            current_batch = batch[:, current_start : current_start + window_length, :]\n            with torch.no_grad():\n                # [bs, 513, 768]\n                output = self._forward(current_batch)\n                tokens_list.append(output[\"tokens\"])\n                quantized_feature_list.append(output[\"quantized_feature\"])\n            current_start += window_length\n        return torch.cat(tokens_list, dim=1)\n\n    def _forward(self, batch):\n        assert batch.size(-2) == 1024 and batch.size(-1) == 128\n\n        if self.device is None:\n            self.device = batch.device\n            self.centroid_npy = self.centroid_npy.to(self.device)\n\n        batch = batch.unsqueeze(1)\n\n        padding_cutoff_index = []\n        temporal_dim = batch.shape[-2]\n        for i in range(batch.shape[0]):\n            active_index = (\n                torch.std(batch[i, 0], dim=-1) <= 1e-7\n            )  # F F T T F F T T T T T\n            # If there are empty segment in the audio or there are padding in the audio\n            try:\n                if active_index.any():\n                    # Convert boolean tensor to integer tensor where False becomes 0\n                    int_tensor = active_index == False\n                    # Find indices where the tensor is False\n                    false_indices = torch.nonzero(int_tensor, as_tuple=False).squeeze()\n                    # Get the last index of False\n                    # last_false_index = 
false_indices[-1].item() if false_indices.numel() > 0 else -1\n                    if false_indices.numel() > 0:\n                        last_false_index = false_indices[-1].item()\n                    else:\n                        last_false_index = -1\n                    column_max = last_false_index + 1\n                # If there are no any empty segment in the audio\n                else:\n                    column_max = temporal_dim\n            except Exception as e:\n                import traceback\n\n                traceback.print_exc()\n                print(false_indices)\n                print(false_indices.numel())\n                column_max = 0\n\n            padding_cutoff_index.append(column_max / temporal_dim)\n\n        with torch.no_grad():\n            # [bs, 513, 768]\n            representation = self.audiomae(\n                batch,\n                no_mask=self.no_audiomae_mask,\n                no_average=self.no_audiomae_average,\n            )\n\n            if self.downsampling_rate != 1:\n                representation = self.concate(representation)\n                representation = (\n                    representation.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1)\n                )\n            else:\n                representation = representation[:, 1:, :]\n\n        if not self.use_oracle:\n            # Quantize the audiomae representation to tokens\n            tokens = self.quant_mem_efficient(\n                representation,\n                first_token_removed=True,\n                feature_dim=self.feature_dimension,\n            )\n\n            # Change the token back to the representations, which information losed\n            representation_quant = self.unquant(tokens)\n            audiomae_feature_after_quant = representation_quant.clone()\n            representation_quant_stack_unquant = torch.cat(\n                [representation, representation_quant], dim=-1\n            )\n\n            representation_quant_stack_unquant = self.mark_out_padding(\n                representation_quant_stack_unquant, padding_cutoff_index\n            )\n\n            # Use the encoder to extract extra information for conditioning\n            if self.residual_encoder == \"transformer\":\n                representation_residual = self.encoder(\n                    representation_quant_stack_unquant.permute(0, 2, 1)\n                ).permute(0, 2, 1)\n            elif (\n                self.residual_encoder == \"lstm\"\n                or self.residual_encoder == \"mamba\"\n                or self.residual_encoder == \"ResidualLSTM\"\n            ):\n                representation_residual = self.encoder(\n                    representation_quant_stack_unquant\n                )\n\n            # If you use LSTM as encoder\n            if type(representation_residual) == tuple:\n                representation_residual = representation_residual[0]\n\n            representation_residual = self.encoder_output_linear(\n                representation_residual\n            )\n            representation_residual_quant, indices, commit_loss = self.quantizer(\n                representation_residual\n            )\n            # import ipdb; ipdb.set_trace()\n            # assert torch.max(self.quantizer.get_output_from_indices(indices).reshape(1, 512, 768)-representation_residual_quant) <= 1e-5\n            tokens = torch.cat([tokens.unsqueeze(-1), indices.unsqueeze(-1)], dim=-1)\n            representation_quant = torch.cat(\n                
[representation_residual_quant, representation_quant], dim=-1\n            )\n        else:\n            # Oracle\n            param = next(self.audiomae.parameters())\n            assert param.requires_grad == False\n            tokens = None\n\n            representation_quant = torch.cat([representation], dim=-1)\n\n        representation_quant = self.mark_out_padding(\n            representation_quant, padding_cutoff_index\n        )\n\n        if self.use_positional_embedding:\n            pe = self.pos_embed(representation_quant).to(representation_quant.device)\n            representation_quant = torch.cat(\n                [representation_quant, pe.repeat(representation_quant.size(0), 1, 1)],\n                dim=-1,\n            )\n\n        return self.wrap_return_dict(\n            crossattn_audiomae_pooled=[\n                representation_quant,\n                torch.ones((representation_quant.size(0), representation_quant.size(1)))\n                .to(representation_quant.device)\n                .float(),\n            ],\n            tokens=tokens,\n        )\n\n    def token_to_quantized_feature(self, tokens):\n        semantic_tokens, acoustic_tokens = tokens[..., 0], tokens[..., 1]\n        semantic_feature = self.unquant(semantic_tokens)\n        token_num, feature_dim = semantic_feature.shape[-2], semantic_feature.shape[-1]\n        acoustic_feature = self.quantizer.get_output_from_indices(\n            acoustic_tokens\n        ).reshape(1, token_num, feature_dim)\n        return torch.cat([acoustic_feature, semantic_feature], dim=-1)\n\n    def wrap_return_dict(self, crossattn_audiomae_pooled, tokens):\n        return {\"quantized_feature\": crossattn_audiomae_pooled, \"tokens\": tokens}\n"
  },
  {
    "path": "semanticodec/utils.py",
    "content": "import torch\nimport math\nimport torch.nn as nn\n\nimport torchaudio\n\n\ndef concat_1x2(tensor):\n    batchsize, width, height, channels = tensor.shape\n    # Check if height is divisible by 2 for concatenation\n    if height % 2 != 0:\n        raise ValueError(\"Height must be divisible by 2 for 1x2 concatenation.\")\n    # Reshape to group 1x2 blocks\n    tensor_reshaped = tensor.view(batchsize, width, height // 2, 2, channels)\n    # Permute to move the 1x2 blocks next to the channel dimension\n    tensor_permuted = tensor_reshaped.permute(0, 1, 2, 3, 4)\n    # Concatenate the 1x2 blocks along the channel dimension\n    tensor_concat = tensor_permuted.reshape(batchsize, width, height // 2, channels * 2)\n    return tensor_concat\n\n\ndef concat_2x2(tensor):\n    batchsize, width, height, channels = tensor.shape\n    # Reshape to group 2x2 blocks\n    tensor_reshaped = tensor.view(batchsize, width // 2, 2, height // 2, 2, channels)\n    # Permute to move the 2x2 blocks next to the channel dimension\n    tensor_permuted = tensor_reshaped.permute(0, 1, 3, 2, 4, 5)\n    # Concatenate the 2x2 blocks along the channel dimension\n    tensor_concat = tensor_permuted.reshape(\n        batchsize, width // 2, height // 2, channels * 4\n    )\n    return tensor_concat\n\n\ndef extract_kaldi_fbank_feature(waveform, sampling_rate, target_length=1024):\n    norm_mean = -4.2677393\n    norm_std = 4.5689974\n\n    sampling_rate = sampling_rate\n\n    if sampling_rate != 16000:\n        waveform_16k = torchaudio.functional.resample(\n            waveform, orig_freq=sampling_rate, new_freq=16000\n        )\n    else:\n        waveform_16k = waveform\n\n    waveform_16k = waveform_16k - waveform_16k.mean()\n    fbank = torchaudio.compliance.kaldi.fbank(\n        waveform_16k,\n        htk_compat=True,\n        sample_frequency=16000,\n        use_energy=False,\n        window_type=\"hanning\",\n        num_mel_bins=128,\n        dither=0.0,\n        frame_shift=10,\n    )\n\n    TARGET_LEN = target_length\n\n    # cut and pad\n    n_frames = fbank.shape[0]\n    p = TARGET_LEN - n_frames\n    if p > 0:\n        m = torch.nn.ZeroPad2d((0, 0, 0, p))\n        fbank = m(fbank)\n    elif p < 0:\n        fbank = fbank[:TARGET_LEN, :]\n\n    fbank = (fbank - norm_mean) / (norm_std * 2)\n\n    return {\"ta_kaldi_fbank\": fbank}  # [1024, 128]\n\n\nclass PositionalEncoding:\n    def __init__(self, seq_length=512, embedding_dim=192):\n        self.seq_length = seq_length\n        self.embedding_dim = embedding_dim\n\n        # Initialize positional embeddings\n        position = torch.arange(seq_length).unsqueeze(1)\n        div_term = torch.exp(\n            torch.arange(0, embedding_dim, 2) * -(math.log(10000.0) / embedding_dim)\n        )\n        pe = torch.zeros(seq_length, embedding_dim)\n        pe[:, 0::2] = torch.sin(position * div_term)\n        pe[:, 1::2] = torch.cos(position * div_term)\n\n        # Add a 'batch' dimension with 'unsqueeze'\n        self.pe = pe.unsqueeze(0)\n\n    def __call__(self, x):\n        \"\"\"\n        Args:\n            x: Tensor, shape [batch_size, seq_length, embedding_dim]\n        \"\"\"\n        # return positional embeddings\n        return self.pe[:, : x.size(1)]\n"
  },
  {
    "path": "setup.py",
    "content": "#!/usr/bin/env python\n# -*- encoding: utf-8 -*-\n# python3 setup.py sdist bdist_wheel\n\"\"\"\n@File    :   setup.py.py    \n@Contact :   haoheliu@gmail.com\n@License :   (C)Copyright 2020-2100\n\n@Modify Time      @Author    @Version    @Desciption\n------------      -------    --------    -----------\n3/5/24 5:16 PM   Haohe Liu      1.0         None\n\"\"\"\n\n# !/usr/bin/env python\n# -*- coding: utf-8 -*-\n\n# Note: To use the 'upload' functionality of this file, you must:\n#   $ pipenv install twine --dev\n\nimport io\nimport os\nimport sys\nfrom shutil import rmtree\n\nfrom setuptools import find_packages, setup, Command\n\n# Package meta-data.\nNAME = \"semanticodec\"\nDESCRIPTION = \"This package is written for semanticodec\"\nURL = \"https://github.com/haoheliu/semanticodec-inference\"\nEMAIL = \"haoheliu@gmail.com\"\nAUTHOR = \"Haohe Liu\"\nREQUIRES_PYTHON = \">=3.8.0\"\nVERSION = \"0.0.1\"\n\n# What packages are required for this module to be executed?\nREQUIRED = [\"torch\", \"torchaudio\", \"soundfile\", \"vector-quantize-pytorch\", \"huggingface_hub\", \"timm\", \"scipy\"]\n\n# What packages are optional?\nEXTRAS = {}\n\n# The rest you shouldn't have to touch too much :)\n# ------------------------------------------------\n# Except, perhaps the License and Trove Classifiers!\n# If you do change the License, remember to change the Trove Classifier for that!\n\nhere = os.path.abspath(os.path.dirname(__file__))\n\n# Import the README and use it as the long-description.\n# Note: this will only work if 'README.md' is present in your MANIFEST.in file!\ntry:\n    with io.open(os.path.join(here, \"README.md\"), encoding=\"utf-8\") as f:\n        long_description = \"\\n\" + f.read()\nexcept FileNotFoundError:\n    long_description = DESCRIPTION\n\n# Load the package's __version__.py module as a dictionary.\nabout = {}\nif not VERSION:\n    project_slug = NAME.lower().replace(\"-\", \"_\").replace(\" \", \"_\")\n    with open(os.path.join(here, project_slug, \"__version__.py\")) as f:\n        exec(f.read(), about)\nelse:\n    about[\"__version__\"] = VERSION\n\n\nclass UploadCommand(Command):\n    \"\"\"Support setup.py upload.\"\"\"\n\n    description = \"Build and publish the package.\"\n    user_options = []\n\n    @staticmethod\n    def status(s):\n        \"\"\"Prints things in bold.\"\"\"\n        print(\"\\033[1m{0}\\033[0m\".format(s))\n\n    def initialize_options(self):\n        pass\n\n    def finalize_options(self):\n        pass\n\n    def run(self):\n        try:\n            self.status(\"Removing previous builds…\")\n            rmtree(os.path.join(here, \"dist\"))\n        except OSError:\n            pass\n\n        self.status(\"Building Source and Wheel (universal) distribution…\")\n        os.system(\"{0} setup.py sdist bdist_wheel --universal\".format(sys.executable))\n\n        self.status(\"Uploading the package to PyPI via Twine…\")\n        os.system(\"twine upload dist/*\")\n\n        self.status(\"Pushing git tags…\")\n        os.system(\"git tag v{0}\".format(about[\"__version__\"]))\n        os.system(\"git push --tags\")\n\n        sys.exit()\n\n\n# Where the magic happens:\nsetup(\n    name=NAME,\n    version=about[\"__version__\"],\n    description=DESCRIPTION,\n    long_description=long_description,\n    long_description_content_type=\"text/markdown\",\n    author=AUTHOR,\n    author_email=EMAIL,\n    python_requires=REQUIRES_PYTHON,\n    url=URL,\n    # packages=find_packages(exclude=[\"tests\", \"*.tests\", \"*.tests.*\", 
\"tests.*\"]),\n    # If your package is a single module, use this instead of 'packages':\n    py_modules=[\"semanticodec\"],\n    # entry_points={\n    #     'console_scripts': ['mycli=mymodule:cli'],\n    # },\n    install_requires=REQUIRED,\n    extras_require=EXTRAS,\n    packages=find_packages(),\n    include_package_data=True,\n    license=\"MIT\",\n    classifiers=[\n        # Trove classifiers\n        # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers\n        \"License :: OSI Approved :: MIT License\",\n        \"Programming Language :: Python\",\n        \"Programming Language :: Python :: 3\",\n        \"Programming Language :: Python :: 3.7\",\n        \"Programming Language :: Python :: Implementation :: CPython\",\n        \"Programming Language :: Python :: Implementation :: PyPy\",\n    ],\n    # $ setup.py publish support.\n    cmdclass={\n        \"upload\": UploadCommand,\n    },\n)\n"
  },
  {
    "path": "test/encoding.py",
    "content": "from semanticodec import SemantiCodec\nimport soundfile as sf\n\nsemanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=16384)  # 1.35 kbps\nfilepath = \"test.wav\"\n\ntokens = semanticodec.encode(filepath)\nwaveform = semanticodec.decode(tokens)\n\nsf.write(\"test_output_100.wav\", waveform[0, 0], 16000)\n#########################################################\n\nsemanticodec = SemantiCodec(token_rate=50, semantic_vocab_size=16384)  # 0.68 kbps\n\ntokens = semanticodec.encode(filepath)\nwaveform = semanticodec.decode(tokens)\n\nsf.write(\"test_output_50.wav\", waveform[0, 0], 16000)\n#########################################################\n\nsemanticodec = SemantiCodec(token_rate=25, semantic_vocab_size=16384)  # 0.34 kbps\n\ntokens = semanticodec.encode(filepath)\nwaveform = semanticodec.decode(tokens)\n\nsf.write(\"test_output_25.wav\", waveform[0, 0], 16000)\n"
  },
  {
    "path": "test/test_all_settings.py",
    "content": "from semanticodec import SemantiCodec\nimport soundfile as sf\n\ndef test_semanticodec(token_rate, semantic_vocab_size, test_id):\n    print(f\"Testing with token_rate: {token_rate}, semantic_vocab_size: {semantic_vocab_size}\")\n    semanticodec = SemantiCodec(token_rate=token_rate, semantic_vocab_size=semantic_vocab_size)\n    filepath = \"test.wav\"\n    \n    # Encoding and decoding process\n    tokens = semanticodec.encode(filepath)\n    waveform = semanticodec.decode(tokens)\n    \n    # Writing the output to a file\n    output_filename = f\"output_{test_id}.wav\"\n    sf.write(output_filename, waveform[0, 0], 16000)\n    print(f\"Output written to {output_filename}\\n\")\n    del semanticodec\n\n# Test cases\ntest_cases = [\n    (100, 32768),\n    (50, 32768),\n    (25, 32768),\n    (100, 16384),\n    (50, 16384),\n    (25, 16384),\n    (100, 8192),\n    (50, 8192),\n    (25, 8192),\n    (100, 4096),\n    (50, 4096),\n    (25, 4096)\n]\n\n# Running all test cases\nfor idx, (rate, vocab_size) in enumerate(test_cases, start=1):\n    test_semanticodec(rate, vocab_size, idx)"
  }
]