Full Code of yuvraj108c/ComfyUI_InvSR for AI

Note: this extract is a truncated preview of the full archive (210K characters total).
Repository: yuvraj108c/ComfyUI_InvSR
Branch: main
Commit: 20a0e003e676
Files: 31
Total size: 199.0 KB

Directory structure:
gitextract_393ad4r7/

├── .github/
│   ├── FUNDING.yml
│   └── workflows/
│       └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── comfyui_invsr_trimmed/
│   ├── __init__.py
│   ├── inference_invsr.py
│   ├── latent_lpips/
│   │   ├── __init__.py
│   │   ├── lpips.py
│   │   └── pretrained_networks.py
│   ├── noise_predictor.py
│   ├── pipeline_stable_diffusion_inversion_sr.py
│   ├── sampler_invsr.py
│   ├── time_aware_encoder.py
│   └── utils/
│       ├── __init__.py
│       ├── resize.py
│       ├── util_color_fix.py
│       ├── util_common.py
│       ├── util_ema.py
│       ├── util_image.py
│       ├── util_net.py
│       ├── util_opts.py
│       └── util_sisr.py
├── configs/
│   ├── degradation_testing_realesrgan.yaml
│   ├── sample-sd-turbo.yaml
│   └── sd-turbo-sr-ldis.yaml
├── node.py
├── pyproject.toml
├── requirements.txt
└── workflows/
    └── invsr.json

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/FUNDING.yml
================================================
github: yuvraj108c
custom: ["https://paypal.me/yuvraj108c", "https://buymeacoffee.com/yuvraj108cz"]


================================================
FILE: .github/workflows/publish.yml
================================================
name: Publish to Comfy registry
on:
  workflow_dispatch:
  push:
    branches:
      - main
      - master
    paths:
      - "pyproject.toml"

permissions:
  issues: write

jobs:
  publish-node:
    name: Publish Custom Node to registry
    runs-on: ubuntu-latest
    if: ${{ github.repository_owner == 'yuvraj108c' }}
    steps:
      - name: Check out code
        uses: actions/checkout@v4
        with:
          submodules: true
      - name: Publish Custom Node
        uses: Comfy-Org/publish-node-action@v1
        with:
          ## Add your own personal access token to your GitHub repository secrets and reference it here.
          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}


================================================
FILE: .gitignore
================================================
.DS_Store
*pyc
.vscode
__pycache__
# *.egg-info
*.bak
checkpoints
results
backup

================================================
FILE: LICENSE
================================================
S-Lab License 1.0

Copyright 2024 S-Lab

Redistribution and use for non-commercial purpose in source and 
binary forms, with or without modification, are permitted provided 
that the following conditions are met:

1. Redistributions of source code must retain the above copyright 
   notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright 
   notice, this list of conditions and the following disclaimer in 
   the documentation and/or other materials provided with the 
   distribution.

3. Neither the name of the copyright holder nor the names of its 
   contributors may be used to endorse or promote products derived 
   from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In the event that redistribution and/or use for commercial purpose in 
source or binary forms, with or without modification is required, 
please contact the contributor(s) of the work.


================================================
FILE: README.md
================================================
<div align="center">

# ComfyUI InvSR
[![arXiv](https://img.shields.io/badge/arXiv%20paper-2412.09013-b31b1b.svg)](https://arxiv.org/abs/2412.09013) 

This project is a ComfyUI wrapper for [InvSR](https://github.com/zsyOAOA/InvSR) (Arbitrary-steps Image Super-resolution via Diffusion Inversion)

**Last tested**: 2 January 2026 (ComfyUI v0.7.0@f2fda02 | Torch 2.9.1 | Python 3.10.12 | RTX4090 | CUDA 13.0 | Debian 12)

<img height="400" src="https://github.com/user-attachments/assets/6c057a3c-3355-4060-9161-a88ab6f6d986" />

</div>

## ⭐ Support
If you like my projects and wish to see updates and new features, please consider supporting me. It helps a lot! 

[![ComfyUI-Depth-Anything-Tensorrt](https://img.shields.io/badge/ComfyUI--Depth--Anything--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Depth-Anything-Tensorrt)
[![ComfyUI-Upscaler-Tensorrt](https://img.shields.io/badge/ComfyUI--Upscaler--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Upscaler-Tensorrt)
[![ComfyUI-Dwpose-Tensorrt](https://img.shields.io/badge/ComfyUI--Dwpose--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Dwpose-Tensorrt)
[![ComfyUI-Rife-Tensorrt](https://img.shields.io/badge/ComfyUI--Rife--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Rife-Tensorrt)

[![ComfyUI-Whisper](https://img.shields.io/badge/ComfyUI--Whisper-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Whisper)
[![ComfyUI_InvSR](https://img.shields.io/badge/ComfyUI__InvSR-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI_InvSR)
[![ComfyUI-Thera](https://img.shields.io/badge/ComfyUI--Thera-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Thera)
[![ComfyUI-Video-Depth-Anything](https://img.shields.io/badge/ComfyUI--Video--Depth--Anything-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Video-Depth-Anything)
[![ComfyUI-PiperTTS](https://img.shields.io/badge/ComfyUI--PiperTTS-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-PiperTTS)

[![buy-me-coffees](https://i.imgur.com/3MDbAtw.png)](https://www.buymeacoffee.com/yuvraj108cZ)
[![paypal-donation](https://i.imgur.com/w5jjubk.png)](https://paypal.me/yuvraj108c)
---

## Installation
Navigate to the ComfyUI `/custom_nodes` directory
```bash
git clone https://github.com/yuvraj108c/ComfyUI_InvSR
cd ComfyUI_InvSR

pip install -r requirements.txt
```

## Usage
- Load [example workflow](workflows/invsr.json) 
- Diffusers model (stabilityai/sd-turbo) will download automatically to `ComfyUI/models/diffusers`
- InvSR model (noise_predictor_sd_turbo_v5.pth) will download automatically to `ComfyUI/models/invsr` (or pre-download it as sketched below)
- To deal with large images (e.g. 1K → 4K upscaling), set `chopping_size` to 256
- If your GPU memory is limited, please set `chopping_batch_size` to 1
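
If the automatic download is blocked (for example on an offline machine), the InvSR checkpoint can be fetched in advance. Below is a minimal sketch that mirrors the `hf_hub_download` call in `comfyui_invsr_trimmed/inference_invsr.py`; the destination directory is an assumption based on the default ComfyUI layout above:

```python
# Optional pre-download sketch (not part of this repo); mirrors the
# hf_hub_download call in comfyui_invsr_trimmed/inference_invsr.py.
# The destination path assumes the default ComfyUI layout described above.
import os
from shutil import copy2
from huggingface_hub import hf_hub_download

dst_dir = "ComfyUI/models/invsr"
os.makedirs(dst_dir, exist_ok=True)
ckpt = hf_hub_download(repo_id="OAOA/InvSR",
                       filename="noise_predictor_sd_turbo_v5.pth")
copy2(ckpt, os.path.join(dst_dir, "noise_predictor_sd_turbo_v5.pth"))
```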

## Parameters
- `num_steps`: Number of inference steps
- `cfg`: Classifier-free guidance scale
- `batch_size`: Controls how many complete images are processed simultaneously
- `chopping_batch_size`: Controls how many patches from the same image are processed simultaneously
- `chopping_size`: Controls the size of patches when splitting large images
- `color_fix`: Method to fix color shift in processed images (see the configuration sketch below)
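
For orientation, here is a minimal sketch (not part of the repo) of how these parameters feed `get_configs` in `comfyui_invsr_trimmed/inference_invsr.py` when driven outside ComfyUI. The attribute names follow that function; the concrete values, paths, and chosen config file are illustrative assumptions, and `cfg` is not read by `get_configs` itself:

```python
# Illustrative only: attribute names follow get_configs(); the values,
# paths, and chosen config file are assumptions, not project defaults.
from comfyui_invsr_trimmed import Namespace, get_configs  # run from the repo root

args = Namespace(
    cfg_path="configs/sample-sd-turbo.yaml",
    num_steps=1,           # 1-5 select preset timestep schedules
    timesteps=None,        # or an explicit list with len == num_steps
    started_step=250,      # starting timestep, only used when num_steps > 5
    sd_path="./weights",   # cache dir for the stabilityai/sd-turbo download
    invsr_model="noise_predictor_sd_turbo_v5.pth",
    bs=1,                  # batch_size: whole images per pass
    chopping_bs=1,         # chopping_batch_size: patches per pass
    chopping_size=256,     # patch size when splitting large images
    tiled_vae=False,
    color_fix="wavelet",   # assumed value, for illustration
)
configs = get_configs(args, log=True)  # also downloads the checkpoint if missing
```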

## Updates
**28 April 2025**
- Update diffusers versions in requirements.txt to fix https://github.com/yuvraj108c/ComfyUI_InvSR/issues/26, https://github.com/yuvraj108c/ComfyUI_InvSR/issues/21, https://github.com/yuvraj108c/ComfyUI_InvSR/issues/15
- Add support for `noise_predictor_sd_turbo_v5_diftune.pth`
  
**03 February 2025**
- Add `cfg` parameter
- Make image dimensions divisible by 16
- Use ComfyUI's `model_management` (`mm`) to set the torch device
  
**31 January 2025**
- Merged https://github.com/yuvraj108c/ComfyUI_InvSR/pull/5 by [wfjsw](https://github.com/wfjsw)
  - Compatibility with `diffusers>=0.28`
  - Massive code refactoring & cleanup

## Citation
```bibtex
@article{yue2024InvSR,
  title={Arbitrary-steps Image Super-resolution via Diffusion Inversion},
  author={Yue, Zongsheng and Liao, Kang and Loy, Chen Change},
  journal = {arXiv preprint arXiv:2412.09013},
  year={2024},
}
```

## License
This project is licensed under [NTU S-Lab License 1.0](LICENSE)

## Acknowledgments
Thanks to [simplepod.ai](https://simplepod.ai/) for providing GPU servers

## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=yuvraj108c/ComfyUI_InvSR&type=Date)](https://star-history.com/#yuvraj108c/ComfyUI_InvSR&Date)


================================================
FILE: __init__.py
================================================
from .node import LoadInvSRModels, InvSRSampler
 
NODE_CLASS_MAPPINGS = { 
    "LoadInvSRModels" : LoadInvSRModels,
    "InvSRSampler" : InvSRSampler
}

NODE_DISPLAY_NAME_MAPPINGS = {
     "LoadInvSRModels" : "Load InvSR Models",
     "InvSRSampler" : "InvSRSampler"
}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']


================================================
FILE: comfyui_invsr_trimmed/__init__.py
================================================
from .inference_invsr import get_configs, Namespace
from .sampler_invsr import InvSamplerSR, BaseSampler
from .noise_predictor import NoisePredictor
from .time_aware_encoder import TimeAwareEncoder

__all__ = [
    "get_configs", 
    "Namespace",
    "InvSamplerSR", 
    "BaseSampler", 
    "NoisePredictor", 
    "TimeAwareEncoder"
]


================================================
FILE: comfyui_invsr_trimmed/inference_invsr.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2023-03-11 17:17:41

import numpy as np
from pathlib import Path
from omegaconf import OmegaConf
from .sampler_invsr import InvSamplerSR, BaseSampler

from .utils import util_common
from .utils.util_opts import str2bool
from huggingface_hub import hf_hub_download
from shutil import copy2

class Namespace:
    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
    
    def __repr__(self):
        items = [f"{key}={repr(value)}" for key, value in vars(self).items()]
        return f"Namespace({', '.join(items)})"

def get_configs(args, log=False):
    configs = OmegaConf.load(args.cfg_path)

    if args.timesteps is not None:
        assert len(args.timesteps) == args.num_steps
        configs.timesteps = sorted(args.timesteps, reverse=True)
    else:
        if args.num_steps == 1:
            configs.timesteps = [200,]
        elif args.num_steps == 2:
            configs.timesteps = [200, 100]
        elif args.num_steps == 3:
            configs.timesteps = [200, 100, 50]
        elif args.num_steps == 4:
            configs.timesteps = [200, 150, 100, 50]
        elif args.num_steps == 5:
            configs.timesteps = [250, 200, 150, 100, 50]
        else:
            assert args.num_steps <= 250
            configs.timesteps = np.linspace(
                start=args.started_step, stop=0, num=args.num_steps, endpoint=False, dtype=np.int64
            ).tolist()
    if log:
        print(f'[InvSR] - Setting timesteps for inference: {configs.timesteps}')

    # path to save Stable Diffusion
    sd_path = args.sd_path if args.sd_path else "./weights"
    util_common.mkdir(sd_path, delete=False, parents=True)
    configs.sd_pipe.params.cache_dir = sd_path

    # path to save noise predictor
    started_ckpt_name = args.invsr_model

    if getattr(args, "started_ckpt_dir", None) is not None:
        started_ckpt_dir = args.started_ckpt_dir
    else:
        started_ckpt_dir = "./weights"

    if getattr(args, "started_ckpt_path", None) is not None:
        started_ckpt_path = args.started_ckpt_path
    else:
        started_ckpt_path = Path(started_ckpt_dir) / started_ckpt_name
        util_common.mkdir(started_ckpt_dir, delete=False, parents=True)

    if not Path(started_ckpt_path).exists():
        temp_path = hf_hub_download(
            repo_id="OAOA/InvSR",
            filename=started_ckpt_name,
        )
        copy2(temp_path, started_ckpt_path)
    configs.model_start.ckpt_path = str(started_ckpt_path)

    configs.bs = args.bs
    configs.tiled_vae = args.tiled_vae
    configs.color_fix = args.color_fix
    configs.basesr.chopping.pch_size = args.chopping_size
    if args.bs > 1:
        configs.basesr.chopping.extra_bs = 1
    else:
        configs.basesr.chopping.extra_bs = args.chopping_bs

    return configs


================================================
FILE: comfyui_invsr_trimmed/latent_lpips/__init__.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function



================================================
FILE: comfyui_invsr_trimmed/latent_lpips/lpips.py
================================================

from __future__ import absolute_import

import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np
from . import pretrained_networks as pn

def normalize_tensor(in_feat,eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
    return in_feat/(norm_factor+eps)

def spatial_average(in_tens, keepdim=True):
    return in_tens.mean([2,3],keepdim=keepdim)

def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
    return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)

# Learned perceptual metric
class LPIPS(nn.Module):
    def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
        pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True,
                latent=False, in_chans=3, verbose=True):
        """ Initializes a perceptual loss torch.nn.Module

        Parameters (default listed first)
        ---------------------------------
        lpips : bool
            [True] use linear layers on top of base/trunk network
            [False] means no linear layers; each layer is averaged together
        pretrained : bool
            This flag controls the linear layers, which are only in effect when lpips=True above
            [True] means linear layers are calibrated with human perceptual judgments
            [False] means linear layers are randomly initialized
        pnet_rand : bool
            [False] means trunk loaded with ImageNet classification weights
            [True] means randomly initialized trunk
        net : str
            ['alex','vgg','squeeze'] are the base/trunk networks available
        version : str
            ['v0.1'] is the default and latest
            ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
        model_path : str
            [None] is default and loads the pretrained weights from the paper https://arxiv.org/abs/1801.03924v1

        The following parameters should only be changed if training the network

        eval_mode : bool
            [True] is for test mode (default)
            [False] is for training mode
        pnet_tune
            [False] keep base/trunk frozen
            [True] tune the base/trunk network
        use_dropout : bool
            [True] to use dropout when training linear layers
            [False] for no dropout when training linear layers
        """

        super(LPIPS, self).__init__()
        if(verbose):
            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'%
                ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.latent = latent
        self.lpips = lpips # false means baseline of just averaging all layers
        self.version = version
        self.scaling_layer = ScalingLayer()

        if(self.pnet_type in ['vgg','vgg16']):
            if not latent:
                net_type = pn.vgg16
            else:
                net_type = pn.vgg16_latent
            self.chns = [64,128,256,512,512]
        elif(self.pnet_type=='alex'):
            net_type = pn.alexnet
            self.chns = [64,192,384,256,256]
        elif(self.pnet_type=='squeeze'):
            net_type = pn.squeezenet
            self.chns = [64,128,256,384,384,512,512]
        self.L = len(self.chns)

        if latent:
            self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune, in_chans=in_chans)
        else:
            self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if(lpips):
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
            if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins+=[self.lin5,self.lin6]
            self.lins = nn.ModuleList(self.lins)

            if(pretrained):
                if(model_path is None):
                    import inspect
                    import os
                    model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net)))

                if(verbose):
                    print('Loading model from: %s'%model_path)
                missing_keys, unexpected_keys = self.load_state_dict(
                        torch.load(model_path, map_location='cpu'),
                        strict=False,
                        )
                print(f'Number of missing keys when loading checkpoint: {len(missing_keys)}')
                print(f'Number of unexpected keys when loading checkpoint: {len(unexpected_keys)}')

        if(eval_mode):
            self.eval()

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
            in0 = 2 * in0  - 1
            in1 = 2 * in1  - 1

        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if (not self.latent and self.version=='0.1') else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

        val = 0
        for l in range(self.L):
            val += res[l]

        if(retPerLayer):
            return (val, res)
        else:
            return val

class ScalingLayer(nn.Module):
    def __init__(self):
        super(ScalingLayer, self).__init__()
        self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
        self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])

    def forward(self, inp):
        return (inp - self.shift) / self.scale

class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv '''
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout(),] if(use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

class Dist2LogitLayer(nn.Module):
    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
    def __init__(self, chn_mid=32, use_sigmoid=True):
        super(Dist2LogitLayer, self).__init__()

        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
        layers += [nn.LeakyReLU(0.2,True),]
        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
        layers += [nn.LeakyReLU(0.2,True),]
        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
        if(use_sigmoid):
            layers += [nn.Sigmoid(),]
        self.model = nn.Sequential(*layers)

    def forward(self,d0,d1,eps=0.1):
        return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))

class BCERankingLoss(nn.Module):
    def __init__(self, chn_mid=32):
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid, use_sigmoid=False)
        # self.parameters = list(self.net.parameters())
        # self.loss = torch.nn.BCELoss()
        self.loss = torch.nn.BCEWithLogitsLoss()

    def forward(self, d0, d1, judge):
        per = (judge+1.)/2.
        self.logit = self.net.forward(d0,d1)
        return self.loss(self.logit, per)

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print('Network',net)
    print('Total number of parameters: %d' % num_params)


================================================
FILE: comfyui_invsr_trimmed/latent_lpips/pretrained_networks.py
================================================
from collections import namedtuple
import torch
from torchvision import models as tv

class squeezenet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(squeezenet, self).__init__()
        pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.slice6 = torch.nn.Sequential()
        self.slice7 = torch.nn.Sequential()
        self.N_slices = 7
        for x in range(2):
            self.slice1.add_module(str(x), pretrained_features[x])
        for x in range(2,5):
            self.slice2.add_module(str(x), pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), pretrained_features[x])
        for x in range(10, 11):
            self.slice5.add_module(str(x), pretrained_features[x])
        for x in range(11, 12):
            self.slice6.add_module(str(x), pretrained_features[x])
        for x in range(12, 13):
            self.slice7.add_module(str(x), pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        h = self.slice6(h)
        h_relu6 = h
        h = self.slice7(h)
        h_relu7 = h
        vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
        out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)

        return out


class alexnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(alexnet, self).__init__()
        alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(2):
            self.slice1.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(2, 5):
            self.slice2.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(5, 8):
            self.slice3.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(8, 10):
            self.slice4.add_module(str(x), alexnet_pretrained_features[x])
        for x in range(10, 12):
            self.slice5.add_module(str(x), alexnet_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1 = h
        h = self.slice2(h)
        h_relu2 = h
        h = self.slice3(h)
        h_relu3 = h
        h = self.slice4(h)
        h_relu4 = h
        h = self.slice5(h)
        h_relu5 = h
        alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)

        return out

class vgg16(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

        return out

class vgg16_latent(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True, in_chans=3):
        super(vgg16_latent, self).__init__()
        vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        # max pooling layers: vgg_pretrained_features[5, 9, 16, 23]
        for x in range(4):
            assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d)
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        if (not in_chans == 3):
            # assert in_chans == 4
            weight = self.slice1[0].weight.data[:, 0,].unsqueeze(1).repeat(1, in_chans, 1, 1)
            self.slice1[0].weight.data = weight
        for x in range(5, 9): # skip max pooling at index 5
            assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d)
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(10, 16): # skip max pooling at index 9
            assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d)
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(17, 23): # skip max pooling at index 16
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = self.slice1(X)
        h_relu1_2 = h
        h = self.slice2(h)
        h_relu2_2 = h
        h = self.slice3(h)
        h_relu3_3 = h
        h = self.slice4(h)
        h_relu4_3 = h
        h = self.slice5(h)
        h_relu5_3 = h
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

        return out


class resnet(torch.nn.Module):
    def __init__(self, requires_grad=False, pretrained=True, num=18):
        super(resnet, self).__init__()
        if(num==18):
            self.net = tv.resnet18(pretrained=pretrained)
        elif(num==34):
            self.net = tv.resnet34(pretrained=pretrained)
        elif(num==50):
            self.net = tv.resnet50(pretrained=pretrained)
        elif(num==101):
            self.net = tv.resnet101(pretrained=pretrained)
        elif(num==152):
            self.net = tv.resnet152(pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

    def forward(self, X):
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)

        return out


================================================
FILE: comfyui_invsr_trimmed/noise_predictor.py
================================================
from typing import Dict, Optional, Tuple, Union
import torch
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.autoencoders.vae import (
    Decoder,
    DecoderOutput,
    DiagonalGaussianDistribution,
    Encoder,
)
from diffusers.models.attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils.accelerate_utils import apply_forward_hook
from .time_aware_encoder import TimeAwareEncoder

class NoisePredictor(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A noise prediction model built from the encoder of AutoencoderKL.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
            mid_block will only have resnet blocks.
        temb_channels (`int`, *optional*, defaults to 256): Number of channels for the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
        flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
            Whether to flip sin to cos for Fourier time embedding.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        mid_block_add_attention: bool = True,
        attention_head_dim: int = 1,
        resnet_time_scale_shift: str = "default",
        temb_channels: int = 256,
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        double_z: bool = True,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = TimeAwareEncoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=double_z,
            mid_block_add_attention=mid_block_add_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
            freq_shift=freq_shift,
            flip_sin_to_cos=flip_sin_to_cos,
            attention_head_dim=attention_head_dim,
        )

        self.use_slicing = False
        self.use_tiling = False
        self.double_z = double_z

        # only relevant if vae tiling is enabled
        self.tile_sample_min_size = self.config.sample_size
        sample_size = (
            self.config.sample_size[0]
            if isinstance(self.config.sample_size, (list, tuple))
            else self.config.sample_size
        )
        self.tile_latent_min_size = int(
            sample_size / (2 ** (len(self.config.block_out_channels) - 1))
        )
        self.tile_overlap_factor = 0.25

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (Encoder, Decoder)):
            module.gradient_checkpointing = value

    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)

    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
    ):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(
            proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
            for proc in self.attn_processors.values()
        ):
            processor = AttnAddedKVProcessor()
        elif all(
            proc.__class__ in CROSS_ATTENTION_PROCESSORS
            for proc in self.attn_processors.values()
        ):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    @apply_forward_hook
    def encode(
        self,
        x: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        return_dict: bool = True,
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            timestep (`int` or `torch.Tensor`): Timestep conditioning the time-aware encoder.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_tiling and (
            x.shape[-1] > self.tile_sample_min_size
            or x.shape[-2] > self.tile_sample_min_size
        ):
            return self.tiled_encode(x, timestep, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice, timestep) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x, timestep)

        if not self.double_z:
            return h

        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def tiled_encode(
        self,
        x: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        return_dict: bool = True,
    ) -> AutoencoderKLOutput:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.Tensor`): Input batch of images.
            timestep (`int` or `torch.Tensor`): Timestep conditioning the time-aware encoder.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into overlapping tiles of size `tile_sample_min_size` and encode them separately.
        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile, timestep)
                if self.config.use_quant_conv:
                    tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))

        moments = torch.cat(result_rows, dim=2)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def forward(
        self,
        sample: torch.Tensor,
        timesteps: torch.Tensor,
        sample_posterior: bool = True,
        center_input_sample: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        if center_input_sample:
            sample = sample * 2 - 1.0

        if not self.double_z:
            h = self.encode(sample, timesteps)
            return h
        else:
            posterior = self.encode(sample, timesteps).latent_dist

        if sample_posterior:
            return posterior.sample()
        else:
            return posterior


================================================
FILE: comfyui_invsr_trimmed/pipeline_stable_diffusion_inversion_sr.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
    PIL_INTERPOLATION,
    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import requests
        >>> import torch
        >>> from PIL import Image
        >>> from io import BytesIO

        >>> from diffusers import StableDiffusionImg2ImgPipeline

        >>> device = "cuda"
        >>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
        >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
        >>> pipe = pipe.to(device)

        >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

        >>> response = requests.get(url)
        >>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
        >>> init_image = init_image.resize((768, 512))

        >>> prompt = "A fantasy landscape, trending on artstation"

        >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
        >>> images[0].save("fantasy_landscape.png")
        ```
"""


def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


def preprocess(image):
    deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
    deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
    if isinstance(image, torch.Tensor):
        return image
    elif isinstance(image, PIL.Image.Image):
        image = [image]

    if isinstance(image[0], PIL.Image.Image):
        w, h = image[0].size
        w, h = (x - x % 8 for x in (w, h))  # resize to integer multiple of 8

        image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
        image = np.concatenate(image, axis=0)
        image = np.array(image).astype(np.float32) / 255.0
        image = image.transpose(0, 3, 1, 2)
        image = 2.0 * image - 1.0
        image = torch.from_numpy(image)
    elif isinstance(image[0], torch.Tensor):
        image = torch.cat(image, dim=0)
    return image


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    timesteps: Optional[List[int]] = None,
    device: Optional[Union[str, torch.device]] = None,
    **kwargs,
):
    """
    Prepare the sampling timesteps and noise sigmas.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    num_inference_steps = len(timesteps)
    timesteps = torch.tensor(timesteps, dtype=torch.float32, device=device) - 1
    scheduler.timesteps = timesteps

    if not hasattr(scheduler, 'sigmas_cache'):
        # cache per-training-timestep sigmas in ascending order (length = num_train_timesteps, typically 1000)
        scheduler.sigmas_cache = scheduler.sigmas.flip(0)[1:].to(device)
    sigmas = scheduler.sigmas_cache[timesteps.long()]

    # minimal sigma
    if scheduler.config.final_sigmas_type == "sigma_min":
        sigma_last = ((1 - scheduler.alphas_cumprod[0]) / scheduler.alphas_cumprod[0]) ** 0.5
    elif scheduler.config.final_sigmas_type == "zero":
        sigma_last = 0
    else:
        raise ValueError(
            f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {scheduler.config.final_sigmas_type}"
        )
    sigma_last = torch.tensor([sigma_last,], dtype=torch.float32).to(device=sigmas.device)
    sigmas = torch.cat([sigmas, sigma_last]).type(torch.float32)
    scheduler.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    scheduler._step_index = None
    scheduler._begin_index = None

    return scheduler.timesteps, num_inference_steps

class StableDiffusionInvEnhancePipeline(
    DiffusionPipeline,
    StableDiffusionMixin,
    TextualInversionLoaderMixin,
    IPAdapterMixin,
    StableDiffusionLoraLoaderMixin,
    FromSingleFileMixin,
):
    r"""
    Pipeline for text-guided image-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    The pipeline also inherits the following loading methods:
        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
        - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
        - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`~transformers.CLIPTextModel`]):
            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
        tokenizer ([`~transformers.CLIPTokenizer`]):
            A `CLIPTokenizer` to tokenize text.
        unet ([`UNet2DConditionModel`]):
            A `UNet2DConditionModel` to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
            about a model's potential harms.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
    """

    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    _exclude_from_cpu_offload = ["safety_checker"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        image_encoder: CLIPVisionModelWithProjection = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        **kwargs,
    ):
        deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)

        prompt_embeds_tuple = self.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=lora_scale,
            **kwargs,
        )

        # concatenate for backwards comp
        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])

        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
        """
        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
        elif self.unet is not None:
            prompt_embeds_dtype = self.unet.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

            max_length = prompt_embeds.shape[1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            negative_prompt_embeds = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            negative_prompt_embeds = negative_prompt_embeds[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        if self.text_encoder is not None:
            if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

        return prompt_embeds, negative_prompt_embeds
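
    # Hedged usage sketch (illustrative; assumes an already-constructed `pipe`):
    # unlike the deprecated `_encode_prompt`, this returns a tuple rather than
    # a single concatenated tensor.
    #
    #   prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
    #       prompt="a photo",
    #       device=pipe._execution_device,
    #       num_images_per_prompt=1,
    #       do_classifier_free_guidance=True,
    #       negative_prompt="blurry",
    #   )
    #   # each: (1, 77, hidden_dim) for a CLIP text encoder with 77-token context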

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
                num_images_per_prompt, dim=0
            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            uncond_image_embeds = torch.zeros_like(image_embeds)

            return image_embeds, uncond_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
    ):
        image_embeds = []
        if do_classifier_free_guidance:
            negative_image_embeds = []
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )

            for single_ip_adapter_image, image_proj_layer in zip(
                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
            ):
                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )

                image_embeds.append(single_image_embeds[None, :])
                if do_classifier_free_guidance:
                    negative_image_embeds.append(single_negative_image_embeds[None, :])
        else:
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                    negative_image_embeds.append(single_negative_image_embeds)
                image_embeds.append(single_image_embeds)

        ip_adapter_image_embeds = []
        for i, single_image_embeds in enumerate(image_embeds):
            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
            if do_classifier_free_guidance:
                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)

            single_image_embeds = single_image_embeds.to(device=device)
            ip_adapter_image_embeds.append(single_image_embeds)

        return ip_adapter_image_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
        deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(
        self,
        prompt,
        strength,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        ip_adapter_image=None,
        ip_adapter_image_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if strength < 0 or strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )

        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(
                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
            )

        if ip_adapter_image_embeds is not None:
            if not isinstance(ip_adapter_image_embeds, list):
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
                )
            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
                raise ValueError(
                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
                )

    def get_timesteps(self, num_inference_steps, strength, device):
        # get the original timestep using init_timestep
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)

        return timesteps, num_inference_steps - t_start
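
    # Worked example (illustrative): with num_inference_steps=20, strength=0.25
    # and a first-order scheduler (order == 1):
    #     init_timestep = min(int(20 * 0.25), 20) = 5
    #     t_start       = max(20 - 5, 0)          = 15
    # so only the final 5 of the 20 scheduled timesteps are actually run, i.e.
    # the input image is only lightly noised before denoising begins.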

    def prepare_latents(
        self, image, timestep, batch_size, num_images_per_prompt, dtype, device,
        noise=None, generator=None,
    ):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        batch_size = batch_size * num_images_per_prompt

        if image.shape[1] == 4:
            init_latents = image

        else:
            if isinstance(generator, list) and len(generator) != batch_size:
                raise ValueError(
                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
                )

            elif isinstance(generator, list):
                if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
                    image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
                elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
                    raise ValueError(
                        f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
                    )

                init_latents = [
                    retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                    for i in range(batch_size)
                ]
                init_latents = torch.cat(init_latents, dim=0)
            else:
                init_latents = retrieve_latents(self.vae.encode(image), generator=generator)

            init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
            # expand init_latents for batch_size
            deprecation_message = (
                f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
                " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
                " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
                " your script to pass as many initial images as text prompts to suppress this warning."
            )
            deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
            additional_image_per_prompt = batch_size // init_latents.shape[0]
            init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
        elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        shape = init_latents.shape
        if noise is None:
            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        # get latents
        init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
        latents = init_latents

        return latents.type(dtype)
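
    # Illustrative note (not from the original file): for a (1, 3, 512, 512)
    # image, the VAE encodes to (1, 4, 64, 64) latents, which are scaled by
    # vae.config.scaling_factor (0.18215 for the SD 2.x VAE) and then noised to
    # `timestep` via scheduler.add_noise; when the start noise predictor has
    # supplied `noise`, that tensor is used instead of freshly sampled
    # Gaussian noise.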

    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
    def get_guidance_scale_embedding(
        self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
    ) -> torch.Tensor:
        """
        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

        Args:
            w (`torch.Tensor`):
                Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
            embedding_dim (`int`, *optional*, defaults to 512):
                Dimension of the embeddings to generate.
            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
                Data type of the generated embeddings.

        Returns:
            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
        """
        assert len(w.shape) == 1
        w = w * 1000.0

        half_dim = embedding_dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
        emb = w.to(dtype)[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = torch.nn.functional.pad(emb, (0, 1))
        assert emb.shape == (w.shape[0], embedding_dim)
        return emb
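
    # Worked shape example (illustrative): w = torch.tensor([6.5]) (i.e.
    # guidance_scale 7.5 minus 1, as used in `__call__`) with embedding_dim=256
    # gives half_dim=128 sine/cosine frequency pairs, so the returned embedding
    # has shape (1, 256); an odd embedding_dim is zero-padded by one column.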

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def clip_skip(self):
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def cross_attention_kwargs(self):
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: PipelineImageInput = None,
        target_size: Tuple[int] = (512, 512),
        strength: float = 0.25,
        num_inference_steps: Optional[int] = 20,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
        guidance_scale: Optional[float] = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: Optional[float] = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        ip_adapter_image: Optional[PipelineImageInput] = None,
        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        clip_skip: int = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        **kwargs,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
                `Image`, numpy array or tensor representing a low-quality image batch to be used as the starting
                point. For both numpy arrays and pytorch tensors, the expected value range is `[0, 1]`. If it's a
                tensor or a list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a
                numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can
                also accept image latents as `image`, in which case the latents are not encoded again.
            target_size (`Tuple[int]`, *optional*, defaults to `(512, 512)`):
                Target image resolution as `(height, width)`.
            strength (`float`, *optional*, defaults to 0.25):
                Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
                starting point and more noise is added the higher the `strength`. The number of denoising steps depends
                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
                essentially ignores `image`.
            num_inference_steps (`int`, *optional*, defaults to 20):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference. This parameter is modulated by `strength`.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                passed will be used. Must be in descending order.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            start_noise_predictor (`nn.Module`, *optional*):
                Noise predictor for the initial step. Note that it is read from a pipeline attribute (set via
                `setattr` in `sampler_invsr.py`) rather than passed to `__call__`.
            intermediate_noise_predictor (`nn.Module`, *optional*):
                Noise predictor for the intermediate steps; likewise read from a pipeline attribute.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                A higher guidance scale value encourages the model to generate images closely linked to the text
                `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide what to not include in image generation. If not defined, you need to
                pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
                provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during inference, with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content.
        """

        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        if callback is not None:
            deprecate(
                "callback",
                "1.0.0",
                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )
        if callback_steps is not None:
            deprecate(
                "callback_steps",
                "1.0.0",
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            strength,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            ip_adapter_image,
            ip_adapter_image_embeds,
            callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
        )
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            self.do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=self.clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
            image_embeds = self.prepare_ip_adapter_image_embeds(
                ip_adapter_image,
                ip_adapter_image_embeds,
                device,
                batch_size * num_images_per_prompt,
                self.do_classifier_free_guidance,
            )

        # 4. Preprocess image
        self.image_processor.config.do_normalize = False
        image = self.image_processor.preprocess(image)  # [0, 1], torch tensor, (b,c,h,w)
        self.image_processor.config.do_normalize = True
        image_up = torch.nn.functional.interpolate(image, size=target_size, mode='bicubic') # upsampling
        image_up = self.image_processor.normalize(image_up)  # [-1, 1]

        # 5. set timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, timesteps, device)
        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

        # 6. Prepare latent variables
        if getattr(self, 'start_noise_predictor', None) is not None:
            with torch.amp.autocast('cuda'):
                noise = self.start_noise_predictor(
                    image, latent_timestep, sample_posterior=True, center_input_sample=True,
                )
        else:
            noise = None
        latents = self.prepare_latents(
            image_up, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype,
            device, noise, generator,
        )

        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7.1 Add image embeds for IP-Adapter
        added_cond_kwargs = (
            {"image_embeds": image_embeds}
            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
            else None
        )

        # 7.2 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
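                # InvSR detail: when an intermediate noise predictor is attached
                # to the pipeline, it supplies the stochastic noise injected at
                # the next scheduler step in place of i.i.d. Gaussian noise.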
                if getattr(self, 'intermediate_noise_predictor', None) is not None and i + 1 < len(timesteps):
                    t_next = timesteps[i+1]
                    with torch.amp.autocast('cuda'):
                        noise = self.intermediate_noise_predictor(image, t_next, center_input_sample=True)
                    extra_step_kwargs['noise'] = noise
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                0
            ]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
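
# Hedged end-to-end sketch (illustrative, not from the original file; `pipe` as
# built above, `lq` a [0, 1] float tensor of shape (1, 3, 128, 128); the
# timestep value is an assumption -- see configs/ for the shipped defaults):
#
#   sr = pipe(
#       image=lq,
#       prompt="Cinematic, high-contrast, photo-realistic, 8k",
#       target_size=(512, 512),   # x4 upscale of the 128x128 input
#       timesteps=[250],          # single inversion step
#       guidance_scale=1.0,       # CFG effectively disabled
#       output_type="pt",
#   ).images                      # (1, 3, 512, 512), values in [0, 1]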


================================================
FILE: comfyui_invsr_trimmed/sampler_invsr.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-07-13 16:59:27

import os, sys, math, random

import numpy as np
from pathlib import Path

from .utils import util_net
from .utils import util_image
from .utils import util_common
from .utils import util_color_fix

import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp

from .pipeline_stable_diffusion_inversion_sr import StableDiffusionInvEnhancePipeline
from diffusers import AutoencoderKL
import comfy.model_management as mm
DEVICE = mm.get_torch_device()

_positive = 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, ' +\
            'meticulous detailing, hyper sharpness, perfect without deformations'
_negative = 'Low quality, blurring, jpeg artifacts, deformed, over-smooth, cartoon, noisy, ' +\
            'painting, drawing, sketch, oil painting'

class BaseSampler:
    def __init__(self, configs):
        '''
        Input:
            configs: config object, see the yaml files in folder ./configs/
                configs.sampler_config.{start_timesteps, padding_mod, seed, sf, num_sample_steps}
                configs.seed: int, random seed (read in setup_seed)
        '''
        self.configs = configs

        self.setup_seed()

        self.build_model()

    def setup_seed(self, seed=None):
        seed = self.configs.seed if seed is None else seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def write_log(self, log_str):
        print(log_str, flush=True)

    def build_model(self):
        # Build Stable diffusion
        params = dict(self.configs.sd_pipe.params)
        torch_dtype = params.pop('torch_dtype')
        params['torch_dtype'] = get_torch_dtype(torch_dtype)
        base_pipe = util_common.get_obj_from_str(self.configs.sd_pipe.target).from_pretrained(**params)
        if self.configs.get('scheduler', None) is not None:
            pipe_id = self.configs.scheduler.target.split('.')[-1]
            self.write_log(f'Loading scheduler of {pipe_id}...')
            base_pipe.scheduler = util_common.get_obj_from_str(self.configs.scheduler.target).from_config(
                base_pipe.scheduler.config
            )
            self.write_log('Loading done')
        if self.configs.get('vae_fp16', None) is not None:
            params_vae = dict(self.configs.vae_fp16.params)
            torch_dtype = params_vae.pop('torch_dtype')
            params_vae['torch_dtype'] = get_torch_dtype(torch_dtype)
            pipe_id = self.configs.vae_fp16.params.pretrained_model_name_or_path
            self.write_log(f'Loading improved vae from {pipe_id}...')
            base_pipe.vae = util_common.get_obj_from_str(self.configs.vae_fp16.target).from_pretrained(
                **params_vae,
            )
            self.write_log('Loading done')
        if self.configs.base_model in ['sd-turbo', 'sd2base']:
            sd_pipe = StableDiffusionInvEnhancePipeline.from_pipe(base_pipe)
        else:
            raise ValueError(f"Unsupported base model: {self.configs.base_model}!")
        sd_pipe.to(DEVICE)
        if self.configs.sliced_vae:
            sd_pipe.vae.enable_slicing()
        if self.configs.tiled_vae:
            sd_pipe.vae.enable_tiling()
            sd_pipe.vae.tile_latent_min_size = self.configs.latent_tiled_size
            sd_pipe.vae.tile_sample_min_size = self.configs.sample_tiled_size
        if self.configs.gradient_checkpointing_vae:
            self.write_log(f"Activating gradient checkpoing for vae...")
            sd_pipe.vae.enable_gradient_checkpointing()

        model_configs = self.configs.model_start
        params = model_configs.get('params', dict())
        model_start = util_common.get_obj_from_str(model_configs.target)(**params)
        model_start.to(DEVICE)
        ckpt_path = model_configs.get('ckpt_path')
        assert ckpt_path is not None
        self.write_log(f"[InvSR] - Loading started model from {ckpt_path}...")
        state = torch.load(ckpt_path, map_location=DEVICE)
        if 'state_dict' in state:
            state = state['state_dict']
        util_net.reload_model(model_start, state)
        # self.write_log(f"Loading Done")
        model_start.eval()
        setattr(sd_pipe, 'start_noise_predictor', model_start)

        self.sd_pipe = sd_pipe
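
    # Hedged config sketch (illustrative; key names follow the lookups in this
    # method, values are assumptions -- see configs/sample-sd-turbo.yaml for
    # the shipped defaults):
    #
    #   sd_pipe:
    #     target: diffusers.StableDiffusionPipeline
    #     params:
    #       pretrained_model_name_or_path: stabilityai/sd-turbo
    #       torch_dtype: torch.float16
    #   base_model: sd-turbo
    #   model_start:
    #     target: <start noise predictor class>
    #     ckpt_path: <path to noise_predictor checkpoint>
    #   sliced_vae: false
    #   tiled_vae: false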

class InvSamplerSR(BaseSampler):
    def __init__(self, base_sampler):
        self.configs = base_sampler.configs
        self.sd_pipe = base_sampler.sd_pipe
        
    @torch.no_grad()
    def sample_func(self, im_cond):
        '''
        Input:
            im_cond: b x c x h x w, torch tensor, [0,1], RGB
        Output:
            xt: h x w x c, numpy array, [0,1], RGB
        '''

        ori_h_lq, ori_w_lq = im_cond.shape[-2:]
        ori_w_hq = ori_w_lq * self.configs.basesr.sf
        ori_h_hq = ori_h_lq * self.configs.basesr.sf
        vae_sf = (2 ** (len(self.sd_pipe.vae.config.block_out_channels) - 1))
        if hasattr(self.sd_pipe, 'unet'):
            diffusion_sf = (2 ** (len(self.sd_pipe.unet.config.block_out_channels) - 1))
        else:
            diffusion_sf = self.sd_pipe.transformer.patch_size
        mod_lq = vae_sf // self.configs.basesr.sf * diffusion_sf
        idle_pch_size = self.configs.basesr.chopping.pch_size
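
        # Worked example (illustrative): for sd-turbo both the VAE and the UNet
        # have 4 `block_out_channels` entries, so vae_sf = diffusion_sf = 2**3 = 8;
        # with the usual x4 setting (sf = 4) this gives
        #     mod_lq = 8 // 4 * 8 = 16
        # i.e. low-quality inputs are padded up to a multiple of 16 pixels.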

        if min(im_cond.shape[-2:]) >= idle_pch_size:
            pad_h_up = pad_w_left = 0
        else:
            while min(im_cond.shape[-2:]) < idle_pch_size:
                pad_h_up = max(min((idle_pch_size - im_cond.shape[-2]) // 2, im_cond.shape[-2]-1), 0)
                pad_h_down = max(min(idle_pch_size - im_cond.shape[-2] - pad_h_up, im_cond.shape[-2]-1), 0)
                pad_w_left = max(min((idle_pch_size - im_cond.shape[-1]) // 2, im_cond.shape[-1]-1), 0)
                pad_w_right = max(min(idle_pch_size - im_cond.shape[-1] - pad_w_left, im_cond.shape[-1]-1), 0)
                im_cond = F.pad(im_cond, pad=(pad_w_left, pad_w_right, pad_h_up, pad_h_down), mode='reflect')

        if im_cond.shape[-2] == idle_pch_size and im_cond.shape[-1] == idle_pch_size:
            target_size = (
                im_cond.shape[-2] * self.configs.basesr.sf,
                im_cond.shape[-1] * self.configs.basesr.sf
            )
            res_sr = self.sd_pipe(
                image=im_cond.type(torch.float16),
                prompt=[_positive, ]*im_cond.shape[0],
                negative_prompt=[_negative, ]*im_cond.shape[0] if self.configs.cfg_scale > 1.0 else None,
                target_size=target_size,
                timesteps=self.configs.timesteps,
                guidance_scale=self.configs.cfg_scale,
                output_type="pt",    # torch tensor, b x c x h x w, [0, 1]
            ).images
        else:
            if not (im_cond.shape[-2] % mod_lq == 0 and im_cond.shape[-1] % mod_lq == 0):
                target_h_lq = math.ceil(im_cond.shape[-2] / mod_lq) * mod_lq
                target_w_lq = math.ceil(im_cond.shape[-1] / mod_lq) * mod_lq
                pad_h = target_h_lq - im_cond.shape[-2]
                pad_w = target_w_lq - im_cond.shape[-1]
                im_cond = F.pad(im_cond, pad=(0, pad_w, 0, pad_h), mode='reflect')

            im_spliter = util_image.ImageSpliterTh(
                im_cond,
                pch_size=idle_pch_size,
                stride=int(idle_pch_size * 0.50),
                sf=self.configs.basesr.sf,
                weight_type=self.configs.basesr.chopping.weight_type,
                extra_bs=self.configs.basesr.chopping.extra_bs,
            )
            
            # pbar = ProgressBar(len(im_spliter) * im_cond.shape[0])

            for im_lq_pch, index_infos in im_spliter:
                target_size = (
                    im_lq_pch.shape[-2] * self.configs.basesr.sf,
                    im_lq_pch.shape[-1] * self.configs.basesr.sf,
                )

                # start = torch.cuda.Event(enable_timing=True)
                # end = torch.cuda.Event(enable_timing=True)
                # start.record()

                res_sr_pch = self.sd_pipe(
                    image=im_lq_pch.type(torch.float16),
                    prompt=[_positive, ]*im_lq_pch.shape[0],
                    negative_prompt=[_negative, ]*im_lq_pch.shape[0] if self.configs.cfg_scale > 1.0 else None,
                    target_size=target_size,
                    timesteps=self.configs.timesteps,
                    guidance_scale=self.configs.cfg_scale,
                    output_type="pt",    # torch tensor, b x c x h x w, [0, 1]
                ).images

                # end.record()
                # torch.cuda.synchronize()
                # print(f"Time: {start.elapsed_time(end):.6f}")

                im_spliter.update(res_sr_pch, index_infos)
                # pbar.update(im_lq_pch.shape[0])
            res_sr = im_spliter.gather()

        pad_h_up *= self.configs.basesr.sf
        pad_w_left *= self.configs.basesr.sf
        res_sr = res_sr[:, :, pad_h_up:ori_h_hq+pad_h_up, pad_w_left:ori_w_hq+pad_w_left]

        if self.configs.color_fix:
            im_cond_up = F.interpolate(
                im_cond, size=res_sr.shape[-2:], mode='bicubic', align_corners=False, antialias=True
            )
            if self.configs.color_fix == 'ycbcr':
                res_sr = util_color_fix.ycbcr_color_replace(res_sr, im_cond_up)
            elif self.configs.color_fix == 'wavelet':
                res_sr = util_color_fix.wavelet_reconstruction(res_sr, im_cond_up)
            else:
                raise ValueError(f"Unsupported color fixing type: {self.configs.color_fix}")

        res_sr = res_sr.clamp(0.0, 1.0).cpu().float().numpy()

        return res_sr

    def inference(self, image_bchw):
        return self.sample_func(image_bchw.to(DEVICE))
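
# Hedged usage sketch (not part of the original file): wiring InvSamplerSR to
# an already-configured BaseSampler. `base` is a stand-in for whatever sampler
# object the node constructs; only `configs` and `sd_pipe` are reused.
#
#   sampler = InvSamplerSR(base)
#   lq = torch.rand(1, 3, 128, 128)    # b x c x h x w, RGB in [0, 1]
#   sr = sampler.inference(lq)         # numpy array, b x c x (h*sf) x (w*sf)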

def get_torch_dtype(torch_dtype: str):
    if torch_dtype == 'torch.float16':
        return torch.float16
    elif torch_dtype == 'torch.bfloat16':
        return torch.bfloat16
    elif torch_dtype == 'torch.float32':
        return torch.float32
    else:
        raise ValueError(f'Unexpected torch dtype: {torch_dtype}')
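
# Example (hedged, not in the original file):
#
#   >>> get_torch_dtype('torch.bfloat16')
#   torch.bfloat16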


================================================
FILE: comfyui_invsr_trimmed/time_aware_encoder.py
================================================
from typing import Tuple, Union
import torch
import torch.nn as nn

from diffusers.utils import is_torch_version
from diffusers.models.unets.unet_2d_blocks import (
    UNetMidBlock2D,
    get_down_block,
)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps


class TimeAwareEncoder(nn.Module):
    r"""
    The `TimeAwareEncoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
            options.
        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
            The number of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        act_fn (`str`, *optional*, defaults to `"silu"`):
            The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
        double_z (`bool`, *optional*, defaults to `True`):
            Whether to double the number of output channels for the last block.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`):
            Time scale shift config for the ResNet blocks (e.g. `"default"` or `"scale_shift"`).
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: Union[int, Tuple[int, ...]] = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        resnet_time_scale_shift: str = "default",
        temb_channels: int = 256,
        freq_shift: int = 0,
        flip_sin_to_cos: bool = True,
        attention_head_dim: int = 1,
    ):
        super().__init__()
        if isinstance(layers_per_block, int):
            layers_per_block = (layers_per_block,) * len(down_block_types)
        self.layers_per_block = layers_per_block

        timestep_input_dim = max(128, block_out_channels[0])
        self.time_proj = Timesteps(timestep_input_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(timestep_input_dim, temb_channels)

        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=attention_head_dim,
                resnet_time_scale_shift=resnet_time_scale_shift,
                temb_channels=temb_channels,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            attention_head_dim=attention_head_dim,
            resnet_groups=norm_num_groups,
            add_attention=mid_block_add_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            temb_channels=temb_channels,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
        )
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = nn.Conv2d(
            block_out_channels[-1], conv_out_channels, 3, padding=1
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.Tensor,
        timesteps: Union[torch.Tensor, int],
    ) -> torch.Tensor:
        r"""The forward method of the `Encoder` class."""

        # time embedding
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor(
                [timesteps], dtype=torch.long, device=sample.device
            )
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(
            sample.shape[0], dtype=timesteps.dtype, device=timesteps.device
        )

        t_emb = self.time_proj(timesteps)

        # `time_proj` does not contain any weights and always returns f32 tensors,
        # but `time_embedding` might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=list(self.time_embedding.parameters())[0].dtype)
        emb = self.time_embedding(t_emb)

        sample = self.conv_in(sample)

        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # down
            if is_torch_version(">=", "1.11.0"):
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block),
                        sample,
                        emb,
                        use_reentrant=False,
                    )
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    emb,
                    use_reentrant=False,
                )
            else:
                for down_block in self.down_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(down_block), sample, emb
                    )
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block), sample, emb
                )

        else:
            # down
            for down_block in self.down_blocks:
                sample, _ = down_block(sample, emb)

            # middle
            sample = self.mid_block(sample, emb)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
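
# Hedged shape-check sketch (not part of the original file). The block types
# and channel counts below are illustrative, not the configuration InvSR uses;
# "DownBlock2D" is chosen because the forward pass above feeds the time
# embedding into each down block.
#
#   enc = TimeAwareEncoder(
#       in_channels=3, out_channels=4,
#       down_block_types=("DownBlock2D", "DownBlock2D"),
#       block_out_channels=(32, 64), norm_num_groups=8,
#   )
#   z = enc(torch.randn(1, 3, 64, 64), timesteps=200)
#   # -> (1, 8, 32, 32): one downsample, and double_z doubles out_channels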


================================================
FILE: comfyui_invsr_trimmed/utils/__init__.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-01-18 11:40:23




================================================
FILE: comfyui_invsr_trimmed/utils/resize.py
================================================
"""
A standalone PyTorch implementation for fast and efficient bicubic resampling.
The resulting values are the same as those of the MATLAB function imresize('bicubic').
## Author:      Sanghyun Son
## Email:       sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary)
## Version:     1.2.0
## Last update: July 9th, 2020 (KST)
Dependency: torch
Example::
>>> import torch
>>> from resize import imresize
>>> x = torch.arange(16).float().view(1, 1, 4, 4)
>>> y = imresize(x, sizes=(3, 3))
>>> print(y)
tensor([[[[ 0.7506,  2.1004,  3.4503],
          [ 6.1505,  7.5000,  8.8499],
          [11.5497, 12.8996, 14.2494]]]])
"""

import math
import typing

import torch
from torch.nn import functional as F

__all__ = ['imresize']

_I = typing.Optional[int]
_D = typing.Optional[torch.dtype]


def nearest_contribution(x: torch.Tensor) -> torch.Tensor:
    range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5))
    cont = range_around_0.to(dtype=x.dtype)
    return cont


def linear_contribution(x: torch.Tensor) -> torch.Tensor:
    ax = x.abs()
    range_01 = ax.le(1)
    cont = (1 - ax) * range_01.to(dtype=x.dtype)
    return cont


def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
    ax = x.abs()
    ax2 = ax * ax
    ax3 = ax * ax2

    range_01 = ax.le(1)
    range_12 = torch.logical_and(ax.gt(1), ax.le(2))

    cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
    cont_01 = cont_01 * range_01.to(dtype=x.dtype)

    cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
    cont_12 = cont_12 * range_12.to(dtype=x.dtype)

    cont = cont_01 + cont_12
    return cont


def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
    range_3sigma = (x.abs() <= 3 * sigma + 1)
    # Normalization will be done after
    cont = torch.exp(-x.pow(2) / (2 * sigma**2))
    cont = cont * range_3sigma.to(dtype=x.dtype)
    return cont


def discrete_kernel(kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor:
    '''
    For downsampling with integer scale only.
    '''
    downsampling_factor = int(1 / scale)
    if kernel == 'cubic':
        kernel_size_orig = 4
    else:
        raise ValueError(f'{kernel} kernel is not supported!')

    if antialiasing:
        kernel_size = kernel_size_orig * downsampling_factor
    else:
        kernel_size = kernel_size_orig

    if downsampling_factor % 2 == 0:
        a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))
    else:
        kernel_size -= 1
        a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))

    with torch.no_grad():
        r = torch.linspace(-a, a, steps=kernel_size)
        k = cubic_contribution(r).view(-1, 1)
        k = torch.matmul(k, k.t())
        k /= k.sum()

    return k


def reflect_padding(x: torch.Tensor, dim: int, pad_pre: int, pad_post: int) -> torch.Tensor:
    '''
    Apply reflect padding to the given Tensor.
    Note that it is slightly different from the PyTorch functional.pad,
    where boundary elements are used only once.
    Instead, we follow the MATLAB implementation
    which uses boundary elements twice.
    For example,
    [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
    while our implementation yields [a, a, b, c, d, d].
    '''
    b, c, h, w = x.size()
    if dim == 2 or dim == -2:
        padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)
        padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)
        for p in range(pad_pre):
            padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])
        for p in range(pad_post):
            padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])
    else:
        padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)
        padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)
        for p in range(pad_pre):
            padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])
        for p in range(pad_post):
            padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])

    return padding_buffer
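
# Hedged illustration (not in the original file) of the MATLAB-style behavior
# described above, duplicating boundary elements once per padded position:
#
#   >>> x = torch.arange(4.).view(1, 1, 1, 4)               # [0, 1, 2, 3]
#   >>> reflect_padding(x, dim=-1, pad_pre=1, pad_post=1).flatten()
#   tensor([0., 0., 1., 2., 3., 3.])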


def padding(x: torch.Tensor,
            dim: int,
            pad_pre: int,
            pad_post: int,
            padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
    if padding_type is None:
        return x
    elif padding_type == 'reflect':
        x_pad = reflect_padding(x, dim, pad_pre, pad_post)
    else:
        raise ValueError('{} padding is not supported!'.format(padding_type))

    return x_pad


def get_padding(base: torch.Tensor, kernel_size: int, x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
    base = base.long()
    r_min = base.min()
    r_max = base.max() + kernel_size - 1

    if r_min <= 0:
        pad_pre = -r_min
        pad_pre = pad_pre.item()
        base += pad_pre
    else:
        pad_pre = 0

    if r_max >= x_size:
        pad_post = r_max - x_size + 1
        pad_post = pad_post.item()
    else:
        pad_post = 0

    return pad_pre, pad_post, base


def get_weight(dist: torch.Tensor,
               kernel_size: int,
               kernel: str = 'cubic',
               sigma: float = 2.0,
               antialiasing_factor: float = 1) -> torch.Tensor:
    buffer_pos = dist.new_zeros(kernel_size, len(dist))
    for idx, buffer_sub in enumerate(buffer_pos):
        buffer_sub.copy_(dist - idx)

    # Expand (downsampling) / Shrink (upsampling) the receptive field.
    buffer_pos *= antialiasing_factor
    if kernel == 'cubic':
        weight = cubic_contribution(buffer_pos)
    elif kernel == 'gaussian':
        weight = gaussian_contribution(buffer_pos, sigma=sigma)
    else:
        raise ValueError('{} kernel is not supported!'.format(kernel))

    weight /= weight.sum(dim=0, keepdim=True)
    return weight


def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
    # Resize height
    if dim == 2 or dim == -2:
        k = (kernel_size, 1)
        h_out = x.size(-2) - kernel_size + 1
        w_out = x.size(-1)
    # Resize width
    else:
        k = (1, kernel_size)
        h_out = x.size(-2)
        w_out = x.size(-1) - kernel_size + 1

    unfold = F.unfold(x, k)
    unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
    return unfold


def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:
    if x.dim() == 4:
        b, c, h, w = x.size()
    elif x.dim() == 3:
        c, h, w = x.size()
        b = None
    elif x.dim() == 2:
        h, w = x.size()
        b = c = None
    else:
        raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))

    x = x.view(-1, 1, h, w)
    return x, b, c, h, w


def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
    rh = x.size(-2)
    rw = x.size(-1)
    # Back to the original dimension
    if b is not None:
        x = x.view(b, c, rh, rw)  # 4-dim
    else:
        if c is not None:
            x = x.view(c, rh, rw)  # 3-dim
        else:
            x = x.view(rh, rw)  # 2-dim

    return x


def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
    if x.dtype != torch.float32 and x.dtype != torch.float64:
        dtype = x.dtype
        x = x.float()
    else:
        dtype = None

    return x, dtype


def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
    if dtype is not None:
        if not dtype.is_floating_point:
            x = x - x.detach() + x.round()
        # To prevent over/underflow when converting types
        if dtype is torch.uint8:
            x = x.clamp(0, 255)

        x = x.to(dtype=dtype)

    return x


def resize_1d(x: torch.Tensor,
              dim: int,
              size: int,
              scale: float,
              kernel: str = 'cubic',
              sigma: float = 2.0,
              padding_type: str = 'reflect',
              antialiasing: bool = True) -> torch.Tensor:
    '''
    Args:
        x (torch.Tensor): A torch.Tensor of dimension (B*C, 1, H, W).
        dim (int):
        scale (float):
        size (int):
    Return:
    '''
    # Identity case
    if scale == 1:
        return x

    # Default bicubic kernel with antialiasing (only when downsampling)
    if kernel == 'cubic':
        kernel_size = 4
    else:
        kernel_size = math.floor(6 * sigma)

    if antialiasing and (scale < 1):
        antialiasing_factor = scale
        kernel_size = math.ceil(kernel_size / antialiasing_factor)
    else:
        antialiasing_factor = 1

    # We allow a margin on both sides
    kernel_size += 2

    # Weights only depend on the shape of input and output,
    # so we do not calculate gradients here.
    with torch.no_grad():
        pos = torch.linspace(
            0,
            size - 1,
            steps=size,
            dtype=x.dtype,
            device=x.device,
        )
        pos = (pos + 0.5) / scale - 0.5
        base = pos.floor() - (kernel_size // 2) + 1
        dist = pos - base
        weight = get_weight(
            dist,
            kernel_size,
            kernel=kernel,
            sigma=sigma,
            antialiasing_factor=antialiasing_factor,
        )
        pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))

    # To backpropagate through x
    x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
    unfold = reshape_tensor(x_pad, dim, kernel_size)
    # Subsampling first
    if dim == 2 or dim == -2:
        sample = unfold[..., base, :]
        weight = weight.view(1, kernel_size, sample.size(2), 1)
    else:
        sample = unfold[..., base]
        weight = weight.view(1, kernel_size, 1, sample.size(3))

    # Apply the kernel
    x = sample * weight
    x = x.sum(dim=1, keepdim=True)
    return x
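
# Hedged sanity check (not in the original file): upscaling the height of a
# (B*C, 1, H, W) tensor from 4 to 8 rows with the default bicubic kernel.
#
#   >>> x = torch.arange(16.).view(1, 1, 4, 4)
#   >>> resize_1d(x, -2, size=8, scale=2.0).shape
#   torch.Size([1, 1, 8, 4])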


def downsampling_2d(x: torch.Tensor, k: torch.Tensor, scale: int, padding_type: str = 'reflect') -> torch.Tensor:
    c = x.size(1)
    k_h = k.size(-2)
    k_w = k.size(-1)

    k = k.to(dtype=x.dtype, device=x.device)
    k = k.view(1, 1, k_h, k_w)
    k = k.repeat(c, c, 1, 1)
    e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
    e = e.view(c, c, 1, 1)
    k = k * e

    pad_h = (k_h - scale) // 2
    pad_w = (k_w - scale) // 2
    x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
    x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
    y = F.conv2d(x, k, padding=0, stride=scale)
    return y


def imresize(x: torch.Tensor,
             scale: typing.Optional[float] = None,
             sizes: typing.Optional[typing.Tuple[int, int]] = None,
             kernel: typing.Union[str, torch.Tensor] = 'cubic',
             sigma: float = 2,
             rotation_degree: float = 0,
             padding_type: str = 'reflect',
             antialiasing: bool = True) -> torch.Tensor:
    """
    Args:
        x (torch.Tensor):
        scale (float):
        sizes (tuple(int, int)):
        kernel (str, default='cubic'):
        sigma (float, default=2):
        rotation_degree (float, default=0):
        padding_type (str, default='reflect'):
        antialiasing (bool, default=True):
    Return:
        torch.Tensor:
    """
    if scale is None and sizes is None:
        raise ValueError('One of scale or sizes must be specified!')
    if scale is not None and sizes is not None:
        raise ValueError('Please specify scale or sizes to avoid conflict!')

    x, b, c, h, w = reshape_input(x)

    if sizes is None and scale is not None:
        '''
        # Check if we can apply the convolution algorithm
        scale_inv = 1 / scale
        if isinstance(kernel, str) and scale_inv.is_integer():
            kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
        elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
            raise ValueError(
                'An integer downsampling factor '
                'should be used with a predefined kernel!'
            )
        '''
        # Determine output size
        sizes = (math.ceil(h * scale), math.ceil(w * scale))
        scales = (scale, scale)

    if scale is None and sizes is not None:
        scales = (sizes[0] / h, sizes[1] / w)

    x, dtype = cast_input(x)

    if isinstance(kernel, str) and sizes is not None:
        # Core resizing module
        x = resize_1d(
            x,
            -2,
            size=sizes[0],
            scale=scales[0],
            kernel=kernel,
            sigma=sigma,
            padding_type=padding_type,
            antialiasing=antialiasing)
        x = resize_1d(
            x,
            -1,
            size=sizes[1],
            scale=scales[1],
            kernel=kernel,
            sigma=sigma,
            padding_type=padding_type,
            antialiasing=antialiasing)
    elif isinstance(kernel, torch.Tensor) and scale is not None:
        x = downsampling_2d(x, kernel, scale=int(1 / scale))

    x = reshape_output(x, b, c)
    x = cast_output(x, dtype)
    return x
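
# Hedged usage example (not in the original file): `scale` and `sizes` are
# mutually exclusive, as enforced above.
#
#   >>> x = torch.rand(1, 3, 32, 48)
#   >>> imresize(x, scale=0.5).shape
#   torch.Size([1, 3, 16, 24])
#   >>> imresize(x, sizes=(64, 96)).shape
#   torch.Size([1, 3, 64, 96])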


================================================
FILE: comfyui_invsr_trimmed/utils/util_color_fix.py
================================================
'''
# --------------------------------------------------------------------------------
#   Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
# --------------------------------------------------------------------------------
'''

import torch
from torch import Tensor
from torch.nn import functional as F

from .util_image import rgb2ycbcrTorch, ycbcr2rgbTorch


def calc_mean_std(feat: Tensor, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.
    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
    feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
    """Adaptive instance normalization.
    Adjust the reference features to have color and illumination similar to
    those of the degraded features.
    Args:
        content_feat (Tensor): The reference features.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output

def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.
    """
    high_freq = torch.zeros_like(image)
    for i in range(levels):
        radius = 2 ** i
        low_freq = wavelet_blur(image, radius)
        high_freq += (image - low_freq)
        image = low_freq

    return high_freq, low_freq

def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
    """
    Apply wavelet decomposition, so that the content will have the same color as the style.
    """
    # calculate the wavelet decomposition of the content feature
    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
    del content_low_freq
    # calculate the wavelet decomposition of the style feature
    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
    del style_high_freq
    # reconstruct the content feature with the style's high frequency
    return content_high_freq + style_low_freq

def ycbcr_color_replace(content_feat:Tensor, style_feat:Tensor):
    """
    Apply ycbcr decomposition, so that the content will have the same color as the style.
    """
    content_y = rgb2ycbcrTorch(content_feat, only_y=True)
    style_ycbcr = rgb2ycbcrTorch(style_feat, only_y=False)

    target_ycbcr = torch.cat([content_y, style_ycbcr[:, 1:]], dim=1)

    target_rgb = ycbcr2rgbTorch(target_ycbcr)

    return target_rgb
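
# Hedged usage sketch (not in the original file): both fixers transfer color
# from a reference onto the super-resolved image, as done in the sampler via
# `configs.color_fix`. Tensors are assumed to be (N, 3, H, W) in [0, 1].
#
#   sr = torch.rand(1, 3, 128, 128)
#   ref = torch.rand(1, 3, 128, 128)           # e.g. bicubic-upscaled LQ input
#   out_wavelet = wavelet_reconstruction(sr, ref)
#   out_ycbcr = ycbcr_color_replace(sr, ref)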




================================================
FILE: comfyui_invsr_trimmed/utils/util_common.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-02-06 10:34:59

import random
import argparse
import importlib
from pathlib import Path

def mkdir(dir_path, delete=False, parents=True):
    import shutil
    if not isinstance(dir_path, Path):
        dir_path = Path(dir_path)
    if delete:
        if dir_path.exists():
            shutil.rmtree(str(dir_path))
    if not dir_path.exists():
        dir_path.mkdir(parents=parents)

def get_obj_from_str(string, reload=False):
    current_package = __package__.rsplit(".", 1)[0]
    is_relative_import = string.startswith(".")
    package = current_package if is_relative_import else None

    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module, package=package)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=package), cls)

def instantiate_from_config(config):
    if not "target" in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
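
# Hedged usage example (not in the original file); the target below is purely
# illustrative.
#
#   >>> cfg = {"target": "collections.OrderedDict", "params": {}}
#   >>> instantiate_from_config(cfg)
#   OrderedDict()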

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("Boolean value expected.")

def get_filenames(dir_path, exts=['png', 'jpg'], recursive=True):
    '''
    Get the file paths in the given folder.
    param exts: list, e.g., ['png',]
    return: list
    '''
    if not isinstance(dir_path, Path):
        dir_path = Path(dir_path)

    file_paths = []
    for current_ext in exts:
        if recursive:
            file_paths.extend([str(x) for x in dir_path.glob('**/*.'+current_ext)])
        else:
            file_paths.extend([str(x) for x in dir_path.glob('*.'+current_ext)])

    return file_paths

def readline_txt(txt_file):
    txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file
    out = []
    for txt_file_current in txt_file:
        with open(txt_file_current, 'r') as ff:
            out.extend([x[:-1] for x in ff.readlines()])

    return out

def scan_files_from_folder(dir_paths, exts, recursive=True):
    '''
    Scanning images from the given folder(s).
    Input:
        dir_paths: str or list.
        exts: list
    '''
    exts = [exts, ] if isinstance(exts, str) else exts
    dir_paths = [dir_paths, ] if isinstance(dir_paths, str) else dir_paths

    file_paths = []
    for current_dir in dir_paths:
        current_dir = Path(current_dir) if not isinstance(current_dir, Path) else current_dir
        for current_ext in exts:
            if recursive:
                search_flag = f"**/*.{current_ext}"
            else:
                search_flag = f"*.{current_ext}"
            file_paths.extend(sorted([str(x) for x in Path(current_dir).glob(search_flag)]))

    return file_paths

def write_path_to_txt(
        dir_folder,
        txt_path,
        search_key,
        num_files=None,
        write_only_name=False,
        write_only_stem=False,
        shuffle=False,
        ):
    '''
    Scan the files in the given folder and write their paths into a txt file.
    Input:
        dir_folder: path of the target folder
        txt_path: path to save the txt file
        search_key: e.g., '*.png'
        write_only_name: bool, only record the file names (including extension),
        write_only_stem: bool, only record the file names (not including extension),
    '''
    txt_path = Path(txt_path) if not isinstance(txt_path, Path) else txt_path
    dir_folder = Path(dir_folder) if not isinstance(dir_folder, Path) else dir_folder
    if txt_path.exists():
        txt_path.unlink()
    if write_only_name:
        path_list = sorted([str(x.name) for x in dir_folder.glob(search_key)])
    elif write_only_stem:
        path_list = sorted([str(x.stem) for x in dir_folder.glob(search_key)])
    else:
        path_list = sorted([str(x) for x in dir_folder.glob(search_key)])
    if shuffle:
        random.shuffle(path_list)
    if num_files is not None:
        path_list = path_list[:num_files]
    with open(txt_path, mode='w') as ff:
        for line in path_list:
            ff.write(line+'\n')


================================================
FILE: comfyui_invsr_trimmed/utils/util_ema.py
================================================
import torch
from torch import nn


class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            'num_updates',
            torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int),
        )

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def reset_num_updates(self):
        del self.num_updates
        self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """
        Copying the ema state (i.e., buffers) to the targeted model
        Input:
            model: targeted model
        """
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the parameters of the targeted model into the temporary pool for restoring later.
        Args:
          parameters: parameters of the targeted model.
                      Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters from the temporary pool (stored with the `store` method).
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def resume(self, ckpt, num_updates):
        """
        Resume from the targeted checkpoint, i.e., copy the checkpoint weights into the ema buffers.
        Input:
            ckpt: checkpoint state dict
            num_updates: number of EMA updates already performed
        """
        self.register_buffer('num_updates', torch.tensor(num_updates, dtype=torch.int))

        shadow_params = dict(self.named_buffers())
        for key, value in ckpt.items():
            try:
                shadow_params[self.m_name2s_name[key]].data.copy_(value.data)
            except KeyError:
                if key.startswith('module') and key not in shadow_params:
                    key = key[7:]
                shadow_params[self.m_name2s_name[key]].data.copy_(value.data)
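
# Hedged usage sketch (not in the original file): the usual store/copy/restore
# pattern for evaluating with EMA weights.
#
#   ema = LitEma(model, decay=0.999)
#   ema(model)                        # after each optimizer step
#   ema.store(model.parameters())     # stash the live weights
#   ema.copy_to(model)                # evaluate with EMA weights
#   ema.restore(model.parameters())   # put the live weights back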


================================================
FILE: comfyui_invsr_trimmed/utils/util_image.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-11-24 16:54:19

import sys
import cv2
import math
import torch
import random
import numpy as np
from pathlib import Path
from torchvision.utils import make_grid

# --------------------------Metrics----------------------------
def ssim(img1, img2):
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()

def calculate_ssim(im1, im2, border=0, ycbcr=False):
    '''
    SSIM with the same output as MATLAB's.
    im1, im2: h x w x c, [0, 255], uint8
    '''
    if not im1.shape == im2.shape:
        raise ValueError('Input images must have the same dimensions.')

    if ycbcr:
        im1 = rgb2ycbcr(im1, True)
        im2 = rgb2ycbcr(im2, True)

    h, w = im1.shape[:2]
    im1 = im1[border:h-border, border:w-border]
    im2 = im2[border:h-border, border:w-border]

    if im1.ndim == 2:
        return ssim(im1, im2)
    elif im1.ndim == 3:
        if im1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(im1[:,:,i], im2[:,:,i]))
            return np.array(ssims).mean()
        elif im1.shape[2] == 1:
            return ssim(np.squeeze(im1), np.squeeze(im2))
    else:
        raise ValueError('Wrong input image dimensions.')
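
# Hedged sanity check (not in the original file): SSIM of an image with
# itself is exactly 1.
#
#   >>> a = (np.random.rand(32, 32) * 255).astype(np.uint8)
#   >>> float(calculate_ssim(a, a))
#   1.0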

def calculate_psnr(im1, im2, border=0, ycbcr=False):
    '''
    PSNR metric.
    im1, im2: h x w x c, [0, 255], uint8
    '''
    if not im1.shape == im2.shape:
        raise ValueError('Input images must have the same dimensions.')

    if ycbcr:
        im1 = rgb2ycbcr(im1, True)
        im2 = rgb2ycbcr(im2, True)

    h, w = im1.shape[:2]
    im1 = im1[border:h-border, border:w-border]
    im2 = im2[border:h-border, border:w-border]

    im1 = im1.astype(np.float64)
    im2 = im2.astype(np.float64)
    mse = np.mean((im1 - im2)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
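
# Hedged sanity check (not in the original file): images differing by 1
# everywhere give MSE = 1, hence PSNR = 20*log10(255) ~= 48.13 dB.
#
#   >>> a = np.zeros((8, 8), dtype=np.uint8)
#   >>> b = np.ones((8, 8), dtype=np.uint8)
#   >>> round(calculate_psnr(a, b), 2)
#   48.13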

def normalize_np(im, mean=0.5, std=0.5, reverse=False):
    '''
    Input:
        im: h x w x c, numpy array
        Normalize: (im - mean) / std
        Reverse: im * std + mean

    '''
    if not isinstance(mean, (list, tuple)):
        mean = [mean, ] * im.shape[2]
    mean = np.array(mean).reshape([1, 1, im.shape[2]])

    if not isinstance(std, (list, tuple)):
        std = [std, ] * im.shape[2]
    std = np.array(std).reshape([1, 1, im.shape[2]])

    if not reverse:
        out = (im.astype(np.float32) - mean) / std
    else:
        out = im.astype(np.float32) * std + mean
    return out

def normalize_th(im, mean=0.5, std=0.5, reverse=False):
    '''
    Input:
        im: b x c x h x w, torch tensor
        Normalize: (im - mean) / std
        Reverse: im * std + mean

    '''
    if not isinstance(mean, (list, tuple)):
        mean = [mean, ] * im.shape[1]
    mean = torch.tensor(mean, device=im.device).view([1, im.shape[1], 1, 1])

    if not isinstance(std, (list, tuple)):
        std = [std, ] * im.shape[1]
    std = torch.tensor(std, device=im.device).view([1, im.shape[1], 1, 1])

    if not reverse:
        out = (im - mean) / std
    else:
        out = im * std + mean
    return out
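
# Hedged round-trip check (not in the original file): normalize then reverse
# recovers the input.
#
#   >>> x = torch.rand(2, 3, 8, 8)
#   >>> y = normalize_th(x, mean=0.5, std=0.5)                  # [0,1] -> [-1,1]
#   >>> torch.allclose(normalize_th(y, 0.5, 0.5, reverse=True), x)
#   True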

# ------------------------Image format--------------------------
def rgb2ycbcr(im, only_y=True):
    '''
    same as matlab rgb2ycbcr
    Input:
        im: uint8 [0,255] or float [0,1]
        only_y: only return Y channel
    '''
    # transform to float64 data type, range [0, 255]
    if im.dtype == np.uint8:
        im_temp = im.astype(np.float64)
    else:
        im_temp = (im * 255).astype(np.float64)

    # convert
    if only_y:
        rlt = np.dot(im_temp, np.array([65.481, 128.553, 24.966])/ 255.0) + 16.0
    else:
        rlt = np.matmul(im_temp, np.array([[65.481,  -37.797, 112.0  ],
                                           [128.553, -74.203, -93.786],
                                           [24.966,  112.0,   -18.214]])/255.0) + [16, 128, 128]
    if im.dtype == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
    return rlt.astype(im.dtype)

def rgb2ycbcrTorch(im, only_y=True):
    '''
    same as matlab rgb2ycbcr
    Input:
        im: float [0,1], N x 3 x H x W
        only_y: only return Y channel
    '''
    # transform to range [0,255.0]
    im_temp = im.permute([0,2,3,1]) * 255.0  # N x C x H x W --> N x H x W x C
    # convert
    if only_y:
        rlt = torch.matmul(im_temp, torch.tensor([65.481, 128.553, 24.966],
                                        device=im.device, dtype=im.dtype).view([3,1])/ 255.0) + 16.0
    else:
        scale = torch.tensor(
            [[65.481,  -37.797, 112.0  ],
             [128.553, -74.203, -93.786],
             [24.966,  112.0,   -18.214]],
            device=im.device, dtype=im.dtype
        ) / 255.0
        bias = torch.tensor([16, 128, 128], device=im.device, dtype=im.dtype).view([-1, 1, 1, 3])
        rlt = torch.matmul(im_temp, scale) + bias

    rlt /= 255.0
    rlt.clamp_(0.0, 1.0)
    return rlt.permute([0, 3, 1, 2])

def ycbcr2rgbTorch(im):
    '''
    same as matlab ycbcr2rgb
    Input:
        im: float [0,1], N x 3 x H x W
    '''
    # transform to range [0,255.0]
    im_temp = im.permute([0,2,3,1]) * 255.0  # N x C x H x W --> N x H x W x C
    # convert
    scale = torch.tensor(
        [[0.00456621, 0.00456621, 0.00456621],
         [0, -0.00153632, 0.00791071],
         [0.00625893, -0.00318811, 0]],
        device=im.device, dtype=im.dtype
        ) * 255.0
    bias = torch.tensor(
        [-222.921, 135.576, -276.836], device=im.device, dtype=im.dtype
    ).view([-1, 1, 1, 3])
    rlt = torch.matmul(im_temp, scale) + bias
    rlt /= 255.0
    rlt.clamp_(0.0, 1.0)
    return rlt.permute([0, 3, 1, 2])

def bgr2rgb(im): return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

def rgb2bgr(im): return cv2.cvtColor(im, cv2.COLOR_RGB2BGR)

def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
        shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    flag_tensor = torch.is_tensor(tensor)
    if flag_tensor:
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1 and flag_tensor:
        result = result[0]
    return result
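
# Hedged usage example (not in the original file):
#
#   >>> t = torch.rand(3, 16, 16)                # CHW, RGB, [0, 1]
#   >>> img = tensor2img(t)                      # HWC, BGR, uint8
#   >>> img.shape, img.dtype
#   ((16, 16, 3), dtype('uint8'))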

# ------------------------Image resize-----------------------------
def imresize_np(img, scale, antialiasing=True):
    # Now the scale should be the same for H and W
    # input: img: Numpy, HWC or HW [0,1]
    # output: HWC or HW [0,1] w/o round
    img = torch.from_numpy(img)
    need_squeeze = True if img.dim() == 2 else False
    if need_squeeze:
        img.unsqueeze_(2)

    in_H, in_W, in_C = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)
    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
    img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)

    sym_patch = img[:sym_len_Hs, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)

    sym_patch = img[-sym_len_He:, :, :]
    inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(0, inv_idx)
    img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(out_H, in_W, in_C)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
    out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)

    sym_patch = out_1[:, :sym_len_Ws, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)

    sym_patch = out_1[:, -sym_len_We:, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(out_H, out_W, in_C)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
    if need_squeeze:
        out_2.squeeze_()

    return out_2.numpy()

def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    if (scale < 1) and (antialiasing):
        # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
        kernel_width = kernel_width / scale

    # Output-space coordinates
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5+scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation?  Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
        1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)
    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. only consider the first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)

# matlab 'imresize' function, now only support 'bicubic'
def cubic(x):
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))

# ------------------------Image I/O-----------------------------
def imread(path, chn='rgb', dtype='float32', force_gray2rgb=True, force_rgba2rgb=False):
    '''
    Read image.
    chn: 'rgb', 'bgr' or 'gray'
    out:
        im: h x w x c, numpy tensor
    '''
    im = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)  # BGR, uint8
    if im is None:
        raise IOError(f"Failed to read the image: {str(path)}")

    if chn.lower() == 'gray':
        assert im.ndim == 2, f"{str(path)} can't be successfully read!"
    else:
        if im.ndim == 2:
            if force_gray2rgb:
                im = np.stack([im, im, im], axis=2)
            else:
                raise ValueError(f"{str(path)} is a single-channel (gray) image!")
        elif im.ndim == 3 and im.shape[2] == 4:
            # cv2.IMREAD_UNCHANGED yields BGRA when the image has an alpha channel
            if force_rgba2rgb:
                im = im[:, :, :3]
                if chn.lower() == 'rgb':
                    im = bgr2rgb(im)
            else:
                raise ValueError(f"{str(path)} has {im.shape[2]} channels!")
        else:
            if chn.lower() == 'rgb':
                im = bgr2rgb(im)
            elif chn.lower() == 'bgr':
                pass

    if dtype == 'float32':
        im = im.astype(np.float32) / 255.
    elif dtype ==  'float64':
        im = im.astype(np.float64) / 255.
    elif dtype == 'uint8':
        pass
    else:
        sys.exit('Please input a correct dtype: float32, float64 or uint8!')

    return im

# ------------------------Augmentation-----------------------------
def data_aug_np(image, mode):
    '''
    Performs data augmentation of the input image
    Input:
        image: a cv2 (OpenCV) image
        mode: int. Choice of transformation to apply to the image
                0 - no transformation
                1 - flip up and down
                2 - rotate counterclockwise 90 degrees
                3 - rotate 90 degrees and flip up and down
                4 - rotate 180 degrees
                5 - rotate 180 degrees and flip
                6 - rotate 270 degrees
                7 - rotate 270 degrees and flip
    '''
    if mode == 0:
        # original
        out = image
    elif mode == 1:
        # flip up and down
        out = np.flipud(image)
    elif mode == 2:
        # rotate counterclockwise 90 degrees
        out = np.rot90(image)
    elif mode == 3:
        # rotate 90 degree and flip up and down
        out = np.rot90(image)
        out = np.flipud(out)
    elif mode == 4:
        # rotate 180 degree
        out = np.rot90(image, k=2)
    elif mode == 5:
        # rotate 180 degree and flip
        out = np.rot90(image, k=2)
        out = np.flipud(out)
    elif mode == 6:
        # rotate 270 degree
        out = np.rot90(image, k=3)
    elif mode == 7:
        # rotate 270 degree and flip
        out = np.rot90(image, k=3)
        out = np.flipud(out)
    else:
        raise Exception('Invalid choice of image transformation')

    return out.copy()

def inverse_data_aug_np(image, mode):
    '''
    Performs inverse data augmentation of the input image
    '''
    if mode == 0:
        # original
        out = image
    elif mode == 1:
        out = np.flipud(image)
    elif mode == 2:
        out = np.rot90(image, axes=(1,0))
    elif mode == 3:
        out = np.flipud(image)
        out = np.rot90(out, axes=(1,0))
    elif mode == 4:
        out = np.rot90(image, k=2, axes=(1,0))
    elif mode == 5:
        out = np.flipud(image)
        out = np.rot90(out, k=2, axes=(1,0))
    elif mode == 6:
        out = np.rot90(image, k=3, axes=(1,0))
    elif mode == 7:
        # rotate 270 degree and flip
        out = np.flipud(image)
        out = np.rot90(out, k=3, axes=(1,0))
    else:
        raise Exception('Invalid choice of image transformation')

    return out
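
# Hedged round-trip check (not in the original file): every inverse mode
# undoes its forward counterpart.
#
#   >>> x = np.arange(6).reshape(2, 3)
#   >>> all(np.array_equal(inverse_data_aug_np(data_aug_np(x, m), m), x)
#   ...     for m in range(8))
#   True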

# ----------------------Visualization----------------------------
def imshow(x, title=None, cbar=False):
    import matplotlib.pyplot as plt
    plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()

def imblend_with_mask(im, mask, alpha=0.25):
    """
    Input:
        im, mask: h x w x c numpy array, uint8, [0, 255]
        alpha: scalar in [0.0, 1.0]
    """
    edge_map = cv2.Canny(mask, 100, 200).astype(np.float32)[:, :, None] / 255.

    assert mask.dtype == np.uint8
    mask = mask.astype(np.float32) / 255.
    if mask.ndim == 2:
        mask = mask[:, :, None]

    back_color = np.array([159, 121, 238], dtype=np.float32).reshape((1,1,3))
    blend = im.astype(np.float32) * alpha + (1 - alpha) * back_color
    blend = np.clip(blend, 0, 255)
    out = im.astype(np.float32) * (1 - mask) + blend * mask

    # paste edge
    out = out * (1 - edge_map) + np.array([0,255,0], dtype=np.float32).reshape((1,1,3)) * edge_map

    return out.astype(np.uint8)
# -----------------------Convolution-----------------------------
def imgrad(im, padding_mode='mirror'):
    '''
    Calculate the image gradient.
    Input:
        im: h x w x c numpy array
    '''
    from scipy.ndimage import correlate  # lazy import
    wx = np.array([[0, 0, 0],
                   [-1, 1, 0],
                   [0, 0, 0]], dtype=np.float32)
    wy = np.array([[0, -1, 0],
                   [0, 1, 0],
                   [0, 0, 0]], dtype=np.float32)
    if im.ndim == 3:
        gradx = np.stack(
                [correlate(im[:,:,c], wx, mode=padding_mode) for c in range(im.shape[2])],
                axis=2
                )
        grady = np.stack(
                [correlate(im[:,:,c], wy, mode=padding_mode) for c in range(im.shape[2])],
                axis=2
                )
        grad = np.concatenate((gradx, grady), axis=2)
    else:
        gradx = correlate(im, wx, mode=padding_mode)
        grady = correlate(im, wy, mode=padding_mode)
        grad = np.stack((gradx, grady), axis=2)

    return {'gradx': gradx, 'grady': grady, 'grad': grad}

def convtorch(im, weight, mode='reflect'):
    '''
    Depthwise image convolution with PyTorch (conv2d is called with
    groups equal to the number of channels).
    Input:
        im: b x c x h x w torch tensor
        weight: c x 1 x k x k torch tensor, one kernel per channel
    Output:
        out: b x c x h x w torch tensor
    '''
    ksize = weight.shape[-1]
    chn = im.shape[1]
    im_pad = torch.nn.functional.pad(im, pad=(ksize // 2, )*4, mode=mode)
    out = torch.nn.functional.conv2d(im_pad, weight, padding=0, groups=chn)
    return out
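
# Usage sketch (illustrative, not part of the original file): since conv2d
# runs with groups equal to the channel count, `weight` must hold one
# kernel per channel, i.e. have shape (c, 1, k, k). A 3x3 box blur on an
# RGB batch would look like:
#
#     im = torch.rand(1, 3, 64, 64)
#     kernel = torch.full((3, 1, 3, 3), 1.0 / 9.0)  # per-channel box filter
#     blurred = convtorch(im, kernel)               # -> (1, 3, 64, 64)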

# ----------------------Patch Cropping----------------------------
def random_crop(im, pch_size):
    '''
    Randomly crop a patch from the given image.
    '''
    h, w = im.shape[:2]
    # padding if necessary
    if h < pch_size or w < pch_size:
        pad_h = min(max(0, pch_size - h), h)
        pad_w = min(max(0, pch_size - w), w)
        im = cv2.copyMakeBorder(im, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)

    h, w = im.shape[:2]
    if h == pch_size:
        ind_h = 0
    elif h > pch_size:
        ind_h = random.randint(0, h-pch_size)
    else:
        raise ValueError('Image height is smaller than the patch size')
    if w == pch_size:
        ind_w = 0
    elif w > pch_size:
        ind_w = random.randint(0, w-pch_size)
    else:
        raise ValueError('Image width is smaller than the patch size')

    im_pch = im[ind_h:ind_h+pch_size, ind_w:ind_w+pch_size,]

    return im_pch

class ToTensor:
    def __init__(self, max_value=1.0):
        self.max_value = max_value

    def __call__(self, im):
        assert isinstance(im, np.ndarray)
        if im.ndim == 2:
            im = im[:, :, np.newaxis]
        if im.dtype == np.uint8:
            assert self.max_value == 255.
            out = torch.from_numpy(im.astype(np.float32).transpose(2,0,1) / self.max_value)
        else:
            assert self.max_value == 1.0
            out = torch.from_numpy(im.transpose(2,0,1))
        return out

class RandomCrop:
    def __init__(self, pch_size, pass_crop=False):
        self.pch_size = pch_size
        self.pass_crop = pass_crop

    def __call__(self, im):
        if self.pass_crop:
            return im
        if isinstance(im, list) or isinstance(im, tuple):
            out = []
            for current_im in im:
                out.append(random_crop(current_im, self.pch_size))
        else:
            out = random_crop(im, self.pch_size)
        return out
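
# Typical pipeline sketch (illustrative, not part of the original file):
# both classes above are callable, so cropping a training patch and then
# converting it to a CHW float tensor compose naturally:
#
#     im = np.random.rand(300, 400, 3).astype(np.float32)  # HWC in [0, 1]
#     patch = ToTensor(max_value=1.0)(RandomCrop(pch_size=256)(im))
#     print(patch.shape)  # torch.Size([3, 256, 256])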

class ImageSpliterNp:
    def __init__(self, im, pch_size, stride, sf=1):
        '''
        Input:
            im: h x w x c, numpy array, [0, 1], low-resolution image in SR
            pch_size, stride: patch setting
            sf: scale factor in image super-resolution
        '''
        assert stride <= pch_size
        self.stride = stride
        self.pch_size = pch_size
        self.sf = sf

        if im.ndim == 2:
            im = im[:, :, None]

        height, width, chn = im.shape
        self.height_starts_list = self.extract_starts(height)
        self.width_starts_list = self.extract_starts(width)
        self.length = self.__len__()
        self.num_pchs = 0

        self.im_ori = im
        self.im_res = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)
        self.pixel_count = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)

    def extract_starts(self, length):
        starts = list(range(0, length, self.stride))
        if starts[-1] + self.pch_size > length:
            starts[-1] = length - self.pch_size
        return starts

    def __len__(self):
        return len(self.height_starts_list) * len(self.width_starts_list)

    def __iter__(self):
        return self

    def __next__(self):
        if self.num_pchs < self.length:
            w_start_idx = self.num_pchs // len(self.height_starts_list)
            w_start = self.width_starts_list[w_start_idx]
            w_end = w_start + self.pch_size

            h_start_idx = self.num_pchs % len(self.height_starts_list)
            h_start = self.height_starts_list[h_start_idx]
            h_end = h_start + self.pch_size

            # crop the patch in low-resolution coordinates first ...
            pch = self.im_ori[h_start:h_end, w_start:w_end,]

            # ... then scale the indices to the super-resolved output grid,
            # which is what update() expects
            h_start, h_end = h_start * self.sf, h_end * self.sf
            w_start, w_end = w_start * self.sf, w_end * self.sf
            self.w_start, self.w_end = w_start, w_end
            self.h_start, self.h_end = h_start, h_end

            self.num_pchs += 1
        else:
            raise StopIteration()

        return pch, (h_start, h_end, w_start, w_end)

    def update(self, pch_res, index_infos):
        '''
        Input:
            pch_res: pch_size x pch_size x 3, [0,1]
            index_infos: (h_start, h_end, w_start, w_end)
        '''
        if index_infos is None:
            w_start, w_end = self.w_start, self.w_end
            h_start, h_end = self.h_start, self.h_end
        else:
            h_start, h_end, w_start, w_end = index_infos

        self.im_res[h_start:h_end, w_start:w_end] += pch_res
        self.pixel_count[h_start:h_end, w_start:w_end] += 1

    def gather(self):
        assert np.all(self.pixel_count != 0)
        return self.im_res / self.pixel_count
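
# Usage sketch (illustrative, not part of the original file): split a
# low-resolution image into overlapping patches, super-resolve each patch,
# and average the overlapping regions. `fake_sr` is a stand-in for a real
# x4 model:
#
#     fake_sr = lambda p: np.repeat(np.repeat(p, 4, axis=0), 4, axis=1)
#     im_lq = np.random.rand(300, 300, 3)          # h x w x c in [0, 1]
#     spliter = ImageSpliterNp(im_lq, pch_size=128, stride=96, sf=4)
#     for pch, index_infos in spliter:
#         spliter.update(fake_sr(pch), index_infos)
#     im_sr = spliter.gather()                     # 1200 x 1200 x 3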

class ImageSpliterTh:
    def __init__(self, im, pch_size, stride, sf=1, extra_bs=1, weight_type='Gaussian'):
        '''
        Input:
            im: n x c x h x w, torch tensor, float, low-resolution image in SR
            pch_size, stride: patch setting
            sf: scale factor in image super-resolution
            extra_bs: number of patches aggregated into one forward batch; only useful when the input is a single image
        '''
        assert weight_type in ['Gaussian', 'ones']
        self.weight_type = weight_type
        assert stride <= pch_size
        self.stride = stride
        self.pch_size = pch_size
        self.sf = sf
        self.extra_bs = extra_bs

        bs, chn, height, width= im.shape
        self.true_bs = bs

        self.height_starts_list = self.extract_starts(height)
        self.width_starts_list = self.extract_starts(width)
        self.starts_list = []
        for ii in self.height_starts_list:
            for jj in self.width_starts_list:
                self.starts_list.append([ii, jj])

        self.length = self.__len__()
        self.count_pchs = 0

        self.im_ori = im
        self.dtype = torch.float64
        self.im_res = torch.zeros([bs, chn, height*sf, width*sf], dtype=self.dtype, device=im.device)
        self.pixel_count = torch.zeros([bs, chn, height*sf, width*sf], dtype=self.dtype, device=im.device)

    def extract_starts(self, length):
        if length <= self.pch_size:
            starts = [0,]
        else:
            starts = list(range(0, length, self.stride))
            for ii in range(len(starts)):
                if starts[ii] + self.pch_size > length:
                    starts[ii] = length - self.pch_size
            starts = sorted(set(starts), key=starts.index)
        return starts

    def __len__(self):
        return len(self.height_starts_list) * len(self.width_starts_list)

    def __iter__(self):
        return self

    def __next__(self):
        if self.count_pchs < self.length:
            index_infos = []
            current_starts_list = self.starts_list[self.count_pchs:self.count_pchs+self.extra_bs]
            for ii, (h_start, w_start) in enumerate(current_starts_list):
                w_end = w_start + self.pch_size
                h_end = h_start + self.pch_size
                current_pch = self.im_ori[:, :, h_start:h_end, w_start:w_end]
                if ii == 0:
                    pch =  current_pch
                else:
                    pch = torch.cat([pch, current_pch], dim=0)

                h_start *= self.sf
                h_end *= self.sf
                w_start *= self.sf
                w_end *= self.sf
                index_infos.append([h_start, h_end, w_start, w_end])

            self.count_pchs += len(current_starts_list)
        else:
            raise StopIteration()

        return pch, index_infos

    def update(self, pch_res, index_infos):
        '''
        Input:
            pch_res: (n*extra_bs) x c x pch_size x pch_size, float
            index_infos: [(h_start, h_end, w_start, w_end),]
        '''
        assert pch_res.shape[0] % self.true_bs == 0
        pch_list = torch.split(pch_res, self.true_bs, dim=0)
        assert len(pch_list) == len(index_infos)
        for ii, (h_start, h_end, w_start, w_end) in enumerate(index_infos):
            current_pch = pch_list[ii].type(self.dtype)
            current_weight = self.get_weight(current_pch.shape[-2], current_pch.shape[-1])
            self.im_res[:, :, h_start:h_end, w_start:w_end] +=  current_pch * current_weight
            self.pixel_count[:, :, h_start:h_end, w_start:w_end] += current_weight

    @staticmethod
    def generate_kernel_1d(ksize):
        sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8  # opencv default setting
        if ksize % 2 == 0:
            kernel = cv2.getGaussianKernel(ksize=ksize+1, sigma=sigma, ktype=cv2.CV_64F)
            kernel = kernel[1:, ]
        else:
            kernel = cv2.getGaussianKernel(ksize=ksize, sigma=sigma, ktype=cv2.CV_64F)

        return kernel

    def get_weight(self, height, width):
        if self.weight_type == 'ones':
            kernel = torch.ones(1, 1, height, width)
        elif self.weight_type == 'Gaussian':
            kernel_h = self.generate_kernel_1d(height).reshape(-1, 1)
            kernel_w = self.generate_kernel_1d(width).reshape(1, -1)
            kernel = np.matmul(kernel_h, kernel_w)
            kernel = torch.from_numpy(kernel).unsqueeze(0).unsqueeze(0) # 1 x 1 x pch_size x pch_size
        else:
            raise ValueError(f"Unsupported weight type: {self.weight_type}")

        return kernel.to(dtype=self.dtype, device=self.im_ori.device)

    def gather(self):
        assert torch.all(self.pixel_count != 0)
        return self.im_res.div(self.pixel_count)
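
# Usage sketch (illustrative, not part of the original file): this is the
# splitter used for tiled ("chopped") processing; patches are blended back
# with the Gaussian window from get_weight(), which hides patch seams.
# `fake_sr` is a stand-in for the actual sampler:
#
#     fake_sr = lambda p: torch.nn.functional.interpolate(p, scale_factor=4)
#     im_lq = torch.rand(1, 3, 256, 256)
#     spliter = ImageSpliterTh(im_lq, pch_size=128, stride=96, sf=4, extra_bs=2)
#     for pch, index_infos in spliter:
#         spliter.update(fake_sr(pch), index_infos)
#     im_sr = spliter.gather()    # 1 x 3 x 1024 x 1024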

# ----------------------Patch Clipping----------------------------
class Clamper:
    def __init__(self, min_max=(-1, 1)):
        self.min_bound, self.max_bound = min_max[0], min_max[1]

    def __call__(self, im):
        if isinstance(im, np.ndarray):
            return np.clip(im, a_min=self.min_bound, a_max=self.max_bound)
        elif isinstance(im, torch.Tensor):
            return torch.clamp(im, min=self.min_bound, max=self.max_bound)
        else:
            raise TypeError(f'ndarray or Tensor expected, got {type(im)}')

# ----------------------Interpolation----------------------------
class Bicubic:
    def __init__(self, scale=None, out_shape=None, activate_matlab=True, resize_back=False):
        self.scale = scale
        self.activate_matlab = activate_matlab
        self.out_shape = out_shape
        self.resize_back = resize_back

    def __call__(self, im):
        if self.activate_matlab:
            out = imresize_np(im, scale=self.scale)
            if self.resize_back:
                out = imresize_np(out, scale=1/self.scale)
        else:
            out = cv2.resize(
                    im,
                    dsize=self.out_shape,
                    fx=self.scale,
                    fy=self.scale,
                    interpolation=cv2.INTER_CUBIC,
                    )
            if self.resize_back:
                out = cv2.resize(
                        out,
                        dsize=self.out_shape,
                        fx=1/self.scale,
                        fy=1/self.scale,
                        interpolation=cv2.INTER_CUBIC,
                        )
        return out

class SmallestMaxSize:
    def __init__(self, max_size, pass_resize=False, interpolation=None):
        self.pass_resize = pass_resize
        self.max_size = max_size
        self.interpolation = interpolation
        self.str2mode = {
                'nearest': cv2.INTER_NEAREST_EXACT,
                'bilinear': cv2.INTER_LINEAR,
                'bicubic': cv2.INTER_CUBIC
                }
        if self.interpolation is not None:
            assert interpolation in self.str2mode, f"Unsupported interpolation mode: {interpolation}"

    def get_interpolation(self, size):
        if self.interpolation is None:
            if size < self.max_size:   # upsampling
                interpolation = cv2.INTER_CUBIC
            else:                      # downsampling
                interpolation = cv2.INTER_AREA
        else:
            interpolation = self.str2mode[self.interpolation]

        return interpolation

    def __call__(self, im):
        h, w = im.shape[:2]
        if self.pass_resize or min(h, w) == self.max_size:
            out = im
        else:
            if h < w:
                dsize = (int(self.max_size * w / h), self.max_size)
                out = cv2.resize(im, dsize=dsize, interpolation=self.get_interpolation(h))
            else:
                dsize = (self.max_size, int(self.max_size * h / w))
                out = cv2.resize(im, dsize=dsize, interpolation=self.get_interpolation(w))
            if out.dtype == np.uint8:
                out = np.clip(out, 0, 255)
            else:
                out = np.clip(out, 0, 1.0)

        return out

# ----------------------Augmentation----------------------------
class SpatialAug:
    def __init__(self, pass_aug, only_hflip=False, only_vflip=False, only_hvflip=False):
        self.only_hflip = only_hflip
        self.only_vflip = only_vflip
        self.only_hvflip = only_hvflip
        self.pass_aug = pass_aug

    def __call__(self, im, flag=None):
        if self.pass_aug:
            return im

        if flag is None:
            if self.only_hflip:
                flag = random.choice([0, 5])
            elif self.only_vflip:
                flag = random.choice([0, 1])
            elif self.only_hvflip:
                flag = random.choice([0, 1, 5])
            else:
                flag = random.randint(0, 7)

        if isinstance(im, list) or isinstance(im, tuple):
            out = []
            for current_im in im:
                out.append(data_aug_np(current_im, flag))
        else:
            out = data_aug_np(im, flag)
        return out


================================================
FILE: comfyui_invsr_trimmed/utils/util_net.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-11-24 20:29:36

def reload_model(model, ckpt):
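    '''
    Copy the weights in `ckpt` into `model`, tolerating checkpoints saved
    from torch.nn.DataParallel ('module.' key prefix) and from
    torch.compile ('_orig_mod.' key prefix).
    '''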
    module_flag = list(ckpt.keys())[0].startswith('module.')
    compile_flag = '_orig_mod' in list(ckpt.keys())[0]

    for source_key, source_value in model.state_dict().items():
        target_key = source_key
        if compile_flag and (not '_orig_mod.' in source_key):
            target_key = '_orig_mod.' + target_key
        if module_flag and (not source_key.startswith('module')):
            target_key = 'module.' + target_key

        assert target_key in ckpt
        source_value.copy_(ckpt[target_key])
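
# Usage sketch (illustrative, not part of the original file): the loader
# lets a plain model consume a checkpoint whose keys were written by a
# wrapped model, without renaming keys by hand. `build_model` and the
# checkpoint path below are hypothetical:
#
#     import torch
#     model = build_model()
#     ckpt = torch.load('model_dataparallel.pth', map_location='cpu')
#     reload_model(model, ckpt)   # copies weights in place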


================================================
FILE: comfyui_invsr_trimmed/utils/util_opts.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-11-24 15:07:43

import argparse

def update_args(args_json, args_parser):
    for arg in vars(args_parser):
        args_json[arg] = getattr(args_parser, arg)

def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("boolean value expected")
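
# Usage sketch (illustrative, not part of the original file): registering
# str2bool as an argparse type makes flags such as "--tiled_vae no" parse
# into real booleans instead of truthy strings:
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--tiled_vae', type=str2bool, default=True)
#     args = parser.parse_args(['--tiled_vae', 'no'])
#     assert args.tiled_vae is False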


================================================
FILE: comfyui_invsr_trimmed/utils/util_sisr.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-12-07 21:37:58

import cv2
import numpy as np

def modcrop(im, sf):
    h, w = im.shape[:2]
    h -= (h % sf)
    w -= (w % sf)
    return im[:h, :w,]
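
# Quick example (illustrative, not part of the original file): modcrop
# trims both sides to a multiple of the scale factor so that an x4 SR
# output stays aligned with the ground truth:
#
#     im = np.zeros((721, 482, 3))
#     print(modcrop(im, 4).shape)   # (720, 480, 3)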

#-----------------------------------------Transform--------------------------------------------
class Bicubic:
    def __init__(self, scale=None, out_shape=None, matlab_mode=True):
        self.scale = scale
        self.out_shape = out_shape
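        # note: matlab_mode is accepted for API compatibility but is not
        # used in this trimmed version; resizing always goes through cv2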

    def __call__(self, im):
        out = cv2.resize(
                im,
                dsize=self.out_shape,
                fx=self.scale,
                fy=self.scale,
                interpolation=cv2.INTER_CUBIC,
                )
        return out


================================================
FILE: configs/degradation_testing_realesrgan.yaml
================================================
degradation:
  sf: 4
  # the first degradation process
  resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
  resize_range: [0.5, 1.5]
  gaussian_noise_prob: 0.5
  noise_range: [1, 15]
  poisson_scale_range: [0.05, 0.3]
  gray_noise_prob: 0.4
  jpeg_range: [70, 95]

  # the second degradation process
  second_order_prob: 0.0
  second_blur_prob: 0.2
  resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
  resize_range2: [0.8, 1.2]
  gaussian_noise_prob2: 0.5
  noise_range2: [1, 10]
  poisson_scale_range2: [0.05, 0.2]
  gray_noise_prob2: 0.4
  jpeg_range2: [80, 100]

  gt_size: 512

opts:
  data_source: ~
  im_exts: ['png', 'JPEG']
  io_backend:
    type: disk
  blur_kernel_size: 13
  kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
  kernel_prob: [0.60, 0.40, 0.0, 0.0, 0.0, 0.0]
  sinc_prob: 0.1
  blur_sigma: [0.2, 0.8]
  betag_range: [1.0, 1.5]
  betap_range: [1, 1.2]

  blur_kernel_size2: 7
  kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
  kernel_prob2: [0.60, 0.4, 0.0, 0.0, 0.0, 0.0]
  sinc_prob2: 0.0
  blur_sigma2: [0.2, 0.5]
  betag_range2: [0.5, 0.8]
  betap_range2: [1, 1.2]

  final_sinc_prob: 0.2

  gt_size: ${degradation.gt_size}
  crop_pad_size: ${degradation.gt_size}
  use_hflip: False
  use_rot: False
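
The `${degradation.gt_size}` entries above are OmegaConf interpolations: `gt_size` and `crop_pad_size` under `opts` resolve to whatever `degradation.gt_size` holds. A minimal sketch of reading such a file (illustrative; the repository itself loads configs through `get_configs`, and the path below is an assumption):

    from omegaconf import OmegaConf

    cfg = OmegaConf.load('configs/degradation_testing_realesrgan.yaml')
    print(cfg.opts.gt_size)         # 512, resolved from degradation.gt_size
    cfg.degradation.gt_size = 256   # interpolations track the source value
    print(cfg.opts.crop_pad_size)   # 256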



================================================
FILE: configs/sample-sd-turbo.yaml
================================================
seed: 12345


# Super-resolution settings
basesr:
  sf: 4
  chopping:     # for latent diffusion
    pch_size: 128
    weight_type: Gaussian

# VAE settings
tiled_vae: True
latent_tiled_size: 128
sample_tiled_size: 1024
gradient_checkpointing_vae: True
sliced_vae: False

# classifier-free guidance
cfg_scale: 1.0

# sampling settings 
start_timesteps: 200

# color fixing
color_fix: ~

# Stable Diffusion 
base_model: sd-turbo
sd_pipe:
  target: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline
  enable_grad_checkpoint: True
  params:
    pretrained_model_name_or_path: stabilityai/sd-turbo
    cache_dir: /mnt/sfs-common/zsyue/modelbase/stable-diffusion/sd-turbo
    use_safetensors: True
    torch_dtype: torch.float16

model_start:
  target: .noise_predictor.NoisePredictor
  ckpt_path: ~           # For initializing
  params:
    in_channels: 3
    down_block_types:
      - AttnDownBlock2D
      - AttnDownBlock2D
    up_block_types:
      - AttnUpBlock2D
      - AttnUpBlock2D
    block_out_channels:
      - 256    # 192, 256
      - 512    # 384, 512
    layers_per_block: 
      - 3
      - 3
    act_fn: silu
    latent_channels: 4
    norm_num_groups: 32
    sample_size: 128
    mid_block_add_attention: True
    resnet_time_scale_shift: default
    temb_channels: 512
    attention_head_dim: 64 
    freq_shift: 0
    flip_sin_to_cos: True
    double_z: True

model_middle:
  target: .noise_predictor.NoisePredictor
  params:
    in_channels: 3
    down_block_types:
      - AttnDownBlock2D
      - AttnDownBlock2D
    up_block_types:
      - AttnUpBlock2D
      - AttnUpBlock2D
    block_out_channels:
      - 256    # 192, 256
      - 512    # 384, 512
    layers_per_block: 
      - 3
      - 3
    act_fn: silu
    latent_channels: 4
    norm_num_groups: 32
    sample_size: 128
    mid_block_add_attention: True
    resnet_time_scale_shift: default
    temb_channels: 512
    attention_head_dim: 64 
    freq_shift: 0
    flip_sin_to_cos: True
    double_z: True
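
Every `target`/`params` pair in these configs follows the `instantiate_from_config` convention from `utils/util_common.py`: `target` is a dotted import path and `params` are keyword arguments for its constructor. A hedged sketch of that pattern (not the repository's exact implementation):

    import importlib

    def instantiate(config, package=None):
        # resolve 'pkg.module.ClassName' and call it with the params dict;
        # targets starting with '.' are package-relative and need `package`
        module, cls = config['target'].rsplit('.', 1)
        obj = getattr(importlib.import_module(module, package=package), cls)
        return obj(**config.get('params', {}))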


================================================
FILE: configs/sd-turbo-sr-ldis.yaml
================================================
trainer:
  target: trainer.TrainerSDTurboSR

sd_pipe:
  target: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline
  num_train_steps: 1000
  enable_grad_checkpoint: True
  compile: False
  vae_split: 8
  params:
    pretrained_model_name_or_path: stabilityai/sd-turbo
    cache_dir: weights
    use_safetensors: True
    torch_dtype: torch.float16

llpips:
  target: latent_lpips.lpips.LPIPS
  ckpt_path: weights/vgg16_sdturbo_lpips.pth
  compile: False
  params:
    pretrained: False
    net: vgg16
    lpips: True
    spatial: False
    pnet_rand: False
    pnet_tune: True
    use_dropout: True
    eval_mode: True
    latent: True
    in_chans: 4
    verbose: True

model:
  target: .noise_predictor.NoisePredictor
  ckpt_start_path: ~     # only used for training the intermediate model
  ckpt_path: ~           # For initializing
  compile: False
  params:
    in_channels: 3
    down_block_types:
      - AttnDownBlock2D
      - AttnDownBlock2D
    up_block_types:
      - AttnUpBlock2D
      - AttnUpBlock2D
    block_out_channels:
      - 256      # 192, 256
      - 512      # 384, 512
    layers_per_block: 
      - 3
      - 3
    act_fn: silu
    latent_channels: 4
    norm_num_groups: 32
    sample_size: 128
    mid_block_add_attention: True
    resnet_time_scale_shift: default
    temb_channels: 512
    attention_head_dim: 64 
    freq_shift: 0
    flip_sin_to_cos: True
    double_z: True

discriminator:
  target: diffusers.models.unets.unet_2d_condition_discriminator.UNet2DConditionDiscriminator
  enable_grad_checkpoint: True
  compile: False
  params:
    sample_size: 64
    in_channels: 4
    center_input_sample: False
    flip_sin_to_cos: True
    freq_shift: 0
    down_block_types:
      - DownBlock2D
      - CrossAttnDownBlock2D
      - CrossAttnDownBlock2D
    mid_block_type: UNetMidBlock2DCrossAttn
    up_block_types:
      - CrossAttnUpBlock2D
      - CrossAttnUpBlock2D
      - UpBlock2D
    only_cross_attention: False
    block_out_channels:
      - 128
      - 256
      - 512
    layers_per_block:
      - 1
      - 2
      - 2
    downsample_padding: 1
    mid_block_scale_factor: 1
    dropout: 0.0
    act_fn: silu
    norm_num_groups: 32
    norm_eps: 1e-5
    cross_attention_dim: 1024
    transformer_layers_per_block: 1
    reverse_transformer_layers_per_block: ~
    encoder_hid_dim: ~
    encoder_hid_dim_type: ~
    attention_head_dim:
      - 8 
      - 16 
      - 16 
    num_attention_heads: ~
    dual_cross_attention: False
    use_linear_projection: False
    class_embed_type: ~
    addition_embed_type: text 
    addition_time_embed_dim: 256
    num_class_embeds: ~
    upcast_attention: ~
    resnet_time_scale_shift: default
    resnet_skip_time_act: False
    resnet_out_scale_factor: 1.0
    time_embedding_type: positional
    time_embedding_dim: ~
    time_embedding_act_fn: ~
    timestep_post_act: ~
    time_cond_proj_dim: ~
    conv_in_kernel: 3
    conv_out_kernel: 3
    projection_class_embeddings_input_dim: 2560
    attention_type: default
    class_embeddings_concat: False
    mid_block_only_cross_attention: ~
    cross_attention_norm: ~
    addition_embed_type_num_heads: 64

degradation:
  sf: 4
  # the first degradation process
  resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
  resize_range: [0.15, 1.5]
  gaussian_noise_prob: 0.5
  noise_range: [1, 30]
  poisson_scale_range: [0.05, 3.0]
  gray_noise_prob: 0.4
  jpeg_range: [30, 95]

  # the second degradation process
  second_order_prob: 0.5
  second_blur_prob: 0.8
  resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
  resize_range2: [0.3, 1.2]
  gaussian_noise_prob2: 0.5
  noise_range2: [1, 25]
  poisson_scale_range2: [0.05, 2.5]
  gray_noise_prob2: 0.4
  jpeg_range2: [30, 95]

  gt_size: 512 
  resize_back: False
  use_sharp: False

data:
  train:
    type: realesrgan
    params:
      data_source: 
        source1:
          root_path: /mnt/sfs-common/zsyue/database/FFHQ
          image_path: images1024
          moment_path: ~
          text_path: ~
          im_ext: png
          length: 20000
        source2:
          root_path: /mnt/sfs-common/zsyue/database/LSDIR/train
          image_path: images 
          moment_path: ~
          text_path: ~
          im_ext: png
      max_token_length: 77   # 77
      io_backend:
        type: disk
      blur_kernel_size: 21
      kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
      kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
      sinc_prob: 0.1
      blur_sigma: [0.2, 3.0]
      betag_range: [0.5, 4.0]
      betap_range: [1, 2.0]

      blur_kernel_size2: 15
      kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
      kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
      sinc_prob2: 0.1
      blur_sigma2: [0.2, 1.5]
      betag_range2: [0.5, 4.0]
      betap_range2: [1, 2.0]

      final_sinc_prob: 0.8

      gt_size: ${degradation.gt_size}
      use_hflip: True
      use_rot: False
      random_crop: True
  val:
    type: base
    params:
      dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/lq
      transform_type: default
      transform_kwargs:
          mean: 0.0
          std: 1.0
      extra_dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/gt
      extra_transform_type: default
      extra_transform_kwargs:
          mean: 0.0
          std: 1.0
      im_exts: png
      length: 16
      recursive: False

train:
  # predict the starting step of the inversion (start mode)
  start_mode: True
  # learning rate
  lr: 5e-5                      # learning rate 
  lr_min: 5e-5                  # minimum learning rate
  lr_schedule: ~
  warmup_iterations: 2000
  # discriminator
  lr_dis: 5e-5                  # learning rate for discriminator
  weight_decay_dis: 1e-3        # weight decay for discriminator
  dis_init_iterations: 10000    # iterations used for updating the discriminator
  dis_update_freq: 1            
  # dataloader
  batch: 64               
  microbatch: 16 
  num_workers: 4
  prefetch_factor: 2            
  use_text: True
  # optimization settings
  weight_decay: 0               
  ema_rate: 0.999
  iterations: 200000            # total iterations
  # logging
  save_freq: 5000
  log_freq: [200, 5000]         # [training loss, training images, val images]
  local_logging: True           # manually save images
  tf_logging: False             # tensorboard logging
  # loss 
  loss_type: L2
  loss_coef:
    ldif: 1.0
  timesteps: [200, 100]
  num_inference_steps: 5
  # mixed precision
  use_amp: True                
  use_fsdp: False                
  # random seed 
  seed: 123456                 
  global_seeding: False
  noise_detach: False

validate:
  batch: 2
  use_ema: True            
  log_freq: 4      # logging frequency
  val_y_channel: True


================================================
FILE: node.py
================================================
from .comfyui_invsr_trimmed import get_configs, InvSamplerSR, BaseSampler, Namespace
import torch
from comfy.utils import ProgressBar
from folder_paths import get_full_path, get_folder_paths, models_dir
import os
import torch.nn.functional as F

def split_tensor_into_batches(tensor, batch_size):
    """
    Split a tensor into smaller batches of specified size
    
    Args:
        tensor (torch.Tensor): Input tensor of shape (N, C, H, W)
        batch_size (int): Desired batch size for splitting
        
    Returns:
        list: List of tensors, each of size batch_size (except possibly the last one)
    """
    # Get original batch size
    original_batch_size = tensor.size(0)
    
    # Calculate number of full batches and remaining samples
    num_full_batches = original_batch_size // batch_size
    remaining_samples = original_batch_size % batch_size
    
    # Split tensor into chunks
    batches = []
    
    # Handle full batches
    for i in range(num_full_batches):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size
        batch = tensor[start_idx:end_idx]
        batches.append(batch)
    
    # Handle remaining samples if any
    if remaining_samples > 0:
        last_batch = tensor[-remaining_samples:]
        batches.append(last_batch)
    
    return batches
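
# Usage sketch (illustrative, not part of the original file): the helper
# behaves like torch.split along dim 0, with a possibly smaller final
# chunk:
#
#     frames = torch.rand(10, 3, 64, 64)
#     chunks = split_tensor_into_batches(frames, batch_size=4)
#     print([c.shape[0] for c in chunks])   # [4, 4, 2]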


class LoadInvSRModels:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "sd_model": (['stabilityai/sd-turbo'],),
                "invsr_model": (['noise_predictor_sd_turbo_v5.pth', 'noise_predictor_sd_turbo_v5_diftune.pth'],),
                "dtype": (['fp16', 'fp32', 'bf16'], {"default": "fp16"}),
                "tiled_vae": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("INVSR_PIPE",)
    RETURN_NAMES = ("invsr_pipe",)
    FUNCTION = "loadmodel"
    CATEGORY = "INVSR"

    def loadmodel(self, sd_model, invsr_model, dtype, tiled_vae):
        match dtype:
            case "fp16":
                dtype = "torch.float16"
            case "fp32":
                dtype = "torch.float32"
            case "bf16":
                dtype = "torch.bfloat16"

        cfg_path = os.path.join(
            os.path.dirname(__file__), "configs", "sample-sd-turbo.yaml"
        )
        sd_path = get_folder_paths("diffusers")[0]

        try:
            ckpt_dir = get_folder_paths("invsr")[0]
        except Exception:
            ckpt_dir = os.path.join(models_dir, "invsr")

        args = Namespace(
            bs=1,
            chopping_bs=8,
            timesteps=None,
            num_steps=1,
            cfg_path=
SYMBOL INDEX (199 symbols across 16 files)

FILE: comfyui_invsr_trimmed/inference_invsr.py
  class Namespace (line 15) | class Namespace:
    method __init__ (line 16) | def __init__(self, **kwargs):
    method __repr__ (line 20) | def __repr__(self):
  function get_configs (line 24) | def get_configs(args, log=False):

FILE: comfyui_invsr_trimmed/latent_lpips/lpips.py
  function normalize_tensor (line 11) | def normalize_tensor(in_feat,eps=1e-10):
  function spatial_average (line 15) | def spatial_average(in_tens, keepdim=True):
  function upsample (line 18) | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same fo...
  class LPIPS (line 23) | class LPIPS(nn.Module):
    method __init__ (line 24) | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=T...
    method forward (line 126) | def forward(self, in0, in1, retPerLayer=False, normalize=False):
  class ScalingLayer (line 160) | class ScalingLayer(nn.Module):
    method __init__ (line 161) | def __init__(self):
    method forward (line 166) | def forward(self, inp):
  class NetLinLayer (line 169) | class NetLinLayer(nn.Module):
    method __init__ (line 171) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
    method forward (line 178) | def forward(self, x):
  class Dist2LogitLayer (line 181) | class Dist2LogitLayer(nn.Module):
    method __init__ (line 183) | def __init__(self, chn_mid=32, use_sigmoid=True):
    method forward (line 195) | def forward(self,d0,d1,eps=0.1):
  class BCERankingLoss (line 198) | class BCERankingLoss(nn.Module):
    method __init__ (line 199) | def __init__(self, chn_mid=32):
    method forward (line 206) | def forward(self, d0, d1, judge):
  function print_network (line 211) | def print_network(net):

FILE: comfyui_invsr_trimmed/latent_lpips/pretrained_networks.py
  class squeezenet (line 5) | class squeezenet(torch.nn.Module):
    method __init__ (line 6) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 35) | def forward(self, X):
  class alexnet (line 56) | class alexnet(torch.nn.Module):
    method __init__ (line 57) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 80) | def forward(self, X):
  class vgg16 (line 96) | class vgg16(torch.nn.Module):
    method __init__ (line 97) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 120) | def forward(self, X):
  class vgg16_latent (line 136) | class vgg16_latent(torch.nn.Module):
    method __init__ (line 137) | def __init__(self, requires_grad=False, pretrained=True, in_chans=3):
    method forward (line 168) | def forward(self, X):
  class resnet (line 185) | class resnet(torch.nn.Module):
    method __init__ (line 186) | def __init__(self, requires_grad=False, pretrained=True, num=18):
    method forward (line 209) | def forward(self, X):

FILE: comfyui_invsr_trimmed/noise_predictor.py
  class NoisePredictor (line 23) | class NoisePredictor(ModelMixin, ConfigMixin, FromOriginalModelMixin):
    method __init__ (line 56) | def __init__(
    method _set_gradient_checkpointing (line 111) | def _set_gradient_checkpointing(self, module, value=False):
    method enable_tiling (line 115) | def enable_tiling(self, use_tiling: bool = True):
    method disable_tiling (line 123) | def disable_tiling(self):
    method enable_slicing (line 130) | def enable_slicing(self):
    method disable_slicing (line 137) | def disable_slicing(self):
    method attn_processors (line 146) | def attn_processors(self) -> Dict[str, AttentionProcessor]:
    method set_attn_processor (line 174) | def set_attn_processor(
    method set_default_attn_processor (line 211) | def set_default_attn_processor(self):
    method encode (line 233) | def encode(
    method tiled_encode (line 273) | def tiled_encode(
    method forward (line 338) | def forward(

FILE: comfyui_invsr_trimmed/pipeline_stable_diffusion_inversion_sr.py
  function retrieve_latents (line 79) | def retrieve_latents(
  function preprocess (line 92) | def preprocess(image):
  function retrieve_timesteps (line 116) | def retrieve_timesteps(
  class StableDiffusionInvEnhancePipeline (line 170) | class StableDiffusionInvEnhancePipeline(
    method __init__ (line 216) | def __init__(
    method _encode_prompt (line 309) | def _encode_prompt(
    method encode_prompt (line 342) | def encode_prompt(
    method encode_image (line 525) | def encode_image(self, image, device, num_images_per_prompt, output_hi...
    method prepare_ip_adapter_image_embeds (line 550) | def prepare_ip_adapter_image_embeds(
    method run_safety_checker (line 596) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 611) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 623) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 640) | def check_inputs(
    method get_timesteps (line 708) | def get_timesteps(self, num_inference_steps, strength, device):
    method prepare_latents (line 719) | def prepare_latents(
    method get_guidance_scale_embedding (line 789) | def get_guidance_scale_embedding(
    method guidance_scale (line 820) | def guidance_scale(self):
    method clip_skip (line 824) | def clip_skip(self):
    method do_classifier_free_guidance (line 831) | def do_classifier_free_guidance(self):
    method cross_attention_kwargs (line 835) | def cross_attention_kwargs(self):
    method num_timesteps (line 839) | def num_timesteps(self):
    method interrupt (line 843) | def interrupt(self):
    method __call__ (line 848) | def __call__(

FILE: comfyui_invsr_trimmed/sampler_invsr.py
  class BaseSampler (line 30) | class BaseSampler:
    method __init__ (line 31) | def __init__(self, configs):
    method setup_seed (line 44) | def setup_seed(self, seed=None):
    method write_log (line 51) | def write_log(self, log_str):
    method build_model (line 54) | def build_model(self):
  class InvSamplerSR (line 109) | class InvSamplerSR(BaseSampler):
    method __init__ (line 110) | def __init__(self, base_sampler):
    method sample_func (line 115) | def sample_func(self, im_cond):
    method inference (line 226) | def inference(self, image_bchw):
  function get_torch_dtype (line 229) | def get_torch_dtype(torch_dtype: str):

FILE: comfyui_invsr_trimmed/time_aware_encoder.py
  class TimeAwareEncoder (line 15) | class TimeAwareEncoder(nn.Module):
    method __init__ (line 40) | def __init__(
    method forward (line 125) | def forward(

FILE: comfyui_invsr_trimmed/utils/resize.py
  function nearest_contribution (line 32) | def nearest_contribution(x: torch.Tensor) -> torch.Tensor:
  function linear_contribution (line 38) | def linear_contribution(x: torch.Tensor) -> torch.Tensor:
  function cubic_contribution (line 45) | def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
  function gaussian_contribution (line 63) | def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch....
  function discrete_kernel (line 71) | def discrete_kernel(kernel: str, scale: float, antialiasing: bool = True...
  function reflect_padding (line 101) | def reflect_padding(x: torch.Tensor, dim: int, pad_pre: int, pad_post: i...
  function padding (line 131) | def padding(x: torch.Tensor,
  function get_padding (line 146) | def get_padding(base: torch.Tensor, kernel_size: int, x_size: int) -> ty...
  function get_weight (line 167) | def get_weight(dist: torch.Tensor,
  function reshape_tensor (line 189) | def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch...
  function reshape_input (line 206) | def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I,...
  function reshape_output (line 222) | def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
  function cast_input (line 237) | def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
  function cast_output (line 247) | def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
  function resize_1d (line 260) | def resize_1d(x: torch.Tensor,
  function downsampling_2d (line 334) | def downsampling_2d(x: torch.Tensor, k: torch.Tensor, scale: int, paddin...
  function imresize (line 354) | def imresize(x: torch.Tensor,

FILE: comfyui_invsr_trimmed/utils/util_color_fix.py
  function calc_mean_std (line 16) | def calc_mean_std(feat: Tensor, eps=1e-5):
  function adaptive_instance_normalization (line 31) | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tens...
  function wavelet_blur (line 45) | def wavelet_blur(image: Tensor, radius: int):
  function wavelet_decomposition (line 66) | def wavelet_decomposition(image: Tensor, levels=5):
  function wavelet_reconstruction (line 80) | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
  function ycbcr_color_replace (line 93) | def ycbcr_color_replace(content_feat:Tensor, style_feat:Tensor):

FILE: comfyui_invsr_trimmed/utils/util_common.py
  function mkdir (line 11) | def mkdir(dir_path, delete=False, parents=True):
  function get_obj_from_str (line 21) | def get_obj_from_str(string, reload=False):
  function instantiate_from_config (line 32) | def instantiate_from_config(config):
  function str2bool (line 37) | def str2bool(v):
  function get_filenames (line 47) | def get_filenames(dir_path, exts=['png', 'jpg'], recursive=True):
  function readline_txt (line 65) | def readline_txt(txt_file):
  function scan_files_from_folder (line 74) | def scan_files_from_folder(dir_paths, exts, recursive=True):
  function write_path_to_txt (line 96) | def write_path_to_txt(

FILE: comfyui_invsr_trimmed/utils/util_ema.py
  class LitEma (line 5) | class LitEma(nn.Module):
    method __init__ (line 6) | def __init__(self, model, decay=0.9999, use_num_upates=True):
    method reset_num_updates (line 25) | def reset_num_updates(self):
    method forward (line 29) | def forward(self, model):
    method copy_to (line 50) | def copy_to(self, model):
    method store (line 64) | def store(self, parameters):
    method restore (line 73) | def restore(self, parameters):
    method resume (line 87) | def resume(self, ckpt, num_updates):

FILE: comfyui_invsr_trimmed/utils/util_image.py
  function ssim (line 14) | def ssim(img1, img2):
  function calculate_ssim (line 36) | def calculate_ssim(im1, im2, border=0, ycbcr=False):
  function calculate_psnr (line 65) | def calculate_psnr(im1, im2, border=0, ycbcr=False):
  function normalize_np (line 88) | def normalize_np(im, mean=0.5, std=0.5, reverse=False):
  function normalize_th (line 110) | def normalize_th(im, mean=0.5, std=0.5, reverse=False):
  function rgb2ycbcr (line 133) | def rgb2ycbcr(im, only_y=True):
  function rgb2ycbcrTorch (line 159) | def rgb2ycbcrTorch(im, only_y=True):
  function ycbcr2rgbTorch (line 186) | def ycbcr2rgbTorch(im):
  function bgr2rgb (line 210) | def bgr2rgb(im): return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  function rgb2bgr (line 212) | def rgb2bgr(im): return cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
  function tensor2img (line 214) | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
  function imresize_np (line 274) | def imresize_np(img, scale, antialiasing=True):
  function calculate_weights_indices (line 346) | def calculate_weights_indices(in_length, out_length, scale, kernel, kern...
  function cubic (line 401) | def cubic(x):
  function imread (line 409) | def imread(path, chn='rgb', dtype='float32', force_gray2rgb=True, force_...
  function data_aug_np (line 455) | def data_aug_np(image, mode):
  function inverse_data_aug_np (line 502) | def inverse_data_aug_np(image, mode):
  function imshow (line 533) | def imshow(x, title=None, cbar=False):
  function imblend_with_mask (line 542) | def imblend_with_mask(im, mask, alpha=0.25):
  function imgrad (line 565) | def imgrad(im, pading_mode='mirror'):
  function convtorch (line 595) | def convtorch(im, weight, mode='reflect'):
  function random_crop (line 611) | def random_crop(im, pch_size):
  class ToTensor (line 640) | class ToTensor:
    method __init__ (line 641) | def __init__(self, max_value=1.0):
    method __call__ (line 644) | def __call__(self, im):
  class RandomCrop (line 656) | class RandomCrop:
    method __init__ (line 657) | def __init__(self, pch_size, pass_crop=False):
    method __call__ (line 661) | def __call__(self, im):
  class ImageSpliterNp (line 672) | class ImageSpliterNp:
    method __init__ (line 673) | def __init__(self, im, pch_size, stride, sf=1):
    method extract_starts (line 698) | def extract_starts(self, length):
    method __len__ (line 704) | def __len__(self):
    method __iter__ (line 707) | def __iter__(self):
    method __next__ (line 710) | def __next__(self):
    method update (line 730) | def update(self, pch_res, index_infos):
    method gather (line 745) | def gather(self):
  class ImageSpliterTh (line 749) | class ImageSpliterTh:
    method __init__ (line 750) | def __init__(self, im, pch_size, stride, sf=1, extra_bs=1, weight_type...
    method extract_starts (line 784) | def extract_starts(self, length):
    method __len__ (line 795) | def __len__(self):
    method __iter__ (line 798) | def __iter__(self):
    method __next__ (line 801) | def __next__(self):
    method update (line 826) | def update(self, pch_res, index_infos):
    method generate_kernel_1d (line 842) | def generate_kernel_1d(ksize):
    method get_weight (line 852) | def get_weight(self, height, width):
    method gather (line 865) | def gather(self):
  class Clamper (line 870) | class Clamper:
    method __init__ (line 871) | def __init__(self, min_max=(-1, 1)):
    method __call__ (line 874) | def __call__(self, im):
  class Bicubic (line 883) | class Bicubic:
    method __init__ (line 884) | def __init__(self, scale=None, out_shape=None, activate_matlab=True, r...
    method __call__ (line 890) | def __call__(self, im):
  class SmallestMaxSize (line 913) | class SmallestMaxSize:
    method __init__ (line 914) | def __init__(self, max_size, pass_resize=False, interpolation=None):
    method get_interpolation (line 926) | def get_interpolation(self, size):
    method __call__ (line 937) | def __call__(self, im):
  class SpatialAug (line 956) | class SpatialAug:
    method __init__ (line 957) | def __init__(self, pass_aug, only_hflip=False, only_vflip=False, only_...
    method __call__ (line 963) | def __call__(self, im, flag=None):

FILE: comfyui_invsr_trimmed/utils/util_net.py
  function reload_model (line 5) | def reload_model(model, ckpt):

FILE: comfyui_invsr_trimmed/utils/util_opts.py
  function update_args (line 7) | def update_args(args_json, args_parser):
  function str2bool (line 11) | def str2bool(v):

FILE: comfyui_invsr_trimmed/utils/util_sisr.py
  function modcrop (line 8) | def modcrop(im, sf):
  class Bicubic (line 15) | class Bicubic:
    method __init__ (line 16) | def __init__(self, scale=None, out_shape=None, matlab_mode=True):
    method __call__ (line 20) | def __call__(self, im):

FILE: node.py
  function split_tensor_into_batches (line 8) | def split_tensor_into_batches(tensor, batch_size):
  class LoadInvSRModels (line 44) | class LoadInvSRModels:
    method INPUT_TYPES (line 46) | def INPUT_TYPES(s):
    method loadmodel (line 61) | def loadmodel(self, sd_model, invsr_model, dtype, tiled_vae):
  class InvSRSampler (line 99) | class InvSRSampler:
    method INPUT_TYPES (line 101) | def INPUT_TYPES(s):
    method process (line 122) | def process(self, invsr_pipe, images, num_steps, cfg, batch_size, chop...