Repository: yuvraj108c/ComfyUI_InvSR
Branch: main
Commit: 20a0e003e676
Files: 31
Total size: 199.0 KB
Directory structure:
gitextract_393ad4r7/
├── .github/
│ ├── FUNDING.yml
│ └── workflows/
│ └── publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── __init__.py
├── comfyui_invsr_trimmed/
│ ├── __init__.py
│ ├── inference_invsr.py
│ ├── latent_lpips/
│ │ ├── __init__.py
│ │ ├── lpips.py
│ │ └── pretrained_networks.py
│ ├── noise_predictor.py
│ ├── pipeline_stable_diffusion_inversion_sr.py
│ ├── sampler_invsr.py
│ ├── time_aware_encoder.py
│ └── utils/
│ ├── __init__.py
│ ├── resize.py
│ ├── util_color_fix.py
│ ├── util_common.py
│ ├── util_ema.py
│ ├── util_image.py
│ ├── util_net.py
│ ├── util_opts.py
│ └── util_sisr.py
├── configs/
│ ├── degradation_testing_realesrgan.yaml
│ ├── sample-sd-turbo.yaml
│ └── sd-turbo-sr-ldis.yaml
├── node.py
├── pyproject.toml
├── requirements.txt
└── workflows/
└── invsr.json
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/FUNDING.yml
================================================
github: yuvraj108c
custom: ["https://paypal.me/yuvraj108c", "https://buymeacoffee.com/yuvraj108cz"]
================================================
FILE: .github/workflows/publish.yml
================================================
name: Publish to Comfy registry
on:
workflow_dispatch:
push:
branches:
- main
- master
paths:
- "pyproject.toml"
permissions:
issues: write
jobs:
publish-node:
name: Publish Custom Node to registry
runs-on: ubuntu-latest
if: ${{ github.repository_owner == 'yuvraj108c' }}
steps:
- name: Check out code
uses: actions/checkout@v4
with:
submodules: true
- name: Publish Custom Node
uses: Comfy-Org/publish-node-action@v1
with:
## Add your own personal access token to your GitHub repository secrets and reference it here.
personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
================================================
FILE: .gitignore
================================================
.DS_Store
*pyc
.vscode
__pycache__
# *.egg-info
*.bak
checkpoints
results
backup
================================================
FILE: LICENSE
================================================
S-Lab License 1.0
Copyright 2024 S-Lab
Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
================================================
FILE: README.md
================================================
# ComfyUI InvSR
[arXiv:2412.09013](https://arxiv.org/abs/2412.09013)
This project is a ComfyUI wrapper for [InvSR](https://github.com/zsyOAOA/InvSR) (Arbitrary-steps Image Super-resolution via Diffusion Inversion).
**Last tested**: 2 January 2026 (ComfyUI v0.7.0@f2fda02 | Torch 2.9.1 | Python 3.10.12 | RTX4090 | CUDA 13.0 | Debian 12)
## ⭐ Support
If you like my projects and wish to see updates and new features, please consider supporting me. It helps a lot!
[ComfyUI-Depth-Anything-Tensorrt](https://github.com/yuvraj108c/ComfyUI-Depth-Anything-Tensorrt)
[ComfyUI-Upscaler-Tensorrt](https://github.com/yuvraj108c/ComfyUI-Upscaler-Tensorrt)
[ComfyUI-Dwpose-Tensorrt](https://github.com/yuvraj108c/ComfyUI-Dwpose-Tensorrt)
[ComfyUI-Rife-Tensorrt](https://github.com/yuvraj108c/ComfyUI-Rife-Tensorrt)
[ComfyUI-Whisper](https://github.com/yuvraj108c/ComfyUI-Whisper)
[ComfyUI_InvSR](https://github.com/yuvraj108c/ComfyUI_InvSR)
[ComfyUI-Thera](https://github.com/yuvraj108c/ComfyUI-Thera)
[ComfyUI-Video-Depth-Anything](https://github.com/yuvraj108c/ComfyUI-Video-Depth-Anything)
[ComfyUI-PiperTTS](https://github.com/yuvraj108c/ComfyUI-PiperTTS)
[Buy Me a Coffee](https://www.buymeacoffee.com/yuvraj108cZ)
[PayPal](https://paypal.me/yuvraj108c)
---
## Installation
Navigate to the ComfyUI `/custom_nodes` directory:
```bash
git clone https://github.com/yuvraj108c/ComfyUI_InvSR
cd ComfyUI_InvSR
pip install -r requirements.txt
```
## Usage
- Load [example workflow](workflows/invsr.json)
- The diffusers model (`stabilityai/sd-turbo`) downloads automatically to `ComfyUI/models/diffusers`
- The InvSR model (`noise_predictor_sd_turbo_v5.pth`) downloads automatically to `ComfyUI/models/invsr`
- For large upscales (e.g., 1K → 4K), set `chopping_size` to 256
- If your GPU memory is limited, set `chopping_batch_size` to 1
## Parameters
- `num_steps`: number of inference steps
- `cfg`: classifier-free guidance scale
- `batch_size`: how many complete images are processed simultaneously
- `chopping_batch_size`: how many patches from the same image are processed simultaneously
- `chopping_size`: size of the patches used when splitting large images (e.g., a 1024×1024 input with `chopping_size` 256 is processed as 256×256 patches)
- `color_fix`: method used to correct color shift in the processed images
## Updates
**28 April 2025**
- Update diffusers versions in requirements.txt to fix https://github.com/yuvraj108c/ComfyUI_InvSR/issues/26, https://github.com/yuvraj108c/ComfyUI_InvSR/issues/21, https://github.com/yuvraj108c/ComfyUI_InvSR/issues/15
- Add support for `noise_predictor_sd_turbo_v5_diftune.pth`
**03 February 2025**
- Add cfg parameter
- Make image divisible by 16
- Use `mm` (ComfyUI model management) to set the torch device
**31 January 2025**
- Merged https://github.com/yuvraj108c/ComfyUI_InvSR/pull/5 by [wfjsw](https://github.com/wfjsw)
- Compatibility with `diffusers>=0.28`
- Massive code refactoring & cleanup
## Citation
```bibtex
@article{yue2024InvSR,
title={Arbitrary-steps Image Super-resolution via Diffusion Inversion},
author={Yue, Zongsheng and Liao, Kang and Loy, Chen Change},
journal = {arXiv preprint arXiv:2412.09013},
year={2024},
}
```
## License
This project is licensed under [NTU S-Lab License 1.0](LICENSE)
## Acknowledgments
Thanks to [simplepod.ai](https://simplepod.ai/) for providing GPU servers
## Star History
[Star History Chart](https://star-history.com/#yuvraj108c/ComfyUI_InvSR&Date)
================================================
FILE: __init__.py
================================================
from .node import LoadInvSRModels, InvSRSampler
NODE_CLASS_MAPPINGS = {
"LoadInvSRModels" : LoadInvSRModels,
"InvSRSampler" : InvSRSampler
}
NODE_DISPLAY_NAME_MAPPINGS = {
"LoadInvSRModels" : "Load InvSR Models",
"InvSRSampler" : "InvSRSampler"
}
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
================================================
FILE: comfyui_invsr_trimmed/__init__.py
================================================
from .inference_invsr import get_configs, Namespace
from .sampler_invsr import InvSamplerSR, BaseSampler
from .noise_predictor import NoisePredictor
from .time_aware_encoder import TimeAwareEncoder
__all__ = [
"get_configs",
"Namespace",
"InvSamplerSR",
"BaseSampler",
"NoisePredictor",
"TimeAwareEncoder"
]
================================================
FILE: comfyui_invsr_trimmed/inference_invsr.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2023-03-11 17:17:41
import numpy as np
from pathlib import Path
from omegaconf import OmegaConf
from .sampler_invsr import InvSamplerSR, BaseSampler
from .utils import util_common
from .utils.util_opts import str2bool
from huggingface_hub import hf_hub_download
from shutil import copy2
class Namespace:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
def __repr__(self):
items = [f"{key}={repr(value)}" for key, value in vars(self).items()]
return f"Namespace({', '.join(items)})"
def get_configs(args, log=False):
configs = OmegaConf.load(args.cfg_path)
if args.timesteps is not None:
assert len(args.timesteps) == args.num_steps
configs.timesteps = sorted(args.timesteps, reverse=True)
else:
if args.num_steps == 1:
configs.timesteps = [200,]
elif args.num_steps == 2:
configs.timesteps = [200, 100]
elif args.num_steps == 3:
configs.timesteps = [200, 100, 50]
elif args.num_steps == 4:
configs.timesteps = [200, 150, 100, 50]
elif args.num_steps == 5:
configs.timesteps = [250, 200, 150, 100, 50]
else:
assert args.num_steps <= 250
configs.timesteps = np.linspace(
                start=args.started_step, stop=0, num=args.num_steps, endpoint=False, dtype=np.int64
).tolist()
if log:
print(f'[InvSR] - Setting timesteps for inference: {configs.timesteps}')
# path to save Stable Diffusion
sd_path = args.sd_path if args.sd_path else "./weights"
util_common.mkdir(sd_path, delete=False, parents=True)
configs.sd_pipe.params.cache_dir = sd_path
# path to save noise predictor
started_ckpt_name = args.invsr_model
if getattr(args, "started_ckpt_dir", None) is not None:
started_ckpt_dir = args.started_ckpt_dir
else:
started_ckpt_dir = "./weights"
if getattr(args, "started_ckpt_path", None) is not None:
started_ckpt_path = args.started_ckpt_path
else:
started_ckpt_path = Path(started_ckpt_dir) / started_ckpt_name
util_common.mkdir(started_ckpt_dir, delete=False, parents=True)
if not Path(started_ckpt_path).exists():
temp_path = hf_hub_download(
repo_id="OAOA/InvSR",
filename=started_ckpt_name,
)
copy2(temp_path, started_ckpt_path)
configs.model_start.ckpt_path = str(started_ckpt_path)
configs.bs = args.bs
configs.tiled_vae = args.tiled_vae
configs.color_fix = args.color_fix
configs.basesr.chopping.pch_size = args.chopping_size
if args.bs > 1:
configs.basesr.chopping.extra_bs = 1
else:
configs.basesr.chopping.extra_bs = args.chopping_bs
return configs
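# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): driving
# get_configs() with the Namespace helper above. The argument values are
# hypothetical; running this downloads the SD and noise-predictor weights on
# first use.
if __name__ == "__main__":
    args = Namespace(
        cfg_path="configs/sample-sd-turbo.yaml",  # config shipped with this repo
        num_steps=3,                              # maps to timesteps [200, 100, 50]
        timesteps=None,
        sd_path=None,                             # defaults to ./weights
        invsr_model="noise_predictor_sd_turbo_v5.pth",
        bs=1,
        chopping_bs=8,
        tiled_vae=True,
        color_fix="",
        chopping_size=128,
    )
    configs = get_configs(args, log=True)  # prints the timestep schedule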
================================================
FILE: comfyui_invsr_trimmed/latent_lpips/__init__.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
================================================
FILE: comfyui_invsr_trimmed/latent_lpips/lpips.py
================================================
from __future__ import absolute_import
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from . import pretrained_networks as pn
def normalize_tensor(in_feat,eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
return in_feat/(norm_factor+eps)
def spatial_average(in_tens, keepdim=True):
return in_tens.mean([2,3],keepdim=keepdim)
def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
in_H, in_W = in_tens.shape[2], in_tens.shape[3]
return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)
# Learned perceptual metric
class LPIPS(nn.Module):
def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True,
latent=False, in_chans=3, verbose=True):
""" Initializes a perceptual loss torch.nn.Module
Parameters (default listed first)
---------------------------------
lpips : bool
[True] use linear layers on top of base/trunk network
[False] means no linear layers; each layer is averaged together
pretrained : bool
This flag controls the linear layers, which are only in effect when lpips=True above
[True] means linear layers are calibrated with human perceptual judgments
[False] means linear layers are randomly initialized
pnet_rand : bool
[False] means trunk loaded with ImageNet classification weights
[True] means randomly initialized trunk
net : str
['alex','vgg','squeeze'] are the base/trunk networks available
version : str
['v0.1'] is the default and latest
['v0.0'] contained a normalization bug; corresponds to old arxiv v1 (https://arxiv.org/abs/1801.03924v1)
model_path : 'str'
[None] is default and loads the pretrained weights from paper https://arxiv.org/abs/1801.03924v1
The following parameters should only be changed if training the network
eval_mode : bool
[True] is for test mode (default)
[False] is for training mode
pnet_tune
[False] keep base/trunk frozen
[True] tune the base/trunk network
use_dropout : bool
[True] to use dropout when training linear layers
[False] for no dropout when training linear layers
"""
super(LPIPS, self).__init__()
if(verbose):
print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'%
('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
self.pnet_type = net
self.pnet_tune = pnet_tune
self.pnet_rand = pnet_rand
self.spatial = spatial
self.latent = latent
self.lpips = lpips # false means baseline of just averaging all layers
self.version = version
self.scaling_layer = ScalingLayer()
if(self.pnet_type in ['vgg','vgg16']):
if not latent:
net_type = pn.vgg16
else:
net_type = pn.vgg16_latent
self.chns = [64,128,256,512,512]
elif(self.pnet_type=='alex'):
net_type = pn.alexnet
self.chns = [64,192,384,256,256]
elif(self.pnet_type=='squeeze'):
net_type = pn.squeezenet
self.chns = [64,128,256,384,384,512,512]
self.L = len(self.chns)
if latent:
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune, in_chans=in_chans)
else:
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
if(lpips):
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
self.lins+=[self.lin5,self.lin6]
self.lins = nn.ModuleList(self.lins)
if(pretrained):
if(model_path is None):
import inspect
import os
model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net)))
if(verbose):
print('Loading model from: %s'%model_path)
missing_keys, unexpected_keys = self.load_state_dict(
torch.load(model_path, map_location='cpu'),
strict=False,
)
            print(f'Number of missing keys when loading checkpoint: {len(missing_keys)}')
            print(f'Number of unexpected keys when loading checkpoint: {len(unexpected_keys)}')
if(eval_mode):
self.eval()
def forward(self, in0, in1, retPerLayer=False, normalize=False):
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
in0 = 2 * in0 - 1
in1 = 2 * in1 - 1
# v0.0 - original release had a bug, where input was not scaled
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if (not self.latent and self.version=='0.1') else (in0, in1)
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}
for kk in range(self.L):
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk]-feats1[kk])**2
if(self.lpips):
if(self.spatial):
res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
else:
res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
else:
if(self.spatial):
res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
else:
res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
val = 0
for l in range(self.L):
val += res[l]
if(retPerLayer):
return (val, res)
else:
return val
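# Usage sketch (illustrative): computing an LPIPS distance between two image
# batches in [0, 1]; `normalize=True` rescales them to the [-1, 1] range the
# trunk network expects. Assumes the calibrated linear-layer weights exist at
# the default `weights/v0.1/alex.pth` path relative to this file.
#
#   loss_fn = LPIPS(net='alex')
#   d = loss_fn(img0, img1, normalize=True)  # img0, img1: (N, 3, H, W)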
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
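        # These constants are the ImageNet per-channel mean/std remapped to the
        # [-1, 1] input range: shift = 2 * mean - 1 and scale = 2 * std, with
        # mean = (0.485, 0.456, 0.406) and std = (0.229, 0.224, 0.225).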
self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
def forward(self, inp):
return (inp - self.shift) / self.scale
class NetLinLayer(nn.Module):
''' A single linear layer which does a 1x1 conv '''
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = [nn.Dropout(),] if(use_dropout) else []
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
self.model = nn.Sequential(*layers)
def forward(self, x):
return self.model(x)
class Dist2LogitLayer(nn.Module):
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
def __init__(self, chn_mid=32, use_sigmoid=True):
super(Dist2LogitLayer, self).__init__()
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
layers += [nn.LeakyReLU(0.2,True),]
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
layers += [nn.LeakyReLU(0.2,True),]
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
if(use_sigmoid):
layers += [nn.Sigmoid(),]
self.model = nn.Sequential(*layers)
def forward(self,d0,d1,eps=0.1):
return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
class BCERankingLoss(nn.Module):
def __init__(self, chn_mid=32):
super(BCERankingLoss, self).__init__()
self.net = Dist2LogitLayer(chn_mid=chn_mid, use_sigmoid=False)
# self.parameters = list(self.net.parameters())
# self.loss = torch.nn.BCELoss()
self.loss = torch.nn.BCEWithLogitsLoss()
def forward(self, d0, d1, judge):
per = (judge+1.)/2.
self.logit = self.net.forward(d0,d1)
return self.loss(self.logit, per)
def print_network(net):
num_params = 0
for param in net.parameters():
num_params += param.numel()
print('Network',net)
print('Total number of parameters: %d' % num_params)
================================================
FILE: comfyui_invsr_trimmed/latent_lpips/pretrained_networks.py
================================================
from collections import namedtuple
import torch
from torchvision import models as tv
class squeezenet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(squeezenet, self).__init__()
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.slice6 = torch.nn.Sequential()
self.slice7 = torch.nn.Sequential()
self.N_slices = 7
for x in range(2):
self.slice1.add_module(str(x), pretrained_features[x])
for x in range(2,5):
self.slice2.add_module(str(x), pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), pretrained_features[x])
for x in range(10, 11):
self.slice5.add_module(str(x), pretrained_features[x])
for x in range(11, 12):
self.slice6.add_module(str(x), pretrained_features[x])
for x in range(12, 13):
self.slice7.add_module(str(x), pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
h = self.slice6(h)
h_relu6 = h
h = self.slice7(h)
h_relu7 = h
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
return out
class alexnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(alexnet, self).__init__()
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(2):
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
for x in range(10, 12):
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
return out
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
class vgg16_latent(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True, in_chans=3):
super(vgg16_latent, self).__init__()
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
# max pooling layers: vgg_pretrained_features[5, 9, 16, 23]
for x in range(4):
assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d)
self.slice1.add_module(str(x), vgg_pretrained_features[x])
if (not in_chans == 3):
# assert in_chans == 4
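            # Adapt the first conv to `in_chans` input channels (e.g. 4-channel
            # VAE latents) by replicating the pretrained weights of channel 0
            # across all input channels, keeping the filters while matching the
            # new input shape.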
weight = self.slice1[0].weight.data[:, 0,].unsqueeze(1).repeat(1, in_chans, 1, 1)
self.slice1[0].weight.data = weight
for x in range(5, 9): # skip max pooling at index 5
assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d)
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(10, 16): # skip max pooling at index 9
assert not isinstance(vgg_pretrained_features[x], torch.nn.MaxPool2d)
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(17, 23): # skip max pooling at index 16
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
class resnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True, num=18):
super(resnet, self).__init__()
if(num==18):
self.net = tv.resnet18(pretrained=pretrained)
elif(num==34):
self.net = tv.resnet34(pretrained=pretrained)
elif(num==50):
self.net = tv.resnet50(pretrained=pretrained)
elif(num==101):
self.net = tv.resnet101(pretrained=pretrained)
elif(num==152):
self.net = tv.resnet152(pretrained=pretrained)
self.N_slices = 5
self.conv1 = self.net.conv1
self.bn1 = self.net.bn1
self.relu = self.net.relu
self.maxpool = self.net.maxpool
self.layer1 = self.net.layer1
self.layer2 = self.net.layer2
self.layer3 = self.net.layer3
self.layer4 = self.net.layer4
def forward(self, X):
h = self.conv1(X)
h = self.bn1(h)
h = self.relu(h)
h_relu1 = h
h = self.maxpool(h)
h = self.layer1(h)
h_conv2 = h
h = self.layer2(h)
h_conv3 = h
h = self.layer3(h)
h_conv4 = h
h = self.layer4(h)
h_conv5 = h
outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
return out
================================================
FILE: comfyui_invsr_trimmed/noise_predictor.py
================================================
from typing import Dict, Optional, Tuple, Union
import torch
from diffusers.models.modeling_utils import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders.single_file_model import FromOriginalModelMixin
from diffusers.models.autoencoders.vae import (
Decoder,
DecoderOutput,
DiagonalGaussianDistribution,
Encoder,
)
from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils.accelerate_utils import apply_forward_hook
from .time_aware_encoder import TimeAwareEncoder
class NoisePredictor(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
    A noise prediction model built from the encoder of AutoencoderKL.
    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
            mid_block will only have resnet blocks.
        temb_channels (`int`, *optional*, defaults to 256): Number of channels for the time embedding.
freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding.
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
Whether to flip sin to cos for Fourier time embedding.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
@register_to_config
def __init__(
self,
in_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
mid_block_add_attention: bool = True,
attention_head_dim: int = 1,
resnet_time_scale_shift: str = "default",
temb_channels: int = 256,
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
double_z: bool = True,
):
super().__init__()
# pass init params to Encoder
self.encoder = TimeAwareEncoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=double_z,
mid_block_add_attention=mid_block_add_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
freq_shift=freq_shift,
flip_sin_to_cos=flip_sin_to_cos,
attention_head_dim=attention_head_dim,
)
self.use_slicing = False
self.use_tiling = False
self.double_z = double_z
# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(
sample_size / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (Encoder, Decoder)):
module.gradient_checkpointing = value
def enable_tiling(self, use_tiling: bool = True):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.use_tiling = use_tiling
def disable_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.enable_tiling(False)
def enable_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnAddedKVProcessor()
elif all(
proc.__class__ in CROSS_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
@apply_forward_hook
def encode(
self,
x: torch.Tensor,
timestep: Union[int, torch.Tensor],
return_dict: bool = True,
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_tiling and (
x.shape[-1] > self.tile_sample_min_size
or x.shape[-2] > self.tile_sample_min_size
):
return self.tiled_encode(x, timestep, return_dict=return_dict)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self.encoder(x_slice, timestep) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x, timestep)
if not self.double_z:
return h
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_encode(
self,
x: torch.Tensor,
timestep: Union[int, torch.Tensor],
return_dict: bool = True,
) -> AutoencoderKLOutput:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
`tuple` is returned.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[
:,
:,
i : i + self.tile_sample_min_size,
j : j + self.tile_sample_min_size,
]
tile = self.encoder(tile, timestep)
                if getattr(self.config, "use_quant_conv", False):
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
moments = torch.cat(result_rows, dim=2)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
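    # NOTE: `tiled_encode` above relies on `blend_v`/`blend_h`, which are not
    # defined in this trimmed file or its base classes. A minimal sketch,
    # adapted from the corresponding methods of diffusers' AutoencoderKL
    # (assumed here to match the behavior the original code expects):
    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            # fade linearly from the bottom rows of `a` into the top rows of `b`
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b
    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            # fade linearly from the right columns of `a` into the left columns of `b`
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b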
def forward(
self,
sample: torch.Tensor,
timesteps: torch.Tensor,
sample_posterior: bool = True,
center_input_sample: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
if center_input_sample:
sample = sample * 2 - 1.0
if not self.double_z:
h = self.encode(sample, timesteps)
return h
else:
posterior = self.encode(sample, timesteps).latent_dist
if sample_posterior:
return posterior.sample()
else:
return posterior
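# Usage sketch (illustrative; the names below are hypothetical): predicting
# the starting latent for a batch of low-quality images `lq` in [0, 1] at
# diffusion timestep 200:
#
#   z = predictor(lq, timesteps=torch.tensor([200], device=lq.device))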
================================================
FILE: comfyui_invsr_trimmed/pipeline_stable_diffusion_inversion_sr.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np
import PIL.Image
import torch
from packaging import version
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import requests
>>> import torch
>>> from PIL import Image
>>> from io import BytesIO
>>> from diffusers import StableDiffusionImg2ImgPipeline
>>> device = "cuda"
>>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
>>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
>>> pipe = pipe.to(device)
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> response = requests.get(url)
>>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
>>> init_image = init_image.resize((768, 512))
>>> prompt = "A fantasy landscape, trending on artstation"
>>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
>>> images[0].save("fantasy_landscape.png")
```
"""
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
def preprocess(image):
deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
if isinstance(image, torch.Tensor):
return image
elif isinstance(image, PIL.Image.Image):
image = [image]
if isinstance(image[0], PIL.Image.Image):
w, h = image[0].size
w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8
image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image]
image = np.concatenate(image, axis=0)
image = np.array(image).astype(np.float32) / 255.0
image = image.transpose(0, 3, 1, 2)
image = 2.0 * image - 1.0
image = torch.from_numpy(image)
elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, dim=0)
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
timesteps: Optional[List[int]] = None,
device: Optional[Union[str, torch.device]] = None,
**kwargs,
):
"""
Prepare the sampling timesteps and noise sigmas.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
num_inference_steps = len(timesteps)
timesteps = torch.tensor(timesteps, dtype=torch.float32, device=device) - 1
scheduler.timesteps = timesteps
if not hasattr(scheduler, 'sigmas_cache'):
        scheduler.sigmas_cache = scheduler.sigmas.flip(0)[1:].to(device)  # ascending order, one sigma per trained timestep
sigmas = scheduler.sigmas_cache[timesteps.long()]
# minimal sigma
if scheduler.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - scheduler.alphas_cumprod[0]) / scheduler.alphas_cumprod[0]) ** 0.5
elif scheduler.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {scheduler.config.final_sigmas_type}"
)
sigma_last = torch.tensor([sigma_last,], dtype=torch.float32).to(device=sigmas.device)
sigmas = torch.cat([sigmas, sigma_last]).type(torch.float32)
scheduler.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
scheduler._step_index = None
scheduler._begin_index = None
return scheduler.timesteps, num_inference_steps
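# Usage sketch (illustrative): unlike the upstream diffusers helper, this
# variant takes an explicit, descending timestep list instead of
# `num_inference_steps`, and rebuilds `scheduler.sigmas` from a cached
# ascending copy so the sigma schedule matches the custom timesteps, e.g.
#
#   timesteps, num_steps = retrieve_timesteps(scheduler, timesteps=[200, 100, 50], device="cuda")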
class StableDiffusionInvEnhancePipeline(
DiffusionPipeline,
StableDiffusionMixin,
TextualInversionLoaderMixin,
IPAdapterMixin,
StableDiffusionLoraLoaderMixin,
FromSingleFileMixin,
):
r"""
Pipeline for text-guided image-to-image generation using Stable Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`~transformers.CLIPTextModel`]):
Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
tokenizer ([`~transformers.CLIPTokenizer`]):
A `CLIPTokenizer` to tokenize text.
unet ([`UNet2DConditionModel`]):
A `UNet2DConditionModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
image_encoder: CLIPVisionModelWithProjection = None,
requires_safety_checker: bool = True,
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
def _encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
**kwargs,
):
deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
prompt_embeds_tuple = self.encode_prompt(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
do_classifier_free_guidance=do_classifier_free_guidance,
negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
**kwargs,
)
# concatenate for backwards comp
prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
def encode_prompt(
self,
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
elif self.unet is not None:
prompt_embeds_dtype = self.unet.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens: List[str]
if negative_prompt is None:
uncond_tokens = [""] * batch_size
elif prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif isinstance(negative_prompt, str):
uncond_tokens = [negative_prompt]
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = negative_prompt
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if self.text_encoder is not None:
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
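# Usage sketch (hedged; `pipe`, the prompts, and the hidden size are illustrative):
#   >>> pe, npe = pipe.encode_prompt("a cat", device, num_images_per_prompt=1,
#   ...                              do_classifier_free_guidance=True, negative_prompt="blurry")
#   Both returned tensors have shape (num_images_per_prompt, 77, hidden_dim) for SD-style CLIP encoders.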
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
image_embeds = []
if do_classifier_free_guidance:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
image_embeds.append(single_image_embeds[None, :])
if do_classifier_free_guidance:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
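# Example: DDIMScheduler.step() accepts both `eta` and `generator`, so this returns
# {'eta': eta, 'generator': generator}; for schedulers lacking either parameter, that key is omitted.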
def check_inputs(
self,
prompt,
strength,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
ip_adapter_image=None,
ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
)
if ip_adapter_image_embeds is not None:
if not isinstance(ip_adapter_image_embeds, list):
raise ValueError(
f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
)
elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
raise ValueError(
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
)
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
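# Worked example: with num_inference_steps=20 and strength=0.25, init_timestep = 5 and
# t_start = 15, so only the last 5 scheduler timesteps are run (a short, late-starting denoise).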
def prepare_latents(
self, image, timestep, batch_size, num_images_per_prompt, dtype, device,
noise=None, generator=None,
):
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
raise ValueError(
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
)
image = image.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
if image.shape[1] == 4:
init_latents = image
else:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
elif isinstance(generator, list):
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
)
init_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(batch_size)
]
init_latents = torch.cat(init_latents, dim=0)
else:
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
init_latents = self.vae.config.scaling_factor * init_latents
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
# expand init_latents for batch_size
deprecation_message = (
f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
" images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
" that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
" your script to pass as many initial images as text prompts to suppress this warning."
)
deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
additional_image_per_prompt = batch_size // init_latents.shape[0]
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
)
else:
init_latents = torch.cat([init_latents], dim=0)
shape = init_latents.shape
if noise is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# get latents
init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
latents = init_latents
return latents.type(dtype)
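# Note (a sketch, assuming the standard SD VAE with 8x spatial downsampling): a 512x512
# RGB `image` yields `latents` of shape (batch, 4, 64, 64), already noised to `timestep`.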
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.Tensor:
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
w (`torch.Tensor`):
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
embedding_dim (`int`, *optional*, defaults to 512):
Dimension of the embeddings to generate.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
Data type of the generated embeddings.
Returns:
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
"""
assert len(w.shape) == 1
w = w * 1000.0
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
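# Example: get_guidance_scale_embedding(torch.tensor([1.5, 7.5]), embedding_dim=256)
# returns a (2, 256) tensor of sin/cos features computed from w * 1000.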
@property
def guidance_scale(self):
return self._guidance_scale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
target_size: Tuple[int] = (512, 512),
strength: float = 0.25,
num_inference_steps: Optional[int] = 20,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: Optional[float] = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: int = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing a low-quality image batch to be used as the starting point. For both
numpy arrays and pytorch tensors, the expected value range is between `[0, 1]`. If it's a tensor or a list
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
latents as `image`, but latents passed directly are not encoded again.
target_size (`Tuple[int]`): Targeted image resolution (height, width).
strength (`float`, *optional*, defaults to 0.25):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 20):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
start_noise_predictor (`nn.Module`, *optional*): Noise predictor for the initial step.
intermediate_noise_predictor (`nn.Module`, *optional*): Noise predictor for the intermediate steps.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during inference, with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
strength,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Encode input prompt
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
self.do_classifier_free_guidance,
)
# 4. Preprocess image
self.image_processor.config.do_normalize = False
image = self.image_processor.preprocess(image) # [0, 1], torch tensor, (b,c,h,w)
self.image_processor.config.do_normalize = True
image_up = torch.nn.functional.interpolate(image, size=target_size, mode='bicubic') # upsampling
image_up = self.image_processor.normalize(image_up) # [-1, 1]
# 5. set timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, timesteps, device)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 6. Prepare latent variables
if getattr(self, 'start_noise_predictor', None) is not None:
with torch.amp.autocast('cuda'):
noise = self.start_noise_predictor(
image, latent_timestep, sample_posterior=True, center_input_sample=True,
)
else:
noise = None
latents = self.prepare_latents(
image_up, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype,
device, noise, generator,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7.1 Add image embeds for IP-Adapter
added_cond_kwargs = (
{"image_embeds": image_embeds}
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
else None
)
# 7.2 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
timestep_cond=timestep_cond,
cross_attention_kwargs=self.cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
if getattr(self, 'intermediate_noise_predictor', None) is not None and i + 1 < len(timesteps):
t_next = timesteps[i+1]
with torch.amp.autocast('cuda'):
noise = self.intermediate_noise_predictor(image, t_next, center_input_sample=True)
extra_step_kwargs['noise'] = noise
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
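# Minimal usage sketch (hedged; `pipe` is an instance built via `from_pipe` as in
# sampler_invsr.py below, and the prompt and 4x target size are purely illustrative):
#   >>> out = pipe(image=lq_bchw, prompt="a photo", target_size=(4 * h, 4 * w),
#   ...            guidance_scale=1.0, output_type="pt")
#   >>> sr = out.images  # (B, 3, 4*h, 4*w) torch tensor in [0, 1]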
================================================
FILE: comfyui_invsr_trimmed/sampler_invsr.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-07-13 16:59:27
import os, sys, math, random
import numpy as np
from pathlib import Path
from .utils import util_net
from .utils import util_image
from .utils import util_common
from .utils import util_color_fix
import torch
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from .pipeline_stable_diffusion_inversion_sr import StableDiffusionInvEnhancePipeline
from diffusers import AutoencoderKL
import comfy.model_management as mm
DEVICE = mm.get_torch_device()
_positive = 'Cinematic, high-contrast, photo-realistic, 8k, ultra HD, ' +\
'meticulous detailing, hyper sharpness, perfect without deformations'
_negative = 'Low quality, blurring, jpeg artifacts, deformed, over-smooth, cartoon, noisy, ' +\
'painting, drawing, sketch, oil painting'
class BaseSampler:
def __init__(self, configs):
'''
Input:
configs: config, see the yaml file in folder ./configs/
configs.sampler_config.{start_timesteps, padding_mod, seed, sf, num_sample_steps}
seed: int, random seed
'''
self.configs = configs
self.setup_seed()
self.build_model()
def setup_seed(self, seed=None):
seed = self.configs.seed if seed is None else seed
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def write_log(self, log_str):
print(log_str, flush=True)
def build_model(self):
# Build Stable diffusion
params = dict(self.configs.sd_pipe.params)
torch_dtype = params.pop('torch_dtype')
params['torch_dtype'] = get_torch_dtype(torch_dtype)
base_pipe = util_common.get_obj_from_str(self.configs.sd_pipe.target).from_pretrained(**params)
if self.configs.get('scheduler', None) is not None:
pipe_id = self.configs.scheduler.target.split('.')[-1]
self.write_log(f'Loading scheduler of {pipe_id}...')
base_pipe.scheduler = util_common.get_obj_from_str(self.configs.scheduler.target).from_config(
base_pipe.scheduler.config
)
self.write_log('Loading done')
if self.configs.get('vae_fp16', None) is not None:
params_vae = dict(self.configs.vae_fp16.params)
torch_dtype = params_vae.pop('torch_dtype')
params_vae['torch_dtype'] = get_torch_dtype(torch_dtype)
pipe_id = self.configs.vae_fp16.params.pretrained_model_name_or_path
self.write_log(f'Loading improved vae from {pipe_id}...')
base_pipe.vae = util_common.get_obj_from_str(self.configs.vae_fp16.target).from_pretrained(
**params_vae,
)
self.write_log('Loading done')
if self.configs.base_model in ['sd-turbo', 'sd2base']:
sd_pipe = StableDiffusionInvEnhancePipeline.from_pipe(base_pipe)
else:
raise ValueError(f"Unsupported base model: {self.configs.base_model}!")
sd_pipe.to(DEVICE)
if self.configs.sliced_vae:
sd_pipe.vae.enable_slicing()
if self.configs.tiled_vae:
sd_pipe.vae.enable_tiling()
sd_pipe.vae.tile_latent_min_size = self.configs.latent_tiled_size
sd_pipe.vae.tile_sample_min_size = self.configs.sample_tiled_size
if self.configs.gradient_checkpointing_vae:
self.write_log(f"Activating gradient checkpoing for vae...")
sd_pipe.vae.enable_gradient_checkpointing()
model_configs = self.configs.model_start
params = model_configs.get('params', dict())
model_start = util_common.get_obj_from_str(model_configs.target)(**params)
model_start.to(DEVICE)
ckpt_path = model_configs.get('ckpt_path')
assert ckpt_path is not None
self.write_log(f"[InvSR] - Loading started model from {ckpt_path}...")
state = torch.load(ckpt_path, map_location=DEVICE)
if 'state_dict' in state:
state = state['state_dict']
util_net.reload_model(model_start, state)
# self.write_log(f"Loading Done")
model_start.eval()
setattr(sd_pipe, 'start_noise_predictor', model_start)
self.sd_pipe = sd_pipe
class InvSamplerSR(BaseSampler):
def __init__(self, base_sampler):
self.configs = base_sampler.configs
self.sd_pipe = base_sampler.sd_pipe
@torch.no_grad()
def sample_func(self, im_cond):
'''
Input:
im_cond: b x c x h x w, torch tensor, [0,1], RGB
Output:
res_sr: b x c x h x w, numpy array, [0,1], RGB
'''
ori_h_lq, ori_w_lq = im_cond.shape[-2:]
ori_w_hq = ori_w_lq * self.configs.basesr.sf
ori_h_hq = ori_h_lq * self.configs.basesr.sf
vae_sf = (2 ** (len(self.sd_pipe.vae.config.block_out_channels) - 1))
if hasattr(self.sd_pipe, 'unet'):
diffusion_sf = (2 ** (len(self.sd_pipe.unet.config.block_out_channels) - 1))
else:
diffusion_sf = self.sd_pipe.transformer.patch_size
mod_lq = vae_sf // self.configs.basesr.sf * diffusion_sf
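# Worked example (hedged, assuming the sd-turbo setup): a 4-level VAE and a 4-level UNet
# both give a spatial factor of 8, so with sf=4: mod_lq = 8 // 4 * 8 = 16, i.e. the
# (padded) low-quality input must have height and width divisible by 16.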
idle_pch_size = self.configs.basesr.chopping.pch_size
if min(im_cond.shape[-2:]) >= idle_pch_size:
pad_h_up = pad_w_left = 0
else:
while min(im_cond.shape[-2:]) < idle_pch_size:
pad_h_up = max(min((idle_pch_size - im_cond.shape[-2]) // 2, im_cond.shape[-2]-1), 0)
pad_h_down = max(min(idle_pch_size - im_cond.shape[-2] - pad_h_up, im_cond.shape[-2]-1), 0)
pad_w_left = max(min((idle_pch_size - im_cond.shape[-1]) // 2, im_cond.shape[-1]-1), 0)
pad_w_right = max(min(idle_pch_size - im_cond.shape[-1] - pad_w_left, im_cond.shape[-1]-1), 0)
im_cond = F.pad(im_cond, pad=(pad_w_left, pad_w_right, pad_h_up, pad_h_down), mode='reflect')
if im_cond.shape[-2] == idle_pch_size and im_cond.shape[-1] == idle_pch_size:
target_size = (
im_cond.shape[-2] * self.configs.basesr.sf,
im_cond.shape[-1] * self.configs.basesr.sf
)
res_sr = self.sd_pipe(
image=im_cond.type(torch.float16),
prompt=[_positive, ]*im_cond.shape[0],
negative_prompt=[_negative, ]*im_cond.shape[0] if self.configs.cfg_scale > 1.0 else None,
target_size=target_size,
timesteps=self.configs.timesteps,
guidance_scale=self.configs.cfg_scale,
output_type="pt", # torch tensor, b x c x h x w, [0, 1]
).images
else:
if not (im_cond.shape[-2] % mod_lq == 0 and im_cond.shape[-1] % mod_lq == 0):
target_h_lq = math.ceil(im_cond.shape[-2] / mod_lq) * mod_lq
target_w_lq = math.ceil(im_cond.shape[-1] / mod_lq) * mod_lq
pad_h = target_h_lq - im_cond.shape[-2]
pad_w = target_w_lq - im_cond.shape[-1]
im_cond = F.pad(im_cond, pad=(0, pad_w, 0, pad_h), mode='reflect')
im_spliter = util_image.ImageSpliterTh(
im_cond,
pch_size=idle_pch_size,
stride=int(idle_pch_size * 0.50),
sf=self.configs.basesr.sf,
weight_type=self.configs.basesr.chopping.weight_type,
extra_bs=self.configs.basesr.chopping.extra_bs,
)
# pbar = ProgressBar(len(im_spliter) * im_cond.shape[0])
for im_lq_pch, index_infos in im_spliter:
target_size = (
im_lq_pch.shape[-2] * self.configs.basesr.sf,
im_lq_pch.shape[-1] * self.configs.basesr.sf,
)
# start = torch.cuda.Event(enable_timing=True)
# end = torch.cuda.Event(enable_timing=True)
# start.record()
res_sr_pch = self.sd_pipe(
image=im_lq_pch.type(torch.float16),
prompt=[_positive, ]*im_lq_pch.shape[0],
negative_prompt=[_negative, ]*im_lq_pch.shape[0] if self.configs.cfg_scale > 1.0 else None,
target_size=target_size,
timesteps=self.configs.timesteps,
guidance_scale=self.configs.cfg_scale,
output_type="pt", # torch tensor, b x c x h x w, [0, 1]
).images
# end.record()
# torch.cuda.synchronize()
# print(f"Time: {start.elapsed_time(end):.6f}")
im_spliter.update(res_sr_pch, index_infos)
# pbar.update(im_lq_pch.shape[0])
res_sr = im_spliter.gather()
pad_h_up *= self.configs.basesr.sf
pad_w_left *= self.configs.basesr.sf
res_sr = res_sr[:, :, pad_h_up:ori_h_hq+pad_h_up, pad_w_left:ori_w_hq+pad_w_left]
if self.configs.color_fix:
im_cond_up = F.interpolate(
im_cond, size=res_sr.shape[-2:], mode='bicubic', align_corners=False, antialias=True
)
if self.configs.color_fix == 'ycbcr':
res_sr = util_color_fix.ycbcr_color_replace(res_sr, im_cond_up)
elif self.configs.color_fix == 'wavelet':
res_sr = util_color_fix.wavelet_reconstruction(res_sr, im_cond_up)
else:
raise ValueError(f"Unsupported color fixing type: {self.configs.color_fix}")
res_sr = res_sr.clamp(0.0, 1.0).cpu().float().numpy()
return res_sr
def inference(self, image_bchw):
return self.sample_func(image_bchw.to(DEVICE))
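# Usage sketch (hedged; `configs` is an OmegaConf-style object built from configs/*.yaml):
#   sampler = InvSamplerSR(BaseSampler(configs))
#   sr = sampler.inference(lq_bchw)  # [0,1] RGB tensor in, [0,1] numpy (b, c, h*sf, w*sf) out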
def get_torch_dtype(torch_dtype: str):
if torch_dtype == 'torch.float16':
return torch.float16
elif torch_dtype == 'torch.bfloat16':
return torch.bfloat16
elif torch_dtype == 'torch.float32':
return torch.float32
else:
raise ValueError(f'Unexpected torch dtype: {torch_dtype}')
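# Example: get_torch_dtype('torch.float16') -> torch.float16; any other string raises ValueError.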
================================================
FILE: comfyui_invsr_trimmed/time_aware_encoder.py
================================================
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from diffusers.utils import is_torch_version
from diffusers.models.unets.unet_2d_blocks import (
UNetMidBlock2D,
get_down_block,
)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
class TimeAwareEncoder(nn.Module):
r"""
The `TimeAwareEncoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
options.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
The number of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2):
The number of layers per block.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
act_fn (`str`, *optional*, defaults to `"silu"`):
The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`):
Time scale shift config for the resnet blocks.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
block_out_channels: Tuple[int, ...] = (64,),
layers_per_block: Union[int, Tuple[int, ...]] = 2,
norm_num_groups: int = 32,
act_fn: str = "silu",
double_z: bool = True,
mid_block_add_attention=True,
resnet_time_scale_shift: str = "default",
temb_channels: int = 256,
freq_shift: int = 0,
flip_sin_to_cos: bool = True,
attention_head_dim: int = 1,
):
super().__init__()
if isinstance(layers_per_block, int):
layers_per_block = (layers_per_block,) * len(down_block_types)
self.layers_per_block = layers_per_block
timestep_input_dim = max(128, block_out_channels[0])
self.time_proj = Timesteps(timestep_input_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(timestep_input_dim, temb_channels)
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=3,
stride=1,
padding=1,
)
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=self.layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
add_downsample=not is_final_block,
resnet_eps=1e-6,
downsample_padding=0,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_head_dim=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
resnet_act_fn=act_fn,
output_scale_factor=1,
attention_head_dim=attention_head_dim,
resnet_groups=norm_num_groups,
add_attention=mid_block_add_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
temb_channels=temb_channels,
)
# out
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
)
self.conv_act = nn.SiLU()
conv_out_channels = 2 * out_channels if double_z else out_channels
self.conv_out = nn.Conv2d(
block_out_channels[-1], conv_out_channels, 3, padding=1
)
self.gradient_checkpointing = False
def forward(
self,
sample: torch.Tensor,
timesteps: Union[torch.Tensor, int],
) -> torch.Tensor:
r"""The forward method of the `Encoder` class."""
# time embedding
if not torch.is_tensor(timesteps):
timesteps = torch.tensor(
[timesteps], dtype=torch.long, device=sample.device
)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(
sample.shape[0], dtype=timesteps.dtype, device=timesteps.device
)
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=list(self.time_embedding.parameters())[0].dtype)
emb = self.time_embedding(t_emb)
sample = self.conv_in(sample)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
# down
if is_torch_version(">=", "1.11.0"):
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block),
sample,
emb,
use_reentrant=False,
)
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block),
sample,
emb,
use_reentrant=False,
)
else:
for down_block in self.down_blocks:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(down_block), sample, emb
)
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.mid_block), sample, emb
)
else:
# down
for down_block in self.down_blocks:
sample, _ = down_block(sample, emb)
# middle
sample = self.mid_block(sample, emb)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
return sample
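# Instantiation sketch (hedged; the channel sizes and block types below are illustrative,
# not the repo's configured values -- the down blocks must accept a timestep embedding,
# e.g. "DownBlock2D"):
#   >>> enc = TimeAwareEncoder(in_channels=3, out_channels=4,
#   ...                        down_block_types=("DownBlock2D",) * 4,
#   ...                        block_out_channels=(128, 256, 512, 512))
#   >>> z = enc(torch.randn(1, 3, 256, 256), timesteps=200)  # -> (1, 8, 32, 32) with double_z=True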
================================================
FILE: comfyui_invsr_trimmed/utils/__init__.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-01-18 11:40:23
================================================
FILE: comfyui_invsr_trimmed/utils/resize.py
================================================
"""
A standalone PyTorch implementation for fast and efficient bicubic resampling.
The resulting values are the same as those of the MATLAB function imresize(..., 'bicubic').
## Author: Sanghyun Son
## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary)
## Version: 1.2.0
## Last update: July 9th, 2020 (KST)
Dependency: torch
Example::
>>> import torch
>>> from resize import imresize
>>> x = torch.arange(16).float().view(1, 1, 4, 4)
>>> y = imresize(x, sizes=(3, 3))
>>> print(y)
tensor([[[[ 0.7506, 2.1004, 3.4503],
[ 6.1505, 7.5000, 8.8499],
[11.5497, 12.8996, 14.2494]]]])
"""
import math
import typing
import torch
from torch.nn import functional as F
__all__ = ['imresize']
_I = typing.Optional[int]
_D = typing.Optional[torch.dtype]
def nearest_contribution(x: torch.Tensor) -> torch.Tensor:
range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5))
cont = range_around_0.to(dtype=x.dtype)
return cont
def linear_contribution(x: torch.Tensor) -> torch.Tensor:
ax = x.abs()
range_01 = ax.le(1)
cont = (1 - ax) * range_01.to(dtype=x.dtype)
return cont
def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
ax = x.abs()
ax2 = ax * ax
ax3 = ax * ax2
range_01 = ax.le(1)
range_12 = torch.logical_and(ax.gt(1), ax.le(2))
cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
cont_01 = cont_01 * range_01.to(dtype=x.dtype)
cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
cont_12 = cont_12 * range_12.to(dtype=x.dtype)
cont = cont_01 + cont_12
return cont
def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
range_3sigma = (x.abs() <= 3 * sigma + 1)
# Normalization will be done after
cont = torch.exp(-x.pow(2) / (2 * sigma**2))
cont = cont * range_3sigma.to(dtype=x.dtype)
return cont
def discrete_kernel(kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor:
'''
For downsampling with integer scale only.
'''
downsampling_factor = int(1 / scale)
if kernel == 'cubic':
kernel_size_orig = 4
else:
raise ValueError('Pass!')
if antialiasing:
kernel_size = kernel_size_orig * downsampling_factor
else:
kernel_size = kernel_size_orig
if downsampling_factor % 2 == 0:
a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))
else:
kernel_size -= 1
a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))
with torch.no_grad():
r = torch.linspace(-a, a, steps=kernel_size)
k = cubic_contribution(r).view(-1, 1)
k = torch.matmul(k, k.t())
k /= k.sum()
return k
def reflect_padding(x: torch.Tensor, dim: int, pad_pre: int, pad_post: int) -> torch.Tensor:
'''
Apply reflect padding to the given Tensor.
Note that it is slightly different from the PyTorch functional.pad,
where boundary elements are used only once.
Instead, we follow the MATLAB implementation
which uses boundary elements twice.
For example,
[a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
while our implementation yields [a, a, b, c, d, d].
'''
b, c, h, w = x.size()
if dim == 2 or dim == -2:
padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)
padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)
for p in range(pad_pre):
padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])
for p in range(pad_post):
padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])
else:
padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)
padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)
for p in range(pad_pre):
padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])
for p in range(pad_post):
padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])
return padding_buffer
def padding(x: torch.Tensor,
dim: int,
pad_pre: int,
pad_post: int,
padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
if padding_type is None:
return x
elif padding_type == 'reflect':
x_pad = reflect_padding(x, dim, pad_pre, pad_post)
else:
raise ValueError('{} padding is not supported!'.format(padding_type))
return x_pad
def get_padding(base: torch.Tensor, kernel_size: int, x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
base = base.long()
r_min = base.min()
r_max = base.max() + kernel_size - 1
if r_min <= 0:
pad_pre = -r_min
pad_pre = pad_pre.item()
base += pad_pre
else:
pad_pre = 0
if r_max >= x_size:
pad_post = r_max - x_size + 1
pad_post = pad_post.item()
else:
pad_post = 0
return pad_pre, pad_post, base
def get_weight(dist: torch.Tensor,
kernel_size: int,
kernel: str = 'cubic',
sigma: float = 2.0,
antialiasing_factor: float = 1) -> torch.Tensor:
buffer_pos = dist.new_zeros(kernel_size, len(dist))
for idx, buffer_sub in enumerate(buffer_pos):
buffer_sub.copy_(dist - idx)
# Expand (downsampling) / Shrink (upsampling) the receptive field.
buffer_pos *= antialiasing_factor
if kernel == 'cubic':
weight = cubic_contribution(buffer_pos)
elif kernel == 'gaussian':
weight = gaussian_contribution(buffer_pos, sigma=sigma)
else:
raise ValueError('{} kernel is not supported!'.format(kernel))
weight /= weight.sum(dim=0, keepdim=True)
return weight
def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
# Resize height
if dim == 2 or dim == -2:
k = (kernel_size, 1)
h_out = x.size(-2) - kernel_size + 1
w_out = x.size(-1)
# Resize width
else:
k = (1, kernel_size)
h_out = x.size(-2)
w_out = x.size(-1) - kernel_size + 1
unfold = F.unfold(x, k)
unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
return unfold
def reshape_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, int, int]:
if x.dim() == 4:
b, c, h, w = x.size()
elif x.dim() == 3:
c, h, w = x.size()
b = None
elif x.dim() == 2:
h, w = x.size()
b = c = None
else:
raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))
x = x.view(-1, 1, h, w)
return x, b, c, h, w
def reshape_output(x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
rh = x.size(-2)
rw = x.size(-1)
# Back to the original dimension
if b is not None:
x = x.view(b, c, rh, rw) # 4-dim
else:
if c is not None:
x = x.view(c, rh, rw) # 3-dim
else:
x = x.view(rh, rw) # 2-dim
return x
def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
if x.dtype != torch.float32 and x.dtype != torch.float64:
dtype = x.dtype
x = x.float()
else:
dtype = None
return x, dtype
def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
if dtype is not None:
if not dtype.is_floating_point:
x = x - x.detach() + x.round()
# To prevent over/underflow when converting types
if dtype is torch.uint8:
x = x.clamp(0, 255)
x = x.to(dtype=dtype)
return x
def resize_1d(x: torch.Tensor,
dim: int,
size: int,
scale: float,
kernel: str = 'cubic',
sigma: float = 2.0,
padding_type: str = 'reflect',
antialiasing: bool = True) -> torch.Tensor:
'''
Args:
x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
dim (int):
scale (float):
size (int):
Return:
'''
# Identity case
if scale == 1:
return x
# Default bicubic kernel with antialiasing (only when downsampling)
if kernel == 'cubic':
kernel_size = 4
else:
kernel_size = math.floor(6 * sigma)
if antialiasing and (scale < 1):
antialiasing_factor = scale
kernel_size = math.ceil(kernel_size / antialiasing_factor)
else:
antialiasing_factor = 1
# We allow margin to both sizes
kernel_size += 2
# Weights only depend on the shape of input and output,
# so we do not calculate gradients here.
with torch.no_grad():
pos = torch.linspace(
0,
size - 1,
steps=size,
dtype=x.dtype,
device=x.device,
)
pos = (pos + 0.5) / scale - 0.5
base = pos.floor() - (kernel_size // 2) + 1
dist = pos - base
weight = get_weight(
dist,
kernel_size,
kernel=kernel,
sigma=sigma,
antialiasing_factor=antialiasing_factor,
)
pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))
# To backpropagate through x
x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
unfold = reshape_tensor(x_pad, dim, kernel_size)
# Subsampling first
if dim == 2 or dim == -2:
sample = unfold[..., base, :]
weight = weight.view(1, kernel_size, sample.size(2), 1)
else:
sample = unfold[..., base]
weight = weight.view(1, kernel_size, 1, sample.size(3))
# Apply the kernel
x = sample * weight
x = x.sum(dim=1, keepdim=True)
return x
def downsampling_2d(x: torch.Tensor, k: torch.Tensor, scale: int, padding_type: str = 'reflect') -> torch.Tensor:
c = x.size(1)
k_h = k.size(-2)
k_w = k.size(-1)
k = k.to(dtype=x.dtype, device=x.device)
k = k.view(1, 1, k_h, k_w)
k = k.repeat(c, c, 1, 1)
e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
e = e.view(c, c, 1, 1)
k = k * e
pad_h = (k_h - scale) // 2
pad_w = (k_w - scale) // 2
x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
y = F.conv2d(x, k, padding=0, stride=scale)
return y
def imresize(x: torch.Tensor,
scale: typing.Optional[float] = None,
sizes: typing.Optional[typing.Tuple[int, int]] = None,
kernel: typing.Union[str, torch.Tensor] = 'cubic',
sigma: float = 2,
rotation_degree: float = 0,
padding_type: str = 'reflect',
antialiasing: bool = True) -> torch.Tensor:
"""
Args:
x (torch.Tensor):
scale (float):
sizes (tuple(int, int)):
kernel (str, default='cubic'):
sigma (float, default=2):
rotation_degree (float, default=0):
padding_type (str, default='reflect'):
antialiasing (bool, default=True):
Return:
torch.Tensor:
"""
if scale is None and sizes is None:
raise ValueError('One of scale or sizes must be specified!')
if scale is not None and sizes is not None:
raise ValueError('Please specify scale or sizes to avoid conflict!')
x, b, c, h, w = reshape_input(x)
if sizes is None and scale is not None:
'''
# Check if we can apply the convolution algorithm
scale_inv = 1 / scale
if isinstance(kernel, str) and scale_inv.is_integer():
kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
raise ValueError(
'An integer downsampling factor '
'should be used with a predefined kernel!'
)
'''
# Determine output size
sizes = (math.ceil(h * scale), math.ceil(w * scale))
scales = (scale, scale)
if scale is None and sizes is not None:
scales = (sizes[0] / h, sizes[1] / w)
x, dtype = cast_input(x)
if isinstance(kernel, str) and sizes is not None:
# Core resizing module
x = resize_1d(
x,
-2,
size=sizes[0],
scale=scales[0],
kernel=kernel,
sigma=sigma,
padding_type=padding_type,
antialiasing=antialiasing)
x = resize_1d(
x,
-1,
size=sizes[1],
scale=scales[1],
kernel=kernel,
sigma=sigma,
padding_type=padding_type,
antialiasing=antialiasing)
elif isinstance(kernel, torch.Tensor) and scale is not None:
x = downsampling_2d(x, kernel, scale=int(1 / scale))
x = reshape_output(x, b, c)
x = cast_output(x, dtype)
return x
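# Usage sketch: exactly one of `scale` or `sizes` must be given.
#   >>> x = torch.arange(16).float().view(1, 1, 4, 4)
#   >>> imresize(x, sizes=(3, 3)).shape   # torch.Size([1, 1, 3, 3])
#   >>> imresize(x, scale=0.5).shape      # torch.Size([1, 1, 2, 2]), antialiased bicubic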
================================================
FILE: comfyui_invsr_trimmed/utils/util_color_fix.py
================================================
'''
# --------------------------------------------------------------------------------
# Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
# --------------------------------------------------------------------------------
'''
import torch
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import ToTensor, ToPILImage
from .util_image import rgb2ycbcrTorch, ycbcr2rgbTorch
def calc_mean_std(feat: Tensor, eps=1e-5):
"""Calculate mean and std for adaptive_instance_normalization.
Args:
feat (Tensor): 4D tensor.
eps (float): A small value added to the variance to avoid
divide-by-zero. Default: 1e-5.
"""
size = feat.size()
assert len(size) == 4, 'The input feature should be 4D tensor.'
b, c = size[:2]
feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
return feat_mean, feat_std
def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):
"""Adaptive instance normalization.
Adjust the reference features to have color and illumination statistics similar to
those of the degraded features.
Args:
content_feat (Tensor): The reference features.
style_feat (Tensor): The degraded features.
"""
size = content_feat.size()
style_mean, style_std = calc_mean_std(style_feat)
content_mean, content_std = calc_mean_std(content_feat)
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
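# Example: adaptive_instance_normalization(sr_feat, lq_feat) re-normalizes each channel of
# sr_feat to the per-channel mean/std of lq_feat, transferring its global color/illumination.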
def wavelet_blur(image: Tensor, radius: int):
"""
Apply wavelet blur to the input tensor.
"""
# input shape: (1, 3, H, W)
# convolution kernel
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
# add channel dimensions to the kernel to make it a 4D tensor
kernel = kernel[None, None]
# repeat the kernel across all input channels
kernel = kernel.repeat(3, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
# apply convolution
output = F.conv2d(image, kernel, groups=3, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels=5):
"""
Apply wavelet decomposition to the input tensor.
This function only returns the low frequency & the high frequency.
"""
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq += (image - low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
"""
Apply wavelet decomposition and recombine, so that the content keeps its details but takes its low-frequency color from the style.
"""
# calculate the wavelet decomposition of the content feature
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
# calculate the wavelet decomposition of the style feature
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
# reconstruct the content feature with the style's high frequency
return content_high_freq + style_low_freq
def ycbcr_color_replace(content_feat:Tensor, style_feat:Tensor):
"""
Convert to YCbCr and swap chroma, so that the content keeps its luma (Y) but takes its color (Cb/Cr) from the style.
"""
content_y = rgb2ycbcrTorch(content_feat, only_y=True)
style_ycbcr = rgb2ycbcrTorch(style_feat, only_y=False)
target_ycbcr = torch.cat([content_y, style_ycbcr[:, 1:,]], dim=1)
target_rgb = ycbcr2rgbTorch(target_ycbcr)
return target_rgb
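# Usage sketch (illustrative, not part of the original repo): given two aligned
# N x 3 x H x W tensors in [0, 1], e.g. a super-resolved output `pred` and its
# upsampled low-quality input `ref` (hypothetical names),
#     out = wavelet_reconstruction(pred, ref)
# keeps the high-frequency detail of `pred` and the low-frequency color and
# illumination of `ref`, while
#     out = ycbcr_color_replace(pred, ref)
# keeps only the luma (Y) of `pred` and takes the chroma (Cb/Cr) from `ref`.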
================================================
FILE: comfyui_invsr_trimmed/utils/util_common.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2022-02-06 10:34:59
import os
import random
import argparse
import requests
import importlib
from pathlib import Path
def mkdir(dir_path, delete=False, parents=True):
import shutil
if not isinstance(dir_path, Path):
dir_path = Path(dir_path)
if delete:
if dir_path.exists():
shutil.rmtree(str(dir_path))
if not dir_path.exists():
dir_path.mkdir(parents=parents)
def get_obj_from_str(string, reload=False):
current_package = __package__.rsplit(".", 1)[0]
is_relative_import = string.startswith(".")
package = current_package if is_relative_import else None
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module, package=package)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=package), cls)
def instantiate_from_config(config):
if not "target" in config:
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
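# Example (sketch): "target" is a dotted import path and "params" holds the
# constructor kwargs, e.g.
#     instantiate_from_config({"target": "torch.nn.Linear",
#                              "params": {"in_features": 4, "out_features": 2}})
# returns torch.nn.Linear(in_features=4, out_features=2); targets starting with
# "." (as used in the yaml configs) are resolved relative to this package.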
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")
def get_filenames(dir_path, exts=['png', 'jpg'], recursive=True):
'''
Get the file paths in the given folder.
param exts: list, e.g., ['png',]
return: list
'''
if not isinstance(dir_path, Path):
dir_path = Path(dir_path)
file_paths = []
for current_ext in exts:
if recursive:
file_paths.extend([str(x) for x in dir_path.glob('**/*.'+current_ext)])
else:
file_paths.extend([str(x) for x in dir_path.glob('*.'+current_ext)])
return file_paths
def readline_txt(txt_file):
txt_file = [txt_file, ] if isinstance(txt_file, str) else txt_file
out = []
for txt_file_current in txt_file:
with open(txt_file_current, 'r') as ff:
out.extend([x.rstrip('\n') for x in ff.readlines()])
return out
def scan_files_from_folder(dir_paths, exts, recursive=True):
'''
Scanning images from the given folder.
Input:
dir_paths: str or list.
exts: str or list
'''
exts = [exts, ] if isinstance(exts, str) else exts
dir_paths = [dir_paths, ] if isinstance(dir_paths, str) else dir_paths
file_paths = []
for current_dir in dir_paths:
current_dir = Path(current_dir) if not isinstance(current_dir, Path) else current_dir
for current_ext in exts:
if recursive:
search_flag = f"**/*.{current_ext}"
else:
search_flag = f"*.{current_ext}"
file_paths.extend(sorted([str(x) for x in Path(current_dir).glob(search_flag)]))
return file_paths
def write_path_to_txt(
dir_folder,
txt_path,
search_key,
num_files=None,
write_only_name=False,
write_only_stem=False,
shuffle=False,
):
'''
Scan the files in the given folder and write their paths into a txt file
Input:
dir_folder: path of the target folder
txt_path: path to save the txt file
search_key: e.g., '*.png'
write_only_name: bool, only record the file names (including extension),
write_only_stem: bool, only record the file names (not including extension),
'''
txt_path = Path(txt_path) if not isinstance(txt_path, Path) else txt_path
dir_folder = Path(dir_folder) if not isinstance(dir_folder, Path) else dir_folder
if txt_path.exists():
txt_path.unlink()
if write_only_name:
path_list = sorted([str(x.name) for x in dir_folder.glob(search_key)])
elif write_only_stem:
path_list = sorted([str(x.stem) for x in dir_folder.glob(search_key)])
else:
path_list = sorted([str(x) for x in dir_folder.glob(search_key)])
if shuffle:
random.shuffle(path_list)
if num_files is not None:
path_list = path_list[:num_files]
with open(txt_path, mode='w') as ff:
for line in path_list:
ff.write(line+'\n')
================================================
FILE: comfyui_invsr_trimmed/utils/util_ema.py
================================================
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
else torch.tensor(-1, dtype=torch.int))
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
"""
Copying the ema state (i.e., buffers) to the targeted model
Input:
model: targeted model
"""
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the parameters of the targeted model into the temporary pool for restoring later.
Args:
parameters: parameters of the targeted model.
Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters from the temporary pool (stored with the `store` method).
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
def resume(self, ckpt, num_updates):
"""
Resume from the targeted checkpoint, i.e., copy the checkpoint weights into the ema buffers.
Input:
ckpt: state dict containing the ema weights
num_updates: update counter to resume the decay schedule from
"""
self.register_buffer('num_updates', torch.tensor(num_updates, dtype=torch.int))
shadow_params = dict(self.named_buffers())
for key, value in ckpt.items():
try:
shadow_params[self.m_name2s_name[key]].data.copy_(value.data)
except KeyError:
if key.startswith('module') and key not in shadow_params:
key = key[7:]
shadow_params[self.m_name2s_name[key]].data.copy_(value.data)
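if __name__ == "__main__":
    # Minimal training-loop sketch (illustration only, not from the original repo):
    model = nn.Linear(4, 2)
    ema = LitEma(model, decay=0.999)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(3):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema(model)                    # update the shadow (EMA) weights
    ema.store(model.parameters())     # stash the raw weights,
    ema.copy_to(model)                # evaluate with the EMA weights,
    ema.restore(model.parameters())   # then restore the raw weights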
================================================
FILE: comfyui_invsr_trimmed/utils/util_image.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-11-24 16:54:19
import sys
import cv2
import math
import torch
import random
import numpy as np
from pathlib import Path
from torchvision.utils import make_grid  # used by tensor2img
# --------------------------Metrics----------------------------
def ssim(img1, img2):
C1 = (0.01 * 255)**2
C2 = (0.03 * 255)**2
img1 = img1.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
(sigma1_sq + sigma2_sq + C2))
return ssim_map.mean()
def calculate_ssim(im1, im2, border=0, ycbcr=False):
'''
SSIM with the same output as MATLAB's ssim.
im1, im2: h x w (x c), [0, 255], uint8
'''
if not im1.shape == im2.shape:
raise ValueError('Input images must have the same dimensions.')
if ycbcr:
im1 = rgb2ycbcr(im1, True)
im2 = rgb2ycbcr(im2, True)
h, w = im1.shape[:2]
im1 = im1[border:h-border, border:w-border]
im2 = im2[border:h-border, border:w-border]
if im1.ndim == 2:
return ssim(im1, im2)
elif im1.ndim == 3:
if im1.shape[2] == 3:
ssims = []
for i in range(3):
ssims.append(ssim(im1[:,:,i], im2[:,:,i]))
return np.array(ssims).mean()
elif im1.shape[2] == 1:
return ssim(np.squeeze(im1), np.squeeze(im2))
else:
raise ValueError('Wrong input image dimensions.')
def calculate_psnr(im1, im2, border=0, ycbcr=False):
'''
PSNR metric.
im1, im2: h x w (x c), [0, 255], uint8
'''
if not im1.shape == im2.shape:
raise ValueError('Input images must have the same dimensions.')
if ycbcr:
im1 = rgb2ycbcr(im1, True)
im2 = rgb2ycbcr(im2, True)
h, w = im1.shape[:2]
im1 = im1[border:h-border, border:w-border]
im2 = im2[border:h-border, border:w-border]
im1 = im1.astype(np.float64)
im2 = im2.astype(np.float64)
mse = np.mean((im1 - im2)**2)
if mse == 0:
return float('inf')
return 20 * math.log10(255.0 / math.sqrt(mse))
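# Worked example (sketch): two uint8 images differing by exactly 1 everywhere
# give mse = 1, so psnr = 20 * log10(255 / 1), about 48.13 dB; identical images
# give mse = 0 and return float('inf'). E.g.:
#     calculate_psnr(np.full((8, 8), 100, np.uint8), np.full((8, 8), 101, np.uint8))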
def normalize_np(im, mean=0.5, std=0.5, reverse=False):
'''
Input:
im: h x w x c, numpy array
Normalize: (im - mean) / std
Reverse: im * std + mean
'''
if not isinstance(mean, (list, tuple)):
mean = [mean, ] * im.shape[2]
mean = np.array(mean).reshape([1, 1, im.shape[2]])
if not isinstance(std, (list, tuple)):
std = [std, ] * im.shape[2]
std = np.array(std).reshape([1, 1, im.shape[2]])
if not reverse:
out = (im.astype(np.float32) - mean) / std
else:
out = im.astype(np.float32) * std + mean
return out
def normalize_th(im, mean=0.5, std=0.5, reverse=False):
'''
Input:
im: b x c x h x w, torch tensor
Normalize: (im - mean) / std
Reverse: im * std + mean
'''
if not isinstance(mean, (list, tuple)):
mean = [mean, ] * im.shape[1]
mean = torch.tensor(mean, device=im.device).view([1, im.shape[1], 1, 1])
if not isinstance(std, (list, tuple)):
std = [std, ] * im.shape[1]
std = torch.tensor(std, device=im.device).view([1, im.shape[1], 1, 1])
if not reverse:
out = (im - mean) / std
else:
out = im * std + mean
return out
# ------------------------Image format--------------------------
def rgb2ycbcr(im, only_y=True):
'''
same as matlab rgb2ycbcr
Input:
im: uint8 [0,255] or float [0,1]
only_y: only return Y channel
'''
# transform to float64 data type, range [0, 255]
if im.dtype == np.uint8:
im_temp = im.astype(np.float64)
else:
im_temp = (im * 255).astype(np.float64)
# convert
if only_y:
rlt = np.dot(im_temp, np.array([65.481, 128.553, 24.966])/ 255.0) + 16.0
else:
rlt = np.matmul(im_temp, np.array([[65.481, -37.797, 112.0 ],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214]])/255.0) + [16, 128, 128]
if im.dtype == np.uint8:
rlt = rlt.round()
else:
rlt /= 255.
return rlt.astype(im.dtype)
def rgb2ycbcrTorch(im, only_y=True):
'''
same as matlab rgb2ycbcr
Input:
im: float [0,1], N x 3 x H x W
only_y: only return Y channel
'''
# transform to range [0,255.0]
im_temp = im.permute([0,2,3,1]) * 255.0 # N x C x H x W --> N x H x W x C
# convert
if only_y:
rlt = torch.matmul(im_temp, torch.tensor([65.481, 128.553, 24.966],
device=im.device, dtype=im.dtype).view([3,1])/ 255.0) + 16.0
else:
scale = torch.tensor(
[[65.481, -37.797, 112.0 ],
[128.553, -74.203, -93.786],
[24.966, 112.0, -18.214]],
device=im.device, dtype=im.dtype
) / 255.0
bias = torch.tensor([16, 128, 128], device=im.device, dtype=im.dtype).view([-1, 1, 1, 3])
rlt = torch.matmul(im_temp, scale) + bias
rlt /= 255.0
rlt.clamp_(0.0, 1.0)
return rlt.permute([0, 3, 1, 2])
def ycbcr2rgbTorch(im):
'''
same as matlab ycbcr2rgb
Input:
im: float [0,1], N x 3 x H x W
'''
# transform to range [0,255.0]
im_temp = im.permute([0,2,3,1]) * 255.0 # N x C x H x W --> N x H x W x C
# convert
scale = torch.tensor(
[[0.00456621, 0.00456621, 0.00456621],
[0, -0.00153632, 0.00791071],
[0.00625893, -0.00318811, 0]],
device=im.device, dtype=im.dtype
) * 255.0
bias = torch.tensor(
[-222.921, 135.576, -276.836], device=im.device, dtype=im.dtype
).view([-1, 1, 1, 3])
rlt = torch.matmul(im_temp, scale) + bias
rlt /= 255.0
rlt.clamp_(0.0, 1.0)
return rlt.permute([0, 3, 1, 2])
def bgr2rgb(im): return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
def rgb2bgr(im): return cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
"""Convert torch Tensors into image numpy arrays.
After clamping to [min, max], values will be normalized to [0, 1].
Args:
tensor (Tensor or list[Tensor]): Accept shapes:
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
2) 3D Tensor of shape (3/1 x H x W);
3) 2D Tensor of shape (H x W).
Tensor channel should be in RGB order.
rgb2bgr (bool): Whether to change rgb to bgr.
out_type (numpy type): output types. If ``np.uint8``, transform outputs
to uint8 type with range [0, 255]; otherwise, float type with
range [0, 1]. Default: ``np.uint8``.
min_max (tuple[int]): min and max values for clamp.
Returns:
(ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
shape (H x W). The channel order is BGR.
"""
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
flag_tensor = torch.is_tensor(tensor)
if flag_tensor:
tensor = [tensor]
result = []
for _tensor in tensor:
_tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
n_dim = _tensor.dim()
if n_dim == 4:
img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
img_np = img_np.transpose(1, 2, 0)
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 3:
img_np = _tensor.numpy()
img_np = img_np.transpose(1, 2, 0)
if img_np.shape[2] == 1: # gray image
img_np = np.squeeze(img_np, axis=2)
else:
if rgb2bgr:
img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
elif n_dim == 2:
img_np = _tensor.numpy()
else:
raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
if out_type == np.uint8:
# Unlike MATLAB, numpy.uint8() will not round by default.
img_np = (img_np * 255.0).round()
img_np = img_np.astype(out_type)
result.append(img_np)
if len(result) == 1 and flag_tensor:
result = result[0]
return result
# ------------------------Image resize-----------------------------
def imresize_np(img, scale, antialiasing=True):
# Now the scale should be the same for H and W
# input: img: Numpy, HWC or HW [0,1]
# output: HWC or HW [0,1] w/o round
img = torch.from_numpy(img)
need_squeeze = True if img.dim() == 2 else False
if need_squeeze:
img.unsqueeze_(2)
in_H, in_W, in_C = img.size()
out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
kernel_width = 4
kernel = 'cubic'
# Return the desired dimension order for performing the resize. The
# strategy is to perform the resize first along the dimension with the
# smallest scale factor.
# Now we do not support this.
# get weights and indices
weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
in_H, out_H, scale, kernel, kernel_width, antialiasing)
weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
in_W, out_W, scale, kernel, kernel_width, antialiasing)
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
sym_patch = img[:sym_len_Hs, :, :]
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(0, inv_idx)
img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
sym_patch = img[-sym_len_He:, :, :]
inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(0, inv_idx)
img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
out_1 = torch.FloatTensor(out_H, in_W, in_C)
kernel_width = weights_H.size(1)
for i in range(out_H):
idx = int(indices_H[i][0])
for j in range(out_C):
out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
# process W dimension
# symmetric copying
out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
sym_patch = out_1[:, :sym_len_Ws, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
sym_patch = out_1[:, -sym_len_We:, :]
inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
sym_patch_inv = sym_patch.index_select(1, inv_idx)
out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
out_2 = torch.FloatTensor(out_H, out_W, in_C)
kernel_width = weights_W.size(1)
for i in range(out_W):
idx = int(indices_W[i][0])
for j in range(out_C):
out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
if need_squeeze:
out_2.squeeze_()
return out_2.numpy()
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
if (scale < 1) and (antialiasing):
# Use a modified kernel (with a larger width) to simultaneously interpolate and antialias
kernel_width = kernel_width / scale
# Output-space coordinates
x = torch.linspace(1, out_length, out_length)
# Input-space coordinates. Calculate the inverse mapping such that 0.5
# in output space maps to 0.5 in input space, and 0.5+scale in output
# space maps to 1.5 in input space.
u = x / scale + 0.5 * (1 - 1 / scale)
# What is the left-most pixel that can be involved in the computation?
left = torch.floor(u - kernel_width / 2)
# What is the maximum number of pixels that can be involved in the
# computation? Note: it's OK to use an extra pixel here; if the
# corresponding weights are all zero, it will be eliminated at the end
# of this function.
P = math.ceil(kernel_width) + 2
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
1, P).expand(out_length, P)
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
# apply cubic kernel
if (scale < 1) and (antialiasing):
weights = scale * cubic(distance_to_center * scale)
else:
weights = cubic(distance_to_center)
# Normalize the weights matrix so that each row sums to 1.
weights_sum = torch.sum(weights, 1).view(out_length, 1)
weights = weights / weights_sum.expand(out_length, P)
# If a column in weights is all zero, get rid of it. only consider the first and last column.
weights_zero_tmp = torch.sum((weights == 0), 0)
if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
indices = indices.narrow(1, 1, P - 2)
weights = weights.narrow(1, 1, P - 2)
if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
indices = indices.narrow(1, 0, P - 2)
weights = weights.narrow(1, 0, P - 2)
weights = weights.contiguous()
indices = indices.contiguous()
sym_len_s = -indices.min() + 1
sym_len_e = indices.max() - in_length
indices = indices + sym_len_s - 1
return weights, indices, int(sym_len_s), int(sym_len_e)
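# Worked example (sketch): for in_length=4, out_length=8 (scale=2, 'cubic'),
# u = x / 2 + 0.25 maps output coordinates 1..8 to input coordinates
# 0.75, 1.25, ..., 4.25, and row k of `weights` holds the (normalized) cubic
# taps used to compute the k-th output pixel from the pixels in row k of `indices`.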
# cubic kernel used by MATLAB's 'imresize' (only 'bicubic' is supported)
def cubic(x):
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
(-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
# ------------------------Image I/O-----------------------------
def imread(path, chn='rgb', dtype='float32', force_gray2rgb=True, force_rgba2rgb=False):
'''
Read image.
chn: 'rgb', 'bgr' or 'gray'
out:
im: h x w x c, numpy tensor
'''
im = cv2.imread(str(path), cv2.IMREAD_UNCHANGED) # BGR(A) or grayscale, uint8
if im is None:
raise ValueError(f"{str(path)} can't be read by cv2.imread!")
if chn.lower() == 'gray':
assert im.ndim == 2, f"{str(path)} can't be successfully read as a grayscale image!"
else:
if im.ndim == 2:
if force_gray2rgb:
im = np.stack([im, im, im], axis=2) # replicated gray, RGB == BGR
else:
raise ValueError(f"{str(path)} is a single-channel (grayscale) image!")
elif im.ndim == 3 and im.shape[2] == 4:
if force_rgba2rgb:
im = im[:, :, :3] # drop the alpha channel (still BGR)
if chn.lower() == 'rgb':
im = bgr2rgb(im)
else:
raise ValueError(f"{str(path)} has 4 channels (BGRA)!")
elif chn.lower() == 'rgb':
im = bgr2rgb(im)
elif chn.lower() == 'bgr':
pass
if dtype == 'float32':
im = im.astype(np.float32) / 255.
elif dtype == 'float64':
im = im.astype(np.float64) / 255.
elif dtype == 'uint8':
pass
else:
sys.exit('Please input a valid dtype: float32, float64 or uint8!')
return im
# ------------------------Augmentation-----------------------------
def data_aug_np(image, mode):
'''
Performs data augmentation of the input image
Input:
image: a cv2 (OpenCV) image
mode: int. Choice of transformation to apply to the image
0 - no transformation
1 - flip up and down
2 - rotate counterclockwise 90 degrees
3 - rotate 90 degree and flip up and down
4 - rotate 180 degree
5 - rotate 180 degree and flip
6 - rotate 270 degree
7 - rotate 270 degree and flip
'''
if mode == 0:
# original
out = image
elif mode == 1:
# flip up and down
out = np.flipud(image)
elif mode == 2:
# rotate counterclockwise 90 degrees
out = np.rot90(image)
elif mode == 3:
# rotate 90 degree and flip up and down
out = np.rot90(image)
out = np.flipud(out)
elif mode == 4:
# rotate 180 degree
out = np.rot90(image, k=2)
elif mode == 5:
# rotate 180 degree and flip
out = np.rot90(image, k=2)
out = np.flipud(out)
elif mode == 6:
# rotate 270 degree
out = np.rot90(image, k=3)
elif mode == 7:
# rotate 270 degree and flip
out = np.rot90(image, k=3)
out = np.flipud(out)
else:
raise Exception('Invalid choice of image transformation')
return out.copy()
def inverse_data_aug_np(image, mode):
'''
Performs inverse data augmentation of the input image
'''
if mode == 0:
# original
out = image
elif mode == 1:
out = np.flipud(image)
elif mode == 2:
out = np.rot90(image, axes=(1,0))
elif mode == 3:
out = np.flipud(image)
out = np.rot90(out, axes=(1,0))
elif mode == 4:
out = np.rot90(image, k=2, axes=(1,0))
elif mode == 5:
out = np.flipud(image)
out = np.rot90(out, k=2, axes=(1,0))
elif mode == 6:
out = np.rot90(image, k=3, axes=(1,0))
elif mode == 7:
# rotate 270 degree and flip
out = np.flipud(image)
out = np.rot90(out, k=3, axes=(1,0))
else:
raise Exception('Invalid choice of image transformation')
return out
# ----------------------Visualization----------------------------
def imshow(x, title=None, cbar=False):
import matplotlib.pyplot as plt
plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
if title:
plt.title(title)
if cbar:
plt.colorbar()
plt.show()
def imblend_with_mask(im, mask, alpha=0.25):
"""
Input:
im, mask: h x w x c numpy array, uint8, [0, 255]
alpha: scalar in [0.0, 1.0]
"""
edge_map = cv2.Canny(mask, 100, 200).astype(np.float32)[:, :, None] / 255.
assert mask.dtype == np.uint8
mask = mask.astype(np.float32) / 255.
if mask.ndim == 2:
mask = mask[:, :, None]
back_color = np.array([159, 121, 238], dtype=np.float32).reshape((1,1,3))
blend = im.astype(np.float32) * alpha + (1 - alpha) * back_color
blend = np.clip(blend, 0, 255)
out = im.astype(np.float32) * (1 - mask) + blend * mask
# paste edge
out = out * (1 - edge_map) + np.array([0,255,0], dtype=np.float32).reshape((1,1,3)) * edge_map
return out.astype(np.uint8)
# -----------------------Convolution------------------------------
def imgrad(im, padding_mode='mirror'):
'''
Calculate image gradient.
Input:
im: h x w x c numpy array
'''
from scipy.ndimage import correlate # lazy import
wx = np.array([[0, 0, 0],
[-1, 1, 0],
[0, 0, 0]], dtype=np.float32)
wy = np.array([[0, -1, 0],
[0, 1, 0],
[0, 0, 0]], dtype=np.float32)
if im.ndim == 3:
gradx = np.stack(
[correlate(im[:,:,c], wx, mode=padding_mode) for c in range(im.shape[2])],
axis=2
)
grady = np.stack(
[correlate(im[:,:,c], wy, mode=padding_mode) for c in range(im.shape[2])],
axis=2
)
grad = np.concatenate((gradx, grady), axis=2)
else:
gradx = correlate(im, wx, mode=padding_mode)
grady = correlate(im, wy, mode=padding_mode)
grad = np.stack((gradx, grady), axis=2)
return {'gradx': gradx, 'grady': grady, 'grad':grad}
def convtorch(im, weight, mode='reflect'):
'''
Image convolution with pytorch
Input:
im: b x c_in x h x w torch tensor
weight: c_out x c_in/groups x k x k torch tensor
Output:
out: b x c_out x h x w torch tensor
'''
ksize = weight.shape[-1] # kernel size, assumed odd
chn = im.shape[1]
im_pad = torch.nn.functional.pad(im, pad=(ksize // 2, )*4, mode=mode)
out = torch.nn.functional.conv2d(im_pad, weight, padding=0, groups=chn)
return out
# ----------------------Patch Cropping----------------------------
def random_crop(im, pch_size):
'''
Randomly crop a patch from the given image.
'''
h, w = im.shape[:2]
# padding if necessary
if h < pch_size or w < pch_size:
pad_h = min(max(0, pch_size - h), h)
pad_w = min(max(0, pch_size - w), w)
im = cv2.copyMakeBorder(im, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101)
h, w = im.shape[:2]
if h == pch_size:
ind_h = 0
elif h > pch_size:
ind_h = random.randint(0, h-pch_size)
else:
raise ValueError('Image height is smaller than the patch size')
if w == pch_size:
ind_w = 0
elif w > pch_size:
ind_w = random.randint(0, w-pch_size)
else:
raise ValueError('Image width is smaller than the patch size')
im_pch = im[ind_h:ind_h+pch_size, ind_w:ind_w+pch_size,]
return im_pch
class ToTensor:
def __init__(self, max_value=1.0):
self.max_value = max_value
def __call__(self, im):
assert isinstance(im, np.ndarray)
if im.ndim == 2:
im = im[:, :, np.newaxis]
if im.dtype == np.uint8:
assert self.max_value == 255.
out = torch.from_numpy(im.astype(np.float32).transpose(2,0,1) / self.max_value)
else:
assert self.max_value == 1.0
out = torch.from_numpy(im.transpose(2,0,1))
return out
class RandomCrop:
def __init__(self, pch_size, pass_crop=False):
self.pch_size = pch_size
self.pass_crop = pass_crop
def __call__(self, im):
if self.pass_crop:
return im
if isinstance(im, list) or isinstance(im, tuple):
out = []
for current_im in im:
out.append(random_crop(current_im, self.pch_size))
else:
out = random_crop(im, self.pch_size)
return out
class ImageSpliterNp:
def __init__(self, im, pch_size, stride, sf=1):
'''
Input:
im: h x w x c, numpy array, [0, 1], low-resolution image in SR
pch_size, stride: patch setting
sf: scale factor in image super-resolution
'''
assert stride <= pch_size
self.stride = stride
self.pch_size = pch_size
self.sf = sf
if im.ndim == 2:
im = im[:, :, None]
height, width, chn = im.shape
self.height_starts_list = self.extract_starts(height)
self.width_starts_list = self.extract_starts(width)
self.length = self.__len__()
self.num_pchs = 0
self.im_ori = im
self.im_res = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)
self.pixel_count = np.zeros([height*sf, width*sf, chn], dtype=im.dtype)
def extract_starts(self, length):
starts = list(range(0, length, self.stride))
if starts[-1] + self.pch_size > length:
starts[-1] = length - self.pch_size
return starts
def __len__(self):
return len(self.height_starts_list) * len(self.width_starts_list)
def __iter__(self):
return self
def __next__(self):
if self.num_pchs < self.length:
w_start_idx = self.num_pchs // len(self.height_starts_list)
w_start = self.width_starts_list[w_start_idx] * self.sf
w_end = w_start + self.pch_size * self.sf
h_start_idx = self.num_pchs % len(self.height_starts_list)
h_start = self.height_starts_list[h_start_idx] * self.sf
h_end = h_start + self.pch_size * self.sf
pch = self.im_ori[h_start:h_end, w_start:w_end,]
self.w_start, self.w_end = w_start, w_end
self.h_start, self.h_end = h_start, h_end
self.num_pchs += 1
else:
raise StopIteration()
return pch, (h_start, h_end, w_start, w_end)
def update(self, pch_res, index_infos):
'''
Input:
pch_res: pch_size x pch_size x 3, [0,1]
index_infos: (h_start, h_end, w_start, w_end)
'''
if index_infos is None:
w_start, w_end = self.w_start, self.w_end
h_start, h_end = self.h_start, self.h_end
else:
h_start, h_end, w_start, w_end = index_infos
self.im_res[h_start:h_end, w_start:w_end] += pch_res
self.pixel_count[h_start:h_end, w_start:w_end] += 1
def gather(self):
assert np.all(self.pixel_count != 0)
return self.im_res / self.pixel_count
class ImageSpliterTh:
def __init__(self, im, pch_size, stride, sf=1, extra_bs=1, weight_type='Gaussian'):
'''
Input:
im: n x c x h x w, torch tensor, float, low-resolution image in SR
pch_size, stride: patch setting
sf: scale factor in image super-resolution
extra_bs: number of patches aggregated per forward pass; only used when inputting a single image
'''
assert weight_type in ['Gaussian', 'ones']
self.weight_type = weight_type
assert stride <= pch_size
self.stride = stride
self.pch_size = pch_size
self.sf = sf
self.extra_bs = extra_bs
bs, chn, height, width= im.shape
self.true_bs = bs
self.height_starts_list = self.extract_starts(height)
self.width_starts_list = self.extract_starts(width)
self.starts_list = []
for ii in self.height_starts_list:
for jj in self.width_starts_list:
self.starts_list.append([ii, jj])
self.length = self.__len__()
self.count_pchs = 0
self.im_ori = im
self.dtype = torch.float64
self.im_res = torch.zeros([bs, chn, height*sf, width*sf], dtype=self.dtype, device=im.device)
self.pixel_count = torch.zeros([bs, chn, height*sf, width*sf], dtype=self.dtype, device=im.device)
def extract_starts(self, length):
if length <= self.pch_size:
starts = [0,]
else:
starts = list(range(0, length, self.stride))
for ii in range(len(starts)):
if starts[ii] + self.pch_size > length:
starts[ii] = length - self.pch_size
starts = sorted(set(starts), key=starts.index)
return starts
def __len__(self):
return len(self.height_starts_list) * len(self.width_starts_list)
def __iter__(self):
return self
def __next__(self):
if self.count_pchs < self.length:
index_infos = []
current_starts_list = self.starts_list[self.count_pchs:self.count_pchs+self.extra_bs]
for ii, (h_start, w_start) in enumerate(current_starts_list):
w_end = w_start + self.pch_size
h_end = h_start + self.pch_size
current_pch = self.im_ori[:, :, h_start:h_end, w_start:w_end]
if ii == 0:
pch = current_pch
else:
pch = torch.cat([pch, current_pch], dim=0)
h_start *= self.sf
h_end *= self.sf
w_start *= self.sf
w_end *= self.sf
index_infos.append([h_start, h_end, w_start, w_end])
self.count_pchs += len(current_starts_list)
else:
raise StopIteration()
return pch, index_infos
def update(self, pch_res, index_infos):
'''
Input:
pch_res: (n*extra_bs) x c x pch_size x pch_size, float
index_infos: [(h_start, h_end, w_start, w_end),]
'''
assert pch_res.shape[0] % self.true_bs == 0
pch_list = torch.split(pch_res, self.true_bs, dim=0)
assert len(pch_list) == len(index_infos)
for ii, (h_start, h_end, w_start, w_end) in enumerate(index_infos):
current_pch = pch_list[ii].type(self.dtype)
current_weight = self.get_weight(current_pch.shape[-2], current_pch.shape[-1])
self.im_res[:, :, h_start:h_end, w_start:w_end] += current_pch * current_weight
self.pixel_count[:, :, h_start:h_end, w_start:w_end] += current_weight
@staticmethod
def generate_kernel_1d(ksize):
sigma = 0.3 * ((ksize - 1) * 0.5 - 1) + 0.8 # opencv default setting
if ksize % 2 == 0:
kernel = cv2.getGaussianKernel(ksize=ksize+1, sigma=sigma, ktype=cv2.CV_64F)
kernel = kernel[1:, ]
else:
kernel = cv2.getGaussianKernel(ksize=ksize, sigma=sigma, ktype=cv2.CV_64F)
return kernel
def get_weight(self, height, width):
if self.weight_type == 'ones':
kernel = torch.ones(1, 1, height, width)
elif self.weight_type == 'Gaussian':
kernel_h = self.generate_kernel_1d(height).reshape(-1, 1)
kernel_w = self.generate_kernel_1d(width).reshape(1, -1)
kernel = np.matmul(kernel_h, kernel_w)
kernel = torch.from_numpy(kernel).unsqueeze(0).unsqueeze(0) # 1 x 1 x pch_size x pch_size
else:
raise ValueError(f"Unsupported weight type: {self.weight_type}")
return kernel.to(dtype=self.dtype, device=self.im_ori.device)
def gather(self):
assert torch.all(self.pixel_count != 0)
return self.im_res.div(self.pixel_count)
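# Usage sketch (illustration only; `model` is a hypothetical patch-wise function):
#     spliter = ImageSpliterTh(im, pch_size=128, stride=64, sf=4)
#     for pch, index_infos in spliter:
#         spliter.update(model(pch), index_infos)
#     out = spliter.gather()
# The Gaussian weighting down-weights patch borders, so overlapping patches are
# blended smoothly and seams average out in gather().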
# ----------------------Patch Cliping----------------------------
class Clamper:
def __init__(self, min_max=(-1, 1)):
self.min_bound, self.max_bound = min_max[0], min_max[1]
def __call__(self, im):
if isinstance(im, np.ndarray):
return np.clip(im, a_min=self.min_bound, a_max=self.max_bound)
elif isinstance(im, torch.Tensor):
return torch.clamp(im, min=self.min_bound, max=self.max_bound)
else:
raise TypeError(f'ndarray or Tensor expected, got {type(im)}')
# ----------------------Interpolation----------------------------
class Bicubic:
def __init__(self, scale=None, out_shape=None, activate_matlab=True, resize_back=False):
self.scale = scale
self.activate_matlab = activate_matlab
self.out_shape = out_shape
self.resize_back = resize_back
def __call__(self, im):
if self.activate_matlab:
out = imresize_np(im, scale=self.scale)
if self.resize_back:
out = imresize_np(out, scale=1/self.scale)
else:
out = cv2.resize(
im,
dsize=self.out_shape,
fx=self.scale,
fy=self.scale,
interpolation=cv2.INTER_CUBIC,
)
if self.resize_back:
out = cv2.resize(
out,
dsize=self.out_shape,
fx=1/self.scale,
fy=1/self.scale,
interpolation=cv2.INTER_CUBIC,
)
return out
class SmallestMaxSize:
def __init__(self, max_size, pass_resize=False, interpolation=None):
self.pass_resize = pass_resize
self.max_size = max_size
self.interpolation = interpolation
self.str2mode = {
'nearest': cv2.INTER_NEAREST_EXACT,
'bilinear': cv2.INTER_LINEAR,
'bicubic': cv2.INTER_CUBIC
}
if self.interpolation is not None:
assert interpolation in self.str2mode, f"Not supported interpolation mode: {interpolation}"
def get_interpolation(self, size):
if self.interpolation is None:
if size < self.max_size: # upsampling
interpolation = cv2.INTER_CUBIC
else: # downsampling
interpolation = cv2.INTER_AREA
else:
interpolation = self.str2mode[self.interpolation]
return interpolation
def __call__(self, im):
h, w = im.shape[:2]
if self.pass_resize or min(h, w) == self.max_size:
out = im
else:
if h < w:
dsize = (int(self.max_size * w / h), self.max_size)
out = cv2.resize(im, dsize=dsize, interpolation=self.get_interpolation(h))
else:
dsize = (self.max_size, int(self.max_size * h / w))
out = cv2.resize(im, dsize=dsize, interpolation=self.get_interpolation(w))
if out.dtype == np.uint8:
out = np.clip(out, 0, 255)
else:
out = np.clip(out, 0, 1.0)
return out
# ----------------------augmentation----------------------------
class SpatialAug:
def __init__(self, pass_aug, only_hflip=False, only_vflip=False, only_hvflip=False):
self.only_hflip = only_hflip
self.only_vflip = only_vflip
self.only_hvflip = only_hvflip
self.pass_aug = pass_aug
def __call__(self, im, flag=None):
if self.pass_aug:
return im
if flag is None:
if self.only_hflip:
flag = random.choice([0, 5])
elif self.only_vflip:
flag = random.choice([0, 1])
elif self.only_hvflip:
flag = random.choice([0, 1, 5])
else:
flag = random.randint(0, 7)
if isinstance(im, list) or isinstance(im, tuple):
out = []
for current_im in im:
out.append(data_aug_np(current_im, flag))
else:
out = data_aug_np(im, flag)
return out
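if __name__ == "__main__":
    # Quick self-test (illustration only, not part of the original repo):
    im = np.random.rand(32, 32, 3).astype(np.float32)
    small = imresize_np(im, scale=0.5)      # MATLAB-style bicubic -> (16, 16, 3)
    patch = random_crop(im, pch_size=16)    # random 16 x 16 crop
    aug = data_aug_np(patch, mode=2)        # rotate 90 degrees counterclockwise
    print(small.shape, patch.shape, aug.shape)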
================================================
FILE: comfyui_invsr_trimmed/utils/util_net.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-11-24 20:29:36
def reload_model(model, ckpt):
module_flag = list(ckpt.keys())[0].startswith('module.')
compile_flag = '_orig_mod' in list(ckpt.keys())[0]
for source_key, source_value in model.state_dict().items():
target_key = source_key
if compile_flag and (not '_orig_mod.' in source_key):
target_key = '_orig_mod.' + target_key
if module_flag and (not source_key.startswith('module')):
target_key = 'module.' + target_key
assert target_key in ckpt
source_value.copy_(ckpt[target_key])
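# Usage sketch (illustration only; "weights.pth" is a hypothetical path):
#     import torch
#     ckpt = torch.load("weights.pth", map_location="cpu")
#     reload_model(model, ckpt)
# The prefix checks above let the same checkpoint load into a plain module, a
# DataParallel-wrapped one ('module.'), or a torch.compile'd one ('_orig_mod.').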
================================================
FILE: comfyui_invsr_trimmed/utils/util_opts.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-11-24 15:07:43
import argparse
def update_args(args_json, args_parser):
for arg in vars(args_parser):
args_json[arg] = getattr(args_parser, arg)
def str2bool(v):
"""
https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
"""
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError("boolean value expected")
================================================
FILE: comfyui_invsr_trimmed/utils/util_sisr.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-12-07 21:37:58
import cv2
import numpy as np
def modcrop(im, sf):
h, w = im.shape[:2]
h -= (h % sf)
w -= (w % sf)
return im[:h, :w,]
#-----------------------------------------Transform--------------------------------------------
class Bicubic:
def __init__(self, scale=None, out_shape=None, matlab_mode=True):
self.scale = scale
self.out_shape = out_shape
def __call__(self, im):
out = cv2.resize(
im,
dsize=self.out_shape,
fx=self.scale,
fy=self.scale,
interpolation=cv2.INTER_CUBIC,
)
return out
================================================
FILE: configs/degradation_testing_realesrgan.yaml
================================================
degradation:
sf: 4
# the first degradation process
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
resize_range: [0.5, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 15]
poisson_scale_range: [0.05, 0.3]
gray_noise_prob: 0.4
jpeg_range: [70, 95]
# the second degradation process
second_order_prob: 0.0
second_blur_prob: 0.2
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
resize_range2: [0.8, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 10]
poisson_scale_range2: [0.05, 0.2]
gray_noise_prob2: 0.4
jpeg_range2: [80, 100]
gt_size: 512
opts:
data_source: ~
im_exts: ['png', 'JPEG']
io_backend:
type: disk
blur_kernel_size: 13
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob: [0.60, 0.40, 0.0, 0.0, 0.0, 0.0]
sinc_prob: 0.1
blur_sigma: [0.2, 0.8]
betag_range: [1.0, 1.5]
betap_range: [1, 1.2]
blur_kernel_size2: 7
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob2: [0.60, 0.4, 0.0, 0.0, 0.0, 0.0]
sinc_prob2: 0.0
blur_sigma2: [0.2, 0.5]
betag_range2: [0.5, 0.8]
betap_range2: [1, 1.2]
final_sinc_prob: 0.2
gt_size: ${degradation.gt_size}
crop_pad_size: ${degradation.gt_size}
use_hflip: False
use_rot: False
================================================
FILE: configs/sample-sd-turbo.yaml
================================================
seed: 12345
# Super-resolution settings
basesr:
sf: 4
chopping: # for latent diffusion
pch_size: 128
weight_type: Gaussian
# VAE settings
tiled_vae: True
latent_tiled_size: 128
sample_tiled_size: 1024
gradient_checkpointing_vae: True
sliced_vae: False
# classifier-free guidance
cfg_scale: 1.0
# sampling settings
start_timesteps: 200
# color fixing
color_fix: ~
# Stable Diffusion
base_model: sd-turbo
sd_pipe:
target: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline
enable_grad_checkpoint: True
params:
pretrained_model_name_or_path: stabilityai/sd-turbo
cache_dir: /mnt/sfs-common/zsyue/modelbase/stable-diffusion/sd-turbo
use_safetensors: True
torch_dtype: torch.float16
model_start:
target: .noise_predictor.NoisePredictor
ckpt_path: ~ # For initializing
params:
in_channels: 3
down_block_types:
- AttnDownBlock2D
- AttnDownBlock2D
up_block_types:
- AttnUpBlock2D
- AttnUpBlock2D
block_out_channels:
- 256 # 192, 256
- 512 # 384, 512
layers_per_block:
- 3
- 3
act_fn: silu
latent_channels: 4
norm_num_groups: 32
sample_size: 128
mid_block_add_attention: True
resnet_time_scale_shift: default
temb_channels: 512
attention_head_dim: 64
freq_shift: 0
flip_sin_to_cos: True
double_z: True
model_middle:
target: .noise_predictor.NoisePredictor
params:
in_channels: 3
down_block_types:
- AttnDownBlock2D
- AttnDownBlock2D
up_block_types:
- AttnUpBlock2D
- AttnUpBlock2D
block_out_channels:
- 256 # 192, 256
- 512 # 384, 512
layers_per_block:
- 3
- 3
act_fn: silu
latent_channels: 4
norm_num_groups: 32
sample_size: 128
mid_block_add_attention: True
resnet_time_scale_shift: default
temb_channels: 512
attention_head_dim: 64
freq_shift: 0
flip_sin_to_cos: True
double_z: True
================================================
FILE: configs/sd-turbo-sr-ldis.yaml
================================================
trainer:
target: trainer.TrainerSDTurboSR
sd_pipe:
target: diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline
num_train_steps: 1000
enable_grad_checkpoint: True
compile: False
vae_split: 8
params:
pretrained_model_name_or_path: stabilityai/sd-turbo
cache_dir: weights
use_safetensors: True
torch_dtype: torch.float16
llpips:
target: latent_lpips.lpips.LPIPS
ckpt_path: weights/vgg16_sdturbo_lpips.pth
compile: False
params:
pretrained: False
net: vgg16
lpips: True
spatial: False
pnet_rand: False
pnet_tune: True
use_dropout: True
eval_mode: True
latent: True
in_chans: 4
verbose: True
model:
target: .noise_predictor.NoisePredictor
ckpt_start_path: ~ # only used for training the intermediate model
ckpt_path: ~ # For initializing
compile: False
params:
in_channels: 3
down_block_types:
- AttnDownBlock2D
- AttnDownBlock2D
up_block_types:
- AttnUpBlock2D
- AttnUpBlock2D
block_out_channels:
- 256 # 192, 256
- 512 # 384, 512
layers_per_block:
- 3
- 3
act_fn: silu
latent_channels: 4
norm_num_groups: 32
sample_size: 128
mid_block_add_attention: True
resnet_time_scale_shift: default
temb_channels: 512
attention_head_dim: 64
freq_shift: 0
flip_sin_to_cos: True
double_z: True
discriminator:
target: diffusers.models.unets.unet_2d_condition_discriminator.UNet2DConditionDiscriminator
enable_grad_checkpoint: True
compile: False
params:
sample_size: 64
in_channels: 4
center_input_sample: False
flip_sin_to_cos: True
freq_shift: 0
down_block_types:
- DownBlock2D
- CrossAttnDownBlock2D
- CrossAttnDownBlock2D
mid_block_type: UNetMidBlock2DCrossAttn
up_block_types:
- CrossAttnUpBlock2D
- CrossAttnUpBlock2D
- UpBlock2D
only_cross_attention: False
block_out_channels:
- 128
- 256
- 512
layers_per_block:
- 1
- 2
- 2
downsample_padding: 1
mid_block_scale_factor: 1
dropout: 0.0
act_fn: silu
norm_num_groups: 32
norm_eps: 1e-5
cross_attention_dim: 1024
transformer_layers_per_block: 1
reverse_transformer_layers_per_block: ~
encoder_hid_dim: ~
encoder_hid_dim_type: ~
attention_head_dim:
- 8
- 16
- 16
num_attention_heads: ~
dual_cross_attention: False
use_linear_projection: False
class_embed_type: ~
addition_embed_type: text
addition_time_embed_dim: 256
num_class_embeds: ~
upcast_attention: ~
resnet_time_scale_shift: default
resnet_skip_time_act: False
resnet_out_scale_factor: 1.0
time_embedding_type: positional
time_embedding_dim: ~
time_embedding_act_fn: ~
timestep_post_act: ~
time_cond_proj_dim: ~
conv_in_kernel: 3
conv_out_kernel: 3
projection_class_embeddings_input_dim: 2560
attention_type: default
class_embeddings_concat: False
mid_block_only_cross_attention: ~
cross_attention_norm: ~
addition_embed_type_num_heads: 64
degradation:
sf: 4
# the first degradation process
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
resize_range: [0.15, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 30]
poisson_scale_range: [0.05, 3.0]
gray_noise_prob: 0.4
jpeg_range: [30, 95]
# the second degradation process
second_order_prob: 0.5
second_blur_prob: 0.8
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
resize_range2: [0.3, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 25]
poisson_scale_range2: [0.05, 2.5]
gray_noise_prob2: 0.4
jpeg_range2: [30, 95]
gt_size: 512
resize_back: False
use_sharp: False
data:
train:
type: realesrgan
params:
data_source:
source1:
root_path: /mnt/sfs-common/zsyue/database/FFHQ
image_path: images1024
moment_path: ~
text_path: ~
im_ext: png
length: 20000
source2:
root_path: /mnt/sfs-common/zsyue/database/LSDIR/train
image_path: images
moment_path: ~
text_path: ~
im_ext: png
max_token_length: 77 # 77
io_backend:
type: disk
blur_kernel_size: 21
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob: 0.1
blur_sigma: [0.2, 3.0]
betag_range: [0.5, 4.0]
betap_range: [1, 2.0]
blur_kernel_size2: 15
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob2: 0.1
blur_sigma2: [0.2, 1.5]
betag_range2: [0.5, 4.0]
betap_range2: [1, 2.0]
final_sinc_prob: 0.8
gt_size: ${degradation.gt_size}
use_hflip: True
use_rot: False
random_crop: True
val:
type: base
params:
dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/lq
transform_type: default
transform_kwargs:
mean: 0.0
std: 1.0
extra_dir_path: /mnt/sfs-common/zsyue/projects/DifInv/SR/testingdata/imagenet512/gt
extra_transform_type: default
extra_transform_kwargs:
mean: 0.0
std: 1.0
im_exts: png
length: 16
recursive: False
train:
# predict the starting state for inversion (the noise predictor)
start_mode: True
# learning rate
lr: 5e-5 # learning rate
lr_min: 5e-5 # learning rate
lr_schedule: ~
warmup_iterations: 2000
# discriminator
lr_dis: 5e-5 # learning rate for discriminator
weight_decay_dis: 1e-3 # weight decay for discriminator
dis_init_iterations: 10000 # iterations used for updating the discriminator
dis_update_freq: 1
# dataloader
batch: 64
microbatch: 16
num_workers: 4
prefetch_factor: 2
use_text: True
# optimization settings
weight_decay: 0
ema_rate: 0.999
iterations: 200000 # total iterations
# logging
save_freq: 5000
log_freq: [200, 5000] # [training loss, training images, val images]
local_logging: True # manually save images
tf_logging: False # tensorboard logging
# loss
loss_type: L2
loss_coef:
ldif: 1.0
timesteps: [200, 100]
num_inference_steps: 5
# mixed precision
use_amp: True
use_fsdp: False
# random seed
seed: 123456
global_seeding: False
noise_detach: False
validate:
batch: 2
use_ema: True
log_freq: 4 # logging frequency
val_y_channel: True
================================================
FILE: node.py
================================================
from .comfyui_invsr_trimmed import get_configs, InvSamplerSR, BaseSampler, Namespace
import torch
from comfy.utils import ProgressBar
from folder_paths import get_full_path, get_folder_paths, models_dir
import os
import torch.nn.functional as F
def split_tensor_into_batches(tensor, batch_size):
"""
Split a tensor into smaller batches of specified size
Args:
tensor (torch.Tensor): Input tensor of shape (N, C, H, W)
batch_size (int): Desired batch size for splitting
Returns:
list: List of tensors, each containing batch_size samples (except possibly the last one)
"""
# Get original batch size
original_batch_size = tensor.size(0)
# Calculate number of full batches and remaining samples
num_full_batches = original_batch_size // batch_size
remaining_samples = original_batch_size % batch_size
# Split tensor into chunks
batches = []
# Handle full batches
for i in range(num_full_batches):
start_idx = i * batch_size
end_idx = start_idx + batch_size
batch = tensor[start_idx:end_idx]
batches.append(batch)
# Handle remaining samples if any
if remaining_samples > 0:
last_batch = tensor[-remaining_samples:]
batches.append(last_batch)
return batches
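# Example (sketch): a (10, C, H, W) tensor with batch_size=4 yields chunks of
# sizes [4, 4, 2]; functionally equivalent to torch.split(tensor, batch_size, dim=0).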
class LoadInvSRModels:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"sd_model": (['stabilityai/sd-turbo'],),
"invsr_model": (['noise_predictor_sd_turbo_v5.pth', 'noise_predictor_sd_turbo_v5_diftune.pth'],),
"dtype": (['fp16', 'fp32', 'bf16'], {"default": "fp16"}),
"tiled_vae": ("BOOLEAN", {"default": True}),
},
}
RETURN_TYPES = ("INVSR_PIPE",)
RETURN_NAMES = ("invsr_pipe",)
FUNCTION = "loadmodel"
CATEGORY = "INVSR"
def loadmodel(self, sd_model, invsr_model, dtype, tiled_vae):
match dtype:
case "fp16":
dtype = "torch.float16"
case "fp32":
dtype = "torch.float32"
case "bf16":
dtype = "torch.bfloat16"
cfg_path = os.path.join(
os.path.dirname(__file__), "configs", "sample-sd-turbo.yaml"
)
sd_path = get_folder_paths("diffusers")[0]
try:
ckpt_dir = get_folder_paths("invsr")[0]
except Exception:
ckpt_dir = os.path.join(models_dir, "invsr")
args = Namespace(
bs=1,
chopping_bs=8,
timesteps=None,
num_steps=1,
cfg_path=cfg_path,
sd_path=sd_path,
started_ckpt_dir=ckpt_dir,
tiled_vae=tiled_vae,
color_fix="",
chopping_size=128,
invsr_model=invsr_model
)
configs = get_configs(args)
configs["sd_pipe"]["params"]["torch_dtype"] = dtype
base_sampler = BaseSampler(configs)
return ((base_sampler, invsr_model),)
class InvSRSampler:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"invsr_pipe": ("INVSR_PIPE",),
"images": ("IMAGE",),
"num_steps": ("INT",{"default": 1, "min": 1, "max": 5}),
"cfg": ("FLOAT",{"default": 1.0, "step":0.1}),
# "scale_factor": ("INT",{"default": 4}),
"batch_size": ("INT",{"default": 1}),
"chopping_batch_size": ("INT",{"default": 8}),
"chopping_size": ([128, 256, 512],{"default": 128}),
"color_fix": (['none', 'wavelet', 'ycbcr'], {"default": "none"}),
"seed": ("INT", {"default": 123, "min": 0, "max": 2**32 - 1, "step": 1}),
},
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
FUNCTION = "process"
CATEGORY = "INVSR"
def process(self, invsr_pipe, images, num_steps, cfg, batch_size, chopping_batch_size, chopping_size, color_fix, seed):
base_sampler, invsr_model = invsr_pipe
if color_fix == "none":
color_fix = ""
cfg_path = os.path.join(
os.path.dirname(__file__), "configs", "sample-sd-turbo.yaml"
)
sd_path = get_folder_paths("diffusers")[0]
try:
ckpt_dir = get_folder_paths("invsr")[0]
except Exception:
ckpt_dir = os.path.join(models_dir, "invsr")
args = Namespace(
bs=batch_size,
chopping_bs=chopping_batch_size,
timesteps=None,
num_steps=num_steps,
cfg_path=cfg_path,
sd_path=sd_path,
started_ckpt_dir=ckpt_dir,
tiled_vae=base_sampler.configs.tiled_vae,
color_fix=color_fix,
chopping_size=chopping_size,
invsr_model=invsr_model
)
configs = get_configs(args, log=True)
configs["cfg_scale"] = cfg
# configs["basesr"]["sf"] = scale_factor
base_sampler.configs = configs
base_sampler.setup_seed(seed)
sampler = InvSamplerSR(base_sampler)
images_bchw = images.permute(0,3,1,2)
og_h, og_w = images_bchw.shape[2:]
# Calculate new dimensions divisible by 16
new_height = ((og_h + 15) // 16) * 16 # Round up to nearest multiple of 16
new_width = ((og_w + 15) // 16) * 16
resized = False
if og_h != new_height or og_w != new_width:
resized = True
print(f"[InvSR] - Image not divisible by 16. Resizing to {new_height} (h) x {new_width} (w)")
images_bchw = F.interpolate(images_bchw, size=(new_height, new_width), mode='bicubic', align_corners=False)
batches = split_tensor_into_batches(images_bchw, batch_size)
results = []
pbar = ProgressBar(len(batches))
for batch in batches:
result = sampler.inference(image_bchw=batch)
results.append(torch.from_numpy(result))
pbar.update(1)
result_t = torch.cat(results, dim=0)
# Resize to original dimensions * 4
if resized:
result_t = F.interpolate(result_t, size=(og_h * 4, og_w * 4), mode='bicubic', align_corners=False)
return (result_t.permute(0,2,3,1),)
================================================
FILE: pyproject.toml
================================================
[project]
name = "invsr"
description = "This project is an unofficial ComfyUI implementation of [a/InvSR](https://github.com/zsyOAOA/InvSR) (Arbitrary-steps Image Super-resolution via Diffusion Inversion)"
version = "1.0.1"
license = {file = "LICENSE"}
dependencies = ["pyiqa==0.1.12", "opencv-python", "albumentations==1.4.18", "bitsandbytes", "protobuf", "python-box", "omegaconf", "loguru"]
[project.urls]
Repository = "https://github.com/yuvraj108c/ComfyUI_InvSR"
# Used by Comfy Registry https://comfyregistry.org
[tool.comfy]
PublisherId = "yuvraj108c"
DisplayName = "ComfyUI_InvSR"
Icon = ""
================================================
FILE: requirements.txt
================================================
opencv-contrib-python-headless
omegaconf
diffusers
numpy<2
huggingface-hub
transformers
================================================
FILE: workflows/invsr.json
================================================
{
"last_node_id": 27,
"last_link_id": 40,
"nodes": [
{
"id": 25,
"type": "LoadInvSRModels",
"pos": [
-877.6565551757812,
719.943603515625
],
"size": [
413.84686279296875,
130
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "invsr_pipe",
"type": "INVSR_PIPE",
"links": [
38
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "LoadInvSRModels"
},
"widgets_values": [
"stabilityai/sd-turbo",
"noise_predictor_sd_turbo_v5.pth",
"bf16",
true
]
},
{
"id": 17,
"type": "LoadImage",
"pos": [
-867.9894409179688,
918.590087890625
],
"size": [
367.18212890625,
439.6049499511719
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
40
],
"slot_index": 0
},
{
"name": "MASK",
"type": "MASK",
"links": null
}
],
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"i1.png",
"image"
]
},
{
"id": 27,
"type": "InvSRSampler",
"pos": [
-365.7756652832031,
859.0978393554688
],
"size": [
315,
246
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"name": "invsr_pipe",
"type": "INVSR_PIPE",
"link": 38
},
{
"name": "images",
"type": "IMAGE",
"link": 40
}
],
"outputs": [
{
"name": "image",
"type": "IMAGE",
"links": [
39
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "InvSRSampler"
},
"widgets_values": [
1,
5,
1,
8,
128,
"wavelet",
536149006,
"randomize"
]
},
{
"id": 26,
"type": "PreviewImage",
"pos": [
35.55903625488281,
643.07470703125
],
"size": [
889.1287231445312,
869.5750732421875
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 39
}
],
"outputs": [],
"properties": {
"Node name for S&R": "PreviewImage"
}
}
],
"links": [
[
38,
25,
0,
27,
0,
"INVSR_PIPE"
],
[
39,
27,
0,
26,
0,
"IMAGE"
],
[
40,
17,
0,
27,
1,
"IMAGE"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 0.9090909090909091,
"offset": [
1131.660545256908,
-643.8089187150072
]
},
"VHS_latentpreview": false,
"VHS_latentpreviewrate": 0
},
"version": 0.4
}