Repository: yuvraj108c/ComfyUI_InvSR
Branch: main
Commit: 20a0e003e676
Files: 31
Total size: 199.0 KB

Directory structure:
gitextract_393ad4r7/
├── .github/
│   ├── FUNDING.yml
│   └── workflows/
│       └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── comfyui_invsr_trimmed/
│   ├── __init__.py
│   ├── inference_invsr.py
│   ├── latent_lpips/
│   │   ├── __init__.py
│   │   ├── lpips.py
│   │   └── pretrained_networks.py
│   ├── noise_predictor.py
│   ├── pipeline_stable_diffusion_inversion_sr.py
│   ├── sampler_invsr.py
│   ├── time_aware_encoder.py
│   └── utils/
│       ├── __init__.py
│       ├── resize.py
│       ├── util_color_fix.py
│       ├── util_common.py
│       ├── util_ema.py
│       ├── util_image.py
│       ├── util_net.py
│       ├── util_opts.py
│       └── util_sisr.py
├── configs/
│   ├── degradation_testing_realesrgan.yaml
│   ├── sample-sd-turbo.yaml
│   └── sd-turbo-sr-ldis.yaml
├── node.py
├── pyproject.toml
├── requirements.txt
└── workflows/
    └── invsr.json

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/FUNDING.yml
================================================

github: yuvraj108c
custom: ["https://paypal.me/yuvraj108c", "https://buymeacoffee.com/yuvraj108cz"]

================================================
FILE: .github/workflows/publish.yml
================================================

name: Publish to Comfy registry
on:
  workflow_dispatch:
  push:
    branches:
      - main
      - master
    paths:
      - "pyproject.toml"

permissions:
  issues: write

jobs:
  publish-node:
    name: Publish Custom Node to registry
    runs-on: ubuntu-latest
    if: ${{ github.repository_owner == 'yuvraj108c' }}
    steps:
      - name: Check out code
        uses: actions/checkout@v4
        with:
          submodules: true
      - name: Publish Custom Node
        uses: Comfy-Org/publish-node-action@v1
        with:
          ## Add your own personal access token to your Github Repository secrets and reference it here.
          personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}

================================================
FILE: .gitignore
================================================

.DS_Store
*pyc
.vscode
__pycache__
# *.egg-info
*.bak
checkpoints
results
backup

================================================
FILE: LICENSE
================================================

S-Lab License 1.0

Copyright 2024 S-Lab

Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.

================================================
FILE: README.md
================================================
# ComfyUI InvSR

[![arXiv](https://img.shields.io/badge/arXiv%20paper-2412.09013-b31b1b.svg)](https://arxiv.org/abs/2412.09013)

This project is a ComfyUI wrapper for [InvSR](https://github.com/zsyOAOA/InvSR) (Arbitrary-steps Image Super-resolution via Diffusion Inversion).

**Last tested**: 2 January 2026 (ComfyUI v0.7.0@f2fda02 | Torch 2.9.1 | Python 3.10.12 | RTX 4090 | CUDA 13.0 | Debian 12)

## ⭐ Support

If you like my projects and wish to see updates and new features, please consider supporting me. It helps a lot!

[![ComfyUI-Depth-Anything-Tensorrt](https://img.shields.io/badge/ComfyUI--Depth--Anything--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Depth-Anything-Tensorrt)
[![ComfyUI-Upscaler-Tensorrt](https://img.shields.io/badge/ComfyUI--Upscaler--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Upscaler-Tensorrt)
[![ComfyUI-Dwpose-Tensorrt](https://img.shields.io/badge/ComfyUI--Dwpose--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Dwpose-Tensorrt)
[![ComfyUI-Rife-Tensorrt](https://img.shields.io/badge/ComfyUI--Rife--Tensorrt-blue?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Rife-Tensorrt)
[![ComfyUI-Whisper](https://img.shields.io/badge/ComfyUI--Whisper-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Whisper)
[![ComfyUI_InvSR](https://img.shields.io/badge/ComfyUI__InvSR-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI_InvSR)
[![ComfyUI-Thera](https://img.shields.io/badge/ComfyUI--Thera-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Thera)
[![ComfyUI-Video-Depth-Anything](https://img.shields.io/badge/ComfyUI--Video--Depth--Anything-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-Video-Depth-Anything)
[![ComfyUI-PiperTTS](https://img.shields.io/badge/ComfyUI--PiperTTS-gray?style=flat-square)](https://github.com/yuvraj108c/ComfyUI-PiperTTS)

[![buy-me-coffees](https://i.imgur.com/3MDbAtw.png)](https://www.buymeacoffee.com/yuvraj108cZ)
[![paypal-donation](https://i.imgur.com/w5jjubk.png)](https://paypal.me/yuvraj108c)

---

## Installation

Navigate to the ComfyUI `/custom_nodes` directory:

```bash
git clone https://github.com/yuvraj108c/ComfyUI_InvSR
cd ComfyUI_InvSR
pip install -r requirements.txt
```

## Usage

- Load the [example workflow](workflows/invsr.json)
- The diffusers model (stabilityai/sd-turbo) will download automatically to `ComfyUI/models/diffusers`
- The InvSR model (noise_predictor_sd_turbo_v5.pth) will download automatically to `ComfyUI/models/invsr` (a manual pre-fetch sketch follows this list)
- To deal with large images (e.g. 1K → 4K), set `chopping_size` to 256
- If your GPU memory is limited, set `chopping_batch_size` to 1
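If you prefer to pre-fetch the InvSR checkpoint yourself (for example on a machine without direct Hub access at runtime), the minimal sketch below mirrors the `hf_hub_download` call used by `get_configs` in `comfyui_invsr_trimmed/inference_invsr.py`. The destination path is an assumption based on a default ComfyUI layout, not something this repo enforces:

```python
import os
from shutil import copy2
from huggingface_hub import hf_hub_download

# Same repo/filename the node fetches on first run
# (see get_configs in comfyui_invsr_trimmed/inference_invsr.py).
cached = hf_hub_download(repo_id="OAOA/InvSR", filename="noise_predictor_sd_turbo_v5.pth")

# Destination assumes a default ComfyUI install layout.
os.makedirs("ComfyUI/models/invsr", exist_ok=True)
copy2(cached, "ComfyUI/models/invsr/noise_predictor_sd_turbo_v5.pth")
```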
## Parameters

- `num_steps`: Number of inference steps
- `cfg`: Classifier-free guidance scale
- `batch_size`: Controls how many complete images are processed simultaneously
- `chopping_batch_size`: Controls how many patches from the same image are processed simultaneously
- `chopping_size`: Controls the size of the patches when splitting large images
- `color_fix`: Method to fix color shift in the processed images
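Under the hood, these parameters end up on the `Namespace` object consumed by `get_configs` in `comfyui_invsr_trimmed/inference_invsr.py` (reproduced further below). A minimal sketch with illustrative values; the attribute names follow that file, while the accepted `color_fix` values are whatever the sampler supports:

```python
from comfyui_invsr_trimmed import Namespace, get_configs

# Illustrative values only; attribute names mirror what get_configs reads.
args = Namespace(
    cfg_path="configs/sample-sd-turbo.yaml",
    num_steps=2,          # -> timesteps [200, 100]
    timesteps=None,       # derived from num_steps when None
    sd_path=None,         # sd-turbo cache dir, defaults to ./weights
    invsr_model="noise_predictor_sd_turbo_v5.pth",
    bs=1,                 # batch_size
    chopping_bs=8,        # chopping_batch_size
    chopping_size=128,    # chopping_size
    tiled_vae=False,
    color_fix="",         # passed through to the sampler
)

# Note: downloads the noise predictor from the HF hub on first run if missing.
configs = get_configs(args, log=True)
```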
## Updates

**28 April 2025**

- Update diffusers versions in requirements.txt to fix https://github.com/yuvraj108c/ComfyUI_InvSR/issues/26, https://github.com/yuvraj108c/ComfyUI_InvSR/issues/21 and https://github.com/yuvraj108c/ComfyUI_InvSR/issues/15
- Add support for `noise_predictor_sd_turbo_v5_diftune.pth`

**03 February 2025**

- Add cfg parameter
- Make image divisible by 16
- Use `mm` to set torch device

**31 January 2025**

- Merged https://github.com/yuvraj108c/ComfyUI_InvSR/pull/5 by [wfjsw](https://github.com/wfjsw)
- Compatibility with `diffusers>=0.28`
- Massive code refactoring & cleanup

## Citation

```bibtex
@article{yue2024InvSR,
  title={Arbitrary-steps Image Super-resolution via Diffusion Inversion},
  author={Yue, Zongsheng and Liao, Kang and Loy, Chen Change},
  journal={arXiv preprint arXiv:2412.09013},
  year={2024},
}
```

## License

This project is licensed under the [NTU S-Lab License 1.0](LICENSE).

## Acknowledgments

Thanks to [simplepod.ai](https://simplepod.ai/) for providing GPU servers.

## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=yuvraj108c/ComfyUI_InvSR&type=Date)](https://star-history.com/#yuvraj108c/ComfyUI_InvSR&Date)

================================================
FILE: __init__.py
================================================

from .node import LoadInvSRModels, InvSRSampler

NODE_CLASS_MAPPINGS = {
    "LoadInvSRModels": LoadInvSRModels,
    "InvSRSampler": InvSRSampler,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LoadInvSRModels": "Load InvSR Models",
    "InvSRSampler": "InvSRSampler",
}

__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']

================================================
FILE: comfyui_invsr_trimmed/__init__.py
================================================

from .inference_invsr import get_configs, Namespace
from .sampler_invsr import InvSamplerSR, BaseSampler
from .noise_predictor import NoisePredictor
from .time_aware_encoder import TimeAwareEncoder

__all__ = [
    "get_configs",
    "Namespace",
    "InvSamplerSR",
    "BaseSampler",
    "NoisePredictor",
    "TimeAwareEncoder",
]

================================================
FILE: comfyui_invsr_trimmed/inference_invsr.py
================================================

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Powered by Zongsheng Yue 2023-03-11 17:17:41

import numpy as np
from pathlib import Path
from omegaconf import OmegaConf

from .sampler_invsr import InvSamplerSR, BaseSampler
from .utils import util_common
from .utils.util_opts import str2bool

from huggingface_hub import hf_hub_download
from shutil import copy2


class Namespace:
    """A minimal stand-in for argparse.Namespace, built from keyword arguments."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        items = [f"{key}={repr(value)}" for key, value in vars(self).items()]
        return f"Namespace({', '.join(items)})"


def get_configs(args, log=False):
    configs = OmegaConf.load(args.cfg_path)

    # Build the (descending) timestep schedule for inference.
    if args.timesteps is not None:
        assert len(args.timesteps) == args.num_steps
        configs.timesteps = sorted(args.timesteps, reverse=True)
    else:
        if args.num_steps == 1:
            configs.timesteps = [200,]
        elif args.num_steps == 2:
            configs.timesteps = [200, 100]
        elif args.num_steps == 3:
            configs.timesteps = [200, 100, 50]
        elif args.num_steps == 4:
            configs.timesteps = [200, 150, 100, 50]
        elif args.num_steps == 5:
            configs.timesteps = [250, 200, 150, 100, 50]
        else:
            assert args.num_steps <= 250
            configs.timesteps = np.linspace(
                start=args.started_step, stop=0, num=args.num_steps, endpoint=False, dtype=np.int64(),
            ).tolist()
    if log:
        print(f'[InvSR] - Setting timesteps for inference: {configs.timesteps}')

    # Path to save Stable Diffusion
    sd_path = args.sd_path if args.sd_path else "./weights"
    util_common.mkdir(sd_path, delete=False, parents=True)
    configs.sd_pipe.params.cache_dir = sd_path

    # Path to save the noise predictor; downloaded from the HF hub if missing
    started_ckpt_name = args.invsr_model
    if getattr(args, "started_ckpt_dir", None) is not None:
        started_ckpt_dir = args.started_ckpt_dir
    else:
        started_ckpt_dir = "./weights"
    if getattr(args, "started_ckpt_path", None) is not None:
        started_ckpt_path = args.started_ckpt_path
    else:
        started_ckpt_path = Path(started_ckpt_dir) / started_ckpt_name
    util_common.mkdir(started_ckpt_dir, delete=False, parents=True)
    if not Path(started_ckpt_path).exists():
        temp_path = hf_hub_download(
            repo_id="OAOA/InvSR",
            filename=started_ckpt_name,
        )
        copy2(temp_path, started_ckpt_path)
    configs.model_start.ckpt_path = str(started_ckpt_path)

    configs.bs =
args.bs configs.tiled_vae = args.tiled_vae configs.color_fix = args.color_fix configs.basesr.chopping.pch_size = args.chopping_size if args.bs > 1: configs.basesr.chopping.extra_bs = 1 else: configs.basesr.chopping.extra_bs = args.chopping_bs return configs ================================================ FILE: comfyui_invsr_trimmed/latent_lpips/__init__.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function ================================================ FILE: comfyui_invsr_trimmed/latent_lpips/lpips.py ================================================ from __future__ import absolute_import import torch import torch.nn as nn import torch.nn.init as init from torch.autograd import Variable import numpy as np from . import pretrained_networks as pn def normalize_tensor(in_feat,eps=1e-10): norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) return in_feat/(norm_factor+eps) def spatial_average(in_tens, keepdim=True): return in_tens.mean([2,3],keepdim=keepdim) def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W in_H, in_W = in_tens.shape[2], in_tens.shape[3] return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) # Learned perceptual metric class LPIPS(nn.Module): def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, latent=False, in_chans=3, verbose=True): """ Initializes a perceptual loss torch.nn.Module Parameters (default listed first) --------------------------------- lpips : bool [True] use linear layers on top of base/trunk network [False] means no linear layers; each layer is averaged together pretrained : bool This flag controls the linear layers, which are only in effect when lpips=True above [True] means linear layers are calibrated with human perceptual judgments [False] means linear layers are randomly initialized pnet_rand : bool [False] means trunk loaded with ImageNet classification weights [True] means randomly initialized trunk net : str ['alex','vgg','squeeze'] are the base/trunk networks available version : str ['v0.1'] is the default and latest ['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1) model_path : 'str' [None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1 The following parameters should only be changed if training the network eval_mode : bool [True] is for test mode (default) [False] is for training mode pnet_tune [False] keep base/trunk frozen [True] tune the base/trunk network use_dropout : bool [True] to use dropout when training linear layers [False] for no dropout when training linear layers """ super(LPIPS, self).__init__() if(verbose): print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'% ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) self.pnet_type = net self.pnet_tune = pnet_tune self.pnet_rand = pnet_rand self.spatial = spatial self.latent = latent self.lpips = lpips # false means baseline of just averaging all layers self.version = version self.scaling_layer = ScalingLayer() if(self.pnet_type in ['vgg','vgg16']): if not latent: net_type = pn.vgg16 else: net_type = pn.vgg16_latent self.chns = [64,128,256,512,512] elif(self.pnet_type=='alex'): net_type = pn.alexnet self.chns = [64,192,384,256,256] 
elif(self.pnet_type=='squeeze'): net_type = pn.squeezenet self.chns = [64,128,256,384,384,512,512] self.L = len(self.chns) if latent: self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune, in_chans=in_chans) else: self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) if(lpips): self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] if(self.pnet_type=='squeeze'): # 7 layers for squeezenet self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) self.lins+=[self.lin5,self.lin6] self.lins = nn.ModuleList(self.lins) if(pretrained): if(model_path is None): import inspect import os model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net))) if(verbose): print('Loading model from: %s'%model_path) missing_keys, unexpected_keys = self.load_state_dict( torch.load(model_path, map_location='cpu'), strict=False, ) print(f'Number of missing keys when loading chckepoint: {len(missing_keys)}') print(f'Number of unexpected keys when loading chckepoint: {len(unexpected_keys)}') if(eval_mode): self.eval() def forward(self, in0, in1, retPerLayer=False, normalize=False): if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] in0 = 2 * in0 - 1 in1 = 2 * in1 - 1 # v0.0 - original release had a bug, where input was not scaled in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if (not self.latent and self.version=='0.1') else (in0, in1) outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} for kk in range(self.L): feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk]-feats1[kk])**2 if(self.lpips): if(self.spatial): res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] else: res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] else: if(self.spatial): res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] else: res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] val = 0 for l in range(self.L): val += res[l] if(retPerLayer): return (val, res) else: return val class ScalingLayer(nn.Module): def __init__(self): super(ScalingLayer, self).__init__() self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) def forward(self, inp): return (inp - self.shift) / self.scale class NetLinLayer(nn.Module): ''' A single linear layer which does a 1x1 conv ''' def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() layers = [nn.Dropout(),] if(use_dropout) else [] layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] self.model = nn.Sequential(*layers) def forward(self, x): return self.model(x) class Dist2LogitLayer(nn.Module): ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' def __init__(self, chn_mid=32, 
use_sigmoid=True): super(Dist2LogitLayer, self).__init__() layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] layers += [nn.LeakyReLU(0.2,True),] layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] layers += [nn.LeakyReLU(0.2,True),] layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] if(use_sigmoid): layers += [nn.Sigmoid(),] self.model = nn.Sequential(*layers) def forward(self,d0,d1,eps=0.1): return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) class BCERankingLoss(nn.Module): def __init__(self, chn_mid=32): super(BCERankingLoss, self).__init__() self.net = Dist2LogitLayer(chn_mid=chn_mid, use_sigmoid=False) # self.parameters = list(self.net.parameters()) # self.loss = torch.nn.BCELoss() self.loss = torch.nn.BCEWithLogitsLoss() def forward(self, d0, d1, judge): per = (judge+1.)/2. self.logit = self.net.forward(d0,d1) return self.loss(self.logit, per) def print_network(net): num_params = 0 for param in net.parameters(): num_params += param.numel() print('Network',net) print('Total number of parameters: %d' % num_params) ================================================ FILE: comfyui_invsr_trimmed/latent_lpips/pretrained_networks.py ================================================ from collections import namedtuple import torch from torchvision import models as tv class squeezenet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(squeezenet, self).__init__() pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.slice6 = torch.nn.Sequential() self.slice7 = torch.nn.Sequential() self.N_slices = 7 for x in range(2): self.slice1.add_module(str(x), pretrained_features[x]) for x in range(2,5): self.slice2.add_module(str(x), pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), pretrained_features[x]) for x in range(8, 10): self.slice4.add_module(str(x), pretrained_features[x]) for x in range(10, 11): self.slice5.add_module(str(x), pretrained_features[x]) for x in range(11, 12): self.slice6.add_module(str(x), pretrained_features[x]) for x in range(12, 13): self.slice7.add_module(str(x), pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1 = h h = self.slice2(h) h_relu2 = h h = self.slice3(h) h_relu3 = h h = self.slice4(h) h_relu4 = h h = self.slice5(h) h_relu5 = h h = self.slice6(h) h_relu6 = h h = self.slice7(h) h_relu7 = h vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) return out class alexnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(alexnet, self).__init__() alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(2): self.slice1.add_module(str(x), alexnet_pretrained_features[x]) for x in range(2, 5): self.slice2.add_module(str(x), alexnet_pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), alexnet_pretrained_features[x]) for x in 
range(8, 10): self.slice4.add_module(str(x), alexnet_pretrained_features[x]) for x in range(10, 12): self.slice5.add_module(str(x), alexnet_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1 = h h = self.slice2(h) h_relu2 = h h = self.slice3(h) h_relu3 = h h = self.slice4(h) h_relu4 = h h = self.slice5(h) h_relu5 = h alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) return out class vgg16(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(vgg16, self).__init__() vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1_2 = h h = self.slice2(h) h_relu2_2 = h h = self.slice3(h) h_relu3_3 = h h = self.slice4(h) h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out class vgg16_latent(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True, in_chans=3): super(vgg16_latent, self).__init__() vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 # max pooling layers: vgg_pretrained_features[5, 9, 16, 23] for x in range(4): assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d) self.slice1.add_module(str(x), vgg_pretrained_features[x]) if (not in_chans == 3): # assert in_chans == 4 weight = self.slice1[0].weight.data[:, 0,].unsqueeze(1).repeat(1, in_chans, 1, 1) self.slice1[0].weight.data = weight for x in range(5, 9): # skip max pooling at index 5 assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d) self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(10, 16): # skip max pooling at index 9 assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d) self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(17, 23): # skip max pooling at index 16 self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1_2 = h h = self.slice2(h) h_relu2_2 = h h = self.slice3(h) h_relu3_3 = h h = self.slice4(h) h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) out = 
vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out class resnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True, num=18): super(resnet, self).__init__() if(num==18): self.net = tv.resnet18(pretrained=pretrained) elif(num==34): self.net = tv.resnet34(pretrained=pretrained) elif(num==50): self.net = tv.resnet50(pretrained=pretrained) elif(num==101): self.net = tv.resnet101(pretrained=pretrained) elif(num==152): self.net = tv.resnet152(pretrained=pretrained) self.N_slices = 5 self.conv1 = self.net.conv1 self.bn1 = self.net.bn1 self.relu = self.net.relu self.maxpool = self.net.maxpool self.layer1 = self.net.layer1 self.layer2 = self.net.layer2 self.layer3 = self.net.layer3 self.layer4 = self.net.layer4 def forward(self, X): h = self.conv1(X) h = self.bn1(h) h = self.relu(h) h_relu1 = h h = self.maxpool(h) h = self.layer1(h) h_conv2 = h h = self.layer2(h) h_conv3 = h h = self.layer3(h) h_conv4 = h h = self.layer4(h) h_conv5 = h outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) return out ================================================ FILE: comfyui_invsr_trimmed/noise_predictor.py ================================================ from typing import Dict, Optional, Tuple, Union import torch from diffusers.models.modeling_utils import ModelMixin from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders.single_file_model import FromOriginalModelMixin from diffusers.models.autoencoders.vae import ( Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder, ) from diffusers.models.attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnAddedKVProcessor, AttnProcessor, ) from diffusers.models.modeling_outputs import AutoencoderKLOutput from diffusers.utils.accelerate_utils import apply_forward_hook from .time_aware_encoder import TimeAwareEncoder class NoisePredictor(ModelMixin, ConfigMixin, FromOriginalModelMixin): r""" A noise predicted model from the encoder of AutoencoderKL. This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented for all models (such as downloading or saving). Parameters: in_channels (int, *optional*, defaults to 3): Number of channels in the input image. down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): Tuple of downsample block types. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): Tuple of block output channels. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. sample_size (`int`, *optional*, defaults to `32`): Sample input size. mid_block_add_attention (`bool`, *optional*, default to `True`): If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the mid_block will only have resnet blocks temb_channels (`int`, *optional*, default to 256): Number of channels for time embedding freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip sin to cos for Fourier time embedding. 
double_z (`bool`, *optional*, defaults to `True`): Whether to double the number of output channels for the last block. """ _supports_gradient_checkpointing = True _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"] @register_to_config def __init__( self, in_channels: int = 3, down_block_types: Tuple[str] = ("DownEncoderBlock2D",), up_block_types: Tuple[str] = ("UpDecoderBlock2D",), block_out_channels: Tuple[int] = (64,), layers_per_block: int = 1, act_fn: str = "silu", latent_channels: int = 4, norm_num_groups: int = 32, sample_size: int = 32, mid_block_add_attention: bool = True, attention_head_dim: int = 1, resnet_time_scale_shift: str = "default", temb_channels: int = 256, freq_shift: int = 0, flip_sin_to_cos: bool = True, double_z: bool = True, ): super().__init__() # pass init params to Encoder self.encoder = TimeAwareEncoder( in_channels=in_channels, out_channels=latent_channels, down_block_types=down_block_types, block_out_channels=block_out_channels, layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=double_z, mid_block_add_attention=mid_block_add_attention, resnet_time_scale_shift=resnet_time_scale_shift, temb_channels=temb_channels, freq_shift=freq_shift, flip_sin_to_cos=flip_sin_to_cos, attention_head_dim=attention_head_dim, ) self.use_slicing = False self.use_tiling = False self.double_z = double_z # only relevant if vae tiling is enabled self.tile_sample_min_size = self.config.sample_size sample_size = ( self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size ) self.tile_latent_min_size = int( sample_size / (2 ** (len(self.config.block_out_channels) - 1)) ) self.tile_overlap_factor = 0.25 def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (Encoder, Decoder)): module.gradient_checkpointing = value def enable_tiling(self, use_tiling: bool = True): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow processing larger images. """ self.use_tiling = use_tiling def disable_tiling(self): r""" Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing decoding in one step. """ self.enable_tiling(False) def enable_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.use_slicing = True def disable_slicing(self): r""" Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing decoding in one step. """ self.use_slicing = False @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. 
""" # set recursively processors = {} def fn_recursive_add_processors( name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor], ): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor( self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] ): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ if all( proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values() ): processor = AttnAddedKVProcessor() elif all( proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values() ): processor = AttnProcessor() else: raise ValueError( f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" ) self.set_attn_processor(processor) @apply_forward_hook def encode( self, x: torch.Tensor, timestep: Union[int, torch.Tensor], return_dict: bool = True, ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: """ Encode a batch of images into latents. Args: x (`torch.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Returns: The latent representations of the encoded images. If `return_dict` is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
""" if self.use_tiling and ( x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size ): return self.tiled_encode(x, timestep, return_dict=return_dict) if self.use_slicing and x.shape[0] > 1: encoded_slices = [self.encoder(x_slice, timestep) for x_slice in x.split(1)] h = torch.cat(encoded_slices) else: h = self.encoder(x, timestep) if not self.double_z: return h posterior = DiagonalGaussianDistribution(h) if not return_dict: return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) def tiled_encode( self, x: torch.Tensor, timestep: Union[int, torch.Tensor], return_dict: bool = True, ) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the output, but they should be much less noticeable. Args: x (`torch.Tensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. Returns: [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. """ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) row_limit = self.tile_latent_min_size - blend_extent # Split the image into 512x512 tiles and encode them separately. rows = [] for i in range(0, x.shape[2], overlap_size): row = [] for j in range(0, x.shape[3], overlap_size): tile = x[ :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size, ] tile = self.encoder(tile, timestep) if self.config.use_quant_conv: tile = self.quant_conv(tile) row.append(tile) rows.append(row) result_rows = [] for i, row in enumerate(rows): result_row = [] for j, tile in enumerate(row): # blend the above tile and the left tile # to the current tile and add the current tile to the result row if i > 0: tile = self.blend_v(rows[i - 1][j], tile, blend_extent) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_extent) result_row.append(tile[:, :, :row_limit, :row_limit]) result_rows.append(torch.cat(result_row, dim=3)) moments = torch.cat(result_rows, dim=2) posterior = DiagonalGaussianDistribution(moments) if not return_dict: return (posterior,) return AutoencoderKLOutput(latent_dist=posterior) def forward( self, sample: torch.Tensor, timesteps: torch.Tensor, sample_posterior: bool = True, center_input_sample: bool = True, generator: Optional[torch.Generator] = None, ) -> Union[DecoderOutput, torch.Tensor]: r""" Args: sample (`torch.Tensor`): Input sample. sample_posterior (`bool`, *optional*, defaults to `False`): Whether to sample from the posterior. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`DecoderOutput`] instead of a plain tuple. 
""" if center_input_sample: sample = sample * 2 - 1.0 if not self.double_z: h = self.encode(sample, timesteps) return h else: posterior = self.encode(sample, timesteps).latent_dist if sample_posterior: return posterior.sample() else: return posterior ================================================ FILE: comfyui_invsr_trimmed/pipeline_stable_diffusion_inversion_sr.py ================================================ # Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union, Tuple import numpy as np import PIL.Image import torch from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback from diffusers.configuration_utils import FrozenDict from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import ( StableDiffusionSafetyChecker, ) logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import requests >>> import torch >>> from PIL import Image >>> from io import BytesIO >>> from diffusers import StableDiffusionImg2ImgPipeline >>> device = "cuda" >>> model_id_or_path = "runwayml/stable-diffusion-v1-5" >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) >>> pipe = pipe.to(device) >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> response = requests.get(url) >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") >>> init_image = init_image.resize((768, 512)) >>> prompt = "A fantasy landscape, trending on artstation" >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images >>> images[0].save("fantasy_landscape.png") ``` """ def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) elif hasattr(encoder_output, "latent_dist") and 
sample_mode == "argmax": return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: raise AttributeError("Could not access latents of provided encoder_output") def preprocess(image): deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): image = [image] if isinstance(image[0], PIL.Image.Image): w, h = image[0].size w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = 2.0 * image - 1.0 image = torch.from_numpy(image) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, dim=0) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, timesteps: Optional[List[int]] = None, device: Optional[Union[str, torch.device]] = None, **kwargs, ): """ Prepare the sampling timesteps and noise sigmas. Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. 
""" num_inference_steps = len(timesteps) timesteps = torch.tensor(timesteps, dtype=torch.float32, device=device) - 1 scheduler.timesteps = timesteps if not hasattr(scheduler, 'sigmas_cache'): scheduler.sigmas_cache = scheduler.sigmas.flip(0)[1:].to(device) #ascending,1000 sigmas = scheduler.sigmas_cache[timesteps.long()] # minimal sigma if scheduler.config.final_sigmas_type == "sigma_min": sigma_last = ((1 - scheduler.alphas_cumprod[0]) / scheduler.alphas_cumprod[0]) ** 0.5 elif scheduler.config.final_sigmas_type == "zero": sigma_last = 0 else: raise ValueError( f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {scheduler.config.final_sigmas_type}" ) sigma_last = torch.tensor([sigma_last,], dtype=torch.float32).to(device=sigmas.device) sigmas = torch.cat([sigmas, sigma_last]).type(torch.float32) scheduler.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication scheduler._step_index = None scheduler._begin_index = None return scheduler.timesteps, num_inference_steps class StableDiffusionInvEnhancePipeline( DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, FromSingleFileMixin, ): r""" Pipeline for text-guided image-to-image generation using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). The pipeline also inherits the following loading methods: - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. text_encoder ([`~transformers.CLIPTextModel`]): Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). tokenizer ([`~transformers.CLIPTokenizer`]): A `CLIPTokenizer` to tokenize text. unet ([`UNet2DConditionModel`]): A `UNet2DConditionModel` to denoise the encoded image latents. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details about a model's potential harms. feature_extractor ([`~transformers.CLIPImageProcessor`]): A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
""" model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] _exclude_from_cpu_offload = ["safety_checker"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, image_encoder: CLIPVisionModelWithProjection = None, requires_safety_checker: bool = True, ): super().__init__() if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " "to update the config accordingly as leaving `steps_offset` might led to incorrect results" " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" " file" ) deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." " `clip_sample` should be set to False in the configuration file. Please make sure to update the" " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" ) deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(scheduler.config) new_config["clip_sample"] = False scheduler._internal_dict = FrozenDict(new_config) if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
) is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( version.parse(unet.config._diffusers_version).base_version ) < version.parse("0.9.0.dev0") is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: deprecation_message = ( "The configuration file of the unet has set the default `sample_size` to smaller than" " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" " in the config might lead to incorrect results in future versions. If you have downloaded this" " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" " the `unet/config.json` file" ) deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, **kwargs, ): deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) prompt_embeds_tuple = self.encode_prompt( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=lora_scale, **kwargs, ) # concatenate for backwards comp prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt def encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, lora_scale: Optional[float] = None, clip_skip: Optional[int] = None, ): r""" Encodes the prompt into text encoder hidden states. 
Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. """ # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if not USE_PEFT_BACKEND: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None if clip_skip is None: prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] else: prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True ) # Access the `hidden_states` first, that contains a tuple of # all the hidden states from the encoder layers. Then index into # the tuple to access the hidden states from the desired layer. 
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] # We also need to apply the final LayerNorm here to not mess with the # representations. The `last_hidden_states` that we typically use for # obtaining the final prompt representations passes through the LayerNorm # layer. prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) if self.text_encoder is not None: prompt_embeds_dtype = self.text_encoder.dtype elif self.unet is not None: prompt_embeds_dtype = self.unet.dtype else: prompt_embeds_dtype = prompt_embeds.dtype prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: process multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) if self.text_encoder is not None: if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): image = self.feature_extractor(image, return_tensors="pt").pixel_values image = image.to(device=device, dtype=dtype) if 
output_hidden_states: image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_enc_hidden_states = self.image_encoder( torch.zeros_like(image), output_hidden_states=True ).hidden_states[-2] uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( num_images_per_prompt, dim=0 ) return image_enc_hidden_states, uncond_image_enc_hidden_states else: image_embeds = self.image_encoder(image).image_embeds image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance ): image_embeds = [] if do_classifier_free_guidance: negative_image_embeds = [] if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): raise ValueError( f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." ) for single_ip_adapter_image, image_proj_layer in zip( ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers ): output_hidden_state = not isinstance(image_proj_layer, ImageProjection) single_image_embeds, single_negative_image_embeds = self.encode_image( single_ip_adapter_image, device, 1, output_hidden_state ) image_embeds.append(single_image_embeds[None, :]) if do_classifier_free_guidance: negative_image_embeds.append(single_negative_image_embeds[None, :]) else: for single_image_embeds in ip_adapter_image_embeds: if do_classifier_free_guidance: single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) negative_image_embeds.append(single_negative_image_embeds) image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] for i, single_image_embeds in enumerate(image_embeds): single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) if do_classifier_free_guidance: single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) return ip_adapter_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ip_adapter_image=None, ip_adapter_image_embeds=None, callback_on_step_end_tensor_inputs=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) if ip_adapter_image is not None and ip_adapter_image_embeds is not None: raise ValueError( "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." ) if ip_adapter_image_embeds is not None: if not isinstance(ip_adapter_image_embeds, list): raise ValueError( f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" ) elif ip_adapter_image_embeds[0].ndim not in [3, 4]: raise ValueError( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] if hasattr(self.scheduler, "set_begin_index"): self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start def prepare_latents( self, image, timestep, batch_size, num_images_per_prompt, dtype, device, noise=None, generator=None, ): if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt if image.shape[1] == 4: init_latents = image else: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) elif isinstance(generator, list): if image.shape[0] < batch_size and batch_size % image.shape[0] == 0: image = torch.cat([image] * (batch_size // image.shape[0]), dim=0) elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} " ) init_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size deprecation_message = ( f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" " your script to pass as many initial images as text prompts to suppress this warning." 
) deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) additional_image_per_prompt = batch_size // init_latents.shape[0] init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." ) else: init_latents = torch.cat([init_latents], dim=0) shape = init_latents.shape if noise is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents return latents.type(dtype) # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding def get_guidance_scale_embedding( self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 ) -> torch.Tensor: """ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 Args: w (`torch.Tensor`): Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. embedding_dim (`int`, *optional*, defaults to 512): Dimension of the embeddings to generate. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): Data type of the generated embeddings. Returns: `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`. """ assert len(w.shape) == 1 w = w * 1000.0 half_dim = embedding_dim // 2 emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) emb = w.to(dtype)[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0, 1)) assert emb.shape == (w.shape[0], embedding_dim) return emb @property def guidance_scale(self): return self._guidance_scale @property def clip_skip(self): return self._clip_skip # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. 
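    # A minimal worked sketch of that guidance step (illustrative only; `noise_pred_uncond`
    # and `noise_pred_text` denote the two chunks of the batched UNet output in the
    # denoising loop below):
    #
    #     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    #
    # With `guidance_scale = 1` this collapses to `noise_pred_text`, i.e. no classifier-free
    # guidance; larger values extrapolate further away from the unconditional prediction.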
@property def do_classifier_free_guidance(self): return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None @property def cross_attention_kwargs(self): return self._cross_attention_kwargs @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, image: PipelineImageInput = None, target_size: Tuple[int] = (512, 512), strength: float = 0.25, num_inference_steps: Optional[int] = 20, timesteps: List[int] = None, sigmas: List[float] = None, guidance_scale: Optional[float] = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, clip_skip: int = None, callback_on_step_end: Optional[ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], **kwargs, ): r""" The call function to the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an low quality image batch to be used as the starting point. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but if passing latents directly it is not encoded again. target_size ('Tuple[int]'): Targeted image resolution (height, width) strength (`float`, *optional*, defaults to 0.25): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 20): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. This parameter is modulated by `strength`. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. sigmas (`List[float]`, *optional*): Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is passed will be used. start_noise_predictor ('nn.Module', *optional*): Noise predictor for the initial step intermediate_noise_predictor ('nn.Module', *optional*): Noise predictor for the intermediate steps. guidance_scale (`float`, *optional*, defaults to 7.5): A higher guidance scale value encourages the model to generate images closely linked to the text `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide what to not include in image generation. If not defined, you need to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. prompt_embeds (`torch.Tensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, text embeddings are generated from the `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between `PIL.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that the output of the pre-final layer will be used for computing the prompt embeddings. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of each denoising step during the inference. with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 
callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. Examples: Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ callback = kwargs.pop("callback", None) callback_steps = kwargs.pop("callback_steps", None) if callback is not None: deprecate( "callback", "1.0.0", "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if callback_steps is not None: deprecate( "callback_steps", "1.0.0", "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", ) if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, ) self._guidance_scale = guidance_scale self._clip_skip = clip_skip self._cross_attention_kwargs = cross_attention_kwargs self._interrupt = False # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device # 3. Encode input prompt text_encoder_lora_scale = ( self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None ) prompt_embeds, negative_prompt_embeds = self.encode_prompt( prompt, device, num_images_per_prompt, self.do_classifier_free_guidance, negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, lora_scale=text_encoder_lora_scale, clip_skip=self.clip_skip, ) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt, self.do_classifier_free_guidance, ) # 4. Preprocess image self.image_processor.config.do_normalize = False image = self.image_processor.preprocess(image) # [0, 1], torch tensor, (b,c,h,w) self.image_processor.config.do_normalize = True image_up = torch.nn.functional.interpolate(image, size=target_size, mode='bicubic') # upsampling image_up = self.image_processor.normalize(image_up) # [-1, 1] # 5. set timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, timesteps, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. 
Prepare latent variables if getattr(self, 'start_noise_predictor', None) is not None: with torch.amp.autocast('cuda'): noise = self.start_noise_predictor( image, latent_timestep, sample_posterior=True, center_input_sample=True, ) else: noise = None latents = self.prepare_latents( image_up, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, noise, generator, ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter added_cond_kwargs = ( {"image_embeds": image_embeds} if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None ) # 7.2 Optionally get Guidance Scale Embedding timestep_cond = None if self.unet.config.time_cond_proj_dim is not None: guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) timestep_cond = self.get_guidance_scale_embedding( guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim ).to(device=device, dtype=latents.dtype) # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, timestep_cond=timestep_cond, cross_attention_kwargs=self.cross_attention_kwargs, added_cond_kwargs=added_cond_kwargs, return_dict=False, )[0] # perform guidance if self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 if getattr(self, 'intermediate_noise_predictor', None) is not None and i + 1 < len(timesteps): t_next = timesteps[i+1] with torch.amp.autocast('cuda'): noise = self.intermediate_noise_predictor(image, t_next, center_input_sample=True) extra_step_kwargs['noise'] = noise latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0 ] image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) else: image = latents has_nsfw_concept = None if has_nsfw_concept is None: do_denormalize = [True] * image.shape[0] else: 
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)



================================================
FILE: comfyui_invsr_trimmed/sampler_invsr.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-07-13 16:59:27

import os, sys, math, random

import numpy as np
from pathlib import Path

from .utils import util_net
from .utils import util_image
from .utils import util_common
from .utils import util_color_fix

import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp

from .pipeline_stable_diffusion_inversion_sr import StableDiffusionInvEnhancePipeline
from diffusers import AutoencoderKL

import comfy.model_management as mm
DEVICE = mm.get_torch_device()

_positive = 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, ' +\
            'meticulous detailing, hyper sharpness, perfect without deformations'
_negative = 'Low quality, blurring, jpeg artifacts, deformed, over-smooth, cartoon, noisy,' +\
            'painting, drawing, sketch, oil painting'

class BaseSampler:
    def __init__(self, configs):
        '''
        Input:
            configs: config, see the yaml file in folder ./configs/
                configs.sampler_config.{start_timesteps, padding_mod, seed, sf, num_sample_steps}
                seed: int, random seed
        '''
        self.configs = configs
        self.setup_seed()
        self.build_model()

    def setup_seed(self, seed=None):
        seed = self.configs.seed if seed is None else seed
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def write_log(self, log_str):
        print(log_str, flush=True)

    def build_model(self):
        # Build Stable diffusion
        params = dict(self.configs.sd_pipe.params)
        torch_dtype = params.pop('torch_dtype')
        params['torch_dtype'] = get_torch_dtype(torch_dtype)
        base_pipe = util_common.get_obj_from_str(self.configs.sd_pipe.target).from_pretrained(**params)
        if self.configs.get('scheduler', None) is not None:
            pipe_id = self.configs.scheduler.target.split('.')[-1]
            self.write_log(f'Loading scheduler of {pipe_id}...')
            base_pipe.scheduler = util_common.get_obj_from_str(self.configs.scheduler.target).from_config(
                base_pipe.scheduler.config
            )
            self.write_log('Loaded Done')
        if self.configs.get('vae_fp16', None) is not None:
            params_vae = dict(self.configs.vae_fp16.params)
            torch_dtype = params_vae.pop('torch_dtype')
            params_vae['torch_dtype'] = get_torch_dtype(torch_dtype)
            pipe_id = self.configs.vae_fp16.params.pretrained_model_name_or_path
            self.write_log(f'Loading improved vae from {pipe_id}...')
            base_pipe.vae = util_common.get_obj_from_str(self.configs.vae_fp16.target).from_pretrained(
                **params_vae,
            )
            self.write_log('Loaded Done')
        if self.configs.base_model in ['sd-turbo', 'sd2base']:
            sd_pipe = StableDiffusionInvEnhancePipeline.from_pipe(base_pipe)
        else:
            raise ValueError(f"Unsupported base model: {self.configs.base_model}!")
        sd_pipe.to(DEVICE)

        if self.configs.sliced_vae:
            sd_pipe.vae.enable_slicing()
        if self.configs.tiled_vae:
            sd_pipe.vae.enable_tiling()
            sd_pipe.vae.tile_latent_min_size = self.configs.latent_tiled_size
            sd_pipe.vae.tile_sample_min_size = self.configs.sample_tiled_size
        if self.configs.gradient_checkpointing_vae:
            self.write_log(f"Activating gradient checkpointing for vae...")
            sd_pipe.vae.enable_gradient_checkpointing()

        model_configs = self.configs.model_start
        params = model_configs.get('params', dict())
        model_start = util_common.get_obj_from_str(model_configs.target)(**params)
        model_start.to(DEVICE)
        ckpt_path = model_configs.get('ckpt_path')
        assert ckpt_path is not None
        self.write_log(f"[InvSR] - Loading the start model from {ckpt_path}...")
        state = torch.load(ckpt_path, map_location=DEVICE)
        if 'state_dict' in state:
            state = state['state_dict']
        util_net.reload_model(model_start, state)
        # self.write_log(f"Loading Done")
        model_start.eval()

        setattr(sd_pipe, 'start_noise_predictor', model_start)

        self.sd_pipe = sd_pipe

class InvSamplerSR(BaseSampler):
    def __init__(self, base_sampler):
        self.configs = base_sampler.configs
        self.sd_pipe = base_sampler.sd_pipe

    @torch.no_grad()
    def sample_func(self, im_cond):
        '''
        Input:
            im_cond: b x c x h x w, torch tensor, [0,1], RGB
        Output:
            xt: h x w x c, numpy array, [0,1], RGB
        '''
        ori_h_lq, ori_w_lq = im_cond.shape[-2:]
        ori_w_hq = ori_w_lq * self.configs.basesr.sf
        ori_h_hq = ori_h_lq * self.configs.basesr.sf
        vae_sf = (2 ** (len(self.sd_pipe.vae.config.block_out_channels) - 1))
        if hasattr(self.sd_pipe, 'unet'):
            diffusion_sf = (2 ** (len(self.sd_pipe.unet.config.block_out_channels) - 1))
        else:
            diffusion_sf = self.sd_pipe.transformer.patch_size
        mod_lq = vae_sf // self.configs.basesr.sf * diffusion_sf

        idle_pch_size = self.configs.basesr.chopping.pch_size
        if min(im_cond.shape[-2:]) >= idle_pch_size:
            pad_h_up = pad_w_left = 0
        else:
            while min(im_cond.shape[-2:]) < idle_pch_size:
                pad_h_up = max(min((idle_pch_size - im_cond.shape[-2]) // 2, im_cond.shape[-2]-1), 0)
                pad_h_down = max(min(idle_pch_size - im_cond.shape[-2] - pad_h_up, im_cond.shape[-2]-1), 0)
                pad_w_left = max(min((idle_pch_size - im_cond.shape[-1]) // 2, im_cond.shape[-1]-1), 0)
                pad_w_right = max(min(idle_pch_size - im_cond.shape[-1] - pad_w_left, im_cond.shape[-1]-1), 0)
                im_cond = F.pad(im_cond, pad=(pad_w_left, pad_w_right, pad_h_up, pad_h_down), mode='reflect')

        if im_cond.shape[-2] == idle_pch_size and im_cond.shape[-1] == idle_pch_size:
            target_size = (
                im_cond.shape[-2] * self.configs.basesr.sf,
                im_cond.shape[-1] * self.configs.basesr.sf
            )
            res_sr = self.sd_pipe(
                image=im_cond.type(torch.float16),
                prompt=[_positive, ]*im_cond.shape[0],
                negative_prompt=[_negative, ]*im_cond.shape[0] if self.configs.cfg_scale > 1.0 else None,
                target_size=target_size,
                timesteps=self.configs.timesteps,
                guidance_scale=self.configs.cfg_scale,
                output_type="pt",    # torch tensor, b x c x h x w, [0, 1]
            ).images
        else:
            if not (im_cond.shape[-2] % mod_lq == 0 and im_cond.shape[-1] % mod_lq == 0):
                target_h_lq = math.ceil(im_cond.shape[-2] / mod_lq) * mod_lq
                target_w_lq = math.ceil(im_cond.shape[-1] / mod_lq) * mod_lq
                pad_h = target_h_lq - im_cond.shape[-2]
                pad_w = target_w_lq - im_cond.shape[-1]
                im_cond = F.pad(im_cond, pad=(0, pad_w, 0, pad_h), mode='reflect')
            im_spliter = util_image.ImageSpliterTh(
                im_cond,
                pch_size=idle_pch_size,
                stride=int(idle_pch_size * 0.50),
                sf=self.configs.basesr.sf,
                weight_type=self.configs.basesr.chopping.weight_type,
                extra_bs=self.configs.basesr.chopping.extra_bs,
            )
            # pbar = ProgressBar(len(im_spliter) * im_cond.shape[0])
            for im_lq_pch, index_infos in im_spliter:
                target_size = (
                    im_lq_pch.shape[-2] * self.configs.basesr.sf,
                    im_lq_pch.shape[-1] * self.configs.basesr.sf,
                )
                # start = torch.cuda.Event(enable_timing=True)
                # end = torch.cuda.Event(enable_timing=True)
                # start.record()
                res_sr_pch = self.sd_pipe(
                    image=im_lq_pch.type(torch.float16),
                    prompt=[_positive, ]*im_lq_pch.shape[0],
negative_prompt=[_negative, ]*im_lq_pch.shape[0] if self.configs.cfg_scale > 1.0 else None, target_size=target_size, timesteps=self.configs.timesteps, guidance_scale=self.configs.cfg_scale, output_type="pt", # torch tensor, b x c x h x w, [0, 1] ).images # end.record() # torch.cuda.synchronize() # print(f"Time: {start.elapsed_time(end):.6f}") im_spliter.update(res_sr_pch, index_infos) # pbar.update(im_lq_pch.shape[0]) res_sr = im_spliter.gather() pad_h_up *= self.configs.basesr.sf pad_w_left *= self.configs.basesr.sf res_sr = res_sr[:, :, pad_h_up:ori_h_hq+pad_h_up, pad_w_left:ori_w_hq+pad_w_left] if self.configs.color_fix: im_cond_up = F.interpolate( im_cond, size=res_sr.shape[-2:], mode='bicubic', align_corners=False, antialias=True ) if self.configs.color_fix == 'ycbcr': res_sr = util_color_fix.ycbcr_color_replace(res_sr, im_cond_up) elif self.configs.color_fix == 'wavelet': res_sr = util_color_fix.wavelet_reconstruction(res_sr, im_cond_up) else: raise ValueError(f"Unsupported color fixing type: {self.configs.color_fix}") res_sr = res_sr.clamp(0.0, 1.0).cpu().float().numpy() return res_sr def inference(self, image_bchw): return self.sample_func(image_bchw.to(DEVICE)) def get_torch_dtype(torch_dtype: str): if torch_dtype == 'torch.float16': return torch.float16 elif torch_dtype == 'torch.bfloat16': return torch.bfloat16 elif torch_dtype == 'torch.float32': return torch.float32 else: raise ValueError(f'Unexpected torch dtype:{torch_dtype}') ================================================ FILE: comfyui_invsr_trimmed/time_aware_encoder.py ================================================ from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np import torch import torch.nn as nn from diffusers.utils import is_torch_version from diffusers.models.unets.unet_2d_blocks import ( UNetMidBlock2D, get_down_block, ) from diffusers.models.embeddings import TimestepEmbedding, Timesteps class TimeAwareEncoder(nn.Module): r""" The `TimeAwareEncoder` layer of a variational autoencoder that encodes its input into a latent representation. Args: in_channels (`int`, *optional*, defaults to 3): The number of input channels. out_channels (`int`, *optional*, defaults to 3): The number of output channels. down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available options. block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): The number of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups for normalization. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. See `~diffusers.models.activations.get_activation` for available options. double_z (`bool`, *optional*, defaults to `True`): Whether to double the number of output channels for the last block. resnet_time_scale_shift (`str`, defaults to `"default"`) """ def __init__( self, in_channels: int = 3, out_channels: int = 3, down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), block_out_channels: Tuple[int, ...] 
= (64,), layers_per_block: Union[int, Tuple[int, ...]] = 2, norm_num_groups: int = 32, act_fn: str = "silu", double_z: bool = True, mid_block_add_attention=True, resnet_time_scale_shift: str = "default", temb_channels: int = 256, freq_shift: int = 0, flip_sin_to_cos: bool = True, attention_head_dim: int = 1, ): super().__init__() if isinstance(layers_per_block, int): layers_per_block = (layers_per_block,) * len(down_block_types) self.layers_per_block = layers_per_block timestep_input_dim = max(128, block_out_channels[0]) self.time_proj = Timesteps(timestep_input_dim, flip_sin_to_cos, freq_shift) self.time_embedding = TimestepEmbedding(timestep_input_dim, temb_channels) self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1, ) self.down_blocks = nn.ModuleList([]) # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=self.layers_per_block[i], in_channels=input_channel, out_channels=output_channel, add_downsample=not is_final_block, resnet_eps=1e-6, downsample_padding=0, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, temb_channels=temb_channels, ) self.down_blocks.append(down_block) # mid self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, attention_head_dim=attention_head_dim, resnet_groups=norm_num_groups, add_attention=mid_block_add_attention, resnet_time_scale_shift=resnet_time_scale_shift, temb_channels=temb_channels, ) # out self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 ) self.conv_act = nn.SiLU() conv_out_channels = 2 * out_channels if double_z else out_channels self.conv_out = nn.Conv2d( block_out_channels[-1], conv_out_channels, 3, padding=1 ) self.gradient_checkpointing = False def forward( self, sample: torch.Tensor, timesteps: Union[torch.Tensor, int], ) -> torch.Tensor: r"""The forward method of the `Encoder` class.""" # time embedding if not torch.is_tensor(timesteps): timesteps = torch.tensor( [timesteps], dtype=torch.long, device=sample.device ) elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps * torch.ones( sample.shape[0], dtype=timesteps.dtype, device=timesteps.device ) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. 
t_emb = t_emb.to(dtype=list(self.time_embedding.parameters())[0].dtype) emb = self.time_embedding(t_emb) sample = self.conv_in(sample) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward # down if is_torch_version(">=", "1.11.0"): for down_block in self.down_blocks: sample = torch.utils.checkpoint.checkpoint( create_custom_forward(down_block), sample, emb, use_reentrant=False, ) # middle sample = torch.utils.checkpoint.checkpoint( create_custom_forward(self.mid_block), sample, emb, use_reentrant=False, ) else: for down_block in self.down_blocks: sample = torch.utils.checkpoint.checkpoint( create_custom_forward(down_block), sample, emb ) # middle sample = torch.utils.checkpoint.checkpoint( create_custom_forward(self.mid_block), sample, emb ) else: # down for down_block in self.down_blocks: sample, _ = down_block(sample, emb) # middle sample = self.mid_block(sample, emb) # post-process sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) return sample ================================================ FILE: comfyui_invsr_trimmed/utils/__init__.py ================================================ #!/usr/bin/env python # -*- coding:utf-8 -*- # Power by Zongsheng Yue 2022-01-18 11:40:23 ================================================ FILE: comfyui_invsr_trimmed/utils/resize.py ================================================ """ A standalone PyTorch implementation for fast and efficient bicubic resampling. The resulting values are the same to MATLAB function imresize('bicubic'). ## Author: Sanghyun Son ## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary) ## Version: 1.2.0 ## Last update: July 9th, 2020 (KST) Dependency: torch Example:: >>> import torch >>> import core >>> x = torch.arange(16).float().view(1, 1, 4, 4) >>> y = core.imresize(x, sizes=(3, 3)) >>> print(y) tensor([[[[ 0.7506, 2.1004, 3.4503], [ 6.1505, 7.5000, 8.8499], [11.5497, 12.8996, 14.2494]]]]) """ import math import typing import torch from torch.nn import functional as F __all__ = ['imresize'] _I = typing.Optional[int] _D = typing.Optional[torch.dtype] def nearest_contribution(x: torch.Tensor) -> torch.Tensor: range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5)) cont = range_around_0.to(dtype=x.dtype) return cont def linear_contribution(x: torch.Tensor) -> torch.Tensor: ax = x.abs() range_01 = ax.le(1) cont = (1 - ax) * range_01.to(dtype=x.dtype) return cont def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor: ax = x.abs() ax2 = ax * ax ax3 = ax * ax2 range_01 = ax.le(1) range_12 = torch.logical_and(ax.gt(1), ax.le(2)) cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1 cont_01 = cont_01 * range_01.to(dtype=x.dtype) cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a) cont_12 = cont_12 * range_12.to(dtype=x.dtype) cont = cont_01 + cont_12 return cont def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor: range_3sigma = (x.abs() <= 3 * sigma + 1) # Normalization will be done after cont = torch.exp(-x.pow(2) / (2 * sigma**2)) cont = cont * range_3sigma.to(dtype=x.dtype) return cont def discrete_kernel(kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor: ''' For downsampling with integer scale only. 
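    Example (an illustrative sketch: a 2x bicubic downsampling kernel, where antialiasing
    expands the 4-tap cubic support by the integer downsampling factor):

        >>> k = discrete_kernel('cubic', scale=0.5)
        >>> k.shape
        torch.Size([8, 8])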
''' downsampling_factor = int(1 / scale) if kernel == 'cubic': kernel_size_orig = 4 else: raise ValueError('Pass!') if antialiasing: kernel_size = kernel_size_orig * downsampling_factor else: kernel_size = kernel_size_orig if downsampling_factor % 2 == 0: a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size)) else: kernel_size -= 1 a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1)) with torch.no_grad(): r = torch.linspace(-a, a, steps=kernel_size) k = cubic_contribution(r).view(-1, 1) k = torch.matmul(k, k.t()) k /= k.sum() return k def reflect_padding(x: torch.Tensor, dim: int, pad_pre: int, pad_post: int) -> torch.Tensor: ''' Apply reflect padding to the given Tensor. Note that it is slightly different from the PyTorch functional.pad, where boundary elements are used only once. Instead, we follow the MATLAB implementation which uses boundary elements twice. For example, [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation, while our implementation yields [a, a, b, c, d, d]. ''' b, c, h, w = x.size() if dim == 2 or dim == -2: padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w) padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x) for p in range(pad_pre): padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :]) for p in range(pad_post): padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :]) else: padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post) padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x) for p in range(pad_pre): padding_buffer[..., pad_pre - p - 1].copy_(x[..., p]) for p in range(pad_post): padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)]) return padding_buffer def padding(x: torch.Tensor, dim: int, pad_pre: int, pad_post: int, padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor: if padding_type is None: return x elif padding_type == 'reflect': x_pad = reflect_padding(x, dim, pad_pre, pad_post) else: raise ValueError('{} padding is not supported!'.format(padding_type)) return x_pad def get_padding(base: torch.Tensor, kernel_size: int, x_size: int) -> typing.Tuple[int, int, torch.Tensor]: base = base.long() r_min = base.min() r_max = base.max() + kernel_size - 1 if r_min <= 0: pad_pre = -r_min pad_pre = pad_pre.item() base += pad_pre else: pad_pre = 0 if r_max >= x_size: pad_post = r_max - x_size + 1 pad_post = pad_post.item() else: pad_post = 0 return pad_pre, pad_post, base def get_weight(dist: torch.Tensor, kernel_size: int, kernel: str = 'cubic', sigma: float = 2.0, antialiasing_factor: float = 1) -> torch.Tensor: buffer_pos = dist.new_zeros(kernel_size, len(dist)) for idx, buffer_sub in enumerate(buffer_pos): buffer_sub.copy_(dist - idx) # Expand (downsampling) / Shrink (upsampling) the receptive field. 
    buffer_pos *= antialiasing_factor
    if kernel == 'cubic':
        weight = cubic_contribution(buffer_pos)
    elif kernel == 'gaussian':
        weight = gaussian_contribution(buffer_pos, sigma=sigma)
    else:
        raise ValueError('{} kernel is not supported!'.format(kernel))

    weight /= weight.sum(dim=0, keepdim=True)
    return weight

def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
    # Resize height
    if dim == 2 or dim == -2:
        k = (kernel_size, 1)
        h_out = x.size(-2) - kernel_size + 1
        w_out = x.size(-1)
    # Resize width
    else:
        k = (1, kernel_size)
        h_out = x.size(-2)
        w_out = x.size(-1) - kernel_size + 1

    unfold = F.unfold(x, k)
    unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
    return unfold

def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:
    if x.dim() == 4:
        b, c, h, w = x.size()
    elif x.dim() == 3:
        c, h, w = x.size()
        b = None
    elif x.dim() == 2:
        h, w = x.size()
        b = c = None
    else:
        raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))

    x = x.view(-1, 1, h, w)
    return x, b, c, h, w

def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
    rh = x.size(-2)
    rw = x.size(-1)
    # Back to the original dimension
    if b is not None:
        x = x.view(b, c, rh, rw)        # 4-dim
    else:
        if c is not None:
            x = x.view(c, rh, rw)       # 3-dim
        else:
            x = x.view(rh, rw)          # 2-dim

    return x

def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
    if x.dtype != torch.float32 and x.dtype != torch.float64:
        dtype = x.dtype
        x = x.float()
    else:
        dtype = None

    return x, dtype

def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
    if dtype is not None:
        if not dtype.is_floating_point:
            x = x - x.detach() + x.round()
        # To prevent over/underflow when converting types
        if dtype is torch.uint8:
            x = x.clamp(0, 255)

        x = x.to(dtype=dtype)

    return x

def resize_1d(x: torch.Tensor,
              dim: int,
              size: int,
              scale: float,
              kernel: str = 'cubic',
              sigma: float = 2.0,
              padding_type: str = 'reflect',
              antialiasing: bool = True) -> torch.Tensor:
    '''
    Args:
        x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
        dim (int):
        scale (float):
        size (int):

    Return:
    '''
    # Identity case
    if scale == 1:
        return x

    # Default bicubic kernel with antialiasing (only when downsampling)
    if kernel == 'cubic':
        kernel_size = 4
    else:
        kernel_size = math.floor(6 * sigma)

    if antialiasing and (scale < 1):
        antialiasing_factor = scale
        kernel_size = math.ceil(kernel_size / antialiasing_factor)
    else:
        antialiasing_factor = 1

    # We allow margin to both sizes
    kernel_size += 2

    # Weights only depend on the shape of input and output,
    # so we do not calculate gradients here.
with torch.no_grad(): pos = torch.linspace( 0, size - 1, steps=size, dtype=x.dtype, device=x.device, ) pos = (pos + 0.5) / scale - 0.5 base = pos.floor() - (kernel_size // 2) + 1 dist = pos - base weight = get_weight( dist, kernel_size, kernel=kernel, sigma=sigma, antialiasing_factor=antialiasing_factor, ) pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim)) # To backpropagate through x x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type) unfold = reshape_tensor(x_pad, dim, kernel_size) # Subsampling first if dim == 2 or dim == -2: sample = unfold[..., base, :] weight = weight.view(1, kernel_size, sample.size(2), 1) else: sample = unfold[..., base] weight = weight.view(1, kernel_size, 1, sample.size(3)) # Apply the kernel x = sample * weight x = x.sum(dim=1, keepdim=True) return x def downsampling_2d(x: torch.Tensor, k: torch.Tensor, scale: int, padding_type: str = 'reflect') -> torch.Tensor: c = x.size(1) k_h = k.size(-2) k_w = k.size(-1) k = k.to(dtype=x.dtype, device=x.device) k = k.view(1, 1, k_h, k_w) k = k.repeat(c, c, 1, 1) e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False) e = e.view(c, c, 1, 1) k = k * e pad_h = (k_h - scale) // 2 pad_w = (k_w - scale) // 2 x = padding(x, -2, pad_h, pad_h, padding_type=padding_type) x = padding(x, -1, pad_w, pad_w, padding_type=padding_type) y = F.conv2d(x, k, padding=0, stride=scale) return y def imresize(x: torch.Tensor, scale: typing.Optional[float] = None, sizes: typing.Optional[typing.Tuple[int, int]] = None, kernel: typing.Union[str, torch.Tensor] = 'cubic', sigma: float = 2, rotation_degree: float = 0, padding_type: str = 'reflect', antialiasing: bool = True) -> torch.Tensor: """ Args: x (torch.Tensor): scale (float): sizes (tuple(int, int)): kernel (str, default='cubic'): sigma (float, default=2): rotation_degree (float, default=0): padding_type (str, default='reflect'): antialiasing (bool, default=True): Return: torch.Tensor: """ if scale is None and sizes is None: raise ValueError('One of scale or sizes must be specified!') if scale is not None and sizes is not None: raise ValueError('Please specify scale or sizes to avoid conflict!') x, b, c, h, w = reshape_input(x) if sizes is None and scale is not None: ''' # Check if we can apply the convolution algorithm scale_inv = 1 / scale if isinstance(kernel, str) and scale_inv.is_integer(): kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing) elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer(): raise ValueError( 'An integer downsampling factor ' 'should be used with a predefined kernel!' 
) ''' # Determine output size sizes = (math.ceil(h * scale), math.ceil(w * scale)) scales = (scale, scale) if scale is None and sizes is not None: scales = (sizes[0] / h, sizes[1] / w) x, dtype = cast_input(x) if isinstance(kernel, str) and sizes is not None: # Core resizing module x = resize_1d( x, -2, size=sizes[0], scale=scales[0], kernel=kernel, sigma=sigma, padding_type=padding_type, antialiasing=antialiasing) x = resize_1d( x, -1, size=sizes[1], scale=scales[1], kernel=kernel, sigma=sigma, padding_type=padding_type, antialiasing=antialiasing) elif isinstance(kernel, torch.Tensor) and scale is not None: x = downsampling_2d(x, kernel, scale=int(1 / scale)) x = reshape_output(x, b, c) x = cast_output(x, dtype) return x ================================================ FILE: comfyui_invsr_trimmed/utils/util_color_fix.py ================================================ ''' # -------------------------------------------------------------------------------- # Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py) # -------------------------------------------------------------------------------- ''' import torch from torch import Tensor from torch.nn import functional as F from torchvision.transforms import ToTensor, ToPILImage from .util_image import rgb2ycbcrTorch, ycbcr2rgbTorch def calc_mean_std(feat: Tensor, eps=1e-5): """Calculate mean and std for adaptive_instance_normalization. Args: feat (Tensor): 4D tensor. eps (float): A small value added to the variance to avoid divide-by-zero. Default: 1e-5. """ size = feat.size() assert len(size) == 4, 'The input feature should be 4D tensor.' b, c = size[:2] feat_var = feat.reshape(b, c, -1).var(dim=2) + eps feat_std = feat_var.sqrt().reshape(b, c, 1, 1) feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) return feat_mean, feat_std def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): """Adaptive instance normalization. Adjust the reference features to have the similar color and illuminations as those in the degradate features. Args: content_feat (Tensor): The reference feature. style_feat (Tensor): The degradate features. """ size = content_feat.size() style_mean, style_std = calc_mean_std(style_feat) content_mean, content_std = calc_mean_std(content_feat) normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) return normalized_feat * style_std.expand(size) + style_mean.expand(size) def wavelet_blur(image: Tensor, radius: int): """ Apply wavelet blur to the input tensor. """ # input shape: (1, 3, H, W) # convolution kernel kernel_vals = [ [0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625], ] kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) # add channel dimensions to the kernel to make it a 4D tensor kernel = kernel[None, None] # repeat the kernel across all input channels kernel = kernel.repeat(3, 1, 1, 1) image = F.pad(image, (radius, radius, radius, radius), mode='replicate') # apply convolution output = F.conv2d(image, kernel, groups=3, dilation=radius) return output def wavelet_decomposition(image: Tensor, levels=5): """ Apply wavelet decomposition to the input tensor. This function only returns the low frequency & the high frequency. 
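    Example (an illustrative sketch; the two bands sum back to the input up to
    floating-point error, since `high_freq` telescopes to `image - low_freq`):

        >>> x = torch.rand(1, 3, 64, 64)
        >>> high_freq, low_freq = wavelet_decomposition(x)
        >>> torch.allclose(high_freq + low_freq, x, atol=1e-5)
        True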
""" high_freq = torch.zeros_like(image) for i in range(levels): radius = 2 ** i low_freq = wavelet_blur(image, radius) high_freq += (image - low_freq) image = low_freq return high_freq, low_freq def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): """ Apply wavelet decomposition, so that the content will have the same color as the style. """ # calculate the wavelet decomposition of the content feature content_high_freq, content_low_freq = wavelet_decomposition(content_feat) del content_low_freq # calculate the wavelet decomposition of the style feature style_high_freq, style_low_freq = wavelet_decomposition(style_feat) del style_high_freq # reconstruct the content feature with the style's high frequency return content_high_freq + style_low_freq def ycbcr_color_replace(content_feat:Tensor, style_feat:Tensor): """ Apply ycbcr decomposition, so that the content will have the same color as the style. """ content_y = rgb2ycbcrTorch(content_feat, only_y=True) style_ycbcr = rgb2ycbcrTorch(style_feat, only_y=False) target_ycbcr = torch.cat([content_y, style_ycbcr[:, 1:,]], dim=1) target_rgb = ycbcr2rgbTorch(target_ycbcr) return target_rgb ================================================ FILE: comfyui_invsr_trimmed/utils/util_common.py ================================================ #!/usr/bin/env python # -*- coding:utf-8 -*- # Power by Zongsheng Yue 2022-02-06 10:34:59 import os import random import requests import importlib from pathlib import Path def mkdir(dir_path, delete=False, parents=True): import shutil if not isinstance(dir_path, Path): dir_path = Path(dir_path) if delete: if dir_path.exists(): shutil.rmtree(str(dir_path)) if not dir_path.exists(): dir_path.mkdir(parents=parents) def get_obj_from_str(string, reload=False): current_package = __package__.rsplit(".", 1)[0] is_relative_import = string.startswith(".") package = current_package if is_relative_import else None module, cls = string.rsplit(".", 1) if reload: module_imp = importlib.import_module(module, package=package) importlib.reload(module_imp) return getattr(importlib.import_module(module, package=package), cls) def instantiate_from_config(config): if not "target" in config: raise KeyError("Expected key `target` to instantiate.") return get_obj_from_str(config["target"])(**config.get("params", dict())) def str2bool(v): if isinstance(v, bool): return v if v.lower() in ("yes", "true", "t", "y", "1"): return True elif v.lower() in ("no", "false", "f", "n", "0"): return False else: raise argparse.ArgumentTypeError("Boolean value expected.") def get_filenames(dir_path, exts=['png', 'jpg'], recursive=True): ''' Get the file paths in the given folder. param exts: list, e.g., ['png',] return: list ''' if not isinstance(dir_path, Path): dir_path = Path(dir_path) file_paths = [] for current_ext in exts: if recursive: file_paths.extend([str(x) for x in dir_path.glob('**/*.'+current_ext)]) else: file_paths.extend([str(x) for x in dir_path.glob('*.'+current_ext)]) return file_paths def readline_txt(txt_file): txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file out = [] for txt_file_current in txt_file: with open(txt_file_current, 'r') as ff: out.extend([x[:-1] for x in ff.readlines()]) return out def scan_files_from_folder(dir_paths, exts, recursive=True): ''' Scaning images from given folder. Input: dir_pathas: str or list. 
exts: list ''' exts = [exts, ] if isinstance(exts, str) else exts dir_paths = [dir_paths, ] if isinstance(dir_paths, str) else dir_paths file_paths = [] for current_dir in dir_paths: current_dir = Path(current_dir) if not isinstance(current_dir, Path) else current_dir for current_ext in exts: if recursive: search_flag = f"**/*.{current_ext}" else: search_flag = f"*.{current_ext}" file_paths.extend(sorted([str(x) for x in Path(current_dir).glob(search_flag)])) return file_paths def write_path_to_txt( dir_folder, txt_path, search_key, num_files=None, write_only_name=False, write_only_stem=False, shuffle=False, ): ''' Scaning the files in the given folder and write them into a txt file Input: dir_folder: path of the target folder txt_path: path to save the txt file search_key: e.g., '*.png' write_only_name: bool, only record the file names (including extension), write_only_stem: bool, only record the file names (not including extension), ''' txt_path = Path(txt_path) if not isinstance(txt_path, Path) else txt_path dir_folder = Path(dir_folder) if not isinstance(dir_folder, Path) else dir_folder if txt_path.exists(): txt_path.unlink() if write_only_name: path_list = sorted([str(x.name) for x in dir_folder.glob(search_key)]) elif write_only_stem: path_list = sorted([str(x.stem) for x in dir_folder.glob(search_key)]) else: path_list = sorted([str(x) for x in dir_folder.glob(search_key)]) if shuffle: random.shuffle(path_list) if num_files is not None: path_list = path_list[:num_files] with open(txt_path, mode='w') as ff: for line in path_list: ff.write(line+'\n') ================================================ FILE: comfyui_invsr_trimmed/utils/util_ema.py ================================================ import torch from torch import nn class LitEma(nn.Module): def __init__(self, model, decay=0.9999, use_num_upates=True): super().__init__() if decay < 0.0 or decay > 1.0: raise ValueError('Decay must be between 0 and 1') self.m_name2s_name = {} self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int)) for name, p in model.named_parameters(): if p.requires_grad: # remove as '.'-character is not allowed in buffers s_name = name.replace('.', '') self.m_name2s_name.update({name: s_name}) self.register_buffer(s_name, p.clone().detach().data) self.collected_params = [] def reset_num_updates(self): del self.num_updates self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) def forward(self, model): decay = self.decay if self.num_updates >= 0: self.num_updates += 1 decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) one_minus_decay = 1.0 - decay with torch.no_grad(): m_param = dict(model.named_parameters()) shadow_params = dict(self.named_buffers()) for key in m_param: if m_param[key].requires_grad: sname = self.m_name2s_name[key] shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) else: assert not key in self.m_name2s_name def copy_to(self, model): """ Copying the ema state (i.e., buffers) to the targeted model Input: model: targeted model """ m_param = dict(model.named_parameters()) shadow_params = dict(self.named_buffers()) for key in m_param: if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: assert not key in self.m_name2s_name def store(self, parameters): """ Save the 
parameters of the targeted model into the temporary pool for restoring later.
        Args:
            parameters: parameters of the targeted model. Iterable of `torch.nn.Parameter`;
                the parameters to be temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters from the temporary pool (stored with the `store` method).
        Useful to validate the model with EMA parameters without affecting the original
        optimization process. Store the parameters before the `copy_to` method. After
        validation (or model saving), use this to restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated
                with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def resume(self, ckpt, num_updates):
        """
        Resume from the targeted checkpoint, i.e., copying the checkpoint values
        to the ema buffers.
        Input:
            ckpt: state dict of the targeted checkpoint
        """
        self.register_buffer('num_updates', torch.tensor(num_updates, dtype=torch.int))

        shadow_params = dict(self.named_buffers())
        for key, value in ckpt.items():
            try:
                shadow_params[self.m_name2s_name[key]].data.copy_(value.data)
            except KeyError:
                # checkpoints saved from DataParallel models carry a 'module.' prefix
                if key.startswith('module') and key not in shadow_params:
                    key = key[7:]
                    shadow_params[self.m_name2s_name[key]].data.copy_(value.data)

================================================
FILE: comfyui_invsr_trimmed/utils/util_image.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-11-24 16:54:19

import sys
import cv2
import math
import torch
import random
import numpy as np
from pathlib import Path
from torchvision.utils import make_grid  # used by tensor2img below

# --------------------------Metrics----------------------------
def ssim(img1, img2):
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
        ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()

def calculate_ssim(im1, im2, border=0, ycbcr=False):
    '''
    SSIM with the same outputs as MATLAB's implementation.
    im1, im2: h x w x c, [0, 255], uint8
    '''
    if not im1.shape == im2.shape:
        raise ValueError('Input images must have the same dimensions.')

    if ycbcr:
        im1 = rgb2ycbcr(im1, True)
        im2 = rgb2ycbcr(im2, True)

    h, w = im1.shape[:2]
    im1 = im1[border:h-border, border:w-border]
    im2 = im2[border:h-border, border:w-border]

    if im1.ndim == 2:
        return ssim(im1, im2)
    elif im1.ndim == 3:
        if im1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(im1[:,:,i], im2[:,:,i]))
            return np.array(ssims).mean()
        elif im1.shape[2] == 1:
            return ssim(np.squeeze(im1), np.squeeze(im2))
    else:
        raise ValueError('Wrong input image dimensions.')

def calculate_psnr(im1, im2, border=0, ycbcr=False):
    '''
    PSNR metric.
    im1, im2: h x w x c, [0, 255], uint8
    '''
    if not im1.shape == im2.shape:
        raise ValueError('Input images must have the same dimensions.')

    if ycbcr:
        im1 = rgb2ycbcr(im1, True)
        im2 = rgb2ycbcr(im2, True)

    h, w = im1.shape[:2]
    im1 = im1[border:h-border, border:w-border]
    im2 = im2[border:h-border, border:w-border]

    im1 = im1.astype(np.float64)
    im2 = im2.astype(np.float64)
    mse = np.mean((im1 - im2)**2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))

def normalize_np(im, mean=0.5, std=0.5, reverse=False):
    '''
    Input:
        im: h x w x c, numpy array
        Normalize: (im - mean) / std
        Reverse: im * std + mean
    '''
    if not isinstance(mean, (list, tuple)):
        mean = [mean, ] * im.shape[2]
    mean = np.array(mean).reshape([1, 1, im.shape[2]])

    if not isinstance(std, (list, tuple)):
        std = [std, ] * im.shape[2]
    std = np.array(std).reshape([1, 1, im.shape[2]])

    if not reverse:
        out = (im.astype(np.float32) - mean) / std
    else:
        out = im.astype(np.float32) * std + mean
    return out

def normalize_th(im, mean=0.5, std=0.5, reverse=False):
    '''
    Input:
        im: b x c x h x w, torch tensor
        Normalize: (im - mean) / std
        Reverse: im * std + mean
    '''
    if not isinstance(mean, (list, tuple)):
        mean = [mean, ] * im.shape[1]
    mean = torch.tensor(mean, device=im.device).view([1, im.shape[1], 1, 1])

    if not isinstance(std, (list, tuple)):
        std = [std, ] * im.shape[1]
    std = torch.tensor(std, device=im.device).view([1, im.shape[1], 1, 1])

    if not reverse:
        out = (im - mean) / std
    else:
        out = im * std + mean
    return out

# ------------------------Image format--------------------------
def rgb2ycbcr(im, only_y=True):
    '''
    same as matlab rgb2ycbcr
    Input:
        im: uint8 [0,255] or float [0,1]
        only_y: only return Y channel
    '''
    # transform to float64 data type, range [0, 255]
    if im.dtype == np.uint8:
        im_temp = im.astype(np.float64)
    else:
        im_temp = (im * 255).astype(np.float64)

    # convert
    if only_y:
        rlt = np.dot(im_temp, np.array([65.481, 128.553, 24.966])/ 255.0) + 16.0
    else:
        rlt = np.matmul(im_temp, np.array([[65.481, -37.797, 112.0 ],
                                           [128.553, -74.203, -93.786],
                                           [24.966, 112.0, -18.214]])/255.0) + [16, 128, 128]
    if im.dtype == np.uint8:
        rlt = rlt.round()
    else:
        rlt /= 255.
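    # note: for float inputs the result is scaled back to [0, 1] above, so the
    # dtype-preserving cast below mirrors MATLAB's rgb2ycbcr behaviour on both
    # uint8 and double images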
    return rlt.astype(im.dtype)

def rgb2ycbcrTorch(im, only_y=True):
    '''
    same as matlab rgb2ycbcr
    Input:
        im: float [0,1], N x 3 x H x W
        only_y: only return Y channel
    '''
    # transform to range [0,255.0]
    im_temp = im.permute([0,2,3,1]) * 255.0  # N x 3 x H x W --> N x H x W x C
    # convert
    if only_y:
        rlt = torch.matmul(im_temp, torch.tensor([65.481, 128.553, 24.966],
                           device=im.device, dtype=im.dtype).view([3,1])/ 255.0) + 16.0
    else:
        scale = torch.tensor(
            [[65.481, -37.797, 112.0 ],
             [128.553, -74.203, -93.786],
             [24.966, 112.0, -18.214]],
            device=im.device, dtype=im.dtype
        ) / 255.0
        bias = torch.tensor([16, 128, 128], device=im.device, dtype=im.dtype).view([-1, 1, 1, 3])
        rlt = torch.matmul(im_temp, scale) + bias
    rlt /= 255.0
    rlt.clamp_(0.0, 1.0)
    return rlt.permute([0, 3, 1, 2])

def ycbcr2rgbTorch(im):
    '''
    same as matlab ycbcr2rgb
    Input:
        im: float [0,1], N x 3 x H x W
    '''
    # transform to range [0,255.0]
    im_temp = im.permute([0,2,3,1]) * 255.0  # N x 3 x H x W --> N x H x W x C
    # convert
    scale = torch.tensor(
        [[0.00456621, 0.00456621, 0.00456621],
         [0, -0.00153632, 0.00791071],
         [0.00625893, -0.00318811, 0]],
        device=im.device, dtype=im.dtype
    ) * 255.0
    bias = torch.tensor(
        [-222.921, 135.576, -276.836],
        device=im.device, dtype=im.dtype
    ).view([-1, 1, 1, 3])
    rlt = torch.matmul(im_temp, scale) + bias
    rlt /= 255.0
    rlt.clamp_(0.0, 1.0)
    return rlt.permute([0, 3, 1, 2])

def bgr2rgb(im):
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

def rgb2bgr(im):
    return cv2.cvtColor(im, cv2.COLOR_RGB2BGR)

def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to [min, max], values will be normalized to [0, 1].

    Args:
        tensor (Tensor or list[Tensor]): Accept shapes:
            1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
            2) 3D Tensor of shape (3/1 x H x W);
            3) 2D Tensor of shape (H x W).
            Tensor channel should be in RGB order.
        rgb2bgr (bool): Whether to change rgb to bgr.
        out_type (numpy type): output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple[int]): min and max values for clamp.

    Returns:
        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
            ndarray of shape (H x W). The channel order is BGR.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    flag_tensor = torch.is_tensor(tensor)
    if flag_tensor:
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
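            # (illustrative: np.uint8(250.6) truncates to 250, so the explicit
            # .round() on the next line is needed to match MATLAB's behaviour)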
img_np = (img_np * 255.0).round() img_np = img_np.astype(out_type) result.append(img_np) if len(result) == 1 and flag_tensor: result = result[0] return result # ------------------------Image resize----------------------------- def imresize_np(img, scale, antialiasing=True): # Now the scale should be the same for H and W # input: img: Numpy, HWC or HW [0,1] # output: HWC or HW [0,1] w/o round img = torch.from_numpy(img) need_squeeze = True if img.dim() == 2 else False if need_squeeze: img.unsqueeze_(2) in_H, in_W, in_C = img.size() out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) kernel_width = 4 kernel = 'cubic' # Return the desired dimension order for performing the resize. The # strategy is to perform the resize first along the dimension with the # smallest scale factor. # Now we do not support this. # get weights and indices weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( in_H, out_H, scale, kernel, kernel_width, antialiasing) weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( in_W, out_W, scale, kernel, kernel_width, antialiasing) # process H dimension # symmetric copying img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) sym_patch = img[:sym_len_Hs, :, :] inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() sym_patch_inv = sym_patch.index_select(0, inv_idx) img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) sym_patch = img[-sym_len_He:, :, :] inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() sym_patch_inv = sym_patch.index_select(0, inv_idx) img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) out_1 = torch.FloatTensor(out_H, in_W, in_C) kernel_width = weights_H.size(1) for i in range(out_H): idx = int(indices_H[i][0]) for j in range(out_C): out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) # process W dimension # symmetric copying out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) sym_patch = out_1[:, :sym_len_Ws, :] inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() sym_patch_inv = sym_patch.index_select(1, inv_idx) out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) sym_patch = out_1[:, -sym_len_We:, :] inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() sym_patch_inv = sym_patch.index_select(1, inv_idx) out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) out_2 = torch.FloatTensor(out_H, out_W, in_C) kernel_width = weights_W.size(1) for i in range(out_W): idx = int(indices_W[i][0]) for j in range(out_C): out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) if need_squeeze: out_2.squeeze_() return out_2.numpy() def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): if (scale < 1) and (antialiasing): # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width kernel_width = kernel_width / scale # Output-space coordinates x = torch.linspace(1, out_length, out_length) # Input-space coordinates. Calculate the inverse mapping such that 0.5 # in output space maps to 0.5 in input space, and 0.5+scale in output # space maps to 1.5 in input space. u = x / scale + 0.5 * (1 - 1 / scale) # What is the left-most pixel that can be involved in the computation? 
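    # (illustrative worked case: for scale = 0.5 with antialiasing, the cubic
    # kernel_width of 4 is widened to 4 / 0.5 = 8, so output pixel k draws
    # from input pixels starting near floor(u_k - 4))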
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    P = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, P) + \
        torch.linspace(0, P - 1, P).view(1, P).expand(out_length, P)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices

    # apply cubic kernel
    if (scale < 1) and (antialiasing):
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, P)

    # If a column in weights is all zero, get rid of it. Only consider the
    # first and last column.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, P - 2)
        weights = weights.narrow(1, 1, P - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, P - 2)
        weights = weights.narrow(1, 0, P - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)

# matlab 'imresize' function, now only supports 'bicubic'
def cubic(x):
    absx = torch.abs(x)
    absx2 = absx**2
    absx3 = absx**3
    return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
        (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))

# ------------------------Image I/O-----------------------------
def imread(path, chn='rgb', dtype='float32', force_gray2rgb=True, force_rgba2rgb=False):
    '''
    Read image.
    chn: 'rgb', 'bgr' or 'gray'
    out:
        im: h x w x c, numpy array
    '''
    try:
        im = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)  # BGR, uint8
    except Exception:
        im = None
    if im is None:
        raise ValueError(f"{str(path)} can't be successfully read!")

    if chn.lower() == 'gray':
        assert im.ndim == 2, f"{str(path)} is not a grayscale image!"
    else:
        if im.ndim == 2:
            if force_gray2rgb:
                im = np.stack([im, im, im], axis=2)
            else:
                raise ValueError(f"{str(path)} has only one channel!")
        elif im.ndim == 3 and im.shape[2] == 4:  # BGRA loaded by IMREAD_UNCHANGED
            if force_rgba2rgb:
                im = im[:, :, :3]  # drop the alpha channel
                if chn.lower() == 'rgb':
                    im = bgr2rgb(im)
            else:
                raise ValueError(f"{str(path)} has {im.shape[2]} channels!")
        else:
            if chn.lower() == 'rgb':
                im = bgr2rgb(im)
            elif chn.lower() == 'bgr':
                pass

    if dtype == 'float32':
        im = im.astype(np.float32) / 255.
    elif dtype == 'float64':
        im = im.astype(np.float64) / 255.
    elif dtype == 'uint8':
        pass
    else:
        sys.exit('Please input a correct dtype: float32, float64 or uint8!')

    return im

# ------------------------Augmentation-----------------------------
def data_aug_np(image, mode):
    '''
    Performs data augmentation of the input image
    Input:
        image: a cv2 (OpenCV) image
        mode: int.
Choice of transformation to apply to the image 0 - no transformation 1 - flip up and down 2 - rotate counterwise 90 degree 3 - rotate 90 degree and flip up and down 4 - rotate 180 degree 5 - rotate 180 degree and flip 6 - rotate 270 degree 7 - rotate 270 degree and flip ''' if mode == 0: # original out = image elif mode == 1: # flip up and down out = np.flipud(image) elif mode == 2: # rotate counterwise 90 degree out = np.rot90(image) elif mode == 3: # rotate 90 degree and flip up and down out = np.rot90(image) out = np.flipud(out) elif mode == 4: # rotate 180 degree out = np.rot90(image, k=2) elif mode == 5: # rotate 180 degree and flip out = np.rot90(image, k=2) out = np.flipud(out) elif mode == 6: # rotate 270 degree out = np.rot90(image, k=3) elif mode == 7: # rotate 270 degree and flip out = np.rot90(image, k=3) out = np.flipud(out) else: raise Exception('Invalid choice of image transformation') return out.copy() def inverse_data_aug_np(image, mode): ''' Performs inverse data augmentation of the input image ''' if mode == 0: # original out = image elif mode == 1: out = np.flipud(image) elif mode == 2: out = np.rot90(image, axes=(1,0)) elif mode == 3: out = np.flipud(image) out = np.rot90(out, axes=(1,0)) elif mode == 4: out = np.rot90(image, k=2, axes=(1,0)) elif mode == 5: out = np.flipud(image) out = np.rot90(out, k=2, axes=(1,0)) elif mode == 6: out = np.rot90(image, k=3, axes=(1,0)) elif mode == 7: # rotate 270 degree and flip out = np.flipud(image) out = np.rot90(out, k=3, axes=(1,0)) else: raise Exception('Invalid choice of image transformation') return out # ----------------------Visualization---------------------------- def imshow(x, title=None, cbar=False): import matplotlib.pyplot as plt plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') if title: plt.title(title) if cbar: plt.colorbar() plt.show() def imblend_with_mask(im, mask, alpha=0.25): """ Input: im, mask: h x w x c numpy array, uint8, [0, 255] alpha: scaler in [0.0, 1.0] """ edge_map = cv2.Canny(mask, 100, 200).astype(np.float32)[:, :, None] / 255. assert mask.dtype == np.uint8 mask = mask.astype(np.float32) / 255. if mask.ndim == 2: mask = mask[:, :, None] back_color = np.array([159, 121, 238], dtype=np.float32).reshape((1,1,3)) blend = im.astype(np.float32) * alpha + (1 - alpha) * back_color blend = np.clip(blend, 0, 255) out = im.astype(np.float32) * (1 - mask) + blend * mask # paste edge out = out * (1 - edge_map) + np.array([0,255,0], dtype=np.float32).reshape((1,1,3)) * edge_map return out.astype(np.uint8) # -----------------------Covolution------------------------------ def imgrad(im, pading_mode='mirror'): ''' Calculate image gradient. 
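    Illustrative example (a minimal sketch with a made-up 5x5 ramp image):
        >>> import numpy as np
        >>> im = np.tile(np.arange(5, dtype=np.float32), (5, 1))
        >>> g = imgrad(im)
        >>> g['gradx'].shape, g['grad'].shape
        ((5, 5), (5, 5, 2))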
Input: im: h x w x c numpy array ''' from scipy.ndimage import correlate # lazy import wx = np.array([[0, 0, 0], [-1, 1, 0], [0, 0, 0]], dtype=np.float32) wy = np.array([[0, -1, 0], [0, 1, 0], [0, 0, 0]], dtype=np.float32) if im.ndim == 3: gradx = np.stack( [correlate(im[:,:,c], wx, mode=pading_mode) for c in range(im.shape[2])], axis=2 ) grady = np.stack( [correlate(im[:,:,c], wy, mode=pading_mode) for c in range(im.shape[2])], axis=2 ) grad = np.concatenate((gradx, grady), axis=2) else: gradx = correlate(im, wx, mode=pading_mode) grady = correlate(im, wy, mode=pading_mode) grad = np.stack((gradx, grady), axis=2) return {'gradx': gradx, 'grady': grady, 'grad':grad} def convtorch(im, weight, mode='reflect'): ''' Image convolution with pytorch Input: im: b x c_in x h x w torch tensor weight: c_out x c_in x k x k torch tensor Output: out: c x h x w torch tensor ''' radius = weight.shape[-1] chn = im.shape[1] im_pad = torch.nn.functional.pad(im, pad=(radius // 2, )*4, mode=mode) out = torch.nn.functional.conv2d(im_pad, weight, padding=0, groups=chn) return out # ----------------------Patch Cropping---------------------------- def random_crop(im, pch_size): ''' Randomly crop a patch from the give image. ''' h, w = im.shape[:2] # padding if necessary if h < pch_size or w < pch_size: pad_h = min(max(0, pch_size - h), h) pad_w = min(max(0, pch_size - w), w) im = cv2.copyMakeBorder(im, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) h, w = im.shape[:2] if h == pch_size: ind_h = 0 elif h > pch_size: ind_h = random.randint(0, h-pch_size) else: raise ValueError('Image height is smaller than the patch size') if w == pch_size: ind_w = 0 elif w > pch_size: ind_w = random.randint(0, w-pch_size) else: raise ValueError('Image width is smaller than the patch size') im_pch = im[ind_h:ind_h+pch_size, ind_w:ind_w+pch_size,] return im_pch class ToTensor: def __init__(self, max_value=1.0): self.max_value = max_value def __call__(self, im): assert isinstance(im, np.ndarray) if im.ndim == 2: im = im[:, :, np.newaxis] if im.dtype == np.uint8: assert self.max_value == 255. 
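            # uint8 inputs must pair with max_value=255.0 so that the division
            # below maps [0, 255] into [0, 1]; float inputs are assumed to be
            # pre-scaled to [0, 1]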
out = torch.from_numpy(im.astype(np.float32).transpose(2,0,1) / self.max_value) else: assert self.max_value == 1.0 out = torch.from_numpy(im.transpose(2,0,1)) return out class RandomCrop: def __init__(self, pch_size, pass_crop=False): self.pch_size = pch_size self.pass_crop = pass_crop def __call__(self, im): if self.pass_crop: return im if isinstance(im, list) or isinstance(im, tuple): out = [] for current_im in im: out.append(random_crop(current_im, self.pch_size)) else: out = random_crop(im, self.pch_size) return out class ImageSpliterNp: def __init__(self, im, pch_size, stride, sf=1): ''' Input: im: h x w x c, numpy array, [0, 1], low-resolution image in SR pch_size, stride: patch setting sf: scale factor in image super-resolution ''' assert stride <= pch_size self.stride = stride self.pch_size = pch_size self.sf = sf if im.ndim == 2: im = im[:, :, None] height, width, chn = im.shape self.height_starts_list = self.extract_starts(height) self.width_starts_list = self.extract_starts(width) self.length = self.__len__() self.num_pchs = 0 self.im_ori = im self.im_res = np.zeros([height*sf, width*sf, chn], dtype=im.dtype) self.pixel_count = np.zeros([height*sf, width*sf, chn], dtype=im.dtype) def extract_starts(self, length): starts = list(range(0, length, self.stride)) if starts[-1] + self.pch_size > length: starts[-1] = length - self.pch_size return starts def __len__(self): return len(self.height_starts_list) * len(self.width_starts_list) def __iter__(self): return self def __next__(self): if self.num_pchs < self.length: w_start_idx = self.num_pchs // len(self.height_starts_list) w_start = self.width_starts_list[w_start_idx] * self.sf w_end = w_start + self.pch_size * self.sf h_start_idx = self.num_pchs % len(self.height_starts_list) h_start = self.height_starts_list[h_start_idx] * self.sf h_end = h_start + self.pch_size * self.sf pch = self.im_ori[h_start:h_end, w_start:w_end,] self.w_start, self.w_end = w_start, w_end self.h_start, self.h_end = h_start, h_end self.num_pchs += 1 else: raise StopIteration(0) return pch, (h_start, h_end, w_start, w_end) def update(self, pch_res, index_infos): ''' Input: pch_res: pch_size x pch_size x 3, [0,1] index_infos: (h_start, h_end, w_start, w_end) ''' if index_infos is None: w_start, w_end = self.w_start, self.w_end h_start, h_end = self.h_start, self.h_end else: h_start, h_end, w_start, w_end = index_infos self.im_res[h_start:h_end, w_start:w_end] += pch_res self.pixel_count[h_start:h_end, w_start:w_end] += 1 def gather(self): assert np.all(self.pixel_count != 0) return self.im_res / self.pixel_count class ImageSpliterTh: def __init__(self, im, pch_size, stride, sf=1, extra_bs=1, weight_type='Gaussian'): ''' Input: im: n x c x h x w, torch tensor, float, low-resolution image in SR pch_size, stride: patch setting sf: scale factor in image super-resolution pch_bs: aggregate pchs to processing, only used when inputing single image ''' assert weight_type in ['Gaussian', 'ones'] self.weight_type = weight_type assert stride <= pch_size self.stride = stride self.pch_size = pch_size self.sf = sf self.extra_bs = extra_bs bs, chn, height, width= im.shape self.true_bs = bs self.height_starts_list = self.extract_starts(height) self.width_starts_list = self.extract_starts(width) self.starts_list = [] for ii in self.height_starts_list: for jj in self.width_starts_list: self.starts_list.append([ii, jj]) self.length = self.__len__() self.count_pchs = 0 self.im_ori = im self.dtype = torch.float64 self.im_res = torch.zeros([bs, chn, height*sf, width*sf], 
dtype=self.dtype, device=im.device) self.pixel_count = torch.zeros([bs, chn, height*sf, width*sf], dtype=self.dtype, device=im.device) def extract_starts(self, length): if length <= self.pch_size: starts = [0,] else: starts = list(range(0, length, self.stride)) for ii in range(len(starts)): if starts[ii] + self.pch_size > length: starts[ii] = length - self.pch_size starts = sorted(set(starts), key=starts.index) return starts def __len__(self): return len(self.height_starts_list) * len(self.width_starts_list) def __iter__(self): return self def __next__(self): if self.count_pchs < self.length: index_infos = [] current_starts_list = self.starts_list[self.count_pchs:self.count_pchs+self.extra_bs] for ii, (h_start, w_start) in enumerate(current_starts_list): w_end = w_start + self.pch_size h_end = h_start + self.pch_size current_pch = self.im_ori[:, :, h_start:h_end, w_start:w_end] if ii == 0: pch = current_pch else: pch = torch.cat([pch, current_pch], dim=0) h_start *= self.sf h_end *= self.sf w_start *= self.sf w_end *= self.sf index_infos.append([h_start, h_end, w_start, w_end]) self.count_pchs += len(current_starts_list) else: raise StopIteration() return pch, index_infos def update(self, pch_res, index_infos): ''' Input: pch_res: (n*extra_bs) x c x pch_size x pch_size, float index_infos: [(h_start, h_end, w_start, w_end),] ''' assert pch_res.shape[0] % self.true_bs == 0 pch_list = torch.split(pch_res, self.true_bs, dim=0) assert len(pch_list) == len(index_infos) for ii, (h_start, h_end, w_start, w_end) in enumerate(index_infos): current_pch = pch_list[ii].type(self.dtype) current_weight = self.get_weight(current_pch.shape[-2], current_pch.shape[-1]) self.im_res[:, :, h_start:h_end, w_start:w_end] += current_pch * current_weight self.pixel_count[:, :, h_start:h_end, w_start:w_end] += current_weight @staticmethod def generate_kernel_1d(ksize): sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8 # opencv default setting if ksize % 2 == 0: kernel = cv2.getGaussianKernel(ksize=ksize+1, sigma=sigma, ktype=cv2.CV_64F) kernel = kernel[1:, ] else: kernel = cv2.getGaussianKernel(ksize=ksize, sigma=sigma, ktype=cv2.CV_64F) return kernel def get_weight(self, height, width): if self.weight_type == 'ones': kernel = torch.ones(1, 1, height, width) elif self.weight_type == 'Gaussian': kernel_h = self.generate_kernel_1d(height).reshape(-1, 1) kernel_w = self.generate_kernel_1d(width).reshape(1, -1) kernel = np.matmul(kernel_h, kernel_w) kernel = torch.from_numpy(kernel).unsqueeze(0).unsqueeze(0) # 1 x 1 x pch_size x pch_size else: raise ValueError(f"Unsupported weight type: {self.weight_type}") return kernel.to(dtype=self.dtype, device=self.im_ori.device) def gather(self): assert torch.all(self.pixel_count != 0) return self.im_res.div(self.pixel_count) # ----------------------Patch Cliping---------------------------- class Clamper: def __init__(self, min_max=(-1, 1)): self.min_bound, self.max_bound = min_max[0], min_max[1] def __call__(self, im): if isinstance(im, np.ndarray): return np.clip(im, a_min=self.min_bound, a_max=self.max_bound) elif isinstance(im, torch.Tensor): return torch.clamp(im, min=self.min_bound, max=self.max_bound) else: raise TypeError(f'ndarray or Tensor expected, got {type(im)}') # ----------------------Interpolation---------------------------- class Bicubic: def __init__(self, scale=None, out_shape=None, activate_matlab=True, resize_back=False): self.scale = scale self.activate_matlab = activate_matlab self.out_shape = out_shape self.resize_back = resize_back def __call__(self, im): if 
self.activate_matlab: out = imresize_np(im, scale=self.scale) if self.resize_back: out = imresize_np(out, scale=1/self.scale) else: out = cv2.resize( im, dsize=self.out_shape, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_CUBIC, ) if self.resize_back: out = cv2.resize( out, dsize=self.out_shape, fx=1/self.scale, fy=1/self.scale, interpolation=cv2.INTER_CUBIC, ) return out class SmallestMaxSize: def __init__(self, max_size, pass_resize=False, interpolation=None): self.pass_resize = pass_resize self.max_size = max_size self.interpolation = interpolation self.str2mode = { 'nearest': cv2.INTER_NEAREST_EXACT, 'bilinear': cv2.INTER_LINEAR, 'bicubic': cv2.INTER_CUBIC } if self.interpolation is not None: assert interpolation in self.str2mode, f"Not supported interpolation mode: {interpolation}" def get_interpolation(self, size): if self.interpolation is None: if size < self.max_size: # upsampling interpolation = cv2.INTER_CUBIC else: # downsampling interpolation = cv2.INTER_AREA else: interpolation = self.str2mode[self.interpolation] return interpolation def __call__(self, im): h, w = im.shape[:2] if self.pass_resize or min(h, w) == self.max_size: out = im else: if h < w: dsize = (int(self.max_size * w / h), self.max_size) out = cv2.resize(im, dsize=dsize, interpolation=self.get_interpolation(h)) else: dsize = (self.max_size, int(self.max_size * h / w)) out = cv2.resize(im, dsize=dsize, interpolation=self.get_interpolation(w)) if out.dtype == np.uint8: out = np.clip(out, 0, 255) else: out = np.clip(out, 0, 1.0) return out # ----------------------augmentation---------------------------- class SpatialAug: def __init__(self, pass_aug, only_hflip=False, only_vflip=False, only_hvflip=False): self.only_hflip = only_hflip self.only_vflip = only_vflip self.only_hvflip = only_hvflip self.pass_aug = pass_aug def __call__(self, im, flag=None): if self.pass_aug: return im if flag is None: if self.only_hflip: flag = random.choice([0, 5]) elif self.only_vflip: flag = random.choice([0, 1]) elif self.only_hvflip: flag = random.choice([0, 1, 5]) else: flag = random.randint(0, 7) if isinstance(im, list) or isinstance(im, tuple): out = [] for current_im in im: out.append(data_aug_np(current_im, flag)) else: out = data_aug_np(im, flag) return out ================================================ FILE: comfyui_invsr_trimmed/utils/util_net.py ================================================ #!/usr/bin/env python # -*- coding:utf-8 -*- # Power by Zongsheng Yue 2021-11-24 20:29:36 def reload_model(model, ckpt): module_flag = list(ckpt.keys())[0].startswith('module.') compile_flag = '_orig_mod' in list(ckpt.keys())[0] for source_key, source_value in model.state_dict().items(): target_key = source_key if compile_flag and (not '_orig_mod.' in source_key): target_key = '_orig_mod.' + target_key if module_flag and (not source_key.startswith('module')): target_key = 'module.' 
+ target_key assert target_key in ckpt source_value.copy_(ckpt[target_key]) ================================================ FILE: comfyui_invsr_trimmed/utils/util_opts.py ================================================ #!/usr/bin/env python # -*- coding:utf-8 -*- # Power by Zongsheng Yue 2021-11-24 15:07:43 import argparse def update_args(args_json, args_parser): for arg in vars(args_parser): args_json[arg] = getattr(args_parser, arg) def str2bool(v): """ https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse """ if isinstance(v, bool): return v if v.lower() in ("yes", "true", "t", "y", "1"): return True elif v.lower() in ("no", "false", "f", "n", "0"): return False else: raise argparse.ArgumentTypeError("boolean value expected") ================================================ FILE: comfyui_invsr_trimmed/utils/util_sisr.py ================================================ #!/usr/bin/env python # -*- coding:utf-8 -*- # Power by Zongsheng Yue 2021-12-07 21:37:58 import cv2 import numpy as np def modcrop(im, sf): h, w = im.shape[:2] h -= (h % sf) w -= (w % sf) return im[:h, :w,] #-----------------------------------------Transform-------------------------------------------- class Bicubic: def __init__(self, scale=None, out_shape=None, matlab_mode=True): self.scale = scale self.out_shape = out_shape def __call__(self, im): out = cv2.resize( im, dsize=self.out_shape, fx=self.scale, fy=self.scale, interpolation=cv2.INTER_CUBIC, ) return out ================================================ FILE: configs/degradation_testing_realesrgan.yaml ================================================ degradation: sf: 4 # the first degradation process resize_prob: [0.2, 0.7, 0.1] # up, down, keep resize_range: [0.5, 1.5] gaussian_noise_prob: 0.5 noise_range: [1, 15] poisson_scale_range: [0.05, 0.3] gray_noise_prob: 0.4 jpeg_range: [70, 95] # the second degradation process second_order_prob: 0.0 second_blur_prob: 0.2 resize_prob2: [0.3, 0.4, 0.3] # up, down, keep resize_range2: [0.8, 1.2] gaussian_noise_prob2: 0.5 noise_range2: [1, 10] poisson_scale_range2: [0.05, 0.2] gray_noise_prob2: 0.4 jpeg_range2: [80, 100] gt_size: 512 opts: data_source: ~ im_exts: ['png', 'JPEG'] io_backend: type: disk blur_kernel_size: 13 kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] kernel_prob: [0.60, 0.40, 0.0, 0.0, 0.0, 0.0] sinc_prob: 0.1 blur_sigma: [0.2, 0.8] betag_range: [1.0, 1.5] betap_range: [1, 1.2] blur_kernel_size2: 7 kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] kernel_prob2: [0.60, 0.4, 0.0, 0.0, 0.0, 0.0] sinc_prob2: 0.0 blur_sigma2: [0.2, 0.5] betag_range2: [0.5, 0.8] betap_range2: [1, 1.2] final_sinc_prob: 0.2 gt_size: ${degradation.gt_size} crop_pad_size: ${degradation.gt_size} use_hflip: False use_rot: False ================================================ FILE: configs/sample-sd-turbo.yaml ================================================ seed: 12345 # Super-resolution settings basesr: sf: 4 chopping: # for latent diffusion pch_size: 128 weight_type: Gaussian # VAE settings tiled_vae: True latent_tiled_size: 128 sample_tiled_size: 1024 gradient_checkpointing_vae: True sliced_vae: False # classifer-free guidance cfg_scale: 1.0 # sampling settings start_timesteps: 200 # color fixing color_fix: ~ # Stable Diffusion base_model: sd-turbo sd_pipe: target: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline enable_grad_checkpoint: True params: 
pretrained_model_name_or_path: stabilityai/sd-turbo cache_dir: /mnt/sfs-common/zsyue/modelbase/stable-diffusion/sd-turbo use_safetensors: True torch_dtype: torch.float16 model_start: target: .noise_predictor.NoisePredictor ckpt_path: ~ # For initializing params: in_channels: 3 down_block_types: - AttnDownBlock2D - AttnDownBlock2D up_block_types: - AttnUpBlock2D - AttnUpBlock2D block_out_channels: - 256 # 192, 256 - 512 # 384, 512 layers_per_block: - 3 - 3 act_fn: silu latent_channels: 4 norm_num_groups: 32 sample_size: 128 mid_block_add_attention: True resnet_time_scale_shift: default temb_channels: 512 attention_head_dim: 64 freq_shift: 0 flip_sin_to_cos: True double_z: True model_middle: target: .noise_predictor.NoisePredictor params: in_channels: 3 down_block_types: - AttnDownBlock2D - AttnDownBlock2D up_block_types: - AttnUpBlock2D - AttnUpBlock2D block_out_channels: - 256 # 192, 256 - 512 # 384, 512 layers_per_block: - 3 - 3 act_fn: silu latent_channels: 4 norm_num_groups: 32 sample_size: 128 mid_block_add_attention: True resnet_time_scale_shift: default temb_channels: 512 attention_head_dim: 64 freq_shift: 0 flip_sin_to_cos: True double_z: True ================================================ FILE: configs/sd-turbo-sr-ldis.yaml ================================================ trainer: target: trainer.TrainerSDTurboSR sd_pipe: target: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline num_train_steps: 1000 enable_grad_checkpoint: True compile: False vae_split: 8 params: pretrained_model_name_or_path: stabilityai/sd-turbo cache_dir: weights use_safetensors: True torch_dtype: torch.float16 llpips: target: latent_lpips.lpips.LPIPS ckpt_path: weights/vgg16_sdturbo_lpips.pth compile: False params: pretrained: False net: vgg16 lpips: True spatial: False pnet_rand: False pnet_tune: True use_dropout: True eval_mode: True latent: True in_chans: 4 verbose: True model: target: .noise_predictor.NoisePredictor ckpt_start_path: ~ # only used for training the intermidiate model ckpt_path: ~ # For initializing compile: False params: in_channels: 3 down_block_types: - AttnDownBlock2D - AttnDownBlock2D up_block_types: - AttnUpBlock2D - AttnUpBlock2D block_out_channels: - 256 # 192, 256 - 512 # 384, 512 layers_per_block: - 3 - 3 act_fn: silu latent_channels: 4 norm_num_groups: 32 sample_size: 128 mid_block_add_attention: True resnet_time_scale_shift: default temb_channels: 512 attention_head_dim: 64 freq_shift: 0 flip_sin_to_cos: True double_z: True discriminator: target: diffusers.models.unets.unet_2d_condition_discriminator.UNet2DConditionDiscriminator enable_grad_checkpoint: True compile: False params: sample_size: 64 in_channels: 4 center_input_sample: False flip_sin_to_cos: True freq_shift: 0 down_block_types: - DownBlock2D - CrossAttnDownBlock2D - CrossAttnDownBlock2D mid_block_type: UNetMidBlock2DCrossAttn up_block_types: - CrossAttnUpBlock2D - CrossAttnUpBlock2D - UpBlock2D only_cross_attention: False block_out_channels: - 128 - 256 - 512 layers_per_block: - 1 - 2 - 2 downsample_padding: 1 mid_block_scale_factor: 1 dropout: 0.0 act_fn: silu norm_num_groups: 32 norm_eps: 1e-5 cross_attention_dim: 1024 transformer_layers_per_block: 1 reverse_transformer_layers_per_block: ~ encoder_hid_dim: ~ encoder_hid_dim_type: ~ attention_head_dim: - 8 - 16 - 16 num_attention_heads: ~ dual_cross_attention: False use_linear_projection: False class_embed_type: ~ addition_embed_type: text addition_time_embed_dim: 256 num_class_embeds: ~ upcast_attention: ~ 
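  # NOTE (illustrative): entries shaped as `target:` + `params:` in these
  # configs follow the convention implemented by instantiate_from_config in
  # comfyui_invsr_trimmed/utils/util_common.py, i.e.
  # get_obj_from_str(target)(**params)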
resnet_time_scale_shift: default resnet_skip_time_act: False resnet_out_scale_factor: 1.0 time_embedding_type: positional time_embedding_dim: ~ time_embedding_act_fn: ~ timestep_post_act: ~ time_cond_proj_dim: ~ conv_in_kernel: 3 conv_out_kernel: 3 projection_class_embeddings_input_dim: 2560 attention_type: default class_embeddings_concat: False mid_block_only_cross_attention: ~ cross_attention_norm: ~ addition_embed_type_num_heads: 64 degradation: sf: 4 # the first degradation process resize_prob: [0.2, 0.7, 0.1] # up, down, keep resize_range: [0.15, 1.5] gaussian_noise_prob: 0.5 noise_range: [1, 30] poisson_scale_range: [0.05, 3.0] gray_noise_prob: 0.4 jpeg_range: [30, 95] # the second degradation process second_order_prob: 0.5 second_blur_prob: 0.8 resize_prob2: [0.3, 0.4, 0.3] # up, down, keep resize_range2: [0.3, 1.2] gaussian_noise_prob2: 0.5 noise_range2: [1, 25] poisson_scale_range2: [0.05, 2.5] gray_noise_prob2: 0.4 jpeg_range2: [30, 95] gt_size: 512 resize_back: False use_sharp: False data: train: type: realesrgan params: data_source: source1: root_path: /mnt/sfs-common/zsyue/database/FFHQ image_path: images1024 moment_path: ~ text_path: ~ im_ext: png length: 20000 source2: root_path: /mnt/sfs-common/zsyue/database/LSDIR/train image_path: images moment_path: ~ text_path: ~ im_ext: png max_token_length: 77 # 77 io_backend: type: disk blur_kernel_size: 21 kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] sinc_prob: 0.1 blur_sigma: [0.2, 3.0] betag_range: [0.5, 4.0] betap_range: [1, 2.0] blur_kernel_size2: 15 kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'] kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03] sinc_prob2: 0.1 blur_sigma2: [0.2, 1.5] betag_range2: [0.5, 4.0] betap_range2: [1, 2.0] final_sinc_prob: 0.8 gt_size: ${degradation.gt_size} use_hflip: True use_rot: False random_crop: True val: type: base params: dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/lq transform_type: default transform_kwargs: mean: 0.0 std: 1.0 extra_dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/gt extra_transform_type: default extra_transform_kwargs: mean: 0.0 std: 1.0 im_exts: png length: 16 recursive: False train: # predict started inverser start_mode: True # learning rate lr: 5e-5 # learning rate lr_min: 5e-5 # learning rate lr_schedule: ~ warmup_iterations: 2000 # discriminator lr_dis: 5e-5 # learning rate for dicriminator weight_decay_dis: 1e-3 # weight decay for dicriminator dis_init_iterations: 10000 # iterations used for updating the discriminator dis_update_freq: 1 # dataloader batch: 64 microbatch: 16 num_workers: 4 prefetch_factor: 2 use_text: True # optimization settings weight_decay: 0 ema_rate: 0.999 iterations: 200000 # total iterations # logging save_freq: 5000 log_freq: [200, 5000] # [training loss, training images, val images] local_logging: True # manually save images tf_logging: False # tensorboard logging # loss loss_type: L2 loss_coef: ldif: 1.0 timesteps: [200, 100] num_inference_steps: 5 # mixed precision use_amp: True use_fsdp: False # random seed seed: 123456 global_seeding: False noise_detach: False validate: batch: 2 use_ema: True log_freq: 4 # logging frequence val_y_channel: True ================================================ FILE: node.py ================================================ from .comfyui_invsr_trimmed import get_configs, InvSamplerSR, 
BaseSampler, Namespace import torch from comfy.utils import ProgressBar from folder_paths import get_full_path, get_folder_paths, models_dir import os import torch.nn.functional as F def split_tensor_into_batches(tensor, batch_size): """ Split a tensor into smaller batches of specified size Args: tensor (torch.Tensor): Input tensor of shape (N, C, H, W) batch_size (int): Desired batch size for splitting Returns: list: List of tensors, each with batch_size (except possibly the last one) """ # Get original batch size original_batch_size = tensor.size(0) # Calculate number of full batches and remaining samples num_full_batches = original_batch_size // batch_size remaining_samples = original_batch_size % batch_size # Split tensor into chunks batches = [] # Handle full batches for i in range(num_full_batches): start_idx = i * batch_size end_idx = start_idx + batch_size batch = tensor[start_idx:end_idx] batches.append(batch) # Handle remaining samples if any if remaining_samples > 0: last_batch = tensor[-remaining_samples:] batches.append(last_batch) return batches class LoadInvSRModels: @classmethod def INPUT_TYPES(s): return { "required": { "sd_model": (['stabilityai/sd-turbo'],), "invsr_model": (['noise_predictor_sd_turbo_v5.pth', 'noise_predictor_sd_turbo_v5_diftune.pth'],), "dtype": (['fp16', 'fp32', 'bf16'], {"default": "fp16"}), "tiled_vae": ("BOOLEAN", {"default": True}), }, } RETURN_TYPES = ("INVSR_PIPE",) RETURN_NAMES = ("invsr_pipe",) FUNCTION = "loadmodel" CATEGORY = "INVSR" def loadmodel(self, sd_model, invsr_model, dtype, tiled_vae): match dtype: case "fp16": dtype = "torch.float16" case "fp32": dtype = "torch.float32" case "bf16": dtype = "torch.bfloat16" cfg_path = os.path.join( os.path.dirname(__file__), "configs", "sample-sd-turbo.yaml" ) sd_path = get_folder_paths("diffusers")[0] try: ckpt_dir = get_folder_paths("invsr")[0] except: ckpt_dir = os.path.join(models_dir, "invsr") args = Namespace( bs=1, chopping_bs=8, timesteps=None, num_steps=1, cfg_path=cfg_path, sd_path=sd_path, started_ckpt_dir=ckpt_dir, tiled_vae=tiled_vae, color_fix="", chopping_size=128, invsr_model=invsr_model ) configs = get_configs(args) configs["sd_pipe"]["params"]["torch_dtype"] = dtype base_sampler = BaseSampler(configs) return ((base_sampler, invsr_model),) class InvSRSampler: @classmethod def INPUT_TYPES(s): return { "required": { "invsr_pipe": ("INVSR_PIPE",), "images": ("IMAGE",), "num_steps": ("INT",{"default": 1, "min": 1, "max": 5}), "cfg": ("FLOAT",{"default": 1.0, "step":0.1}), # "scale_factor": ("INT",{"default": 4}), "batch_size": ("INT",{"default": 1}), "chopping_batch_size": ("INT",{"default": 8}), "chopping_size": ([128, 256, 512],{"default": 128}), "color_fix": (['none', 'wavelet', 'ycbcr'], {"default": "none"}), "seed": ("INT", {"default": 123, "min": 0, "max": 2**32 - 1, "step": 1}), }, } RETURN_TYPES = ("IMAGE",) RETURN_NAMES = ("image",) FUNCTION = "process" CATEGORY = "INVSR" def process(self, invsr_pipe, images, num_steps, cfg, batch_size, chopping_batch_size, chopping_size, color_fix, seed): base_sampler, invsr_model = invsr_pipe if color_fix == "none": color_fix = "" cfg_path = os.path.join( os.path.dirname(__file__), "configs", "sample-sd-turbo.yaml" ) sd_path = get_folder_paths("diffusers")[0] try: ckpt_dir = get_folder_paths("invsr")[0] except: ckpt_dir = os.path.join(models_dir, "invsr") args = Namespace( bs=batch_size, chopping_bs=chopping_batch_size, timesteps=None, num_steps=num_steps, cfg_path=cfg_path, sd_path=sd_path, started_ckpt_dir=ckpt_dir, 
tiled_vae=base_sampler.configs.tiled_vae, color_fix=color_fix, chopping_size=chopping_size, invsr_model=invsr_model ) configs = get_configs(args, log=True) configs["cfg_scale"] = cfg # configs["basesr"]["sf"] = scale_factor base_sampler.configs = configs base_sampler.setup_seed(seed) sampler = InvSamplerSR(base_sampler) images_bchw = images.permute(0,3,1,2) og_h, og_w = images_bchw.shape[2:] # Calculate new dimensions divisible by 16 new_height = ((og_h + 15) // 16) * 16 # Round up to nearest multiple of 16 new_width = ((og_w + 15) // 16) * 16 resized = False if og_h != new_height or og_w != new_width: resized = True print(f"[InvSR] - Image not divisible by 16. Resizing to {new_height} (h) x {new_width} (w)") images_bchw = F.interpolate(images_bchw, size=(new_height, new_width), mode='bicubic', align_corners=False) batches = split_tensor_into_batches(images_bchw, batch_size) results = [] pbar = ProgressBar(len(batches)) for batch in batches: result = sampler.inference(image_bchw=batch) results.append(torch.from_numpy(result)) pbar.update(1) result_t = torch.cat(results, dim=0) # Resize to original dimensions * 4 if resized: result_t = F.interpolate(result_t, size=(og_h * 4, og_w * 4), mode='bicubic', align_corners=False) return (result_t.permute(0,2,3,1),) ================================================ FILE: pyproject.toml ================================================ [project] name = "invsr" description = "This project is an unofficial ComfyUI implementation of [a/InvSR](https://github.com/zsyOAOA/InvSR) (Arbitrary-steps Image Super-resolution via Diffusion Inversion)" version = "1.0.1" license = {file = "LICENSE"} dependencies = ["pyiqa==0.1.12", "opencv-python", "albumentations==1.4.18", "bitsandbytes", "protobuf", "python-box", "omegaconf", "loguru"] [project.urls] Repository = "https://github.com/yuvraj108c/ComfyUI_InvSR" # Used by Comfy Registry https://comfyregistry.org [tool.comfy] PublisherId = "yuvraj108c" DisplayName = "ComfyUI_InvSR" Icon = "" ================================================ FILE: requirements.txt ================================================ opencv-contrib-python-headless omegaconf diffusers numpy<2 huggingface-hub transformers ================================================ FILE: workflows/invsr.json ================================================ { "last_node_id": 27, "last_link_id": 40, "nodes": [ { "id": 25, "type": "LoadInvSRModels", "pos": [ -877.6565551757812, 719.943603515625 ], "size": [ 413.84686279296875, 130 ], "flags": {}, "order": 0, "mode": 0, "inputs": [], "outputs": [ { "name": "invsr_pipe", "type": "INVSR_PIPE", "links": [ 38 ], "slot_index": 0 } ], "properties": { "Node name for S&R": "LoadInvSRModels" }, "widgets_values": [ "stabilityai/sd-turbo", "noise_predictor_sd_turbo_v5.pth", "bf16", true ] }, { "id": 17, "type": "LoadImage", "pos": [ -867.9894409179688, 918.590087890625 ], "size": [ 367.18212890625, 439.6049499511719 ], "flags": {}, "order": 1, "mode": 0, "inputs": [], "outputs": [ { "name": "IMAGE", "type": "IMAGE", "links": [ 40 ], "slot_index": 0 }, { "name": "MASK", "type": "MASK", "links": null } ], "properties": { "Node name for S&R": "LoadImage" }, "widgets_values": [ "i1.png", "image" ] }, { "id": 27, "type": "InvSRSampler", "pos": [ -365.7756652832031, 859.0978393554688 ], "size": [ 315, 246 ], "flags": {}, "order": 2, "mode": 0, "inputs": [ { "name": "invsr_pipe", "type": "INVSR_PIPE", "link": 38 }, { "name": "images", "type": "IMAGE", "link": 40 } ], "outputs": [ { "name": "image", "type": "IMAGE", "links": [ 39 
], "slot_index": 0 } ], "properties": { "Node name for S&R": "InvSRSampler" }, "widgets_values": [ 1, 5, 1, 8, 128, "wavelet", 536149006, "randomize" ] }, { "id": 26, "type": "PreviewImage", "pos": [ 35.55903625488281, 643.07470703125 ], "size": [ 889.1287231445312, 869.5750732421875 ], "flags": {}, "order": 3, "mode": 0, "inputs": [ { "name": "images", "type": "IMAGE", "link": 39 } ], "outputs": [], "properties": { "Node name for S&R": "PreviewImage" } } ], "links": [ [ 38, 25, 0, 27, 0, "INVSR_PIPE" ], [ 39, 27, 0, 26, 0, "IMAGE" ], [ 40, 17, 0, 27, 1, "IMAGE" ] ], "groups": [], "config": {}, "extra": { "ds": { "scale": 0.9090909090909091, "offset": [ 1131.660545256908, -643.8089187150072 ] }, "VHS_latentpreview": false, "VHS_latentpreviewrate": 0 }, "version": 0.4 }