[
  {
    "path": ".github/FUNDING.yml",
    "content": "github: yuvraj108c\ncustom: [\"https://paypal.me/yuvraj108c\", \"https://buymeacoffee.com/yuvraj108cz\"]\n"
  },
  {
    "path": ".github/workflows/publish.yml",
    "content": "name: Publish to Comfy registry\non:\n  workflow_dispatch:\n  push:\n    branches:\n      - main\n      - master\n    paths:\n      - \"pyproject.toml\"\n\npermissions:\n  issues: write\n\njobs:\n  publish-node:\n    name: Publish Custom Node to registry\n    runs-on: ubuntu-latest\n    if: ${{ github.repository_owner == 'yuvraj108c' }}\n    steps:\n      - name: Check out code\n        uses: actions/checkout@v4\n        with:\n          submodules: true\n      - name: Publish Custom Node\n        uses: Comfy-Org/publish-node-action@v1\n        with:\n          ## Add your own personal access token to your Github Repository secrets and reference it here.\n          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}\n"
  },
  {
    "path": ".gitignore",
    "content": ".DS_Store\n*pyc\n.vscode\n__pycache__\n# *.egg-info\n*.bak\ncheckpoints\nresults\nbackup"
  },
  {
    "path": "LICENSE",
    "content": "S-Lab License 1.0\n\nCopyright 2024 S-Lab\n\nRedistribution and use for non-commercial purpose in source and \nbinary forms, with or without modification, are permitted provided \nthat the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright \n   notice, this list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright \n   notice, this list of conditions and the following disclaimer in \n   the documentation and/or other materials provided with the \n   distribution.\n\n3. Neither the name of the copyright holder nor the names of its \n   contributors may be used to endorse or promote products derived \n   from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \n\"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT \nLIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR \nA PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT \nHOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, \nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT \nLIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, \nDATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY \nTHEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT \n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE \nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nIn the event that redistribution and/or use for commercial purpose in \nsource or binary forms, with or without modification is required, \nplease contact the contributor(s) of the work.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n# ComfyUI InvSR\n[![arXiv](https://img.shields.io/badge/arXiv%20paper-2412.09013-b31b1b.svg)](https://arxiv.org/abs/2412.09013) \n\nThis project is a ComfyUI wrapper for [InvSR](https://github.com/zsyOAOA/InvSR) (Arbitrary-steps Image Super-resolution via Diffusion Inversion)\n\n**Last tested**: 2 January 2026 (ComfyUI v0.7.0@f2fda02 | Torch 2.9.1 | Python 3.10.12 | RTX4090 | CUDA 13.0 | Debian 12)\n\n<img height=\"400\" src=\"https://github.com/user-attachments/assets/6c057a3c-3355-4060-9161-a88ab6f6d986\" />\n\n</div>\n\n## ⭐ Support\nIf you like my projects and wish to see updates and new features, please consider supporting me. It helps a lot! \n\n[![ComfyUI-Depth-Anything-Tensorrt](https://img.shields.io/badge/ComfyUI--Depth--Anything--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Depth-Anything-Tensorrt)\n[![ComfyUI-Upscaler-Tensorrt](https://img.shields.io/badge/ComfyUI--Upscaler--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Upscaler-Tensorrt)\n[![ComfyUI-Dwpose-Tensorrt](https://img.shields.io/badge/ComfyUI--Dwpose--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Dwpose-Tensorrt)\n[![ComfyUI-Rife-Tensorrt](https://img.shields.io/badge/ComfyUI--Rife--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Rife-Tensorrt)\n\n[![ComfyUI-Whisper](https://img.shields.io/badge/ComfyUI--Whisper-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Whisper)\n[![ComfyUI_InvSR](https://img.shields.io/badge/ComfyUI__InvSR-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI_InvSR)\n[![ComfyUI-Thera](https://img.shields.io/badge/ComfyUI--Thera-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Thera)\n[![ComfyUI-Video-Depth-Anything](https://img.shields.io/badge/ComfyUI--Video--Depth--Anything-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Video-Depth-Anything)\n[![ComfyUI-PiperTTS](https://img.shields.io/badge/ComfyUI--PiperTTS-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-PiperTTS)\n\n[![buy-me-coffees](https://i.imgur.com/3MDbAtw.png)](https://www.buymeacoffee.com/yuvraj108cZ)\n[![paypal-donation](https://i.imgur.com/w5jjubk.png)](https://paypal.me/yuvraj108c)\n---\n\n## Installation\nNavigate to the ComfyUI `/custom_nodes` directory\n```bash\ngit clone https://github.com/yuvraj108c/ComfyUI_InvSR\ncd ComfyUI_InvSR\n\npip install -r requirements.txt\n```\n\n## Usage\n- Load [example workflow](workflows/invsr.json) \n- Diffusers model (stabilityai/sd-turbo) will download automatically to `ComfyUI/models/diffusers`\n- InvSR model (noise_predictor_sd_turbo_v5.pth) will download automatically to `ComfyUI/models/invsr`\n- To deal with large images, e.g, 1k---->4k, set `chopping_size` 256\n- If your GPU memory is limited, please set `chopping_batch_size` to 1\n\n## Parameters\n- `num_steps`: number of inference steps\n- `cfg`: classifier-free guidance scale\n- `batch_size`: Controls how many complete images are processed simultaneously\n- `chopping_batch_size`: Controls how many patches from the same image are processed simultaneously\n- `chopping_size`: Controls the size of patches when splitting large images\n- `color_fix`: Method to fix color shift in processed images\n\n## Updates\n**28 April 2025**\n- Update diffusers versions in requirements.txt to fix https://github.com/yuvraj108c/ComfyUI_InvSR/issues/26, https://github.com/yuvraj108c/ComfyUI_InvSR/issues/21, 
\n- Add support for `noise_predictor_sd_turbo_v5_diftune.pth`\n  \n**03 February 2025**\n- Add `cfg` parameter\n- Make image divisible by 16\n- Use `mm` to set torch device\n  \n**31 January 2025**\n- Merged https://github.com/yuvraj108c/ComfyUI_InvSR/pull/5 by [wfjsw](https://github.com/wfjsw)\n  - Compatibility with `diffusers>=0.28`\n  - Massive code refactoring & cleanup\n\n## Citation\n```bibtex\n@article{yue2024InvSR,\n  title={Arbitrary-steps Image Super-resolution via Diffusion Inversion},\n  author={Yue, Zongsheng and Liao, Kang and Loy, Chen Change},\n  journal={arXiv preprint arXiv:2412.09013},\n  year={2024},\n}\n```\n\n## License\nThis project is licensed under the [NTU S-Lab License 1.0](LICENSE).\n\n## Acknowledgments\nThanks to [simplepod.ai](https://simplepod.ai/) for providing GPU servers.\n\n## Star History\n[![Star History Chart](https://api.star-history.com/svg?repos=yuvraj108c/ComfyUI_InvSR&type=Date)](https://star-history.com/#yuvraj108c/ComfyUI_InvSR&Date)\n"
  },
  {
    "path": "__init__.py",
    "content": "from .node import LoadInvSRModels, InvSRSampler\n \nNODE_CLASS_MAPPINGS = { \n    \"LoadInvSRModels\" : LoadInvSRModels,\n    \"InvSRSampler\" : InvSRSampler\n}\n\nNODE_DISPLAY_NAME_MAPPINGS = {\n     \"LoadInvSRModels\" : \"Load InvSR Models\",\n     \"InvSRSampler\" : \"InvSRSampler\"\n}\n\n__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']\n"
  },
  {
    "path": "comfyui_invsr_trimmed/__init__.py",
    "content": "from .inference_invsr import get_configs, Namespace\nfrom .sampler_invsr import InvSamplerSR, BaseSampler\nfrom .noise_predictor import NoisePredictor\nfrom .time_aware_encoder import TimeAwareEncoder\n\n__all__ = [\n    \"get_configs\", \n    \"Namespace\",\n    \"InvSamplerSR\", \n    \"BaseSampler\", \n    \"NoisePredictor\", \n    \"TimeAwareEncoder\"\n]\n"
  },
  {
    "path": "comfyui_invsr_trimmed/inference_invsr.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# Power by Zongsheng Yue 2023-03-11 17:17:41\n\nimport numpy as np\nfrom pathlib import Path\nfrom omegaconf import OmegaConf\nfrom .sampler_invsr import InvSamplerSR, BaseSampler\n\nfrom .utils import util_common\nfrom .utils.util_opts import str2bool\nfrom huggingface_hub import hf_hub_download\nfrom shutil import copy2\n\nclass Namespace:\n    def __init__(self, **kwargs):\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n    \n    def __repr__(self):\n        items = [f\"{key}={repr(value)}\" for key, value in vars(self).items()]\n        return f\"Namespace({', '.join(items)})\"\n\ndef get_configs(args, log=False):\n    configs = OmegaConf.load(args.cfg_path)\n\n    if args.timesteps is not None:\n        assert len(args.timesteps) == args.num_steps\n        configs.timesteps = sorted(args.timesteps, reverse=True)\n    else:\n        if args.num_steps == 1:\n            configs.timesteps = [200,]\n        elif args.num_steps == 2:\n            configs.timesteps = [200, 100]\n        elif args.num_steps == 3:\n            configs.timesteps = [200, 100, 50]\n        elif args.num_steps == 4:\n            configs.timesteps = [200, 150, 100, 50]\n        elif args.num_steps == 5:\n            configs.timesteps = [250, 200, 150, 100, 50]\n        else:\n            assert args.num_steps <= 250\n            configs.timesteps = np.linspace(\n                start=args.started_step, stop=0, num=args.num_steps, endpoint=False, dtype=np.int64()\n            ).tolist()\n    if log:\n        print(f'[InvSR] - Setting timesteps for inference: {configs.timesteps}')\n\n    # path to save Stable Diffusion\n    sd_path = args.sd_path if args.sd_path else \"./weights\"\n    util_common.mkdir(sd_path, delete=False, parents=True)\n    configs.sd_pipe.params.cache_dir = sd_path\n\n    # path to save noise predictor\n    started_ckpt_name = args.invsr_model\n\n    if getattr(args, \"started_ckpt_dir\", None) is not None:\n        started_ckpt_dir = args.started_ckpt_dir\n    else:\n        started_ckpt_dir = \"./weights\"\n\n    if getattr(args, \"started_ckpt_path\", None) is not None:\n        started_ckpt_path = args.started_ckpt_path\n    else:\n        started_ckpt_path = Path(started_ckpt_dir) / started_ckpt_name\n        util_common.mkdir(started_ckpt_dir, delete=False, parents=True)\n\n    if not Path(started_ckpt_path).exists():\n        temp_path = hf_hub_download(\n            repo_id=\"OAOA/InvSR\",\n            filename=started_ckpt_name,\n        )\n        copy2(temp_path, started_ckpt_path)\n    configs.model_start.ckpt_path = str(started_ckpt_path)\n\n    configs.bs = args.bs\n    configs.tiled_vae = args.tiled_vae\n    configs.color_fix = args.color_fix\n    configs.basesr.chopping.pch_size = args.chopping_size\n    if args.bs > 1:\n        configs.basesr.chopping.extra_bs = 1\n    else:\n        configs.basesr.chopping.extra_bs = args.chopping_bs\n\n    return configs\n"
  },
  {
    "path": "comfyui_invsr_trimmed/latent_lpips/__init__.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n"
  },
  {
    "path": "comfyui_invsr_trimmed/latent_lpips/lpips.py",
    "content": "\nfrom __future__ import absolute_import\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.autograd import Variable\nimport numpy as np\nfrom . import pretrained_networks as pn\n\ndef normalize_tensor(in_feat,eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))\n    return in_feat/(norm_factor+eps)\n\ndef spatial_average(in_tens, keepdim=True):\n    return in_tens.mean([2,3],keepdim=keepdim)\n\ndef upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W\n    in_H, in_W = in_tens.shape[2], in_tens.shape[3]\n    return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)\n\n# Learned perceptual metric\nclass LPIPS(nn.Module):\n    def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,\n        pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True,\n                latent=False, in_chans=3, verbose=True):\n        \"\"\" Initializes a perceptual loss torch.nn.Module\n\n        Parameters (default listed first)\n        ---------------------------------\n        lpips : bool\n            [True] use linear layers on top of base/trunk network\n            [False] means no linear layers; each layer is averaged together\n        pretrained : bool\n            This flag controls the linear layers, which are only in effect when lpips=True above\n            [True] means linear layers are calibrated with human perceptual judgments\n            [False] means linear layers are randomly initialized\n        pnet_rand : bool\n            [False] means trunk loaded with ImageNet classification weights\n            [True] means randomly initialized trunk\n        net : str\n            ['alex','vgg','squeeze'] are the base/trunk networks available\n        version : str\n            ['v0.1'] is the default and latest\n            ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)\n        model_path : 'str'\n            [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1\n\n        The following parameters should only be changed if training the network\n\n        eval_mode : bool\n            [True] is for test mode (default)\n            [False] is for training mode\n        pnet_tune\n            [False] keep base/trunk frozen\n            [True] tune the base/trunk network\n        use_dropout : bool\n            [True] to use dropout when training linear layers\n            [False] for no dropout when training linear layers\n        \"\"\"\n\n        super(LPIPS, self).__init__()\n        if(verbose):\n            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'%\n                ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))\n\n        self.pnet_type = net\n        self.pnet_tune = pnet_tune\n        self.pnet_rand = pnet_rand\n        self.spatial = spatial\n        self.latent = latent\n        self.lpips = lpips # false means baseline of just averaging all layers\n        self.version = version\n        self.scaling_layer = ScalingLayer()\n\n        if(self.pnet_type in ['vgg','vgg16']):\n            if not latent:\n                net_type = pn.vgg16\n            else:\n                net_type = pn.vgg16_latent\n            self.chns = [64,128,256,512,512]\n        elif(self.pnet_type=='alex'):\n            net_type = pn.alexnet\n            self.chns = 
[64,192,384,256,256]\n        elif(self.pnet_type=='squeeze'):\n            net_type = pn.squeezenet\n            self.chns = [64,128,256,384,384,512,512]\n        self.L = len(self.chns)\n\n        if latent:\n            self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune, in_chans=in_chans)\n        else:\n            self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)\n\n        if(lpips):\n            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)\n            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)\n            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)\n            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)\n            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)\n            self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]\n            if(self.pnet_type=='squeeze'): # 7 layers for squeezenet\n                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)\n                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)\n                self.lins+=[self.lin5,self.lin6]\n            self.lins = nn.ModuleList(self.lins)\n\n            if(pretrained):\n                if(model_path is None):\n                    import inspect\n                    import os\n                    model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net)))\n\n                if(verbose):\n                    print('Loading model from: %s'%model_path)\n                missing_keys, unexpected_keys = self.load_state_dict(\n                        torch.load(model_path, map_location='cpu'),\n                        strict=False,\n                        )\n                print(f'Number of missing keys when loading checkpoint: {len(missing_keys)}')\n                print(f'Number of unexpected keys when loading checkpoint: {len(unexpected_keys)}')\n\n        if(eval_mode):\n            self.eval()\n\n    def forward(self, in0, in1, retPerLayer=False, normalize=False):\n        if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]\n            in0 = 2 * in0  - 1\n            in1 = 2 * in1  - 1\n\n        # v0.0 - original release had a bug, where input was not scaled\n        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if (not self.latent and self.version=='0.1') else (in0, in1)\n        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)\n        feats0, feats1, diffs = {}, {}, {}\n\n        for kk in range(self.L):\n            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])\n            diffs[kk] = (feats0[kk]-feats1[kk])**2\n\n        if(self.lpips):\n            if(self.spatial):\n                res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]\n            else:\n                res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]\n        else:\n            if(self.spatial):\n                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]\n            else:\n                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]\n\n        val = 0\n        for l in range(self.L):\n            val += res[l]\n\n        if(retPerLayer):\n            return (val, res)\n
        else:\n            return val\n\nclass ScalingLayer(nn.Module):\n    def __init__(self):\n        super(ScalingLayer, self).__init__()\n        self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])\n        self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])\n\n    def forward(self, inp):\n        return (inp - self.shift) / self.scale\n\nclass NetLinLayer(nn.Module):\n    ''' A single linear layer which does a 1x1 conv '''\n    def __init__(self, chn_in, chn_out=1, use_dropout=False):\n        super(NetLinLayer, self).__init__()\n\n        layers = [nn.Dropout(),] if(use_dropout) else []\n        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]\n        self.model = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.model(x)\n\nclass Dist2LogitLayer(nn.Module):\n    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''\n    def __init__(self, chn_mid=32, use_sigmoid=True):\n        super(Dist2LogitLayer, self).__init__()\n\n        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]\n        layers += [nn.LeakyReLU(0.2,True),]\n        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]\n        layers += [nn.LeakyReLU(0.2,True),]\n        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]\n        if(use_sigmoid):\n            layers += [nn.Sigmoid(),]\n        self.model = nn.Sequential(*layers)\n\n    def forward(self,d0,d1,eps=0.1):\n        return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))\n\nclass BCERankingLoss(nn.Module):\n    def __init__(self, chn_mid=32):\n        super(BCERankingLoss, self).__init__()\n        self.net = Dist2LogitLayer(chn_mid=chn_mid, use_sigmoid=False)\n        # self.parameters = list(self.net.parameters())\n        # self.loss = torch.nn.BCELoss()\n        self.loss = torch.nn.BCEWithLogitsLoss()\n\n    def forward(self, d0, d1, judge):\n        per = (judge+1.)/2.\n        self.logit = self.net.forward(d0,d1)\n        return self.loss(self.logit, per)\n\ndef print_network(net):\n    num_params = 0\n    for param in net.parameters():\n        num_params += param.numel()\n    print('Network',net)\n    print('Total number of parameters: %d' % num_params)\n"
  },
  {
    "path": "comfyui_invsr_trimmed/latent_lpips/pretrained_networks.py",
    "content": "from collections import namedtuple\nimport torch\nfrom torchvision import models as tv\n\nclass squeezenet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(squeezenet, self).__init__()\n        pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.slice6 = torch.nn.Sequential()\n        self.slice7 = torch.nn.Sequential()\n        self.N_slices = 7\n        for x in range(2):\n            self.slice1.add_module(str(x), pretrained_features[x])\n        for x in range(2,5):\n            self.slice2.add_module(str(x), pretrained_features[x])\n        for x in range(5, 8):\n            self.slice3.add_module(str(x), pretrained_features[x])\n        for x in range(8, 10):\n            self.slice4.add_module(str(x), pretrained_features[x])\n        for x in range(10, 11):\n            self.slice5.add_module(str(x), pretrained_features[x])\n        for x in range(11, 12):\n            self.slice6.add_module(str(x), pretrained_features[x])\n        for x in range(12, 13):\n            self.slice7.add_module(str(x), pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1 = h\n        h = self.slice2(h)\n        h_relu2 = h\n        h = self.slice3(h)\n        h_relu3 = h\n        h = self.slice4(h)\n        h_relu4 = h\n        h = self.slice5(h)\n        h_relu5 = h\n        h = self.slice6(h)\n        h_relu6 = h\n        h = self.slice7(h)\n        h_relu7 = h\n        vgg_outputs = namedtuple(\"SqueezeOutputs\", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])\n        out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)\n\n        return out\n\n\nclass alexnet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(alexnet, self).__init__()\n        alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(2):\n            self.slice1.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(2, 5):\n            self.slice2.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(5, 8):\n            self.slice3.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(8, 10):\n            self.slice4.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(10, 12):\n            self.slice5.add_module(str(x), alexnet_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1 = h\n        h = self.slice2(h)\n        h_relu2 = h\n        h = self.slice3(h)\n        h_relu3 = h\n        h = self.slice4(h)\n        h_relu4 = h\n        h = self.slice5(h)\n        h_relu5 = h\n        alexnet_outputs = namedtuple(\"AlexnetOutputs\", ['relu1', 
'relu2', 'relu3', 'relu4', 'relu5'])\n        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)\n\n        return out\n\nclass vgg16(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(vgg16, self).__init__()\n        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(4):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(4, 9):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(9, 16):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(16, 23):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(23, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1_2 = h\n        h = self.slice2(h)\n        h_relu2_2 = h\n        h = self.slice3(h)\n        h_relu3_3 = h\n        h = self.slice4(h)\n        h_relu4_3 = h\n        h = self.slice5(h)\n        h_relu5_3 = h\n        vgg_outputs = namedtuple(\"VggOutputs\", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])\n        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)\n\n        return out\n\nclass vgg16_latent(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True, in_chans=3):\n        super(vgg16_latent, self).__init__()\n        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        # max pooling layers: vgg_pretrained_features[5, 9, 16, 23]\n        for x in range(4):\n            assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d)\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        if (not in_chans == 3):\n            # assert in_chans == 4\n            weight = self.slice1[0].weight.data[:, 0,].unsqueeze(1).repeat(1, in_chans, 1, 1)\n            self.slice1[0].weight.data = weight\n        for x in range(5, 9): # skip max pooling at index 5\n            assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d)\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(10, 16): # skip max pooling at index 9\n            assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d)\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(17, 23): # skip max pooling at index 16\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(23, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        
h_relu1_2 = h\n        h = self.slice2(h)\n        h_relu2_2 = h\n        h = self.slice3(h)\n        h_relu3_3 = h\n        h = self.slice4(h)\n        h_relu4_3 = h\n        h = self.slice5(h)\n        h_relu5_3 = h\n        vgg_outputs = namedtuple(\"VggOutputs\", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])\n        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)\n\n        return out\n\n\nclass resnet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True, num=18):\n        super(resnet, self).__init__()\n        if(num==18):\n            self.net = tv.resnet18(pretrained=pretrained)\n        elif(num==34):\n            self.net = tv.resnet34(pretrained=pretrained)\n        elif(num==50):\n            self.net = tv.resnet50(pretrained=pretrained)\n        elif(num==101):\n            self.net = tv.resnet101(pretrained=pretrained)\n        elif(num==152):\n            self.net = tv.resnet152(pretrained=pretrained)\n        self.N_slices = 5\n\n        self.conv1 = self.net.conv1\n        self.bn1 = self.net.bn1\n        self.relu = self.net.relu\n        self.maxpool = self.net.maxpool\n        self.layer1 = self.net.layer1\n        self.layer2 = self.net.layer2\n        self.layer3 = self.net.layer3\n        self.layer4 = self.net.layer4\n\n    def forward(self, X):\n        h = self.conv1(X)\n        h = self.bn1(h)\n        h = self.relu(h)\n        h_relu1 = h\n        h = self.maxpool(h)\n        h = self.layer1(h)\n        h_conv2 = h\n        h = self.layer2(h)\n        h_conv3 = h\n        h = self.layer3(h)\n        h_conv4 = h\n        h = self.layer4(h)\n        h_conv5 = h\n\n        outputs = namedtuple(\"Outputs\", ['relu1','conv2','conv3','conv4','conv5'])\n        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)\n\n        return out\n"
  },
  {
    "path": "comfyui_invsr_trimmed/noise_predictor.py",
    "content": "from typing import Dict, Optional, Tuple, Union\nimport torch\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders.single_file_model import FromOriginalModelMixin\nfrom diffusers.models.autoencoders.vae import (\n    Decoder,\n    DecoderOutput,\n    DiagonalGaussianDistribution,\n    Encoder,\n)\nfrom diffusers.models.attention_processor import (\n    ADDED_KV_ATTENTION_PROCESSORS,\n    CROSS_ATTENTION_PROCESSORS,\n    AttentionProcessor,\n    AttnAddedKVProcessor,\n    AttnProcessor,\n)\nfrom diffusers.models.modeling_outputs import AutoencoderKLOutput\nfrom diffusers.utils.accelerate_utils import apply_forward_hook\nfrom .time_aware_encoder import TimeAwareEncoder\n\nclass NoisePredictor(ModelMixin, ConfigMixin, FromOriginalModelMixin):\n    r\"\"\"\n    A noise predicted model from the encoder of AutoencoderKL.\n\n    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"DownEncoderBlock2D\",)`):\n            Tuple of downsample block types.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpDecoderBlock2D\",)`):\n            Tuple of upsample block types.\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):\n            Tuple of block output channels.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.\n        sample_size (`int`, *optional*, defaults to `32`): Sample input size.\n        mid_block_add_attention (`bool`, *optional*, default to `True`):\n            If enabled, the mid_block of the Encoder and Decoder will have attention blocks. 
If set to false, the\n            mid_block will only have resnet blocks\n        temb_channels (`int`, *optional*, default to 256): Number of channels for time embedding\n        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.\n        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):\n            Whether to flip sin to cos for Fourier time embedding.\n        double_z (`bool`, *optional*, defaults to `True`):\n            Whether to double the number of output channels for the last block.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n    _no_split_modules = [\"BasicTransformerBlock\", \"ResnetBlock2D\"]\n\n    @register_to_config\n    def __init__(\n        self,\n        in_channels: int = 3,\n        down_block_types: Tuple[str] = (\"DownEncoderBlock2D\",),\n        up_block_types: Tuple[str] = (\"UpDecoderBlock2D\",),\n        block_out_channels: Tuple[int] = (64,),\n        layers_per_block: int = 1,\n        act_fn: str = \"silu\",\n        latent_channels: int = 4,\n        norm_num_groups: int = 32,\n        sample_size: int = 32,\n        mid_block_add_attention: bool = True,\n        attention_head_dim: int = 1,\n        resnet_time_scale_shift: str = \"default\",\n        temb_channels: int = 256,\n        freq_shift: int = 0,\n        flip_sin_to_cos: bool = True,\n        double_z: bool = True,\n    ):\n        super().__init__()\n\n        # pass init params to Encoder\n        self.encoder = TimeAwareEncoder(\n            in_channels=in_channels,\n            out_channels=latent_channels,\n            down_block_types=down_block_types,\n            block_out_channels=block_out_channels,\n            layers_per_block=layers_per_block,\n            act_fn=act_fn,\n            norm_num_groups=norm_num_groups,\n            double_z=double_z,\n            mid_block_add_attention=mid_block_add_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            temb_channels=temb_channels,\n            freq_shift=freq_shift,\n            flip_sin_to_cos=flip_sin_to_cos,\n            attention_head_dim=attention_head_dim,\n        )\n\n        self.use_slicing = False\n        self.use_tiling = False\n        self.double_z = double_z\n\n        # only relevant if vae tiling is enabled\n        self.tile_sample_min_size = self.config.sample_size\n        sample_size = (\n            self.config.sample_size[0]\n            if isinstance(self.config.sample_size, (list, tuple))\n            else self.config.sample_size\n        )\n        self.tile_latent_min_size = int(\n            sample_size / (2 ** (len(self.config.block_out_channels) - 1))\n        )\n        self.tile_overlap_factor = 0.25\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (Encoder, Decoder)):\n            module.gradient_checkpointing = value\n\n    def enable_tiling(self, use_tiling: bool = True):\n        r\"\"\"\n        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to\n        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow\n        processing larger images.\n        \"\"\"\n        self.use_tiling = use_tiling\n\n    def disable_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. 
If `enable_tiling` was previously enabled, this method will go back to computing\n        decoding in one step.\n        \"\"\"\n        self.enable_tiling(False)\n\n    def enable_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to\n        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.use_slicing = True\n\n    def disable_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing\n        decoding in one step.\n        \"\"\"\n        self.use_slicing = False\n\n    @property\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(\n            name: str,\n            module: torch.nn.Module,\n            processors: Dict[str, AttentionProcessor],\n        ):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor()\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor\n    def set_attn_processor(\n        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]\n    ):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. 
Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        if all(\n            proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS\n            for proc in self.attn_processors.values()\n        ):\n            processor = AttnAddedKVProcessor()\n        elif all(\n            proc.__class__ in CROSS_ATTENTION_PROCESSORS\n            for proc in self.attn_processors.values()\n        ):\n            processor = AttnProcessor()\n        else:\n            raise ValueError(\n                f\"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}\"\n            )\n\n        self.set_attn_processor(processor)\n\n    @apply_forward_hook\n    def encode(\n        self,\n        x: torch.Tensor,\n        timestep: Union[int, torch.Tensor],\n        return_dict: bool = True,\n    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:\n        \"\"\"\n        Encode a batch of images into latents.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.\n\n        Returns:\n                The latent representations of the encoded images. If `return_dict` is True, a\n                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.\n        \"\"\"\n        if self.use_tiling and (\n            x.shape[-1] > self.tile_sample_min_size\n            or x.shape[-2] > self.tile_sample_min_size\n        ):\n            return self.tiled_encode(x, timestep, return_dict=return_dict)\n\n        if self.use_slicing and x.shape[0] > 1:\n            encoded_slices = [self.encoder(x_slice, timestep) for x_slice in x.split(1)]\n            h = torch.cat(encoded_slices)\n        else:\n            h = self.encoder(x, timestep)\n\n        if not self.double_z:\n            return h\n\n        posterior = DiagonalGaussianDistribution(h)\n\n        if not return_dict:\n            return (posterior,)\n\n        return AutoencoderKLOutput(latent_dist=posterior)\n\n    def tiled_encode(\n        self,\n        x: torch.Tensor,\n        timestep: Union[int, torch.Tensor],\n        return_dict: bool = True,\n    ) -> AutoencoderKLOutput:\n        r\"\"\"Encode a batch of images using a tiled encoder.\n\n        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several\n        steps. This is useful to keep memory use constant regardless of image size. 
The end result of tiled encoding is\n        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the\n        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the\n        output, but they should be much less noticeable.\n\n        Args:\n            x (`torch.Tensor`): Input batch of images.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:\n                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain\n                `tuple` is returned.\n        \"\"\"\n        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))\n        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)\n        row_limit = self.tile_latent_min_size - blend_extent\n\n        # Split the image into 512x512 tiles and encode them separately.\n        rows = []\n        for i in range(0, x.shape[2], overlap_size):\n            row = []\n            for j in range(0, x.shape[3], overlap_size):\n                tile = x[\n                    :,\n                    :,\n                    i : i + self.tile_sample_min_size,\n                    j : j + self.tile_sample_min_size,\n                ]\n                tile = self.encoder(tile, timestep)\n                if self.config.use_quant_conv:\n                    tile = self.quant_conv(tile)\n                row.append(tile)\n            rows.append(row)\n        result_rows = []\n        for i, row in enumerate(rows):\n            result_row = []\n            for j, tile in enumerate(row):\n                # blend the above tile and the left tile\n                # to the current tile and add the current tile to the result row\n                if i > 0:\n                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)\n                if j > 0:\n                    tile = self.blend_h(row[j - 1], tile, blend_extent)\n                result_row.append(tile[:, :, :row_limit, :row_limit])\n            result_rows.append(torch.cat(result_row, dim=3))\n\n        moments = torch.cat(result_rows, dim=2)\n        posterior = DiagonalGaussianDistribution(moments)\n\n        if not return_dict:\n            return (posterior,)\n\n        return AutoencoderKLOutput(latent_dist=posterior)\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        timesteps: torch.Tensor,\n        sample_posterior: bool = True,\n        center_input_sample: bool = True,\n        generator: Optional[torch.Generator] = None,\n    ) -> Union[DecoderOutput, torch.Tensor]:\n        r\"\"\"\n        Args:\n            sample (`torch.Tensor`): Input sample.\n            sample_posterior (`bool`, *optional*, defaults to `False`):\n                Whether to sample from the posterior.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.\n        \"\"\"\n        if center_input_sample:\n            sample = sample * 2 - 1.0\n\n        if not self.double_z:\n            h = self.encode(sample, timesteps)\n            return h\n        else:\n            posterior = self.encode(sample, timesteps).latent_dist\n\n        if sample_posterior:\n            
return posterior.sample()\n        else:\n            return posterior\n"
  },
  {
    "path": "comfyui_invsr_trimmed/pipeline_stable_diffusion_inversion_sr.py",
    "content": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union, Tuple\n\nimport numpy as np\nimport PIL.Image\nimport torch\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection\n\nfrom diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback\nfrom diffusers.configuration_utils import FrozenDict\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel\nfrom diffusers.models.lora import adjust_lora_scale_text_encoder\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    PIL_INTERPOLATION,\n    USE_PEFT_BACKEND,\n    deprecate,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import (\n    StableDiffusionSafetyChecker,\n)\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import requests\n        >>> import torch\n        >>> from PIL import Image\n        >>> from io import BytesIO\n\n        >>> from diffusers import StableDiffusionImg2ImgPipeline\n\n        >>> device = \"cuda\"\n        >>> model_id_or_path = \"runwayml/stable-diffusion-v1-5\"\n        >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)\n        >>> pipe = pipe.to(device)\n\n        >>> url = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\n\n        >>> response = requests.get(url)\n        >>> init_image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n        >>> init_image = init_image.resize((768, 512))\n\n        >>> prompt = \"A fantasy landscape, trending on artstation\"\n\n        >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images\n        >>> images[0].save(\"fantasy_landscape.png\")\n        ```\n\"\"\"\n\n\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n  
  elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\ndef preprocess(image):\n    deprecation_message = \"The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead\"\n    deprecate(\"preprocess\", \"1.0.0\", deprecation_message, standard_warn=False)\n    if isinstance(image, torch.Tensor):\n        return image\n    elif isinstance(image, PIL.Image.Image):\n        image = [image]\n\n    if isinstance(image[0], PIL.Image.Image):\n        w, h = image[0].size\n        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8\n\n        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION[\"lanczos\"]))[None, :] for i in image]\n        image = np.concatenate(image, axis=0)\n        image = np.array(image).astype(np.float32) / 255.0\n        image = image.transpose(0, 3, 1, 2)\n        image = 2.0 * image - 1.0\n        image = torch.from_numpy(image)\n    elif isinstance(image[0], torch.Tensor):\n        image = torch.cat(image, dim=0)\n    return image\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    timesteps: Optional[List[int]] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    **kwargs,\n):\n    \"\"\"\n    Prepare the sampling timesteps and noise sigmas.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. 
If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    num_inference_steps = len(timesteps)\n    timesteps = torch.tensor(timesteps, dtype=torch.float32, device=device) - 1\n    scheduler.timesteps = timesteps\n\n    if not hasattr(scheduler, 'sigmas_cache'):\n        scheduler.sigmas_cache = scheduler.sigmas.flip(0)[1:].to(device) #ascending,1000\n    sigmas = scheduler.sigmas_cache[timesteps.long()]\n\n    # minimal sigma\n    if scheduler.config.final_sigmas_type == \"sigma_min\":\n        sigma_last = ((1 - scheduler.alphas_cumprod[0]) / scheduler.alphas_cumprod[0]) ** 0.5\n    elif scheduler.config.final_sigmas_type == \"zero\":\n        sigma_last = 0\n    else:\n        raise ValueError(\n            f\"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {scheduler.config.final_sigmas_type}\"\n        )\n    sigma_last = torch.tensor([sigma_last,], dtype=torch.float32).to(device=sigmas.device)\n    sigmas = torch.cat([sigmas, sigma_last]).type(torch.float32)\n    scheduler.sigmas = sigmas.to(\"cpu\")  # to avoid too much CPU/GPU communication\n\n    scheduler._step_index = None\n    scheduler._begin_index = None\n\n    return scheduler.timesteps, num_inference_steps\n\nclass StableDiffusionInvEnhancePipeline(\n    DiffusionPipeline,\n    StableDiffusionMixin,\n    TextualInversionLoaderMixin,\n    IPAdapterMixin,\n    StableDiffusionLoraLoaderMixin,\n    FromSingleFileMixin,\n):\n    r\"\"\"\n    Pipeline for text-guided image-to-image generation using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods\n    implemented for all pipelines (downloading, saving, running on a particular device, etc.).\n\n    The pipeline also inherits the following loading methods:\n        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings\n        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights\n        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights\n        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files\n        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.\n        text_encoder ([`~transformers.CLIPTextModel`]):\n            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).\n        tokenizer ([`~transformers.CLIPTokenizer`]):\n            A `CLIPTokenizer` to tokenize text.\n        unet ([`UNet2DConditionModel`]):\n            A `UNet2DConditionModel` to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details\n            about a model's potential harms.\n        feature_extractor ([`~transformers.CLIPImageProcessor`]):\n            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->image_encoder->unet->vae\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\", \"image_encoder\"]\n    _exclude_from_cpu_offload = [\"safety_checker\"]\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\", \"negative_prompt_embeds\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        image_encoder: CLIPVisionModelWithProjection = None,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if hasattr(scheduler.config, \"steps_offset\") and scheduler.config.steps_offset != 1:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`\"\n                f\" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure \"\n                \"to update the config accordingly as leaving `steps_offset` might led to incorrect results\"\n                \" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,\"\n                \" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`\"\n                \" file\"\n            )\n            deprecate(\"steps_offset!=1\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"steps_offset\"] = 1\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if hasattr(scheduler.config, \"clip_sample\") and scheduler.config.clip_sample is True:\n            deprecation_message = (\n                f\"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.\"\n                \" `clip_sample` should be set to False in the configuration file. Please make sure to update the\"\n                \" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in\"\n                \" future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub, it would be very\"\n                \" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file\"\n            )\n            deprecate(\"clip_sample not set\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(scheduler.config)\n            new_config[\"clip_sample\"] = False\n            scheduler._internal_dict = FrozenDict(new_config)\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                \"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = hasattr(unet.config, \"_diffusers_version\") and version.parse(\n            version.parse(unet.config._diffusers_version).base_version\n        ) < version.parse(\"0.9.0.dev0\")\n        is_unet_sample_size_less_64 = hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- runwayml/stable-diffusion-v1-5\"\n                \" \\n- runwayml/stable-diffusion-inpainting \\n you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. 
If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False)\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            image_encoder=image_encoder,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        **kwargs,\n    ):\n        deprecation_message = \"`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple.\"\n        deprecate(\"_encode_prompt()\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        prompt_embeds_tuple = self.encode_prompt(\n            prompt=prompt,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            negative_prompt=negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=lora_scale,\n            **kwargs,\n        )\n\n        # concatenate for backwards comp\n        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        lora_scale: Optional[float] = None,\n        clip_skip: Optional[int] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                
The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n            lora_scale (`float`, *optional*):\n                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n        \"\"\"\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if not USE_PEFT_BACKEND:\n                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)\n            else:\n                scale_lora_layers(self.text_encoder, lora_scale)\n\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n                )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            if 
clip_skip is None:\n                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)\n                prompt_embeds = prompt_embeds[0]\n            else:\n                prompt_embeds = self.text_encoder(\n                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True\n                )\n                # Access the `hidden_states` first, that contains a tuple of\n                # all the hidden states from the encoder layers. Then index into\n                # the tuple to access the hidden states from the desired layer.\n                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]\n                # We also need to apply the final LayerNorm here to not mess with the\n                # representations. The `last_hidden_states` that we typically use for\n                # obtaining the final prompt representations passes through the LayerNorm\n                # layer.\n                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)\n\n        if self.text_encoder is not None:\n            prompt_embeds_dtype = self.text_encoder.dtype\n        elif self.unet is not None:\n            prompt_embeds_dtype = self.unet.dtype\n        else:\n            prompt_embeds_dtype = prompt_embeds.dtype\n\n        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: process multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        if self.text_encoder is not None:\n            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        return prompt_embeds, negative_prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image\n    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if not isinstance(image, torch.Tensor):\n            image = self.feature_extractor(image, return_tensors=\"pt\").pixel_values\n\n        image = image.to(device=device, dtype=dtype)\n        if output_hidden_states:\n            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]\n            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_enc_hidden_states = self.image_encoder(\n                torch.zeros_like(image), output_hidden_states=True\n            ).hidden_states[-2]\n            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(\n                num_images_per_prompt, dim=0\n            )\n            return image_enc_hidden_states, uncond_image_enc_hidden_states\n        else:\n            image_embeds = self.image_encoder(image).image_embeds\n            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)\n            uncond_image_embeds = torch.zeros_like(image_embeds)\n\n            return image_embeds, uncond_image_embeds\n\n    # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds\n    def prepare_ip_adapter_image_embeds(\n        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance\n    ):\n        image_embeds = []\n        if do_classifier_free_guidance:\n            negative_image_embeds = []\n        if ip_adapter_image_embeds is None:\n            if not isinstance(ip_adapter_image, list):\n                ip_adapter_image = [ip_adapter_image]\n\n            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):\n                raise ValueError(\n                    f\"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters.\"\n                )\n\n            for single_ip_adapter_image, image_proj_layer in zip(\n                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers\n            ):\n                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)\n                single_image_embeds, single_negative_image_embeds = self.encode_image(\n                    single_ip_adapter_image, device, 1, output_hidden_state\n                )\n\n                image_embeds.append(single_image_embeds[None, :])\n                if do_classifier_free_guidance:\n                    negative_image_embeds.append(single_negative_image_embeds[None, :])\n        else:\n            for single_image_embeds in ip_adapter_image_embeds:\n                if do_classifier_free_guidance:\n                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)\n                    negative_image_embeds.append(single_negative_image_embeds)\n                image_embeds.append(single_image_embeds)\n\n        ip_adapter_image_embeds = []\n        for i, single_image_embeds in enumerate(image_embeds):\n            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)\n            if do_classifier_free_guidance:\n                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)\n                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)\n\n            single_image_embeds = single_image_embeds.to(device=device)\n            ip_adapter_image_embeds.append(single_image_embeds)\n\n        return ip_adapter_image_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n 
   def decode_latents(self, latents):\n        deprecation_message = \"The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead\"\n        deprecate(\"decode_latents\", \"1.0.0\", deprecation_message, standard_warn=False)\n\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        strength,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        ip_adapter_image=None,\n        ip_adapter_image_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should in [0.0, 1.0] but is {strength}\")\n\n        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:\n            raise ValueError(\n                \"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined.\"\n            )\n\n        if ip_adapter_image_embeds is not None:\n            if not isinstance(ip_adapter_image_embeds, list):\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}\"\n                )\n            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:\n                raise ValueError(\n                    f\"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D\"\n                )\n\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using init_timestep\n        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)\n\n        t_start = max(num_inference_steps - init_timestep, 0)\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n        if hasattr(self.scheduler, \"set_begin_index\"):\n            self.scheduler.set_begin_index(t_start * self.scheduler.order)\n\n        return timesteps, num_inference_steps - t_start\n\n    def prepare_latents(\n        self, image, timestep, batch_size, num_images_per_prompt, dtype, device,\n        noise=None, generator=None,\n    ):\n        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):\n            raise ValueError(\n                f\"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}\"\n            )\n\n        image = image.to(device=device, dtype=dtype)\n\n        batch_size = batch_size * num_images_per_prompt\n\n        if image.shape[1] == 4:\n            init_latents = image\n\n        else:\n            if isinstance(generator, list) and len(generator) != batch_size:\n                raise ValueError(\n                    f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                    f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n                )\n\n            elif isinstance(generator, list):\n                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:\n                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)\n                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:\n                    raise ValueError(\n                        f\"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} \"\n                    )\n\n                init_latents = [\n                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                    for i in range(batch_size)\n                ]\n                init_latents = torch.cat(init_latents, dim=0)\n            else:\n                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n            init_latents = self.vae.config.scaling_factor * init_latents\n\n        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            deprecation_message = (\n                f\"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial\"\n                \" images (`image`). Initial images are now duplicating to match the number of text prompts. Note\"\n                \" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update\"\n                \" your script to pass as many initial images as text prompts to suppress this warning.\"\n            )\n            deprecate(\"len(prompt) != len(image)\", \"1.0.0\", deprecation_message, standard_warn=False)\n            additional_image_per_prompt = batch_size // init_latents.shape[0]\n            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            init_latents = torch.cat([init_latents], dim=0)\n\n        shape = init_latents.shape\n        if noise is None:\n            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n\n        # get latents\n        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)\n        latents = init_latents\n\n        return latents.type(dtype)\n\n    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding\n    def get_guidance_scale_embedding(\n        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32\n    ) -> torch.Tensor:\n        \"\"\"\n        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298\n\n        Args:\n            w (`torch.Tensor`):\n                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.\n            embedding_dim (`int`, *optional*, defaults to 512):\n                Dimension of the embeddings to generate.\n            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):\n                Data type of the generated embeddings.\n\n        Returns:\n            
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.\n        \"\"\"\n        assert len(w.shape) == 1\n        w = w * 1000.0\n\n        half_dim = embedding_dim // 2\n        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)\n        emb = w.to(dtype)[:, None] * emb[None, :]\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n        if embedding_dim % 2 == 1:  # zero pad\n            emb = torch.nn.functional.pad(emb, (0, 1))\n        assert emb.shape == (w.shape[0], embedding_dim)\n        return emb\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def clip_skip(self):\n        return self._clip_skip\n\n    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)\n    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n    # corresponds to doing no classifier free guidance.\n    @property\n    def do_classifier_free_guidance(self):\n        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None\n\n    @property\n    def cross_attention_kwargs(self):\n        return self._cross_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        image: PipelineImageInput = None,\n        target_size: Tuple[int] = (512, 512),\n        strength: float = 0.25,\n        num_inference_steps: Optional[int] = 20,\n        timesteps: List[int] = None,\n        sigmas: List[float] = None,\n        guidance_scale: Optional[float] = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: Optional[float] = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        prompt_embeds: Optional[torch.Tensor] = None,\n        negative_prompt_embeds: Optional[torch.Tensor] = None,\n        ip_adapter_image: Optional[PipelineImageInput] = None,\n        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        clip_skip: int = None,\n        callback_on_step_end: Optional[\n            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]\n        ] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        **kwargs,\n    ):\n        r\"\"\"\n        The call function to the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):\n                `Image`, numpy array or tensor representing an low quality image batch to be used as the starting point. For both\n                numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list\n                or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. 
If it is a numpy array or a\n                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image\n                latents as `image`, but if passing latents directly it is not encoded again.\n            target_size ('Tuple[int]'): Targeted image resolution (height, width)\n            strength (`float`, *optional*, defaults to 0.25):\n                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1\n                essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 20):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference. This parameter is modulated by `strength`.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            sigmas (`List[float]`, *optional*):\n                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in\n                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed\n                will be used.\n            start_noise_predictor ('nn.Module', *optional*): Noise predictor for the initial step\n            intermediate_noise_predictor ('nn.Module', *optional*): Noise predictor for the intermediate steps.\n            guidance_scale (`float`, *optional*, defaults to 7.5):\n                A higher guidance scale value encourages the model to generate images closely linked to the text\n                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide what to not include in image generation. If not defined, you need to\n                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies\n                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make\n                generation deterministic.\n            prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). 
If not\n                provided, text embeddings are generated from the `prompt` input argument.\n            negative_prompt_embeds (`torch.Tensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If\n                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.\n            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.\n            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):\n                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of\n                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should\n                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not\n                provided, embeddings are computed from the `ip_adapter_image` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between `PIL.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in\n                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            clip_skip (`int`, *optional*):\n                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that\n                the output of the pre-final layer will be used for computing the prompt embeddings.\n            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):\n                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of\n                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:\n                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a\n                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. 
You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n        Examples:\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,\n                otherwise a `tuple` is returned where the first element is a list with the generated images and the\n                second element is a list of `bool`s indicating whether the corresponding generated image contains\n                \"not-safe-for-work\" (nsfw) content.\n        \"\"\"\n\n        callback = kwargs.pop(\"callback\", None)\n        callback_steps = kwargs.pop(\"callback_steps\", None)\n\n        if callback is not None:\n            deprecate(\n                \"callback\",\n                \"1.0.0\",\n                \"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`\",\n            )\n        if callback_steps is not None:\n            deprecate(\n                \"callback_steps\",\n                \"1.0.0\",\n                \"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`\",\n            )\n\n        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):\n            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            strength,\n            callback_steps,\n            negative_prompt,\n            prompt_embeds,\n            negative_prompt_embeds,\n            ip_adapter_image,\n            ip_adapter_image_embeds,\n            callback_on_step_end_tensor_inputs,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._clip_skip = clip_skip\n        self._cross_attention_kwargs = cross_attention_kwargs\n        self._interrupt = False\n\n        # 2. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        # 3. 
Encode input prompt\n        text_encoder_lora_scale = (\n            self.cross_attention_kwargs.get(\"scale\", None) if self.cross_attention_kwargs is not None else None\n        )\n        prompt_embeds, negative_prompt_embeds = self.encode_prompt(\n            prompt,\n            device,\n            num_images_per_prompt,\n            self.do_classifier_free_guidance,\n            negative_prompt,\n            prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds,\n            lora_scale=text_encoder_lora_scale,\n            clip_skip=self.clip_skip,\n        )\n        # For classifier free guidance, we need to do two forward passes.\n        # Here we concatenate the unconditional and text embeddings into a single batch\n        # to avoid doing two forward passes\n        if self.do_classifier_free_guidance:\n            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:\n            image_embeds = self.prepare_ip_adapter_image_embeds(\n                ip_adapter_image,\n                ip_adapter_image_embeds,\n                device,\n                batch_size * num_images_per_prompt,\n                self.do_classifier_free_guidance,\n            )\n\n        # 4. Preprocess image\n        self.image_processor.config.do_normalize = False\n        image = self.image_processor.preprocess(image)  # [0, 1], torch tensor, (b,c,h,w)\n        self.image_processor.config.do_normalize = True\n        image_up = torch.nn.functional.interpolate(image, size=target_size, mode='bicubic') # upsampling\n        image_up = self.image_processor.normalize(image_up)  # [-1, 1]\n\n        # 5. set timesteps\n        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, timesteps, device)\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 6. Prepare latent variables\n        if getattr(self, 'start_noise_predictor', None) is not None:\n            with torch.amp.autocast('cuda'):\n                noise = self.start_noise_predictor(\n                    image, latent_timestep, sample_posterior=True, center_input_sample=True,\n                )\n        else:\n            noise = None\n        latents = self.prepare_latents(\n            image_up, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype,\n            device, noise, generator,\n        )\n\n        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7.1 Add image embeds for IP-Adapter\n        added_cond_kwargs = (\n            {\"image_embeds\": image_embeds}\n            if ip_adapter_image is not None or ip_adapter_image_embeds is not None\n            else None\n        )\n\n        # 7.2 Optionally get Guidance Scale Embedding\n        timestep_cond = None\n        if self.unet.config.time_cond_proj_dim is not None:\n            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)\n            timestep_cond = self.get_guidance_scale_embedding(\n                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim\n            ).to(device=device, dtype=latents.dtype)\n\n        # 8. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        self._num_timesteps = len(timesteps)\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents\n                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=prompt_embeds,\n                    timestep_cond=timestep_cond,\n                    cross_attention_kwargs=self.cross_attention_kwargs,\n                    added_cond_kwargs=added_cond_kwargs,\n                    return_dict=False,\n                )[0]\n\n                # perform guidance\n                if self.do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                # compute the previous noisy sample x_t -> x_t-1\n                if getattr(self, 'intermediate_noise_predictor', None) is not None and i + 1 < len(timesteps):\n                    t_next = timesteps[i+1]\n                    with torch.amp.autocast('cuda'):\n                        noise = self.intermediate_noise_predictor(image, t_next, center_input_sample=True)\n                    extra_step_kwargs['noise'] = noise\n                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n                    negative_prompt_embeds = callback_outputs.pop(\"negative_prompt_embeds\", negative_prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        step_idx = i // getattr(self.scheduler, \"order\", 1)\n                        callback(step_idx, t, latents)\n\n        if not output_type == \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[\n                0\n            ]\n            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = 
self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "comfyui_invsr_trimmed/sampler_invsr.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# Power by Zongsheng Yue 2022-07-13 16:59:27\n\nimport os, sys, math, random\n\nimport numpy as np\nfrom pathlib import Path\n\nfrom .utils import util_net\nfrom .utils import util_image\nfrom .utils import util_common\nfrom .utils import util_color_fix\n\nimport torch\nimport torch.nn.functional as F\nimport torch.distributed as dist\nimport torch.multiprocessing as mean_psnr\n\nfrom .pipeline_stable_diffusion_inversion_sr import StableDiffusionInvEnhancePipeline\nfrom diffusers import AutoencoderKL\nimport comfy.model_management as mm\nDEVICE = mm.get_torch_device()\n\n_positive= 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, ' +\\\n           'meticulous detailing, hyper sharpness, perfect without deformations'\n_negative= 'Low quality, blurring, jpeg artifacts, deformed, over-smooth, cartoon, noisy,' +\\\n           'painting, drawing, sketch, oil painting'\n\nclass BaseSampler:\n    def __init__(self, configs):\n        '''\n        Input:\n            configs: config, see the yaml file in folder ./configs/\n                configs.sampler_config.{start_timesteps, padding_mod, seed, sf, num_sample_steps}\n            seed: int, random seed\n        '''\n        self.configs = configs\n\n        self.setup_seed()\n\n        self.build_model()\n\n    def setup_seed(self, seed=None):\n        seed = self.configs.seed if seed is None else seed\n        random.seed(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n\n    def write_log(self, log_str):\n        print(log_str, flush=True)\n\n    def build_model(self):\n        # Build Stable diffusion\n        params = dict(self.configs.sd_pipe.params)\n        torch_dtype = params.pop('torch_dtype')\n        params['torch_dtype'] = get_torch_dtype(torch_dtype)\n        base_pipe = util_common.get_obj_from_str(self.configs.sd_pipe.target).from_pretrained(**params)\n        if self.configs.get('scheduler', None) is not None:\n            pipe_id = self.configs.scheduler.target.split('.')[-1]\n            self.write_log(f'Loading scheduler of {pipe_id}...')\n            base_pipe.scheduler = util_common.get_obj_from_str(self.configs.scheduler.target).from_config(\n                base_pipe.scheduler.config\n            )\n            self.write_log('Loaded Done')\n        if self.configs.get('vae_fp16', None) is not None:\n            params_vae = dict(self.configs.vae_fp16.params)\n            torch_dtype = params_vae.pop('torch_dtype')\n            params_vae['torch_dtype'] = get_torch_dtype(torch_dtype)\n            pipe_id = self.configs.vae_fp16.params.pretrained_model_name_or_path\n            self.write_log(f'Loading improved vae from {pipe_id}...')\n            base_pipe.vae = util_common.get_obj_from_str(self.configs.vae_fp16.target).from_pretrained(\n                **params_vae,\n            )\n            self.write_log('Loaded Done')\n        if self.configs.base_model in ['sd-turbo', 'sd2base'] :\n            sd_pipe = StableDiffusionInvEnhancePipeline.from_pipe(base_pipe)\n        else:\n            raise ValueError(f\"Unsupported base model: {self.configs.base_model}!\")\n        sd_pipe.to(DEVICE)\n        if self.configs.sliced_vae:\n            sd_pipe.vae.enable_slicing()\n        if self.configs.tiled_vae:\n            sd_pipe.vae.enable_tiling()\n            sd_pipe.vae.tile_latent_min_size = self.configs.latent_tiled_size\n            sd_pipe.vae.tile_sample_min_size = 
self.configs.sample_tiled_size\n        if self.configs.gradient_checkpointing_vae:\n            self.write_log(\"Activating gradient checkpointing for vae...\")\n            sd_pipe.vae.enable_gradient_checkpointing()\n\n        model_configs = self.configs.model_start\n        params = model_configs.get('params', dict())\n        model_start = util_common.get_obj_from_str(model_configs.target)(**params)\n        model_start.to(DEVICE)\n        ckpt_path = model_configs.get('ckpt_path')\n        assert ckpt_path is not None\n        self.write_log(f\"[InvSR] - Loading start model from {ckpt_path}...\")\n        state = torch.load(ckpt_path, map_location=DEVICE)\n        if 'state_dict' in state:\n            state = state['state_dict']\n        util_net.reload_model(model_start, state)\n        # self.write_log(f\"Loading Done\")\n        model_start.eval()\n        setattr(sd_pipe, 'start_noise_predictor', model_start)\n\n        self.sd_pipe = sd_pipe\n\nclass InvSamplerSR(BaseSampler):\n    def __init__(self, base_sampler):\n        self.configs = base_sampler.configs\n        self.sd_pipe = base_sampler.sd_pipe\n\n    @torch.no_grad()\n    def sample_func(self, im_cond):\n        '''\n        Input:\n            im_cond: b x c x h x w, torch tensor, [0,1], RGB\n        Output:\n            res_sr: b x c x h x w, numpy array, [0,1], RGB\n        '''\n        ori_h_lq, ori_w_lq = im_cond.shape[-2:]\n        ori_w_hq = ori_w_lq * self.configs.basesr.sf\n        ori_h_hq = ori_h_lq * self.configs.basesr.sf\n        # spatial downsampling factors of the vae and the diffusion backbone\n        vae_sf = (2 ** (len(self.sd_pipe.vae.config.block_out_channels) - 1))\n        if hasattr(self.sd_pipe, 'unet'):\n            diffusion_sf = (2 ** (len(self.sd_pipe.unet.config.block_out_channels) - 1))\n        else:\n            diffusion_sf = self.sd_pipe.transformer.patch_size\n        mod_lq = vae_sf // self.configs.basesr.sf * diffusion_sf  # size modulus required of the low-quality input\n        idle_pch_size = self.configs.basesr.chopping.pch_size\n\n        # reflect-pad inputs that are smaller than the chopping patch size\n        if min(im_cond.shape[-2:]) >= idle_pch_size:\n            pad_h_up = pad_w_left = 0\n        else:\n            while min(im_cond.shape[-2:]) < idle_pch_size:\n                pad_h_up = max(min((idle_pch_size - im_cond.shape[-2]) // 2, im_cond.shape[-2]-1), 0)\n                pad_h_down = max(min(idle_pch_size - im_cond.shape[-2] - pad_h_up, im_cond.shape[-2]-1), 0)\n                pad_w_left = max(min((idle_pch_size - im_cond.shape[-1]) // 2, im_cond.shape[-1]-1), 0)\n                pad_w_right = max(min(idle_pch_size - im_cond.shape[-1] - pad_w_left, im_cond.shape[-1]-1), 0)\n                im_cond = F.pad(im_cond, pad=(pad_w_left, pad_w_right, pad_h_up, pad_h_down), mode='reflect')\n\n        if im_cond.shape[-2] == idle_pch_size and im_cond.shape[-1] == idle_pch_size:\n            target_size = (\n                im_cond.shape[-2] * self.configs.basesr.sf,\n                im_cond.shape[-1] * self.configs.basesr.sf\n            )\n            res_sr = self.sd_pipe(\n                image=im_cond.type(torch.float16),\n                prompt=[_positive, ]*im_cond.shape[0],\n                negative_prompt=[_negative, ]*im_cond.shape[0] if self.configs.cfg_scale > 1.0 else None,\n                target_size=target_size,\n                timesteps=self.configs.timesteps,\n                guidance_scale=self.configs.cfg_scale,\n                output_type=\"pt\",    # torch tensor, b x c x h x w, [0, 1]\n            ).images\n        else:\n            if not (im_cond.shape[-2] % mod_lq == 0 and im_cond.shape[-1] % mod_lq == 0):\n                
target_h_lq = math.ceil(im_cond.shape[-2] / mod_lq) * mod_lq\n                target_w_lq = math.ceil(im_cond.shape[-1] / mod_lq) * mod_lq\n                pad_h = target_h_lq - im_cond.shape[-2]\n                pad_w = target_w_lq - im_cond.shape[-1]\n                im_cond= F.pad(im_cond, pad=(0, pad_w, 0, pad_h), mode='reflect')\n\n            im_spliter = util_image.ImageSpliterTh(\n                im_cond,\n                pch_size=idle_pch_size,\n                stride= int(idle_pch_size * 0.50),\n                sf=self.configs.basesr.sf,\n                weight_type=self.configs.basesr.chopping.weight_type,\n                extra_bs=self.configs.basesr.chopping.extra_bs,\n            )\n            \n            # pbar = ProgressBar(len(im_spliter) * im_cond.shape[0])\n\n            for im_lq_pch, index_infos in im_spliter:\n                target_size = (\n                    im_lq_pch.shape[-2] * self.configs.basesr.sf,\n                    im_lq_pch.shape[-1] * self.configs.basesr.sf,\n                )\n\n                # start = torch.cuda.Event(enable_timing=True)\n                # end = torch.cuda.Event(enable_timing=True)\n                # start.record()\n\n                res_sr_pch = self.sd_pipe(\n                    image=im_lq_pch.type(torch.float16),\n                    prompt=[_positive, ]*im_lq_pch.shape[0],\n                    negative_prompt=[_negative, ]*im_lq_pch.shape[0] if self.configs.cfg_scale > 1.0 else None,\n                    target_size=target_size,\n                    timesteps=self.configs.timesteps,\n                    guidance_scale=self.configs.cfg_scale,\n                    output_type=\"pt\",    # torch tensor, b x c x h x w, [0, 1]\n                ).images\n\n                # end.record()\n                # torch.cuda.synchronize()\n                # print(f\"Time: {start.elapsed_time(end):.6f}\")\n\n                im_spliter.update(res_sr_pch, index_infos)\n                # pbar.update(im_lq_pch.shape[0])\n            res_sr = im_spliter.gather()\n\n        pad_h_up *= self.configs.basesr.sf\n        pad_w_left *= self.configs.basesr.sf\n        res_sr = res_sr[:, :, pad_h_up:ori_h_hq+pad_h_up, pad_w_left:ori_w_hq+pad_w_left]\n\n        if self.configs.color_fix:\n            im_cond_up = F.interpolate(\n                im_cond, size=res_sr.shape[-2:], mode='bicubic', align_corners=False, antialias=True\n            )\n            if self.configs.color_fix == 'ycbcr':\n                res_sr = util_color_fix.ycbcr_color_replace(res_sr, im_cond_up)\n            elif self.configs.color_fix == 'wavelet':\n                res_sr = util_color_fix.wavelet_reconstruction(res_sr, im_cond_up)\n            else:\n                raise ValueError(f\"Unsupported color fixing type: {self.configs.color_fix}\")\n\n        res_sr = res_sr.clamp(0.0, 1.0).cpu().float().numpy()\n\n        return res_sr\n\n    def inference(self, image_bchw):\n        return self.sample_func(image_bchw.to(DEVICE))\n\ndef get_torch_dtype(torch_dtype: str):\n    if torch_dtype == 'torch.float16':\n        return torch.float16\n    elif torch_dtype == 'torch.bfloat16':\n        return torch.bfloat16\n    elif torch_dtype == 'torch.float32':\n        return torch.float32\n    else:\n        raise ValueError(f'Unexpected torch dtype:{torch_dtype}')\n"
  },
  {
    "path": "comfyui_invsr_trimmed/time_aware_encoder.py",
    "content": "from dataclasses import dataclass\nfrom typing import Optional, Tuple, Union\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom diffusers.utils import is_torch_version\nfrom diffusers.models.unets.unet_2d_blocks import (\n    UNetMidBlock2D,\n    get_down_block,\n)\nfrom diffusers.models.embeddings import TimestepEmbedding, Timesteps\n\n\nclass TimeAwareEncoder(nn.Module):\n    r\"\"\"\n    The `TimeAwareEncoder` layer of a variational autoencoder that encodes its input into a latent representation.\n\n    Args:\n        in_channels (`int`, *optional*, defaults to 3):\n            The number of input channels.\n        out_channels (`int`, *optional*, defaults to 3):\n            The number of output channels.\n        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `(\"DownEncoderBlock2D\",)`):\n            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available\n            options.\n        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):\n            The number of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2):\n            The number of layers per block.\n        norm_num_groups (`int`, *optional*, defaults to 32):\n            The number of groups for normalization.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`):\n            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.\n        double_z (`bool`, *optional*, defaults to `True`):\n            Whether to double the number of output channels for the last block.\n        resnet_time_scale_shift (`str`, defaults to `\"default\"`)\n    \"\"\"\n\n    def __init__(\n        self,\n        in_channels: int = 3,\n        out_channels: int = 3,\n        down_block_types: Tuple[str, ...] = (\"DownEncoderBlock2D\",),\n        block_out_channels: Tuple[int, ...] 
= (64,),\n        layers_per_block: Union[int, Tuple[int, ...]] = 2,\n        norm_num_groups: int = 32,\n        act_fn: str = \"silu\",\n        double_z: bool = True,\n        mid_block_add_attention=True,\n        resnet_time_scale_shift: str = \"default\",\n        temb_channels: int = 256,\n        freq_shift: int = 0,\n        flip_sin_to_cos: bool = True,\n        attention_head_dim: int = 1,\n    ):\n        super().__init__()\n        if isinstance(layers_per_block, int):\n            layers_per_block = (layers_per_block,) * len(down_block_types)\n        self.layers_per_block = layers_per_block\n\n        timestep_input_dim = max(128, block_out_channels[0])\n        self.time_proj = Timesteps(timestep_input_dim, flip_sin_to_cos, freq_shift)\n        self.time_embedding = TimestepEmbedding(timestep_input_dim, temb_channels)\n\n        self.conv_in = nn.Conv2d(\n            in_channels,\n            block_out_channels[0],\n            kernel_size=3,\n            stride=1,\n            padding=1,\n        )\n\n        self.down_blocks = nn.ModuleList([])\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=self.layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                add_downsample=not is_final_block,\n                resnet_eps=1e-6,\n                downsample_padding=0,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                attention_head_dim=attention_head_dim,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                temb_channels=temb_channels,\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        self.mid_block = UNetMidBlock2D(\n            in_channels=block_out_channels[-1],\n            resnet_eps=1e-6,\n            resnet_act_fn=act_fn,\n            output_scale_factor=1,\n            attention_head_dim=attention_head_dim,\n            resnet_groups=norm_num_groups,\n            add_attention=mid_block_add_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            temb_channels=temb_channels,\n        )\n\n        # out\n        self.conv_norm_out = nn.GroupNorm(\n            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6\n        )\n        self.conv_act = nn.SiLU()\n\n        conv_out_channels = 2 * out_channels if double_z else out_channels\n        self.conv_out = nn.Conv2d(\n            block_out_channels[-1], conv_out_channels, 3, padding=1\n        )\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        sample: torch.Tensor,\n        timesteps: Union[torch.Tensor, int],\n    ) -> torch.Tensor:\n        r\"\"\"The forward method of the `Encoder` class.\"\"\"\n\n        # time embedding\n        if not torch.is_tensor(timesteps):\n            timesteps = torch.tensor(\n                [timesteps], dtype=torch.long, device=sample.device\n            )\n        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core 
ML\n        timesteps = timesteps * torch.ones(\n            sample.shape[0], dtype=timesteps.dtype, device=timesteps.device\n        )\n\n        t_emb = self.time_proj(timesteps)\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=list(self.time_embedding.parameters())[0].dtype)\n        emb = self.time_embedding(t_emb)\n\n        sample = self.conv_in(sample)\n\n        if self.training and self.gradient_checkpointing:\n\n            def create_custom_forward(module):\n                def custom_forward(*inputs):\n                    return module(*inputs)\n\n                return custom_forward\n\n            # down\n            if is_torch_version(\">=\", \"1.11.0\"):\n                for down_block in self.down_blocks:\n                    sample = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(down_block),\n                        sample,\n                        emb,\n                        use_reentrant=False,\n                    )\n                # middle\n                sample = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(self.mid_block),\n                    sample,\n                    emb,\n                    use_reentrant=False,\n                )\n            else:\n                for down_block in self.down_blocks:\n                    sample = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(down_block), sample, emb\n                    )\n                # middle\n                sample = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(self.mid_block), sample, emb\n                )\n\n        else:\n            # down\n            for down_block in self.down_blocks:\n                sample, _ = down_block(sample, emb)\n\n            # middle\n            sample = self.mid_block(sample, emb)\n\n        # post-process\n        sample = self.conv_norm_out(sample)\n        sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        return sample\n"
  },
  {
    "path": "comfyui_invsr_trimmed/utils/__init__.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# Power by Zongsheng Yue 2022-01-18 11:40:23\n\n\n"
  },
  {
    "path": "comfyui_invsr_trimmed/utils/resize.py",
    "content": "\"\"\"\nA standalone PyTorch implementation for fast and efficient bicubic resampling.\nThe resulting values are the same to MATLAB function imresize('bicubic').\n## Author:      Sanghyun Son\n## Email:       sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary)\n## Version:     1.2.0\n## Last update: July 9th, 2020 (KST)\nDependency: torch\nExample::\n>>> import torch\n>>> import core\n>>> x = torch.arange(16).float().view(1, 1, 4, 4)\n>>> y = core.imresize(x, sizes=(3, 3))\n>>> print(y)\ntensor([[[[ 0.7506,  2.1004,  3.4503],\n          [ 6.1505,  7.5000,  8.8499],\n          [11.5497, 12.8996, 14.2494]]]])\n\"\"\"\n\nimport math\nimport typing\n\nimport torch\nfrom torch.nn import functional as F\n\n__all__ = ['imresize']\n\n_I = typing.Optional[int]\n_D = typing.Optional[torch.dtype]\n\n\ndef nearest_contribution(x: torch.Tensor) -> torch.Tensor:\n    range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5))\n    cont = range_around_0.to(dtype=x.dtype)\n    return cont\n\n\ndef linear_contribution(x: torch.Tensor) -> torch.Tensor:\n    ax = x.abs()\n    range_01 = ax.le(1)\n    cont = (1 - ax) * range_01.to(dtype=x.dtype)\n    return cont\n\n\ndef cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:\n    ax = x.abs()\n    ax2 = ax * ax\n    ax3 = ax * ax2\n\n    range_01 = ax.le(1)\n    range_12 = torch.logical_and(ax.gt(1), ax.le(2))\n\n    cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1\n    cont_01 = cont_01 * range_01.to(dtype=x.dtype)\n\n    cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)\n    cont_12 = cont_12 * range_12.to(dtype=x.dtype)\n\n    cont = cont_01 + cont_12\n    return cont\n\n\ndef gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:\n    range_3sigma = (x.abs() <= 3 * sigma + 1)\n    # Normalization will be done after\n    cont = torch.exp(-x.pow(2) / (2 * sigma**2))\n    cont = cont * range_3sigma.to(dtype=x.dtype)\n    return cont\n\n\ndef discrete_kernel(kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor:\n    '''\n    For downsampling with integer scale only.\n    '''\n    downsampling_factor = int(1 / scale)\n    if kernel == 'cubic':\n        kernel_size_orig = 4\n    else:\n        raise ValueError('Pass!')\n\n    if antialiasing:\n        kernel_size = kernel_size_orig * downsampling_factor\n    else:\n        kernel_size = kernel_size_orig\n\n    if downsampling_factor % 2 == 0:\n        a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))\n    else:\n        kernel_size -= 1\n        a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))\n\n    with torch.no_grad():\n        r = torch.linspace(-a, a, steps=kernel_size)\n        k = cubic_contribution(r).view(-1, 1)\n        k = torch.matmul(k, k.t())\n        k /= k.sum()\n\n    return k\n\n\ndef reflect_padding(x: torch.Tensor, dim: int, pad_pre: int, pad_post: int) -> torch.Tensor:\n    '''\n    Apply reflect padding to the given Tensor.\n    Note that it is slightly different from the PyTorch functional.pad,\n    where boundary elements are used only once.\n    Instead, we follow the MATLAB implementation\n    which uses boundary elements twice.\n    For example,\n    [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,\n    while our implementation yields [a, a, b, c, d, d].\n    '''\n    b, c, h, w = x.size()\n    if dim == 2 or dim == -2:\n        padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)\n        padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)\n        for 
p in range(pad_pre):\n            padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])\n        for p in range(pad_post):\n            padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])\n    else:\n        padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)\n        padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)\n        for p in range(pad_pre):\n            padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])\n        for p in range(pad_post):\n            padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])\n\n    return padding_buffer\n\n\ndef padding(x: torch.Tensor,\n            dim: int,\n            pad_pre: int,\n            pad_post: int,\n            padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:\n    if padding_type is None:\n        return x\n    elif padding_type == 'reflect':\n        x_pad = reflect_padding(x, dim, pad_pre, pad_post)\n    else:\n        raise ValueError('{} padding is not supported!'.format(padding_type))\n\n    return x_pad\n\n\ndef get_padding(base: torch.Tensor, kernel_size: int, x_size: int) -> typing.Tuple[int, int, torch.Tensor]:\n    base = base.long()\n    r_min = base.min()\n    r_max = base.max() + kernel_size - 1\n\n    if r_min <= 0:\n        pad_pre = -r_min\n        pad_pre = pad_pre.item()\n        base += pad_pre\n    else:\n        pad_pre = 0\n\n    if r_max >= x_size:\n        pad_post = r_max - x_size + 1\n        pad_post = pad_post.item()\n    else:\n        pad_post = 0\n\n    return pad_pre, pad_post, base\n\n\ndef get_weight(dist: torch.Tensor,\n               kernel_size: int,\n               kernel: str = 'cubic',\n               sigma: float = 2.0,\n               antialiasing_factor: float = 1) -> torch.Tensor:\n    buffer_pos = dist.new_zeros(kernel_size, len(dist))\n    for idx, buffer_sub in enumerate(buffer_pos):\n        buffer_sub.copy_(dist - idx)\n\n    # Expand (downsampling) / Shrink (upsampling) the receptive field.\n    buffer_pos *= antialiasing_factor\n    if kernel == 'cubic':\n        weight = cubic_contribution(buffer_pos)\n    elif kernel == 'gaussian':\n        weight = gaussian_contribution(buffer_pos, sigma=sigma)\n    else:\n        raise ValueError('{} kernel is not supported!'.format(kernel))\n\n    weight /= weight.sum(dim=0, keepdim=True)\n    return weight\n\n\ndef reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:\n    # Resize height\n    if dim == 2 or dim == -2:\n        k = (kernel_size, 1)\n        h_out = x.size(-2) - kernel_size + 1\n        w_out = x.size(-1)\n    # Resize width\n    else:\n        k = (1, kernel_size)\n        h_out = x.size(-2)\n        w_out = x.size(-1) - kernel_size + 1\n\n    unfold = F.unfold(x, k)\n    unfold = unfold.view(unfold.size(0), -1, h_out, w_out)\n    return unfold\n\n\ndef reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:\n    if x.dim() == 4:\n        b, c, h, w = x.size()\n    elif x.dim() == 3:\n        c, h, w = x.size()\n        b = None\n    elif x.dim() == 2:\n        h, w = x.size()\n        b = c = None\n    else:\n        raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))\n\n    x = x.view(-1, 1, h, w)\n    return x, b, c, h, w\n\n\ndef reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:\n    rh = x.size(-2)\n    rw = x.size(-1)\n    # Back to the original dimension\n    if b is not None:\n        x = x.view(b, c, rh, rw)  # 4-dim\n    else:\n        if c is not None:\n            x = x.view(c, rh, rw)  
# 3-dim\n        else:\n            x = x.view(rh, rw)  # 2-dim\n\n    return x\n\n\ndef cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:\n    if x.dtype != torch.float32 or x.dtype != torch.float64:\n        dtype = x.dtype\n        x = x.float()\n    else:\n        dtype = None\n\n    return x, dtype\n\n\ndef cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:\n    if dtype is not None:\n        if not dtype.is_floating_point:\n            x = x - x.detach() + x.round()\n        # To prevent over/underflow when converting types\n        if dtype is torch.uint8:\n            x = x.clamp(0, 255)\n\n        x = x.to(dtype=dtype)\n\n    return x\n\n\ndef resize_1d(x: torch.Tensor,\n              dim: int,\n              size: int,\n              scale: float,\n              kernel: str = 'cubic',\n              sigma: float = 2.0,\n              padding_type: str = 'reflect',\n              antialiasing: bool = True) -> torch.Tensor:\n    '''\n    Args:\n        x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).\n        dim (int):\n        scale (float):\n        size (int):\n    Return:\n    '''\n    # Identity case\n    if scale == 1:\n        return x\n\n    # Default bicubic kernel with antialiasing (only when downsampling)\n    if kernel == 'cubic':\n        kernel_size = 4\n    else:\n        kernel_size = math.floor(6 * sigma)\n\n    if antialiasing and (scale < 1):\n        antialiasing_factor = scale\n        kernel_size = math.ceil(kernel_size / antialiasing_factor)\n    else:\n        antialiasing_factor = 1\n\n    # We allow margin to both sizes\n    kernel_size += 2\n\n    # Weights only depend on the shape of input and output,\n    # so we do not calculate gradients here.\n    with torch.no_grad():\n        pos = torch.linspace(\n            0,\n            size - 1,\n            steps=size,\n            dtype=x.dtype,\n            device=x.device,\n        )\n        pos = (pos + 0.5) / scale - 0.5\n        base = pos.floor() - (kernel_size // 2) + 1\n        dist = pos - base\n        weight = get_weight(\n            dist,\n            kernel_size,\n            kernel=kernel,\n            sigma=sigma,\n            antialiasing_factor=antialiasing_factor,\n        )\n        pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))\n\n    # To backpropagate through x\n    x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)\n    unfold = reshape_tensor(x_pad, dim, kernel_size)\n    # Subsampling first\n    if dim == 2 or dim == -2:\n        sample = unfold[..., base, :]\n        weight = weight.view(1, kernel_size, sample.size(2), 1)\n    else:\n        sample = unfold[..., base]\n        weight = weight.view(1, kernel_size, 1, sample.size(3))\n\n    # Apply the kernel\n    x = sample * weight\n    x = x.sum(dim=1, keepdim=True)\n    return x\n\n\ndef downsampling_2d(x: torch.Tensor, k: torch.Tensor, scale: int, padding_type: str = 'reflect') -> torch.Tensor:\n    c = x.size(1)\n    k_h = k.size(-2)\n    k_w = k.size(-1)\n\n    k = k.to(dtype=x.dtype, device=x.device)\n    k = k.view(1, 1, k_h, k_w)\n    k = k.repeat(c, c, 1, 1)\n    e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)\n    e = e.view(c, c, 1, 1)\n    k = k * e\n\n    pad_h = (k_h - scale) // 2\n    pad_w = (k_w - scale) // 2\n    x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)\n    x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)\n    y = F.conv2d(x, k, padding=0, stride=scale)\n    return y\n\n\ndef 
imresize(x: torch.Tensor,\n             scale: typing.Optional[float] = None,\n             sizes: typing.Optional[typing.Tuple[int, int]] = None,\n             kernel: typing.Union[str, torch.Tensor] = 'cubic',\n             sigma: float = 2,\n             rotation_degree: float = 0,\n             padding_type: str = 'reflect',\n             antialiasing: bool = True) -> torch.Tensor:\n    \"\"\"\n    Args:\n        x (torch.Tensor):\n        scale (float):\n        sizes (tuple(int, int)):\n        kernel (str, default='cubic'):\n        sigma (float, default=2):\n        rotation_degree (float, default=0):\n        padding_type (str, default='reflect'):\n        antialiasing (bool, default=True):\n    Return:\n        torch.Tensor:\n    \"\"\"\n    if scale is None and sizes is None:\n        raise ValueError('One of scale or sizes must be specified!')\n    if scale is not None and sizes is not None:\n        raise ValueError('Please specify scale or sizes to avoid conflict!')\n\n    x, b, c, h, w = reshape_input(x)\n\n    if sizes is None and scale is not None:\n        '''\n        # Check if we can apply the convolution algorithm\n        scale_inv = 1 / scale\n        if isinstance(kernel, str) and scale_inv.is_integer():\n            kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)\n        elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():\n            raise ValueError(\n                'An integer downsampling factor '\n                'should be used with a predefined kernel!'\n            )\n        '''\n        # Determine output size\n        sizes = (math.ceil(h * scale), math.ceil(w * scale))\n        scales = (scale, scale)\n\n    if scale is None and sizes is not None:\n        scales = (sizes[0] / h, sizes[1] / w)\n\n    x, dtype = cast_input(x)\n\n    if isinstance(kernel, str) and sizes is not None:\n        # Core resizing module\n        x = resize_1d(\n            x,\n            -2,\n            size=sizes[0],\n            scale=scales[0],\n            kernel=kernel,\n            sigma=sigma,\n            padding_type=padding_type,\n            antialiasing=antialiasing)\n        x = resize_1d(\n            x,\n            -1,\n            size=sizes[1],\n            scale=scales[1],\n            kernel=kernel,\n            sigma=sigma,\n            padding_type=padding_type,\n            antialiasing=antialiasing)\n    elif isinstance(kernel, torch.Tensor) and scale is not None:\n        x = downsampling_2d(x, kernel, scale=int(1 / scale))\n\n    x = reshape_output(x, b, c)\n    x = cast_output(x, dtype)\n    return x\n"
  },
  {
    "path": "comfyui_invsr_trimmed/utils/util_color_fix.py",
    "content": "'''\n# --------------------------------------------------------------------------------\n#   Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)\n# --------------------------------------------------------------------------------\n'''\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import functional as F\n\nfrom torchvision.transforms import ToTensor, ToPILImage\n\nfrom .util_image import rgb2ycbcrTorch, ycbcr2rgbTorch\n\n\ndef calc_mean_std(feat: Tensor, eps=1e-5):\n    \"\"\"Calculate mean and std for adaptive_instance_normalization.\n    Args:\n        feat (Tensor): 4D tensor.\n        eps (float): A small value added to the variance to avoid\n            divide-by-zero. Default: 1e-5.\n    \"\"\"\n    size = feat.size()\n    assert len(size) == 4, 'The input feature should be 4D tensor.'\n    b, c = size[:2]\n    feat_var = feat.reshape(b, c, -1).var(dim=2) + eps\n    feat_std = feat_var.sqrt().reshape(b, c, 1, 1)\n    feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)\n    return feat_mean, feat_std\n\ndef adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):\n    \"\"\"Adaptive instance normalization.\n    Adjust the reference features to have the similar color and illuminations\n    as those in the degradate features.\n    Args:\n        content_feat (Tensor): The reference feature.\n        style_feat (Tensor): The degradate features.\n    \"\"\"\n    size = content_feat.size()\n    style_mean, style_std = calc_mean_std(style_feat)\n    content_mean, content_std = calc_mean_std(content_feat)\n    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)\n    return normalized_feat * style_std.expand(size) + style_mean.expand(size)\n\ndef wavelet_blur(image: Tensor, radius: int):\n    \"\"\"\n    Apply wavelet blur to the input tensor.\n    \"\"\"\n    # input shape: (1, 3, H, W)\n    # convolution kernel\n    kernel_vals = [\n        [0.0625, 0.125, 0.0625],\n        [0.125, 0.25, 0.125],\n        [0.0625, 0.125, 0.0625],\n    ]\n    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)\n    # add channel dimensions to the kernel to make it a 4D tensor\n    kernel = kernel[None, None]\n    # repeat the kernel across all input channels\n    kernel = kernel.repeat(3, 1, 1, 1)\n    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')\n    # apply convolution\n    output = F.conv2d(image, kernel, groups=3, dilation=radius)\n    return output\n\ndef wavelet_decomposition(image: Tensor, levels=5):\n    \"\"\"\n    Apply wavelet decomposition to the input tensor.\n    This function only returns the low frequency & the high frequency.\n    \"\"\"\n    high_freq = torch.zeros_like(image)\n    for i in range(levels):\n        radius = 2 ** i\n        low_freq = wavelet_blur(image, radius)\n        high_freq += (image - low_freq)\n        image = low_freq\n\n    return high_freq, low_freq\n\ndef wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):\n    \"\"\"\n    Apply wavelet decomposition, so that the content will have the same color as the style.\n    \"\"\"\n    # calculate the wavelet decomposition of the content feature\n    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)\n    del content_low_freq\n    # calculate the wavelet decomposition of the style feature\n    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)\n    del style_high_freq\n   
 # reconstruct the content feature with the style's low frequency\n    return content_high_freq + style_low_freq\n\ndef ycbcr_color_replace(content_feat:Tensor, style_feat:Tensor):\n    \"\"\"\n    Apply ycbcr decomposition, so that the content will have the same color as the style.\n    \"\"\"\n    content_y = rgb2ycbcrTorch(content_feat, only_y=True)\n    style_ycbcr = rgb2ycbcrTorch(style_feat, only_y=False)\n\n    target_ycbcr = torch.cat([content_y, style_ycbcr[:, 1:,]], dim=1)\n\n    target_rgb = ycbcr2rgbTorch(target_ycbcr)\n\n    return target_rgb\n\n\n"
  },
  {
    "path": "comfyui_invsr_trimmed/utils/util_common.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# Power by Zongsheng Yue 2022-02-06 10:34:59\n\nimport os\nimport random\nimport requests\nimport importlib\nfrom pathlib import Path\n\ndef mkdir(dir_path, delete=False, parents=True):\n    import shutil\n    if not isinstance(dir_path, Path):\n        dir_path = Path(dir_path)\n    if delete:\n        if dir_path.exists():\n            shutil.rmtree(str(dir_path))\n    if not dir_path.exists():\n        dir_path.mkdir(parents=parents)\n\ndef get_obj_from_str(string, reload=False):\n    current_package = __package__.rsplit(\".\", 1)[0]\n    is_relative_import = string.startswith(\".\")\n    package = current_package if is_relative_import else None\n\n    module, cls = string.rsplit(\".\", 1)\n    if reload:\n        module_imp = importlib.import_module(module, package=package)\n        importlib.reload(module_imp)\n    return getattr(importlib.import_module(module, package=package), cls)\n\ndef instantiate_from_config(config):\n    if not \"target\" in config:\n        raise KeyError(\"Expected key `target` to instantiate.\")\n    return get_obj_from_str(config[\"target\"])(**config.get(\"params\", dict()))\n\ndef str2bool(v):\n    if isinstance(v, bool):\n        return v\n    if v.lower() in (\"yes\", \"true\", \"t\", \"y\", \"1\"):\n        return True\n    elif v.lower() in (\"no\", \"false\", \"f\", \"n\", \"0\"):\n        return False\n    else:\n        raise argparse.ArgumentTypeError(\"Boolean value expected.\")\n\ndef get_filenames(dir_path, exts=['png', 'jpg'], recursive=True):\n    '''\n    Get the file paths in the given folder.\n    param exts: list, e.g., ['png',]\n    return: list\n    '''\n    if not isinstance(dir_path, Path):\n        dir_path = Path(dir_path)\n\n    file_paths = []\n    for current_ext in exts:\n        if recursive:\n            file_paths.extend([str(x) for x in dir_path.glob('**/*.'+current_ext)])\n        else:\n            file_paths.extend([str(x) for x in dir_path.glob('*.'+current_ext)])\n\n    return file_paths\n\ndef readline_txt(txt_file):\n    txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file\n    out = []\n    for txt_file_current in txt_file:\n        with open(txt_file_current, 'r') as ff:\n            out.extend([x[:-1] for x in ff.readlines()])\n\n    return out\n\ndef scan_files_from_folder(dir_paths, exts, recursive=True):\n    '''\n    Scaning images from given folder.\n    Input:\n        dir_pathas: str or list.\n        exts: list\n    '''\n    exts = [exts, ] if isinstance(exts, str) else exts\n    dir_paths = [dir_paths, ] if isinstance(dir_paths, str) else dir_paths\n\n    file_paths = []\n    for current_dir in dir_paths:\n        current_dir = Path(current_dir) if not isinstance(current_dir, Path) else current_dir\n        for current_ext in exts:\n            if recursive:\n                search_flag = f\"**/*.{current_ext}\"\n            else:\n                search_flag = f\"*.{current_ext}\"\n            file_paths.extend(sorted([str(x) for x in Path(current_dir).glob(search_flag)]))\n\n    return file_paths\n\ndef write_path_to_txt(\n        dir_folder,\n        txt_path,\n        search_key,\n        num_files=None,\n        write_only_name=False,\n        write_only_stem=False,\n        shuffle=False,\n        ):\n    '''\n    Scaning the files in the given folder and write them into a txt file\n    Input:\n        dir_folder: path of the target folder\n        txt_path: path to save the txt file\n        search_key: e.g., '*.png'\n      
  write_only_name: bool, only record the file names (including extension),\n        write_only_stem: bool, only record the file names (not including extension),\n    '''\n    txt_path = Path(txt_path) if not isinstance(txt_path, Path) else txt_path\n    dir_folder = Path(dir_folder) if not isinstance(dir_folder, Path) else dir_folder\n    if txt_path.exists():\n        txt_path.unlink()\n    if write_only_name:\n        path_list = sorted([str(x.name) for x in dir_folder.glob(search_key)])\n    elif write_only_stem:\n        path_list = sorted([str(x.stem) for x in dir_folder.glob(search_key)])\n    else:\n        path_list = sorted([str(x) for x in dir_folder.glob(search_key)])\n    if shuffle:\n        random.shuffle(path_list)\n    if num_files is not None:\n        path_list = path_list[:num_files]\n    with open(txt_path, mode='w') as ff:\n        for line in path_list:\n            ff.write(line+'\\n')\n"
  },
  {
    "path": "comfyui_invsr_trimmed/utils/util_ema.py",
    "content": "import torch\nfrom torch import nn\n\n\nclass LitEma(nn.Module):\n    def __init__(self, model, decay=0.9999, use_num_upates=True):\n        super().__init__()\n        if decay < 0.0 or decay > 1.0:\n            raise ValueError('Decay must be between 0 and 1')\n\n        self.m_name2s_name = {}\n        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))\n        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates\n        else torch.tensor(-1, dtype=torch.int))\n\n        for name, p in model.named_parameters():\n            if p.requires_grad:\n                # remove as '.'-character is not allowed in buffers\n                s_name = name.replace('.', '')\n                self.m_name2s_name.update({name: s_name})\n                self.register_buffer(s_name, p.clone().detach().data)\n\n        self.collected_params = []\n\n    def reset_num_updates(self):\n        del self.num_updates\n        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))\n\n    def forward(self, model):\n        decay = self.decay\n\n        if self.num_updates >= 0:\n            self.num_updates += 1\n            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))\n\n        one_minus_decay = 1.0 - decay\n\n        with torch.no_grad():\n            m_param = dict(model.named_parameters())\n            shadow_params = dict(self.named_buffers())\n\n            for key in m_param:\n                if m_param[key].requires_grad:\n                    sname = self.m_name2s_name[key]\n                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])\n                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))\n                else:\n                    assert not key in self.m_name2s_name\n\n    def copy_to(self, model):\n        \"\"\"\n        Copying the ema state (i.e., buffers) to the targeted model\n        Input:\n            model: targeted model\n        \"\"\"\n        m_param = dict(model.named_parameters())\n        shadow_params = dict(self.named_buffers())\n        for key in m_param:\n            if m_param[key].requires_grad:\n                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)\n            else:\n                assert not key in self.m_name2s_name\n\n    def store(self, parameters):\n        \"\"\"\n        Save the parameters of the targeted model into the temporary pool for restoring later.\n        Args:\n          parameters: parameters of the targeted model.\n                      Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.\n        \"\"\"\n        self.collected_params = [param.clone() for param in parameters]\n\n    def restore(self, parameters):\n        \"\"\"\n        Restore the parameters from the temporaty pool (stored with the `store` method).\n        Useful to validate the model with EMA parameters without affecting the\n        original optimization process. Store the parameters before the\n        `copy_to` method. 
After validation (or model saving), use this to\n        restore the former parameters.\n        Args:\n          parameters: Iterable of `torch.nn.Parameter`; the parameters to be\n            updated with the stored parameters.\n        \"\"\"\n        for c_param, param in zip(self.collected_params, parameters):\n            param.data.copy_(c_param.data)\n\n    def resume(self, ckpt, num_updates):\n        \"\"\"\n        Resume from the targeted checkpoint, i.e., copying the checkpoint weights to the ema buffers\n        Input:\n            ckpt: state dict of the targeted checkpoint\n            num_updates: number of EMA updates to resume from\n        \"\"\"\n        self.register_buffer('num_updates', torch.tensor(num_updates, dtype=torch.int))\n\n        shadow_params = dict(self.named_buffers())\n        for key, value in ckpt.items():\n            try:\n                shadow_params[self.m_name2s_name[key]].data.copy_(value.data)\n            except KeyError:\n                # strip the 'module.' prefix added by DataParallel/DDP checkpoints\n                if key.startswith('module') and key not in shadow_params:\n                    key = key[7:]\n                shadow_params[self.m_name2s_name[key]].data.copy_(value.data)\n"
  },
  {
    "path": "comfyui_invsr_trimmed/utils/util_image.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# Power by Zongsheng Yue 2021-11-24 16:54:19\n\nimport sys\nimport cv2\nimport math\nimport torch\nimport random\nimport numpy as np\nfrom pathlib import Path\n\n# --------------------------Metrics----------------------------\ndef ssim(img1, img2):\n    C1 = (0.01 * 255)**2\n    C2 = (0.03 * 255)**2\n\n    img1 = img1.astype(np.float64)\n    img2 = img2.astype(np.float64)\n    kernel = cv2.getGaussianKernel(11, 1.5)\n    window = np.outer(kernel, kernel.transpose())\n\n    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid\n    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]\n    mu1_sq = mu1**2\n    mu2_sq = mu2**2\n    mu1_mu2 = mu1 * mu2\n    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq\n    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq\n    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *\n                                                            (sigma1_sq + sigma2_sq + C2))\n    return ssim_map.mean()\n\ndef calculate_ssim(im1, im2, border=0, ycbcr=False):\n    '''\n    SSIM the same outputs as MATLAB's\n    im1, im2: h x w x , [0, 255], uint8\n    '''\n    if not im1.shape == im2.shape:\n        raise ValueError('Input images must have the same dimensions.')\n\n    if ycbcr:\n        im1 = rgb2ycbcr(im1, True)\n        im2 = rgb2ycbcr(im2, True)\n\n    h, w = im1.shape[:2]\n    im1 = im1[border:h-border, border:w-border]\n    im2 = im2[border:h-border, border:w-border]\n\n    if im1.ndim == 2:\n        return ssim(im1, im2)\n    elif im1.ndim == 3:\n        if im1.shape[2] == 3:\n            ssims = []\n            for i in range(3):\n                ssims.append(ssim(im1[:,:,i], im2[:,:,i]))\n            return np.array(ssims).mean()\n        elif im1.shape[2] == 1:\n            return ssim(np.squeeze(im1), np.squeeze(im2))\n    else:\n        raise ValueError('Wrong input image dimensions.')\n\ndef calculate_psnr(im1, im2, border=0, ycbcr=False):\n    '''\n    PSNR metric.\n    im1, im2: h x w x , [0, 255], uint8\n    '''\n    if not im1.shape == im2.shape:\n        raise ValueError('Input images must have the same dimensions.')\n\n    if ycbcr:\n        im1 = rgb2ycbcr(im1, True)\n        im2 = rgb2ycbcr(im2, True)\n\n    h, w = im1.shape[:2]\n    im1 = im1[border:h-border, border:w-border]\n    im2 = im2[border:h-border, border:w-border]\n\n    im1 = im1.astype(np.float64)\n    im2 = im2.astype(np.float64)\n    mse = np.mean((im1 - im2)**2)\n    if mse == 0:\n        return float('inf')\n    return 20 * math.log10(255.0 / math.sqrt(mse))\n\ndef normalize_np(im, mean=0.5, std=0.5, reverse=False):\n    '''\n    Input:\n        im: h x w x c, numpy array\n        Normalize: (im - mean) / std\n        Reverse: im * std + mean\n\n    '''\n    if not isinstance(mean, (list, tuple)):\n        mean = [mean, ] * im.shape[2]\n    mean = np.array(mean).reshape([1, 1, im.shape[2]])\n\n    if not isinstance(std, (list, tuple)):\n        std = [std, ] * im.shape[2]\n    std = np.array(std).reshape([1, 1, im.shape[2]])\n\n    if not reverse:\n        out = (im.astype(np.float32) - mean) / std\n    else:\n        out = im.astype(np.float32) * std + mean\n    return out\n\ndef normalize_th(im, mean=0.5, std=0.5, reverse=False):\n    '''\n    Input:\n        im: b x c x h x w, torch tensor\n        Normalize: (im - mean) / std\n        Reverse: im * std + mean\n\n    '''\n    
if not isinstance(mean, (list, tuple)):\n        mean = [mean, ] * im.shape[1]\n    mean = torch.tensor(mean, device=im.device).view([1, im.shape[1], 1, 1])\n\n    if not isinstance(std, (list, tuple)):\n        std = [std, ] * im.shape[1]\n    std = torch.tensor(std, device=im.device).view([1, im.shape[1], 1, 1])\n\n    if not reverse:\n        out = (im - mean) / std\n    else:\n        out = im * std + mean\n    return out\n\n# ------------------------Image format--------------------------\ndef rgb2ycbcr(im, only_y=True):\n    '''\n    same as matlab rgb2ycbcr\n    Input:\n        im: uint8 [0,255] or float [0,1]\n        only_y: only return Y channel\n    '''\n    # transform to float64 data type, range [0, 255]\n    if im.dtype == np.uint8:\n        im_temp = im.astype(np.float64)\n    else:\n        im_temp = (im * 255).astype(np.float64)\n\n    # convert\n    if only_y:\n        rlt = np.dot(im_temp, np.array([65.481, 128.553, 24.966])/ 255.0) + 16.0\n    else:\n        rlt = np.matmul(im_temp, np.array([[65.481,  -37.797, 112.0  ],\n                                           [128.553, -74.203, -93.786],\n                                           [24.966,  112.0,   -18.214]])/255.0) + [16, 128, 128]\n    if im.dtype == np.uint8:\n        rlt = rlt.round()\n    else:\n        rlt /= 255.\n    return rlt.astype(im.dtype)\n\ndef rgb2ycbcrTorch(im, only_y=True):\n    '''\n    same as matlab rgb2ycbcr\n    Input:\n        im: float [0,1], N x 3 x H x W\n        only_y: only return Y channel\n    '''\n    # transform to range [0,255.0]\n    im_temp = im.permute([0,2,3,1]) * 255.0  # N x H x W x C --> N x H x W x C\n    # convert\n    if only_y:\n        rlt = torch.matmul(im_temp, torch.tensor([65.481, 128.553, 24.966],\n                                        device=im.device, dtype=im.dtype).view([3,1])/ 255.0) + 16.0\n    else:\n        scale = torch.tensor(\n            [[65.481,  -37.797, 112.0  ],\n             [128.553, -74.203, -93.786],\n             [24.966,  112.0,   -18.214]],\n            device=im.device, dtype=im.dtype\n        ) / 255.0\n        bias = torch.tensor([16, 128, 128], device=im.device, dtype=im.dtype).view([-1, 1, 1, 3])\n        rlt = torch.matmul(im_temp, scale) + bias\n\n    rlt /= 255.0\n    rlt.clamp_(0.0, 1.0)\n    return rlt.permute([0, 3, 1, 2])\n\ndef ycbcr2rgbTorch(im):\n    '''\n    same as matlab ycbcr2rgb\n    Input:\n        im: float [0,1], N x 3 x H x W\n        only_y: only return Y channel\n    '''\n    # transform to range [0,255.0]\n    im_temp = im.permute([0,2,3,1]) * 255.0  # N x H x W x C --> N x H x W x C\n    # convert\n    scale = torch.tensor(\n        [[0.00456621, 0.00456621, 0.00456621],\n         [0, -0.00153632, 0.00791071],\n         [0.00625893, -0.00318811, 0]],\n        device=im.device, dtype=im.dtype\n        ) * 255.0\n    bias = torch.tensor(\n        [-222.921, 135.576, -276.836], device=im.device, dtype=im.dtype\n    ).view([-1, 1, 1, 3])\n    rlt = torch.matmul(im_temp, scale) + bias\n    rlt /= 255.0\n    rlt.clamp_(0.0, 1.0)\n    return rlt.permute([0, 3, 1, 2])\n\ndef bgr2rgb(im): return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)\n\ndef rgb2bgr(im): return cv2.cvtColor(im, cv2.COLOR_RGB2BGR)\n\ndef tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):\n    \"\"\"Convert torch Tensors into image numpy arrays.\n\n    After clamping to [min, max], values will be normalized to [0, 1].\n\n    Args:\n        tensor (Tensor or list[Tensor]): Accept shapes:\n            1) 4D mini-batch Tensor of shape (B x 3/1 x 
H x W);\n            2) 3D Tensor of shape (3/1 x H x W);\n            3) 2D Tensor of shape (H x W).\n            Tensor channel should be in RGB order.\n        rgb2bgr (bool): Whether to change rgb to bgr.\n        out_type (numpy type): output types. If ``np.uint8``, transform outputs\n            to uint8 type with range [0, 255]; otherwise, float type with\n            range [0, 1]. Default: ``np.uint8``.\n        min_max (tuple[int]): min and max values for clamp.\n\n    Returns:\n        (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of\n        shape (H x W). The channel order is BGR.\n    \"\"\"\n    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):\n        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')\n\n    flag_tensor = torch.is_tensor(tensor)\n    if flag_tensor:\n        tensor = [tensor]\n    result = []\n    for _tensor in tensor:\n        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)\n        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])\n\n        n_dim = _tensor.dim()\n        if n_dim == 4:\n            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()\n            img_np = img_np.transpose(1, 2, 0)\n            if rgb2bgr:\n                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)\n        elif n_dim == 3:\n            img_np = _tensor.numpy()\n            img_np = img_np.transpose(1, 2, 0)\n            if img_np.shape[2] == 1:  # gray image\n                img_np = np.squeeze(img_np, axis=2)\n            else:\n                if rgb2bgr:\n                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)\n        elif n_dim == 2:\n            img_np = _tensor.numpy()\n        else:\n            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')\n        if out_type == np.uint8:\n            # Unlike MATLAB, numpy.unit8() WILL NOT round by default.\n            img_np = (img_np * 255.0).round()\n        img_np = img_np.astype(out_type)\n        result.append(img_np)\n    if len(result) == 1 and flag_tensor:\n        result = result[0]\n    return result\n\n# ------------------------Image resize-----------------------------\ndef imresize_np(img, scale, antialiasing=True):\n    # Now the scale should be the same for H and W\n    # input: img: Numpy, HWC or HW [0,1]\n    # output: HWC or HW [0,1] w/o round\n    img = torch.from_numpy(img)\n    need_squeeze = True if img.dim() == 2 else False\n    if need_squeeze:\n        img.unsqueeze_(2)\n\n    in_H, in_W, in_C = img.size()\n    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)\n    kernel_width = 4\n    kernel = 'cubic'\n\n    # Return the desired dimension order for performing the resize.  
The\n    # strategy is to perform the resize first along the dimension with the\n    # smallest scale factor.\n    # Now we do not support this.\n\n    # get weights and indices\n    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(\n        in_H, out_H, scale, kernel, kernel_width, antialiasing)\n    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(\n        in_W, out_W, scale, kernel, kernel_width, antialiasing)\n    # process H dimension\n    # symmetric copying\n    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)\n    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)\n\n    sym_patch = img[:sym_len_Hs, :, :]\n    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(0, inv_idx)\n    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)\n\n    sym_patch = img[-sym_len_He:, :, :]\n    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(0, inv_idx)\n    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)\n\n    out_1 = torch.FloatTensor(out_H, in_W, in_C)\n    kernel_width = weights_H.size(1)\n    for i in range(out_H):\n        idx = int(indices_H[i][0])\n        for j in range(out_C):\n            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])\n\n    # process W dimension\n    # symmetric copying\n    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)\n    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)\n\n    sym_patch = out_1[:, :sym_len_Ws, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)\n\n    sym_patch = out_1[:, -sym_len_We:, :]\n    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()\n    sym_patch_inv = sym_patch.index_select(1, inv_idx)\n    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)\n\n    out_2 = torch.FloatTensor(out_H, out_W, in_C)\n    kernel_width = weights_W.size(1)\n    for i in range(out_W):\n        idx = int(indices_W[i][0])\n        for j in range(out_C):\n            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])\n    if need_squeeze:\n        out_2.squeeze_()\n\n    return out_2.numpy()\n\ndef calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):\n    if (scale < 1) and (antialiasing):\n        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width\n        kernel_width = kernel_width / scale\n\n    # Output-space coordinates\n    x = torch.linspace(1, out_length, out_length)\n\n    # Input-space coordinates. Calculate the inverse mapping such that 0.5\n    # in output space maps to 0.5 in input space, and 0.5+scale in output\n    # space maps to 1.5 in input space.\n    u = x / scale + 0.5 * (1 - 1 / scale)\n\n    # What is the left-most pixel that can be involved in the computation?\n    left = torch.floor(u - kernel_width / 2)\n\n    # What is the maximum number of pixels that can be involved in the\n    # computation?  
Note: it's OK to use an extra pixel here; if the\n    # corresponding weights are all zero, it will be eliminated at the end\n    # of this function.\n    P = math.ceil(kernel_width) + 2\n\n    # The indices of the input pixels involved in computing the k-th output\n    # pixel are in row k of the indices matrix.\n    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(\n        1, P).expand(out_length, P)\n\n    # The weights used to compute the k-th output pixel are in row k of the\n    # weights matrix.\n    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices\n    # apply cubic kernel\n    if (scale < 1) and (antialiasing):\n        weights = scale * cubic(distance_to_center * scale)\n    else:\n        weights = cubic(distance_to_center)\n    # Normalize the weights matrix so that each row sums to 1.\n    weights_sum = torch.sum(weights, 1).view(out_length, 1)\n    weights = weights / weights_sum.expand(out_length, P)\n\n    # If a column in weights is all zero, get rid of it. only consider the first and last column.\n    weights_zero_tmp = torch.sum((weights == 0), 0)\n    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):\n        indices = indices.narrow(1, 1, P - 2)\n        weights = weights.narrow(1, 1, P - 2)\n    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):\n        indices = indices.narrow(1, 0, P - 2)\n        weights = weights.narrow(1, 0, P - 2)\n    weights = weights.contiguous()\n    indices = indices.contiguous()\n    sym_len_s = -indices.min() + 1\n    sym_len_e = indices.max() - in_length\n    indices = indices + sym_len_s - 1\n    return weights, indices, int(sym_len_s), int(sym_len_e)\n\n# matlab 'imresize' function, now only support 'bicubic'\ndef cubic(x):\n    absx = torch.abs(x)\n    absx2 = absx**2\n    absx3 = absx**3\n    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \\\n        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))\n\n# ------------------------Image I/O-----------------------------\ndef imread(path, chn='rgb', dtype='float32', force_gray2rgb=True, force_rgba2rgb=False):\n    '''\n    Read image.\n    chn: 'rgb', 'bgr' or 'gray'\n    out:\n        im: h x w x c, numpy tensor\n    '''\n    try:\n        im = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)  # BGR, uint8\n    except:\n        print(str(path))\n\n    if im is None:\n        print(str(path))\n\n    if chn.lower() == 'gray':\n        assert im.ndim == 2, f\"{str(path)} can't be successfuly read!\"\n    else:\n        if im.ndim == 2:\n            if force_gray2rgb:\n                im = np.stack([im, im, im], axis=2)\n            else:\n                raise ValueError(f\"{str(path)} has {im.ndim} channels!\")\n        elif im.ndim == 4:\n            if force_rgba2rgb:\n                im = im[:, :, :3]\n            else:\n                raise ValueError(f\"{str(path)} has {im.ndim} channels!\")\n        else:\n            if chn.lower() == 'rgb':\n                im = bgr2rgb(im)\n            elif chn.lower() == 'bgr':\n                pass\n\n    if dtype == 'float32':\n        im = im.astype(np.float32) / 255.\n    elif dtype ==  'float64':\n        im = im.astype(np.float64) / 255.\n    elif dtype == 'uint8':\n        pass\n    else:\n        sys.exit('Please input corrected dtype: float32, float64 or uint8!')\n\n    return im\n\n# ------------------------Augmentation-----------------------------\ndef data_aug_np(image, mode):\n    '''\n    Performs 
data augmentation of the input image\n    Input:\n        image: a cv2 (OpenCV) image\n        mode: int. Choice of transformation to apply to the image\n                0 - no transformation\n                1 - flip up and down\n                2 - rotate counterwise 90 degree\n                3 - rotate 90 degree and flip up and down\n                4 - rotate 180 degree\n                5 - rotate 180 degree and flip\n                6 - rotate 270 degree\n                7 - rotate 270 degree and flip\n    '''\n    if mode == 0:\n        # original\n        out = image\n    elif mode == 1:\n        # flip up and down\n        out = np.flipud(image)\n    elif mode == 2:\n        # rotate counterwise 90 degree\n        out = np.rot90(image)\n    elif mode == 3:\n        # rotate 90 degree and flip up and down\n        out = np.rot90(image)\n        out = np.flipud(out)\n    elif mode == 4:\n        # rotate 180 degree\n        out = np.rot90(image, k=2)\n    elif mode == 5:\n        # rotate 180 degree and flip\n        out = np.rot90(image, k=2)\n        out = np.flipud(out)\n    elif mode == 6:\n        # rotate 270 degree\n        out = np.rot90(image, k=3)\n    elif mode == 7:\n        # rotate 270 degree and flip\n        out = np.rot90(image, k=3)\n        out = np.flipud(out)\n    else:\n        raise Exception('Invalid choice of image transformation')\n\n    return out.copy()\n\ndef inverse_data_aug_np(image, mode):\n    '''\n    Performs inverse data augmentation of the input image\n    '''\n    if mode == 0:\n        # original\n        out = image\n    elif mode == 1:\n        out = np.flipud(image)\n    elif mode == 2:\n        out = np.rot90(image, axes=(1,0))\n    elif mode == 3:\n        out = np.flipud(image)\n        out = np.rot90(out, axes=(1,0))\n    elif mode == 4:\n        out = np.rot90(image, k=2, axes=(1,0))\n    elif mode == 5:\n        out = np.flipud(image)\n        out = np.rot90(out, k=2, axes=(1,0))\n    elif mode == 6:\n        out = np.rot90(image, k=3, axes=(1,0))\n    elif mode == 7:\n        # rotate 270 degree and flip\n        out = np.flipud(image)\n        out = np.rot90(out, k=3, axes=(1,0))\n    else:\n        raise Exception('Invalid choice of image transformation')\n\n    return out\n\n# ----------------------Visualization----------------------------\ndef imshow(x, title=None, cbar=False):\n    import matplotlib.pyplot as plt\n    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')\n    if title:\n        plt.title(title)\n    if cbar:\n        plt.colorbar()\n    plt.show()\n\ndef imblend_with_mask(im, mask, alpha=0.25):\n    \"\"\"\n    Input:\n        im, mask: h x w x c numpy array, uint8, [0, 255]\n        alpha: scaler in [0.0, 1.0]\n    \"\"\"\n    edge_map = cv2.Canny(mask, 100, 200).astype(np.float32)[:, :, None] / 255.\n\n    assert mask.dtype == np.uint8\n    mask = mask.astype(np.float32) / 255.\n    if mask.ndim == 2:\n        mask = mask[:, :, None]\n\n    back_color = np.array([159, 121, 238], dtype=np.float32).reshape((1,1,3))\n    blend = im.astype(np.float32) * alpha + (1 - alpha) * back_color\n    blend = np.clip(blend, 0, 255)\n    out = im.astype(np.float32) * (1 - mask) + blend * mask\n\n    # paste edge\n    out = out * (1 - edge_map) + np.array([0,255,0], dtype=np.float32).reshape((1,1,3)) * edge_map\n\n    return out.astype(np.uint8)\n# -----------------------Covolution------------------------------\ndef imgrad(im, pading_mode='mirror'):\n    '''\n    Calculate image gradient.\n    Input:\n        im: h x 
w x c numpy array\n    '''\n    from scipy.ndimage import correlate  # lazy import\n    wx = np.array([[0, 0, 0],\n                   [-1, 1, 0],\n                   [0, 0, 0]], dtype=np.float32)\n    wy = np.array([[0, -1, 0],\n                   [0, 1, 0],\n                   [0, 0, 0]], dtype=np.float32)\n    if im.ndim == 3:\n        gradx = np.stack(\n                [correlate(im[:,:,c], wx, mode=pading_mode) for c in range(im.shape[2])],\n                axis=2\n                )\n        grady = np.stack(\n                [correlate(im[:,:,c], wy, mode=pading_mode) for c in range(im.shape[2])],\n                axis=2\n                )\n        grad = np.concatenate((gradx, grady), axis=2)\n    else:\n        gradx = correlate(im, wx, mode=pading_mode)\n        grady = correlate(im, wy, mode=pading_mode)\n        grad = np.stack((gradx, grady), axis=2)\n\n    return {'gradx': gradx, 'grady': grady, 'grad':grad}\n\ndef convtorch(im, weight, mode='reflect'):\n    '''\n    Image convolution with pytorch\n    Input:\n        im: b x c_in x h x w torch tensor\n        weight: c_out x c_in x k x k torch tensor\n    Output:\n        out: c x h x w torch tensor\n    '''\n    radius = weight.shape[-1]\n    chn = im.shape[1]\n    im_pad = torch.nn.functional.pad(im, pad=(radius // 2, )*4, mode=mode)\n    out = torch.nn.functional.conv2d(im_pad, weight, padding=0, groups=chn)\n    return out\n\n# ----------------------Patch Cropping----------------------------\ndef random_crop(im, pch_size):\n    '''\n    Randomly crop a patch from the give image.\n    '''\n    h, w = im.shape[:2]\n    # padding if necessary\n    if h < pch_size or w < pch_size:\n        pad_h = min(max(0, pch_size - h), h)\n        pad_w = min(max(0, pch_size - w), w)\n        im = cv2.copyMakeBorder(im, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)\n\n    h, w = im.shape[:2]\n    if h == pch_size:\n        ind_h = 0\n    elif h > pch_size:\n        ind_h = random.randint(0, h-pch_size)\n    else:\n        raise ValueError('Image height is smaller than the patch size')\n    if w == pch_size:\n        ind_w = 0\n    elif w > pch_size:\n        ind_w = random.randint(0, w-pch_size)\n    else:\n        raise ValueError('Image width is smaller than the patch size')\n\n    im_pch = im[ind_h:ind_h+pch_size, ind_w:ind_w+pch_size,]\n\n    return im_pch\n\nclass ToTensor:\n    def __init__(self, max_value=1.0):\n        self.max_value = max_value\n\n    def __call__(self, im):\n        assert isinstance(im, np.ndarray)\n        if im.ndim == 2:\n            im = im[:, :, np.newaxis]\n        if im.dtype == np.uint8:\n            assert self.max_value == 255.\n            out = torch.from_numpy(im.astype(np.float32).transpose(2,0,1) / self.max_value)\n        else:\n            assert self.max_value == 1.0\n            out = torch.from_numpy(im.transpose(2,0,1))\n        return out\n\nclass RandomCrop:\n    def __init__(self, pch_size, pass_crop=False):\n        self.pch_size = pch_size\n        self.pass_crop = pass_crop\n\n    def __call__(self, im):\n        if self.pass_crop:\n            return im\n        if isinstance(im, list) or isinstance(im, tuple):\n            out = []\n            for current_im in im:\n                out.append(random_crop(current_im, self.pch_size))\n        else:\n            out = random_crop(im, self.pch_size)\n        return out\n\nclass ImageSpliterNp:\n    def __init__(self, im, pch_size, stride, sf=1):\n        '''\n        Input:\n            im: h x w x c, numpy array, [0, 1], low-resolution 
image in SR\n            pch_size, stride: patch setting\n            sf: scale factor in image super-resolution\n        '''\n        assert stride <= pch_size\n        self.stride = stride\n        self.pch_size = pch_size\n        self.sf = sf\n\n        if im.ndim == 2:\n            im = im[:, :, None]\n\n        height, width, chn = im.shape\n        self.height_starts_list = self.extract_starts(height)\n        self.width_starts_list = self.extract_starts(width)\n        self.length = self.__len__()\n        self.num_pchs = 0\n\n        self.im_ori = im\n        self.im_res = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)\n        self.pixel_count = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)\n\n    def extract_starts(self, length):\n        starts = list(range(0, length, self.stride))\n        if starts[-1] + self.pch_size > length:\n            starts[-1] = length - self.pch_size\n        return starts\n\n    def __len__(self):\n        return len(self.height_starts_list) * len(self.width_starts_list)\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        if self.num_pchs < self.length:\n            w_start_idx = self.num_pchs // len(self.height_starts_list)\n            w_start = self.width_starts_list[w_start_idx] * self.sf\n            w_end = w_start + self.pch_size * self.sf\n\n            h_start_idx = self.num_pchs % len(self.height_starts_list)\n            h_start = self.height_starts_list[h_start_idx] * self.sf\n            h_end = h_start + self.pch_size * self.sf\n\n            pch = self.im_ori[h_start:h_end, w_start:w_end,]\n            self.w_start, self.w_end = w_start, w_end\n            self.h_start, self.h_end = h_start, h_end\n\n            self.num_pchs += 1\n        else:\n            raise StopIteration(0)\n\n        return pch, (h_start, h_end, w_start, w_end)\n\n    def update(self, pch_res, index_infos):\n        '''\n        Input:\n            pch_res: pch_size x pch_size x 3, [0,1]\n            index_infos: (h_start, h_end, w_start, w_end)\n        '''\n        if index_infos is None:\n            w_start, w_end = self.w_start, self.w_end\n            h_start, h_end = self.h_start, self.h_end\n        else:\n            h_start, h_end, w_start, w_end = index_infos\n\n        self.im_res[h_start:h_end, w_start:w_end] += pch_res\n        self.pixel_count[h_start:h_end, w_start:w_end] += 1\n\n    def gather(self):\n        assert np.all(self.pixel_count != 0)\n        return self.im_res / self.pixel_count\n\nclass ImageSpliterTh:\n    def __init__(self, im, pch_size, stride, sf=1, extra_bs=1, weight_type='Gaussian'):\n        '''\n        Input:\n            im: n x c x h x w, torch tensor, float, low-resolution image in SR\n            pch_size, stride: patch setting\n            sf: scale factor in image super-resolution\n            extra_bs: number of patches aggregated into one batch, only used when inputting a single image\n        '''\n        assert weight_type in ['Gaussian', 'ones']\n        self.weight_type = weight_type\n        assert stride <= pch_size\n        self.stride = stride\n        self.pch_size = pch_size\n        self.sf = sf\n        self.extra_bs = extra_bs\n\n        bs, chn, height, width = im.shape\n        self.true_bs = bs\n\n        self.height_starts_list = self.extract_starts(height)\n        self.width_starts_list = self.extract_starts(width)\n        self.starts_list = []\n        for ii in self.height_starts_list:\n            for jj in self.width_starts_list:\n                
self.starts_list.append([ii, jj])\n\n        self.length = self.__len__()\n        self.count_pchs = 0\n\n        self.im_ori = im\n        self.dtype = torch.float64\n        self.im_res = torch.zeros([bs, chn, height*sf, width*sf], dtype=self.dtype, device=im.device)\n        self.pixel_count = torch.zeros([bs, chn, height*sf, width*sf], dtype=self.dtype, device=im.device)\n\n    def extract_starts(self, length):\n        if length <= self.pch_size:\n            starts = [0,]\n        else:\n            starts = list(range(0, length, self.stride))\n            for ii in range(len(starts)):\n                if starts[ii] + self.pch_size > length:\n                    starts[ii] = length - self.pch_size\n            starts = sorted(set(starts), key=starts.index)\n        return starts\n\n    def __len__(self):\n        return len(self.height_starts_list) * len(self.width_starts_list)\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        if self.count_pchs < self.length:\n            index_infos = []\n            current_starts_list = self.starts_list[self.count_pchs:self.count_pchs+self.extra_bs]\n            for ii, (h_start, w_start) in enumerate(current_starts_list):\n                w_end = w_start + self.pch_size\n                h_end = h_start + self.pch_size\n                current_pch = self.im_ori[:, :, h_start:h_end, w_start:w_end]\n                if ii == 0:\n                    pch =  current_pch\n                else:\n                    pch = torch.cat([pch, current_pch], dim=0)\n\n                h_start *= self.sf\n                h_end *= self.sf\n                w_start *= self.sf\n                w_end *= self.sf\n                index_infos.append([h_start, h_end, w_start, w_end])\n\n            self.count_pchs += len(current_starts_list)\n        else:\n            raise StopIteration()\n\n        return pch, index_infos\n\n    def update(self, pch_res, index_infos):\n        '''\n        Input:\n            pch_res: (n*extra_bs) x c x pch_size x pch_size, float\n            index_infos: [(h_start, h_end, w_start, w_end),]\n        '''\n        assert pch_res.shape[0] % self.true_bs == 0\n        pch_list = torch.split(pch_res, self.true_bs, dim=0)\n        assert len(pch_list) == len(index_infos)\n        for ii, (h_start, h_end, w_start, w_end) in enumerate(index_infos):\n            current_pch = pch_list[ii].type(self.dtype)\n            current_weight = self.get_weight(current_pch.shape[-2], current_pch.shape[-1])\n            self.im_res[:, :, h_start:h_end, w_start:w_end] +=  current_pch * current_weight\n            self.pixel_count[:, :, h_start:h_end, w_start:w_end] += current_weight\n\n    @staticmethod\n    def generate_kernel_1d(ksize):\n        sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8  # opencv default setting\n        if ksize % 2 == 0:\n            kernel = cv2.getGaussianKernel(ksize=ksize+1, sigma=sigma, ktype=cv2.CV_64F)\n            kernel = kernel[1:, ]\n        else:\n            kernel = cv2.getGaussianKernel(ksize=ksize, sigma=sigma, ktype=cv2.CV_64F)\n\n        return kernel\n\n    def get_weight(self, height, width):\n        if self.weight_type == 'ones':\n            kernel = torch.ones(1, 1, height, width)\n        elif self.weight_type == 'Gaussian':\n            kernel_h = self.generate_kernel_1d(height).reshape(-1, 1)\n            kernel_w = self.generate_kernel_1d(width).reshape(1, -1)\n            kernel = np.matmul(kernel_h, kernel_w)\n            kernel = 
torch.from_numpy(kernel).unsqueeze(0).unsqueeze(0) # 1 x 1 x pch_size x pch_size\n        else:\n            raise ValueError(f\"Unsupported weight type: {self.weight_type}\")\n\n        return kernel.to(dtype=self.dtype, device=self.im_ori.device)\n\n    def gather(self):\n        assert torch.all(self.pixel_count != 0)\n        return self.im_res.div(self.pixel_count)\n\n# ----------------------Patch Clipping----------------------------\nclass Clamper:\n    def __init__(self, min_max=(-1, 1)):\n        self.min_bound, self.max_bound = min_max[0], min_max[1]\n\n    def __call__(self, im):\n        if isinstance(im, np.ndarray):\n            return np.clip(im, a_min=self.min_bound, a_max=self.max_bound)\n        elif isinstance(im, torch.Tensor):\n            return torch.clamp(im, min=self.min_bound, max=self.max_bound)\n        else:\n            raise TypeError(f'ndarray or Tensor expected, got {type(im)}')\n\n# ----------------------Interpolation----------------------------\nclass Bicubic:\n    def __init__(self, scale=None, out_shape=None, activate_matlab=True, resize_back=False):\n        self.scale = scale\n        self.activate_matlab = activate_matlab\n        self.out_shape = out_shape\n        self.resize_back = resize_back\n\n    def __call__(self, im):\n        if self.activate_matlab:\n            out = imresize_np(im, scale=self.scale)\n            if self.resize_back:\n                out = imresize_np(out, scale=1/self.scale)\n        else:\n            out = cv2.resize(\n                    im,\n                    dsize=self.out_shape,\n                    fx=self.scale,\n                    fy=self.scale,\n                    interpolation=cv2.INTER_CUBIC,\n                    )\n            if self.resize_back:\n                out = cv2.resize(\n                        out,\n                        dsize=self.out_shape,\n                        fx=1/self.scale,\n                        fy=1/self.scale,\n                        interpolation=cv2.INTER_CUBIC,\n                        )\n        return out\n\nclass SmallestMaxSize:\n    def __init__(self, max_size, pass_resize=False, interpolation=None):\n        self.pass_resize = pass_resize\n        self.max_size = max_size\n        self.interpolation = interpolation\n        self.str2mode = {\n                'nearest': cv2.INTER_NEAREST_EXACT,\n                'bilinear': cv2.INTER_LINEAR,\n                'bicubic': cv2.INTER_CUBIC\n                }\n        if self.interpolation is not None:\n            assert interpolation in self.str2mode, f\"Unsupported interpolation mode: {interpolation}\"\n\n    def get_interpolation(self, size):\n        if self.interpolation is None:\n            if size < self.max_size:   # upsampling\n                interpolation = cv2.INTER_CUBIC\n            else:                      # downsampling\n                interpolation = cv2.INTER_AREA\n        else:\n            interpolation = self.str2mode[self.interpolation]\n\n        return interpolation\n\n    def __call__(self, im):\n        h, w = im.shape[:2]\n        if self.pass_resize or min(h, w) == self.max_size:\n            out = im\n        else:\n            if h < w:\n                dsize = (int(self.max_size * w / h), self.max_size)\n                out = cv2.resize(im, dsize=dsize, interpolation=self.get_interpolation(h))\n            else:\n                dsize = (self.max_size, int(self.max_size * h / w))\n                out = cv2.resize(im, dsize=dsize, interpolation=self.get_interpolation(w))\n            if 
out.dtype == np.uint8:\n                out = np.clip(out, 0, 255)\n            else:\n                out = np.clip(out, 0, 1.0)\n\n        return out\n\n# ----------------------augmentation----------------------------\nclass SpatialAug:\n    def __init__(self, pass_aug, only_hflip=False, only_vflip=False, only_hvflip=False):\n        self.only_hflip = only_hflip\n        self.only_vflip = only_vflip\n        self.only_hvflip = only_hvflip\n        self.pass_aug = pass_aug\n\n    def __call__(self, im, flag=None):\n        if self.pass_aug:\n            return im\n\n        if flag is None:\n            if self.only_hflip:\n                flag = random.choice([0, 5])\n            elif self.only_vflip:\n                flag = random.choice([0, 1])\n            elif self.only_hvflip:\n                flag = random.choice([0, 1, 5])\n            else:\n                flag = random.randint(0, 7)\n\n        if isinstance(im, list) or isinstance(im, tuple):\n            out = []\n            for current_im in im:\n                out.append(data_aug_np(current_im, flag))\n        else:\n            out = data_aug_np(im, flag)\n        return out\n"
  },
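  {
    "path": "examples/data_aug_roundtrip.py",
    "content": "#!/usr/bin/env python\n# Illustrative sketch, not part of upstream InvSR: checks that every mode of\n# data_aug_np is undone by inverse_data_aug_np. The import path below is an\n# assumption (the augmentation helpers are assumed to live in\n# comfyui_invsr_trimmed/utils/util_image.py); run from the repo root with the\n# node's requirements installed, since that module also pulls in cv2 and torch.\nimport numpy as np\n\nfrom comfyui_invsr_trimmed.utils.util_image import data_aug_np, inverse_data_aug_np\n\nif __name__ == \"__main__\":\n    rng = np.random.default_rng(0)\n    im = rng.random((16, 24, 3)).astype(np.float32)  # h x w x c test image\n    for mode in range(8):\n        restored = inverse_data_aug_np(data_aug_np(im, mode), mode)\n        assert np.allclose(restored, im), f\"mode {mode} does not round-trip\"\n    print(\"all 8 augmentation modes round-trip to the original image\")\n"
  },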
  {
    "path": "comfyui_invsr_trimmed/utils/util_net.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# Power by Zongsheng Yue 2021-11-24 20:29:36\n\ndef reload_model(model, ckpt):\n    module_flag = list(ckpt.keys())[0].startswith('module.')\n    compile_flag = '_orig_mod' in list(ckpt.keys())[0]\n\n    for source_key, source_value in model.state_dict().items():\n        target_key = source_key\n        if compile_flag and (not '_orig_mod.' in source_key):\n            target_key = '_orig_mod.' + target_key\n        if module_flag and (not source_key.startswith('module')):\n            target_key = 'module.' + target_key\n\n        assert target_key in ckpt\n        source_value.copy_(ckpt[target_key])\n"
  },
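  {
    "path": "examples/reload_model_demo.py",
    "content": "#!/usr/bin/env python\n# Illustrative sketch, not part of upstream InvSR: shows that reload_model from\n# comfyui_invsr_trimmed/utils/util_net.py restores weights even when the\n# checkpoint keys carry a 'module.' (DataParallel) prefix. The tiny Linear\n# model and the hand-built checkpoint below are hypothetical stand-ins; run\n# from the repo root with torch installed.\nimport torch\n\nfrom comfyui_invsr_trimmed.utils.util_net import reload_model\n\nmodel = torch.nn.Linear(4, 2, bias=False)\nckpt = {\"module.weight\": torch.ones(2, 4)}  # simulated DataParallel checkpoint\nreload_model(model, ckpt)\nassert torch.equal(model.weight.data, torch.ones(2, 4))\nprint(\"weights restored from a 'module.'-prefixed checkpoint\")\n"
  },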
  {
    "path": "comfyui_invsr_trimmed/utils/util_opts.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# Power by Zongsheng Yue 2021-11-24 15:07:43\n\nimport argparse\n\ndef update_args(args_json, args_parser):\n    for arg in vars(args_parser):\n        args_json[arg] = getattr(args_parser, arg)\n\ndef str2bool(v):\n    \"\"\"\n    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse\n    \"\"\"\n    if isinstance(v, bool):\n        return v\n    if v.lower() in (\"yes\", \"true\", \"t\", \"y\", \"1\"):\n        return True\n    elif v.lower() in (\"no\", \"false\", \"f\", \"n\", \"0\"):\n        return False\n    else:\n        raise argparse.ArgumentTypeError(\"boolean value expected\")\n"
  },
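  {
    "path": "examples/str2bool_cli_demo.py",
    "content": "#!/usr/bin/env python\n# Illustrative sketch, not part of upstream InvSR: shows how str2bool from\n# comfyui_invsr_trimmed/utils/util_opts.py can back a boolean CLI flag so that\n# values such as \"false\", \"no\" or \"1\" are parsed instead of raising. The\n# --tiled_vae flag below is only an example name; run from the repo root.\nimport argparse\n\nfrom comfyui_invsr_trimmed.utils.util_opts import str2bool\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\"--tiled_vae\", type=str2bool, default=True)\n\nargs = parser.parse_args([\"--tiled_vae\", \"false\"])\nassert args.tiled_vae is False\nprint(f\"tiled_vae parsed as {args.tiled_vae}\")\n"
  },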
  {
    "path": "comfyui_invsr_trimmed/utils/util_sisr.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# Power by Zongsheng Yue 2021-12-07 21:37:58\n\nimport cv2\nimport numpy as np\n\ndef modcrop(im, sf):\n    h, w = im.shape[:2]\n    h -= (h % sf)\n    w -= (w % sf)\n    return im[:h, :w,]\n\n#-----------------------------------------Transform--------------------------------------------\nclass Bicubic:\n    def __init__(self, scale=None, out_shape=None, matlab_mode=True):\n        self.scale = scale\n        self.out_shape = out_shape\n\n    def __call__(self, im):\n        out = cv2.resize(\n                im,\n                dsize=self.out_shape,\n                fx=self.scale,\n                fy=self.scale,\n                interpolation=cv2.INTER_CUBIC,\n                )\n        return out\n"
  },
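  {
    "path": "examples/modcrop_demo.py",
    "content": "#!/usr/bin/env python\n# Illustrative sketch, not part of upstream InvSR: modcrop from\n# comfyui_invsr_trimmed/utils/util_sisr.py trims an image so that both sides\n# are divisible by the scale factor, which keeps LR/HR shapes consistent for\n# x4 super-resolution. The zero array below is a dummy stand-in for a real\n# image; run from the repo root with numpy and cv2 installed.\nimport numpy as np\n\nfrom comfyui_invsr_trimmed.utils.util_sisr import modcrop\n\nim = np.zeros((257, 514, 3), dtype=np.uint8)\nprint(modcrop(im, 4).shape)  # (256, 512, 3): both sides cropped to multiples of 4\n"
  },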
  {
    "path": "configs/degradation_testing_realesrgan.yaml",
    "content": "degradation:\n  sf: 4\n  # the first degradation process\n  resize_prob: [0.2, 0.7, 0.1]  # up, down, keep\n  resize_range: [0.5, 1.5]\n  gaussian_noise_prob: 0.5\n  noise_range: [1, 15]\n  poisson_scale_range: [0.05, 0.3]\n  gray_noise_prob: 0.4\n  jpeg_range: [70, 95]\n\n  # the second degradation process\n  second_order_prob: 0.0\n  second_blur_prob: 0.2\n  resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep\n  resize_range2: [0.8, 1.2]\n  gaussian_noise_prob2: 0.5\n  noise_range2: [1, 10]\n  poisson_scale_range2: [0.05, 0.2]\n  gray_noise_prob2: 0.4\n  jpeg_range2: [80, 100]\n\n  gt_size: 512\n\nopts:\n  data_source: ~\n  im_exts: ['png', 'JPEG']\n  io_backend:\n    type: disk\n  blur_kernel_size: 13\n  kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']\n  kernel_prob: [0.60, 0.40, 0.0, 0.0, 0.0, 0.0]\n  sinc_prob: 0.1\n  blur_sigma: [0.2, 0.8]\n  betag_range: [1.0, 1.5]\n  betap_range: [1, 1.2]\n\n  blur_kernel_size2: 7\n  kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']\n  kernel_prob2: [0.60, 0.4, 0.0, 0.0, 0.0, 0.0]\n  sinc_prob2: 0.0\n  blur_sigma2: [0.2, 0.5]\n  betag_range2: [0.5, 0.8]\n  betap_range2: [1, 1.2]\n\n  final_sinc_prob: 0.2\n\n  gt_size: ${degradation.gt_size}\n  crop_pad_size: ${degradation.gt_size}\n  use_hflip: False\n  use_rot: False\n\n"
  },
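  {
    "path": "examples/load_degradation_config.py",
    "content": "#!/usr/bin/env python\n# Illustrative sketch, not part of upstream InvSR: loads\n# configs/degradation_testing_realesrgan.yaml with OmegaConf (already listed in\n# requirements.txt) and shows that opts.gt_size is resolved from the\n# ${degradation.gt_size} interpolation. Run from the repo root so the relative\n# path resolves.\nfrom omegaconf import OmegaConf\n\ncfg = OmegaConf.load(\"configs/degradation_testing_realesrgan.yaml\")\nprint(cfg.degradation.sf)       # 4\nprint(cfg.degradation.gt_size)  # 512\nprint(cfg.opts.gt_size)         # 512, pulled in via ${degradation.gt_size}\n"
  },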
  {
    "path": "configs/sample-sd-turbo.yaml",
    "content": "seed: 12345\n\n\n# Super-resolution settings\nbasesr:\n  sf: 4\n  chopping:     # for latent diffusion\n    pch_size: 128\n    weight_type: Gaussian\n\n# VAE settings\ntiled_vae: True\nlatent_tiled_size: 128\nsample_tiled_size: 1024\ngradient_checkpointing_vae: True\nsliced_vae: False\n\n# classifer-free guidance\ncfg_scale: 1.0\n\n# sampling settings \nstart_timesteps: 200\n\n# color fixing\ncolor_fix: ~\n\n# Stable Diffusion \nbase_model: sd-turbo\nsd_pipe:\n  target: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline\n  enable_grad_checkpoint: True\n  params:\n    pretrained_model_name_or_path: stabilityai/sd-turbo\n    cache_dir: /mnt/sfs-common/zsyue/modelbase/stable-diffusion/sd-turbo\n    use_safetensors: True\n    torch_dtype: torch.float16\n\nmodel_start:\n  target: .noise_predictor.NoisePredictor\n  ckpt_path: ~           # For initializing\n  params:\n    in_channels: 3\n    down_block_types:\n      - AttnDownBlock2D\n      - AttnDownBlock2D\n    up_block_types:\n      - AttnUpBlock2D\n      - AttnUpBlock2D\n    block_out_channels:\n      - 256    # 192, 256\n      - 512    # 384, 512\n    layers_per_block: \n      - 3\n      - 3\n    act_fn: silu\n    latent_channels: 4\n    norm_num_groups: 32\n    sample_size: 128\n    mid_block_add_attention: True\n    resnet_time_scale_shift: default\n    temb_channels: 512\n    attention_head_dim: 64 \n    freq_shift: 0\n    flip_sin_to_cos: True\n    double_z: True\n\nmodel_middle:\n  target: .noise_predictor.NoisePredictor\n  params:\n    in_channels: 3\n    down_block_types:\n      - AttnDownBlock2D\n      - AttnDownBlock2D\n    up_block_types:\n      - AttnUpBlock2D\n      - AttnUpBlock2D\n    block_out_channels:\n      - 256    # 192, 256\n      - 512    # 384, 512\n    layers_per_block: \n      - 3\n      - 3\n    act_fn: silu\n    latent_channels: 4\n    norm_num_groups: 32\n    sample_size: 128\n    mid_block_add_attention: True\n    resnet_time_scale_shift: default\n    temb_channels: 512\n    attention_head_dim: 64 \n    freq_shift: 0\n    flip_sin_to_cos: True\n    double_z: True\n"
  },
  {
    "path": "configs/sd-turbo-sr-ldis.yaml",
    "content": "trainer:\n  target: trainer.TrainerSDTurboSR\n\nsd_pipe:\n  target: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline\n  num_train_steps: 1000\n  enable_grad_checkpoint: True\n  compile: False\n  vae_split: 8\n  params:\n    pretrained_model_name_or_path: stabilityai/sd-turbo\n    cache_dir: weights\n    use_safetensors: True\n    torch_dtype: torch.float16\n\nllpips:\n  target: latent_lpips.lpips.LPIPS\n  ckpt_path: weights/vgg16_sdturbo_lpips.pth\n  compile: False\n  params:\n    pretrained: False\n    net: vgg16\n    lpips: True\n    spatial: False\n    pnet_rand: False\n    pnet_tune: True\n    use_dropout: True\n    eval_mode: True\n    latent: True\n    in_chans: 4\n    verbose: True\n\nmodel:\n  target: .noise_predictor.NoisePredictor\n  ckpt_start_path: ~     # only used for training the intermidiate model\n  ckpt_path: ~           # For initializing\n  compile: False\n  params:\n    in_channels: 3\n    down_block_types:\n      - AttnDownBlock2D\n      - AttnDownBlock2D\n    up_block_types:\n      - AttnUpBlock2D\n      - AttnUpBlock2D\n    block_out_channels:\n      - 256      # 192, 256\n      - 512      # 384, 512\n    layers_per_block: \n      - 3\n      - 3\n    act_fn: silu\n    latent_channels: 4\n    norm_num_groups: 32\n    sample_size: 128\n    mid_block_add_attention: True\n    resnet_time_scale_shift: default\n    temb_channels: 512\n    attention_head_dim: 64 \n    freq_shift: 0\n    flip_sin_to_cos: True\n    double_z: True\n\ndiscriminator:\n  target: diffusers.models.unets.unet_2d_condition_discriminator.UNet2DConditionDiscriminator\n  enable_grad_checkpoint: True\n  compile: False\n  params:\n    sample_size: 64\n    in_channels: 4\n    center_input_sample: False\n    flip_sin_to_cos: True\n    freq_shift: 0\n    down_block_types:\n      - DownBlock2D\n      - CrossAttnDownBlock2D\n      - CrossAttnDownBlock2D\n    mid_block_type: UNetMidBlock2DCrossAttn\n    up_block_types:\n      - CrossAttnUpBlock2D\n      - CrossAttnUpBlock2D\n      - UpBlock2D\n    only_cross_attention: False\n    block_out_channels:\n      - 128\n      - 256\n      - 512\n    layers_per_block:\n      - 1\n      - 2\n      - 2\n    downsample_padding: 1\n    mid_block_scale_factor: 1\n    dropout: 0.0\n    act_fn: silu\n    norm_num_groups: 32\n    norm_eps: 1e-5\n    cross_attention_dim: 1024\n    transformer_layers_per_block: 1\n    reverse_transformer_layers_per_block: ~\n    encoder_hid_dim: ~\n    encoder_hid_dim_type: ~\n    attention_head_dim:\n      - 8 \n      - 16 \n      - 16 \n    num_attention_heads: ~\n    dual_cross_attention: False\n    use_linear_projection: False\n    class_embed_type: ~\n    addition_embed_type: text \n    addition_time_embed_dim: 256\n    num_class_embeds: ~\n    upcast_attention: ~\n    resnet_time_scale_shift: default\n    resnet_skip_time_act: False\n    resnet_out_scale_factor: 1.0\n    time_embedding_type: positional\n    time_embedding_dim: ~\n    time_embedding_act_fn: ~\n    timestep_post_act: ~\n    time_cond_proj_dim: ~\n    conv_in_kernel: 3\n    conv_out_kernel: 3\n    projection_class_embeddings_input_dim: 2560\n    attention_type: default\n    class_embeddings_concat: False\n    mid_block_only_cross_attention: ~\n    cross_attention_norm: ~\n    addition_embed_type_num_heads: 64\n\ndegradation:\n  sf: 4\n  # the first degradation process\n  resize_prob: [0.2, 0.7, 0.1]  # up, down, keep\n  resize_range: [0.15, 1.5]\n  gaussian_noise_prob: 0.5\n  noise_range: [1, 30]\n  poisson_scale_range: 
[0.05, 3.0]\n  gray_noise_prob: 0.4\n  jpeg_range: [30, 95]\n\n  # the second degradation process\n  second_order_prob: 0.5\n  second_blur_prob: 0.8\n  resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep\n  resize_range2: [0.3, 1.2]\n  gaussian_noise_prob2: 0.5\n  noise_range2: [1, 25]\n  poisson_scale_range2: [0.05, 2.5]\n  gray_noise_prob2: 0.4\n  jpeg_range2: [30, 95]\n\n  gt_size: 512 \n  resize_back: False\n  use_sharp: False\n\ndata:\n  train:\n    type: realesrgan\n    params:\n      data_source: \n        source1:\n          root_path: /mnt/sfs-common/zsyue/database/FFHQ\n          image_path: images1024\n          moment_path: ~\n          text_path: ~\n          im_ext: png\n          length: 20000\n        source2:\n          root_path: /mnt/sfs-common/zsyue/database/LSDIR/train\n          image_path: images \n          moment_path: ~\n          text_path: ~\n          im_ext: png\n      max_token_length: 77   # 77\n      io_backend:\n        type: disk\n      blur_kernel_size: 21\n      kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']\n      kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]\n      sinc_prob: 0.1\n      blur_sigma: [0.2, 3.0]\n      betag_range: [0.5, 4.0]\n      betap_range: [1, 2.0]\n\n      blur_kernel_size2: 15\n      kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']\n      kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]\n      sinc_prob2: 0.1\n      blur_sigma2: [0.2, 1.5]\n      betag_range2: [0.5, 4.0]\n      betap_range2: [1, 2.0]\n\n      final_sinc_prob: 0.8\n\n      gt_size: ${degradation.gt_size}\n      use_hflip: True\n      use_rot: False\n      random_crop: True\n  val:\n    type: base\n    params:\n      dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/lq\n      transform_type: default\n      transform_kwargs:\n          mean: 0.0\n          std: 1.0\n      extra_dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/gt\n      extra_transform_type: default\n      extra_transform_kwargs:\n          mean: 0.0\n          std: 1.0\n      im_exts: png\n      length: 16\n      recursive: False\n\ntrain:\n  # predict started inverser\n  start_mode: True\n  # learning rate\n  lr: 5e-5                      # learning rate \n  lr_min: 5e-5                  # learning rate \n  lr_schedule: ~\n  warmup_iterations: 2000\n  # discriminator\n  lr_dis: 5e-5                  # learning rate for dicriminator\n  weight_decay_dis: 1e-3        # weight decay for dicriminator\n  dis_init_iterations: 10000    # iterations used for updating the discriminator\n  dis_update_freq: 1            \n  # dataloader\n  batch: 64               \n  microbatch: 16 \n  num_workers: 4\n  prefetch_factor: 2            \n  use_text: True\n  # optimization settings\n  weight_decay: 0               \n  ema_rate: 0.999\n  iterations: 200000            # total iterations\n  # logging\n  save_freq: 5000\n  log_freq: [200, 5000]         # [training loss, training images, val images]\n  local_logging: True           # manually save images\n  tf_logging: False             # tensorboard logging\n  # loss \n  loss_type: L2\n  loss_coef:\n    ldif: 1.0\n  timesteps: [200, 100]\n  num_inference_steps: 5\n  # mixed precision\n  use_amp: True                \n  use_fsdp: False                \n  # random seed \n  seed: 123456                 \n  global_seeding: False\n  noise_detach: False\n\nvalidate:\n  batch: 2\n  use_ema: True            \n  
log_freq: 4      # logging frequence\n  val_y_channel: True\n"
  },
  {
    "path": "node.py",
    "content": "from .comfyui_invsr_trimmed import get_configs, InvSamplerSR, BaseSampler, Namespace\nimport torch\nfrom comfy.utils import ProgressBar\nfrom folder_paths import get_full_path, get_folder_paths, models_dir\nimport os\nimport torch.nn.functional as F\n\ndef split_tensor_into_batches(tensor, batch_size):\n    \"\"\"\n    Split a tensor into smaller batches of specified size\n    \n    Args:\n        tensor (torch.Tensor): Input tensor of shape (N, C, H, W)\n        batch_size (int): Desired batch size for splitting\n        \n    Returns:\n        list: List of tensors, each with batch_size (except possibly the last one)\n    \"\"\"\n    # Get original batch size\n    original_batch_size = tensor.size(0)\n    \n    # Calculate number of full batches and remaining samples\n    num_full_batches = original_batch_size // batch_size\n    remaining_samples = original_batch_size % batch_size\n    \n    # Split tensor into chunks\n    batches = []\n    \n    # Handle full batches\n    for i in range(num_full_batches):\n        start_idx = i * batch_size\n        end_idx = start_idx + batch_size\n        batch = tensor[start_idx:end_idx]\n        batches.append(batch)\n    \n    # Handle remaining samples if any\n    if remaining_samples > 0:\n        last_batch = tensor[-remaining_samples:]\n        batches.append(last_batch)\n    \n    return batches\n\n\nclass LoadInvSRModels:\n    @classmethod\n    def INPUT_TYPES(s):\n        return {\n            \"required\": {\n                \"sd_model\": (['stabilityai/sd-turbo'],),\n                \"invsr_model\": (['noise_predictor_sd_turbo_v5.pth', 'noise_predictor_sd_turbo_v5_diftune.pth'],),\n                \"dtype\": (['fp16', 'fp32', 'bf16'], {\"default\": \"fp16\"}),\n                \"tiled_vae\": (\"BOOLEAN\", {\"default\": True}),\n            },\n        }\n\n    RETURN_TYPES = (\"INVSR_PIPE\",)\n    RETURN_NAMES = (\"invsr_pipe\",)\n    FUNCTION = \"loadmodel\"\n    CATEGORY = \"INVSR\"\n\n    def loadmodel(self, sd_model, invsr_model, dtype, tiled_vae):\n        match dtype:\n            case \"fp16\":\n                dtype = \"torch.float16\"\n            case \"fp32\":\n                dtype = \"torch.float32\"\n            case \"bf16\":\n                dtype = \"torch.bfloat16\"\n\n        cfg_path = os.path.join(\n            os.path.dirname(__file__), \"configs\", \"sample-sd-turbo.yaml\"\n        )\n        sd_path = get_folder_paths(\"diffusers\")[0]\n\n        try:\n            ckpt_dir = get_folder_paths(\"invsr\")[0]\n        except:\n            ckpt_dir = os.path.join(models_dir, \"invsr\")\n\n        args = Namespace(\n            bs=1,\n            chopping_bs=8,\n            timesteps=None,\n            num_steps=1,\n            cfg_path=cfg_path,\n            sd_path=sd_path,\n            started_ckpt_dir=ckpt_dir,\n            tiled_vae=tiled_vae,\n            color_fix=\"\",\n            chopping_size=128,\n            invsr_model=invsr_model\n        )\n        configs = get_configs(args)\n        configs[\"sd_pipe\"][\"params\"][\"torch_dtype\"] = dtype\n        base_sampler = BaseSampler(configs)\n\n        return ((base_sampler, invsr_model),)\n\nclass InvSRSampler:\n    @classmethod\n    def INPUT_TYPES(s):\n        return {\n            \"required\": {\n                \"invsr_pipe\": (\"INVSR_PIPE\",),\n                \"images\": (\"IMAGE\",),\n                \"num_steps\": (\"INT\",{\"default\": 1, \"min\": 1, \"max\": 5}),\n                \"cfg\": (\"FLOAT\",{\"default\": 1.0, 
\"step\":0.1}),\n                # \"scale_factor\": (\"INT\",{\"default\": 4}),\n                \"batch_size\": (\"INT\",{\"default\": 1}),\n                \"chopping_batch_size\": (\"INT\",{\"default\": 8}),\n                \"chopping_size\": ([128, 256, 512],{\"default\": 128}),\n                \"color_fix\": (['none', 'wavelet', 'ycbcr'], {\"default\": \"none\"}),\n                \"seed\": (\"INT\", {\"default\": 123, \"min\": 0, \"max\": 2**32 - 1, \"step\": 1}),\n            },\n        }\n\n    RETURN_TYPES = (\"IMAGE\",)\n    RETURN_NAMES = (\"image\",)\n    FUNCTION = \"process\"\n    CATEGORY = \"INVSR\"\n\n    def process(self, invsr_pipe, images, num_steps, cfg, batch_size, chopping_batch_size, chopping_size, color_fix, seed):\n        base_sampler, invsr_model = invsr_pipe\n        if color_fix == \"none\":\n            color_fix = \"\"\n\n        cfg_path = os.path.join(\n            os.path.dirname(__file__), \"configs\", \"sample-sd-turbo.yaml\"\n        )\n        sd_path = get_folder_paths(\"diffusers\")[0]\n\n        try:\n            ckpt_dir = get_folder_paths(\"invsr\")[0]\n        except:\n            ckpt_dir = os.path.join(models_dir, \"invsr\")\n\n        args = Namespace(\n            bs=batch_size,\n            chopping_bs=chopping_batch_size,\n            timesteps=None,\n            num_steps=num_steps,\n            cfg_path=cfg_path,\n            sd_path=sd_path,\n            started_ckpt_dir=ckpt_dir,\n            tiled_vae=base_sampler.configs.tiled_vae,\n            color_fix=color_fix,\n            chopping_size=chopping_size,\n            invsr_model=invsr_model\n        )\n        configs = get_configs(args, log=True)\n        configs[\"cfg_scale\"] = cfg\n        # configs[\"basesr\"][\"sf\"] = scale_factor\n        \n        base_sampler.configs = configs\n        base_sampler.setup_seed(seed)\n        sampler = InvSamplerSR(base_sampler)\n\n        images_bchw = images.permute(0,3,1,2)\n        og_h, og_w = images_bchw.shape[2:]\n\n        # Calculate new dimensions divisible by 16\n        new_height = ((og_h + 15) // 16) * 16  # Round up to nearest multiple of 16\n        new_width = ((og_w + 15) // 16) * 16\n        resized = False\n        \n        if og_h != new_height or og_w != new_width:\n            resized = True\n            print(f\"[InvSR] - Image not divisible by 16. Resizing to {new_height} (h) x {new_width} (w)\")\n            images_bchw = F.interpolate(images_bchw, size=(new_height, new_width), mode='bicubic', align_corners=False)\n\n        batches = split_tensor_into_batches(images_bchw, batch_size)\n\n        results = []\n        pbar = ProgressBar(len(batches))\n\n        for batch in batches:\n            result = sampler.inference(image_bchw=batch)\n            results.append(torch.from_numpy(result))\n            pbar.update(1)\n\n        result_t = torch.cat(results, dim=0)\n\n        # Resize to original dimensions * 4\n        if resized:\n            result_t = F.interpolate(result_t, size=(og_h * 4, og_w * 4), mode='bicubic', align_corners=False)\n            \n        return (result_t.permute(0,2,3,1),)\n"
  },
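  {
    "path": "examples/split_batches_demo.py",
    "content": "#!/usr/bin/env python\n# Illustrative sketch, not part of upstream InvSR: node.py splits the incoming\n# image batch into chunks of batch_size before calling the sampler, but it\n# cannot be imported outside a running ComfyUI. The same chunking is reproduced\n# here with torch.split, which yields the same chunk sizes (full batches plus\n# one smaller remainder); the tensor shape below is arbitrary.\nimport torch\n\nimages_bchw = torch.zeros(5, 3, 64, 64)  # e.g. five frames of a small clip\nbatches = list(torch.split(images_bchw, 2, dim=0))\nprint([b.shape[0] for b in batches])  # [2, 2, 1]: two full batches and a remainder of one\n"
  },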
  {
    "path": "pyproject.toml",
    "content": "[project]\nname = \"invsr\"\ndescription = \"This project is an unofficial ComfyUI implementation of [a/InvSR](https://github.com/zsyOAOA/InvSR) (Arbitrary-steps Image Super-resolution via Diffusion Inversion)\"\nversion = \"1.0.1\"\nlicense = {file = \"LICENSE\"}\ndependencies = [\"pyiqa==0.1.12\", \"opencv-python\", \"albumentations==1.4.18\", \"bitsandbytes\", \"protobuf\", \"python-box\", \"omegaconf\", \"loguru\"]\n\n[project.urls]\nRepository = \"https://github.com/yuvraj108c/ComfyUI_InvSR\"\n#  Used by Comfy Registry https://comfyregistry.org\n\n[tool.comfy]\nPublisherId = \"yuvraj108c\"\nDisplayName = \"ComfyUI_InvSR\"\nIcon = \"\"\n"
  },
  {
    "path": "requirements.txt",
    "content": "opencv-contrib-python-headless\nomegaconf\ndiffusers\nnumpy<2\nhuggingface-hub\ntransformers\n"
  },
  {
    "path": "workflows/invsr.json",
    "content": "{\n  \"last_node_id\": 27,\n  \"last_link_id\": 40,\n  \"nodes\": [\n    {\n      \"id\": 25,\n      \"type\": \"LoadInvSRModels\",\n      \"pos\": [\n        -877.6565551757812,\n        719.943603515625\n      ],\n      \"size\": [\n        413.84686279296875,\n        130\n      ],\n      \"flags\": {},\n      \"order\": 0,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"invsr_pipe\",\n          \"type\": \"INVSR_PIPE\",\n          \"links\": [\n            38\n          ],\n          \"slot_index\": 0\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"LoadInvSRModels\"\n      },\n      \"widgets_values\": [\n        \"stabilityai/sd-turbo\",\n        \"noise_predictor_sd_turbo_v5.pth\",\n        \"bf16\",\n        true\n      ]\n    },\n    {\n      \"id\": 17,\n      \"type\": \"LoadImage\",\n      \"pos\": [\n        -867.9894409179688,\n        918.590087890625\n      ],\n      \"size\": [\n        367.18212890625,\n        439.6049499511719\n      ],\n      \"flags\": {},\n      \"order\": 1,\n      \"mode\": 0,\n      \"inputs\": [],\n      \"outputs\": [\n        {\n          \"name\": \"IMAGE\",\n          \"type\": \"IMAGE\",\n          \"links\": [\n            40\n          ],\n          \"slot_index\": 0\n        },\n        {\n          \"name\": \"MASK\",\n          \"type\": \"MASK\",\n          \"links\": null\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"LoadImage\"\n      },\n      \"widgets_values\": [\n        \"i1.png\",\n        \"image\"\n      ]\n    },\n    {\n      \"id\": 27,\n      \"type\": \"InvSRSampler\",\n      \"pos\": [\n        -365.7756652832031,\n        859.0978393554688\n      ],\n      \"size\": [\n        315,\n        246\n      ],\n      \"flags\": {},\n      \"order\": 2,\n      \"mode\": 0,\n      \"inputs\": [\n        {\n          \"name\": \"invsr_pipe\",\n          \"type\": \"INVSR_PIPE\",\n          \"link\": 38\n        },\n        {\n          \"name\": \"images\",\n          \"type\": \"IMAGE\",\n          \"link\": 40\n        }\n      ],\n      \"outputs\": [\n        {\n          \"name\": \"image\",\n          \"type\": \"IMAGE\",\n          \"links\": [\n            39\n          ],\n          \"slot_index\": 0\n        }\n      ],\n      \"properties\": {\n        \"Node name for S&R\": \"InvSRSampler\"\n      },\n      \"widgets_values\": [\n        1,\n        5,\n        1,\n        8,\n        128,\n        \"wavelet\",\n        536149006,\n        \"randomize\"\n      ]\n    },\n    {\n      \"id\": 26,\n      \"type\": \"PreviewImage\",\n      \"pos\": [\n        35.55903625488281,\n        643.07470703125\n      ],\n      \"size\": [\n        889.1287231445312,\n        869.5750732421875\n      ],\n      \"flags\": {},\n      \"order\": 3,\n      \"mode\": 0,\n      \"inputs\": [\n        {\n          \"name\": \"images\",\n          \"type\": \"IMAGE\",\n          \"link\": 39\n        }\n      ],\n      \"outputs\": [],\n      \"properties\": {\n        \"Node name for S&R\": \"PreviewImage\"\n      }\n    }\n  ],\n  \"links\": [\n    [\n      38,\n      25,\n      0,\n      27,\n      0,\n      \"INVSR_PIPE\"\n    ],\n    [\n      39,\n      27,\n      0,\n      26,\n      0,\n      \"IMAGE\"\n    ],\n    [\n      40,\n      17,\n      0,\n      27,\n      1,\n      \"IMAGE\"\n    ]\n  ],\n  \"groups\": [],\n  \"config\": {},\n  \"extra\": {\n    \"ds\": {\n      \"scale\": 
0.9090909090909091,\n      \"offset\": [\n        1131.660545256908,\n        -643.8089187150072\n      ]\n    },\n    \"VHS_latentpreview\": false,\n    \"VHS_latentpreviewrate\": 0\n  },\n  \"version\": 0.4\n}"
  }
]