Showing preview only (9,268K chars total). Download the full file or copy to clipboard to get everything.
Repository: JyChen9811/FaithDiff
Branch: main
Commit: f11af6c81a03
Files: 100
Total size: 8.8 MB
Directory structure:
gitextract_k3bn1c1l/
├── .gitignore
├── CKPT_PTH.py
├── FaithDiff/
│ ├── create_FaithDiff_model.py
│ ├── models/
│ │ ├── bsrnet_arch.py
│ │ └── unet_2d_condition_vae_extension.py
│ ├── pipelines/
│ │ ├── __init__.py
│ │ ├── pipeline_FaithDiff_tlc.py
│ │ └── pipeline_output.py
│ └── training_utils.py
├── LICENSE
├── README.md
├── dataloader/
│ ├── Realesrgan_offline_dataset.py
│ ├── accelerate_config.yaml
│ ├── realesrgan.py
│ └── train_kernel.yml
├── environment.yml
├── gradio_demo.py
├── llava/
│ ├── __init__.py
│ ├── constants.py
│ ├── conversation.py
│ ├── eval/
│ │ ├── eval_gpt_review.py
│ │ ├── eval_gpt_review_bench.py
│ │ ├── eval_gpt_review_visual.py
│ │ ├── eval_pope.py
│ │ ├── eval_science_qa.py
│ │ ├── eval_science_qa_gpt4.py
│ │ ├── eval_science_qa_gpt4_requery.py
│ │ ├── eval_textvqa.py
│ │ ├── generate_webpage_data_from_table.py
│ │ ├── m4c_evaluator.py
│ │ ├── model_qa.py
│ │ ├── model_vqa.py
│ │ ├── model_vqa_loader.py
│ │ ├── model_vqa_mmbench.py
│ │ ├── model_vqa_science.py
│ │ ├── qa_baseline_gpt35.py
│ │ ├── run_llava.py
│ │ ├── summarize_gpt_review.py
│ │ ├── table/
│ │ │ ├── answer/
│ │ │ │ ├── answer_alpaca-13b.jsonl
│ │ │ │ ├── answer_bard.jsonl
│ │ │ │ ├── answer_gpt35.jsonl
│ │ │ │ ├── answer_llama-13b.jsonl
│ │ │ │ └── answer_vicuna-13b.jsonl
│ │ │ ├── caps_boxes_coco2014_val_80.jsonl
│ │ │ ├── model.jsonl
│ │ │ ├── prompt.jsonl
│ │ │ ├── question.jsonl
│ │ │ ├── results/
│ │ │ │ ├── test_sqa_llava_13b_v0.json
│ │ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json
│ │ │ ├── review/
│ │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl
│ │ │ │ ├── review_bard_vicuna-13b.jsonl
│ │ │ │ ├── review_gpt35_vicuna-13b.jsonl
│ │ │ │ └── review_llama-13b_vicuna-13b.jsonl
│ │ │ ├── reviewer.jsonl
│ │ │ └── rule.json
│ │ └── webpage/
│ │   ├── index.html
│ │   ├── script.js
│ │   └── styles.css
│ ├── llm_agent.py
│ ├── mm_utils.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── apply_delta.py
│ │ ├── builder.py
│ │ ├── consolidate.py
│ │ ├── language_model/
│ │ │ ├── llava_llama.py
│ │ │ ├── llava_mistral.py
│ │ │ └── llava_mpt.py
│ │ ├── llava_arch.py
│ │ ├── make_delta.py
│ │ ├── multimodal_encoder/
│ │ │ ├── builder.py
│ │ │ └── clip_encoder.py
│ │ ├── multimodal_projector/
│ │ │ └── builder.py
│ │ └── utils.py
│ ├── serve/
│ │ ├── __init__.py
│ │ ├── cli.py
│ │ ├── controller.py
│ │ ├── gradio_web_server.py
│ │ ├── model_worker.py
│ │ ├── register_worker.py
│ │ ├── sglang_worker.py
│ │ └── test_message.py
│ ├── train/
│ │ ├── llama_flash_attn_monkey_patch.py
│ │ ├── llama_xformers_attn_monkey_patch.py
│ │ ├── llava_trainer.py
│ │ ├── train.py
│ │ ├── train_mem.py
│ │ └── train_xformers.py
│ └── utils.py
├── requirements.txt
├── test.py
├── test_generate_caption.py
├── test_metrics.py
├── test_wo_llava.py
├── train_SDXL_stage_1.py
├── train_SDXL_stage_2.py
├── train_stage_1.sh
├── train_stage_2.sh
└── utils/
  ├── color_fix.py
  ├── image_process.py
  └── system.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
__pycache__/
*.py[cod]
*$py.class
/outputs/
/input/
/dataset/
/save/
/output/
checkpoints/
/.vs
.vscode/
.idea/
venv/
.venv/
*.log
.DS_Store
*.pyc
================================================
FILE: CKPT_PTH.py
================================================
# Default checkpoint locations. All paths are relative, so they resolve
# against the current working directory of the process that imports them.
# NOTE(review): presumably imported by the test / demo scripts — confirm.
LLAVA_CLIP_PATH = './checkpoints/CLIP_VIT/'  # LLaVA CLIP vision tower directory
LLAVA_MODEL_PATH = './checkpoints/llava_v1.5-13b/llava'  # LLaVA v1.5-13B weights
SDXL_PATH = './checkpoints/Real_4_SDXL/'  # SDXL checkpoint directory (presumably the sdxl_path of FaithDiff_pipeline)
FAITHDIFF_PATH = './checkpoints/FaithDiff.bin'  # FaithDiff extra-layer weight file
VAE_FP16_PATH = './checkpoints/VAE_FP16/'  # fp16 VAE directory
BSRNet_PATH = './checkpoints/BSRNet.pth'  # BSRNet (RRDBNet) state_dict
================================================
FILE: FaithDiff/create_FaithDiff_model.py
================================================
from utils.system import quantize_8bit
from .pipelines.pipeline_FaithDiff_tlc import FaithDiffStableDiffusionXLPipeline
import torch
from diffusers import AutoencoderKL
from .models.unet_2d_condition_vae_extension import UNet2DConditionModel
from diffusers import AutoencoderKL, DDPMScheduler
from .models.bsrnet_arch import RRDBNet as BSRNet
def FaithDiff_pipeline(sdxl_path, VAE_FP16_path, FaithDiff_path, use_fp8 = False):
    """Assemble a float16 FaithDiff SDXL pipeline.

    Args:
        sdxl_path: SDXL checkpoint directory; its ``unet`` and ``scheduler``
            subfolders (fp16 variant) are loaded from here.
        VAE_FP16_path: Directory of the fp16 ``AutoencoderKL``.
        FaithDiff_path: Weight file holding the FaithDiff extension layers.
        use_fp8: When True, ``quantize_8bit`` is applied to the UNet instead of
            an explicit float16 cast.

    Returns:
        A ``FaithDiffStableDiffusionXLPipeline`` whose ``unet`` attribute is
        the extended FaithDiff UNet.
    """
    half = torch.float16

    vae = AutoencoderKL.from_pretrained(VAE_FP16_path).to(dtype=half)

    # Stock SDXL UNet plus the FaithDiff extension modules (denoise encoder etc.).
    unet = UNet2DConditionModel.from_pretrained(sdxl_path, subfolder="unet", variant="fp16")
    unet.load_additional_layers(weight_path=FaithDiff_path, dtype=half)
    if not use_fp8:
        unet = unet.to(dtype=half)
    else:
        quantize_8bit(unet)

    scheduler = DDPMScheduler.from_pretrained(sdxl_path, subfolder="scheduler")
    pipeline_kwargs = dict(
        vae=vae,
        add_sample=True,
        denoise_encoder=unet.denoise_encoder,
        DDPM_scheduler=scheduler,
        add_watermarker=False,
        torch_dtype=half,
        variant="fp16",
    )
    pipe = FaithDiffStableDiffusionXLPipeline.from_pretrained(sdxl_path, **pipeline_kwargs)
    # Replace whatever UNet the pipeline loaded with the extended FaithDiff one.
    pipe.unet = unet
    return pipe
def create_bsrnet(bsrnet_path):
    """Build the x4 BSRNet (RRDBNet) model and load its weights.

    Args:
        bsrnet_path: Path to a ``state_dict`` checkpoint for the network.

    Returns:
        The ``BSRNet`` module with weights loaded (strict key matching). The
        module is constructed on CPU and not moved anywhere.
    """
    bsrnet = BSRNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4)
    # The freshly built model lives on CPU, so map the checkpoint there too:
    # this avoids a GPU round-trip and lets CUDA-saved checkpoints load on
    # CPU-only machines.
    state_dict = torch.load(bsrnet_path, map_location="cpu")
    bsrnet.load_state_dict(state_dict, strict=True)
    return bsrnet
================================================
FILE: FaithDiff/models/bsrnet_arch.py
================================================
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import math
def initialize_weights(net_l, scale=1):
    """Initialise Conv2d / Linear / BatchNorm2d layers in place.

    Conv2d and Linear weights get Kaiming-normal (``fan_in``) initialisation
    multiplied by ``scale`` (e.g. 0.1 for the residual dense blocks), and their
    biases are zeroed. BatchNorm2d affine parameters, when present, are set to
    weight=1 and bias=0.

    Args:
        net_l: A module, or a list/tuple of modules. Every submodule of each
            one (via ``.modules()``) is visited.
        scale: Multiplier applied to the Kaiming-initialised weights.
    """
    nets = net_l if isinstance(net_l, (list, tuple)) else [net_l]
    for net in nets:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # down-scaling used for residual blocks
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                # BatchNorm2d(affine=False) has weight/bias set to None.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias, 0.0)
def make_layer(block, n_layers):
    """Return an ``nn.Sequential`` of ``n_layers`` fresh modules, each built by calling ``block()``."""
    return nn.Sequential(*[block() for _ in range(n_layers)])
class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution residual dense block.

    conv1..conv4 each see the concatenation of the block input and every
    earlier intermediate feature map and emit ``gc`` channels. conv5 maps the
    full concatenation back to ``nf`` channels; that result, scaled by 0.2, is
    added to the input.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc ("growth channels") is how many channels each intermediate conv
        # adds to the running concatenation.
        for idx in range(4):
            setattr(self, f'conv{idx + 1}', nn.Conv2d(nf + idx * gc, gc, 3, 1, 1, bias=bias))
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # Down-scaled (0.1) Kaiming init for all five convolutions.
        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        fused = self.conv5(torch.cat(features, 1))
        return fused * 0.2 + x
class RRDB(nn.Module):
    '''Residual in Residual Dense Block: three chained dense blocks wrapped by a 0.2-scaled residual.'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        hidden = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            hidden = dense_block(hidden)
        return hidden * 0.2 + x
class RRDBNet(nn.Module):
    """RRDB super-resolution network (used here as BSRNet).

    Args:
        in_nc: Input image channels.
        out_nc: Output image channels.
        nf: Feature channels.
        nb: Number of RRDB blocks in the trunk.
        gc: Growth channels inside each dense block.
        sf: Upscaling factor. One nearest-neighbour x2 stage is always applied
            and a second one is added only when ``sf == 4``, so 2 and 4 are the
            meaningful values.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        self.sf = sf
        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.sf == 4:
            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def check_image_size(self, x, scale):
        """Zero-pad ``x`` on the bottom/right so its height and width are multiples of ``scale``."""
        _, _, h, w = x.size()
        mod_pad_h = (scale - h % scale) % scale
        mod_pad_w = (scale - w % scale) % scale
        return F.pad(x, (0, mod_pad_w, 0, mod_pad_h))

    def forward(self, x):
        """Upscale a (N, in_nc, H, W) batch to (N, out_nc, H*sf, W*sf)."""
        _, _, H, W = x.size()
        x = self.check_image_size(x, scale=4)
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk
        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        if self.sf == 4:
            fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))
        # Crop away the region produced from the padding added above; use the
        # real upscaling factor so sf=2 is cropped correctly too.
        return out[:, :, :H * self.sf, :W * self.sf]

    @torch.no_grad()
    def deg_remove(self, input, tile_size=512, tile_pad=16):
        """Run the network tile by tile and stitch the results.

        The input is split into ``tile_size`` x ``tile_size`` tiles. Each tile
        is processed with up to ``tile_pad`` pixels of surrounding context, and
        only the un-padded part of each output tile is written back.
        Modified from: https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py

        Args:
            input: Image batch (N, C, H, W).
            tile_size: Tile edge length in input pixels.
            tile_pad: Context border around each tile, in input pixels.

        Returns:
            Tensor of shape (N, out_nc, H*sf, W*sf) with the dtype/device of ``input``.
        """
        scale = self.sf
        # The previous implementation stored this attribute; kept in case
        # callers read it.
        self.scale_factor = scale
        batch, _, height, width = input.shape
        out_channels = self.conv_last.out_channels
        # Start from a black image and fill it tile by tile.
        output = input.new_zeros((batch, out_channels, height * scale, width * scale))
        tiles_x = math.ceil(width / tile_size)
        tiles_y = math.ceil(height / tile_size)
        for y in range(tiles_y):
            for x in range(tiles_x):
                # Tile area (without context) in input coordinates.
                start_x = x * tile_size
                end_x = min(start_x + tile_size, width)
                start_y = y * tile_size
                end_y = min(start_y + tile_size, height)
                # Same area grown by tile_pad of context, clamped to the image.
                pad_start_x = max(start_x - tile_pad, 0)
                pad_end_x = min(end_x + tile_pad, width)
                pad_start_y = max(start_y - tile_pad, 0)
                pad_end_y = min(end_y + tile_pad, height)

                tile_out = self.forward(input[:, :, pad_start_y:pad_end_y, pad_start_x:pad_end_x])

                # Location of the context-free tile inside tile_out.
                inner_x = (start_x - pad_start_x) * scale
                inner_y = (start_y - pad_start_y) * scale
                tile_w = (end_x - start_x) * scale
                tile_h = (end_y - start_y) * scale
                output[:, :, start_y * scale:end_y * scale, start_x * scale:end_x * scale] = \
                    tile_out[:, :, inner_y:inner_y + tile_h, inner_x:inner_x + tile_w]
        return output
================================================
FILE: FaithDiff/models/unet_2d_condition_vae_extension.py
================================================
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers import UNet2DConditionModel as OriginalUNet2DConditionModel
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D, get_down_block
from collections import OrderedDict
from diffusers.utils import is_torch_version
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return the same module.

    Used for output projections (e.g. ``conv_out`` of the conditioning
    embedding, ``spatial_ch_projs``) that should contribute nothing at
    initialisation.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class Encoder(nn.Module):
    """Encoder layer of a variational autoencoder that encodes input into a latent representation.

    Pipeline: ``conv_in`` -> down blocks -> mid block. There is no output
    projection, so the result has ``block_out_channels[-1]`` channels.
    ``out_channels`` and ``double_z`` are accepted for signature compatibility
    but are not used.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 4,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"),
        block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention: bool = True,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])
        self.use_rgb = False
        self.down_block_type = down_block_types
        self.block_out_channels = block_out_channels
        # Tiling configuration: tiles are tile_sample_min_size input pixels wide.
        # The /8 assumes the default four blocks (three with downsamplers).
        self.tile_sample_min_size = 1024
        self.tile_latent_min_size = int(self.tile_sample_min_size / 8)
        self.tile_overlap_factor = 0.25
        self.use_tiling = False

        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            # Every block except the last gets a downsampler.
            is_final_block = i == len(block_out_channels) - 1
            down_block = get_down_block(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        self.mid_block = UNetMidBlock2D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )
        self.gradient_checkpointing = False

    def to_rgb_init(self):
        """Create one 3x3 conv per down block mapping its output channels to RGB, and enable ``use_rgb``."""
        self.to_rgbs = nn.ModuleList([])
        self.use_rgb = True
        for i, _ in enumerate(self.down_block_type):
            output_channel = self.block_out_channels[i]
            self.to_rgbs.append(nn.Conv2d(output_channel, 3, kernel_size=3, padding=1))

    def enable_tiling(self):
        """Enable tiled encoding for inputs larger than one tile."""
        self.use_tiling = True

    def encode(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """Encode ``sample`` in one pass (no tiling): conv_in, down blocks, mid block."""
        sample = self.conv_in(sample)
        blocks = list(self.down_blocks) + [self.mid_block]
        if self.training and self.gradient_checkpointing:
            # Recompute activations during backward instead of storing them.
            if is_torch_version(">=", "1.11.0"):
                ckpt_kwargs = {"use_reentrant": False}
            else:
                ckpt_kwargs = {}
            for block in blocks:
                sample = torch.utils.checkpoint.checkpoint(block, sample, **ckpt_kwargs)
        else:
            for block in blocks:
                sample = block(sample)
        return sample

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Cross-fade the bottom rows of ``a`` into the top rows of ``b`` (modifies ``b`` in place and returns it)."""
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """Cross-fade the right columns of ``a`` into the left columns of ``b`` (modifies ``b`` in place and returns it)."""
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Encode ``x`` as overlapping tiles and stitch the feature maps.

        Tiles are ``tile_sample_min_size`` pixels wide with a stride of
        ``(1 - tile_overlap_factor)`` of that size. Overlapping feature regions
        are linearly cross-faded with the tile above/left, and each tile is
        cropped to ``row_limit`` before concatenation.
        """
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        rows = []
        for i in range(0, x.shape[2], overlap_size):
            row = []
            for j in range(0, x.shape[3], overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                row.append(self.encode(tile))
            rows.append(row)

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=3))
        return torch.cat(result_rows, dim=2)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """Encode ``sample``; with tiling enabled, inputs larger than one tile are encoded tile by tile."""
        # `sample` is input (pixel) space, so compare against the pixel tile
        # size; comparing against tile_latent_min_size tiled anything > 128 px.
        if self.use_tiling and (
            sample.shape[-1] > self.tile_sample_min_size or sample.shape[-2] > self.tile_sample_min_size
        ):
            return self.tiled_encode(sample)
        return self.encode(sample)
class ControlNetConditioningEmbedding(nn.Module):
    """A small network to preprocess conditioning inputs, inspired by ControlNet.

    forward: GroupNorm -> 3x3 conv -> SiLU -> 3x3 conv. It maps
    ``conditioning_channels`` to ``conditioning_embedding_channels`` at the
    same spatial size. ``conv_out`` is zero-initialised, so the freshly
    constructed module outputs zeros.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 4
    ):
        super().__init__()
        import math

        self.conv_in = nn.Conv2d(conditioning_channels, conditioning_channels, kernel_size=3, padding=1)
        # GroupNorm needs num_channels % num_groups == 0. Keep 32 groups whenever
        # that divides the channels (unchanged for e.g. 512), otherwise fall back
        # to the largest common divisor so small counts such as the default 4
        # no longer raise at construction time.
        num_groups = math.gcd(32, conditioning_channels)
        self.norm_in = nn.GroupNorm(num_channels=conditioning_channels, num_groups=num_groups, eps=1e-6)
        self.conv_out = zero_module(
            nn.Conv2d(conditioning_channels, conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        """Normalise, convolve, apply SiLU and project ``conditioning`` to the embedding channels."""
        conditioning = self.norm_in(conditioning)
        embedding = self.conv_in(conditioning)
        embedding = F.silu(embedding)
        embedding = self.conv_out(embedding)
        return embedding
class QuickGELU(nn.Module):
    """Sigmoid-gated GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        """Gate ``x`` elementwise by ``sigmoid(1.702 * x)``."""
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is cast back to the input's dtype (fp16-friendly)."""

    def forward(self, x: torch.Tensor):
        """Apply layer normalisation and return the result in ``x.dtype``."""
        input_dtype = x.dtype
        return super().forward(x).to(input_dtype)
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``.

    Uses ``nn.MultiheadAttention`` with its default (non batch-first) layout,
    i.e. (sequence, batch, feature). Presumably callers pass sequence-first
    tensors — TODO confirm. The MLP hidden width is ``2 * d_model`` with
    QuickGELU.
    """

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        hidden_width = d_model * 2
        self.mlp = nn.Sequential(
            OrderedDict([
                ("c_fc", nn.Linear(d_model, hidden_width)),
                ("gelu", QuickGELU()),
                ("c_proj", nn.Linear(hidden_width, d_model)),
            ])
        )
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        """Self-attention over ``x``; a stored mask is first moved (and re-stored) to x's dtype/device."""
        if self.attn_mask is not None:
            self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)
        return attended

    def forward(self, x: torch.Tensor):
        """Apply the attention and MLP sub-layers, each with a residual connection."""
        x = x + self.attention(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
@dataclass
class UNet2DConditionOutput(BaseOutput):
    """Output container of the extended ``UNet2DConditionModel`` defined in this module.

    Attributes:
        sample: Output tensor of the UNet forward pass (presumably the
            noise/sample prediction — confirm against ``forward``); ``None``
            by default.
    """
    sample: Optional[torch.FloatTensor] = None
class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
"""A unified 2D UNet model extending OriginalUNet2DConditionModel with custom functionality."""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 4,
    out_channels: int = 4,
    center_input_sample: bool = False,
    flip_sin_to_cos: bool = True,
    freq_shift: int = 0,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    only_cross_attention: Union[bool, Tuple[bool]] = False,
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    layers_per_block: Union[int, Tuple[int]] = 2,
    downsample_padding: int = 1,
    mid_block_scale_factor: float = 1,
    dropout: float = 0.0,
    act_fn: str = "silu",
    norm_num_groups: Optional[int] = 32,
    norm_eps: float = 1e-5,
    cross_attention_dim: Union[int, Tuple[int]] = 1280,
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
    reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
    encoder_hid_dim: Optional[int] = None,
    encoder_hid_dim_type: Optional[str] = None,
    attention_head_dim: Union[int, Tuple[int]] = 8,
    num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    class_embed_type: Optional[str] = None,
    addition_embed_type: Optional[str] = None,
    addition_time_embed_dim: Optional[int] = None,
    num_class_embeds: Optional[int] = None,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    resnet_skip_time_act: bool = False,
    resnet_out_scale_factor: float = 1.0,
    time_embedding_type: str = "positional",
    time_embedding_dim: Optional[int] = None,
    time_embedding_act_fn: Optional[str] = None,
    timestep_post_act: Optional[str] = None,
    time_cond_proj_dim: Optional[int] = None,
    conv_in_kernel: int = 3,
    conv_out_kernel: int = 3,
    projection_class_embeddings_input_dim: Optional[int] = None,
    attention_type: str = "default",
    class_embeddings_concat: bool = False,
    mid_block_only_cross_attention: Optional[bool] = None,
    cross_attention_norm: Optional[str] = None,
    addition_embed_type_num_heads: int = 64,
):
    """Initialize the extended UNet2DConditionModel.

    Every argument is forwarded unchanged, by keyword, to the diffusers
    ``UNet2DConditionModel`` constructor. ``@register_to_config`` records
    them in the model config, so this signature must stay in sync with the
    parent's.

    The FaithDiff extension modules are not built here. They start as
    ``None`` and are created by ``load_additional_layers``.
    """
    super().__init__(
        sample_size=sample_size,
        in_channels=in_channels,
        out_channels=out_channels,
        center_input_sample=center_input_sample,
        flip_sin_to_cos=flip_sin_to_cos,
        freq_shift=freq_shift,
        down_block_types=down_block_types,
        mid_block_type=mid_block_type,
        up_block_types=up_block_types,
        only_cross_attention=only_cross_attention,
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        downsample_padding=downsample_padding,
        mid_block_scale_factor=mid_block_scale_factor,
        dropout=dropout,
        act_fn=act_fn,
        norm_num_groups=norm_num_groups,
        norm_eps=norm_eps,
        cross_attention_dim=cross_attention_dim,
        transformer_layers_per_block=transformer_layers_per_block,
        reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
        encoder_hid_dim=encoder_hid_dim,
        encoder_hid_dim_type=encoder_hid_dim_type,
        attention_head_dim=attention_head_dim,
        num_attention_heads=num_attention_heads,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        class_embed_type=class_embed_type,
        addition_embed_type=addition_embed_type,
        addition_time_embed_dim=addition_time_embed_dim,
        num_class_embeds=num_class_embeds,
        upcast_attention=upcast_attention,
        resnet_time_scale_shift=resnet_time_scale_shift,
        resnet_skip_time_act=resnet_skip_time_act,
        resnet_out_scale_factor=resnet_out_scale_factor,
        time_embedding_type=time_embedding_type,
        time_embedding_dim=time_embedding_dim,
        time_embedding_act_fn=time_embedding_act_fn,
        timestep_post_act=timestep_post_act,
        time_cond_proj_dim=time_cond_proj_dim,
        conv_in_kernel=conv_in_kernel,
        conv_out_kernel=conv_out_kernel,
        projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
        attention_type=attention_type,
        class_embeddings_concat=class_embeddings_concat,
        mid_block_only_cross_attention=mid_block_only_cross_attention,
        cross_attention_norm=cross_attention_norm,
        addition_embed_type_num_heads=addition_embed_type_num_heads,
    )
    # Additional attributes: FaithDiff extension modules, created lazily by
    # load_additional_layers. Their names (including the "layes" spelling)
    # are also the state-dict key prefixes routed by load_state_dict.
    self.denoise_encoder = None
    self.information_transformer_layes = None
    self.condition_embedding = None
    self.agg_net = None
    self.spatial_ch_projs = None
def init_vae_encoder(self, dtype):
    """Create ``denoise_encoder`` as a fresh default ``Encoder``.

    A non-``None`` ``dtype`` is only stored as a plain ``dtype`` attribute on
    the encoder; it does not cast the weights. ``load_additional_layers``
    performs the actual cast through ``self.to``.
    """
    encoder = Encoder()
    self.denoise_encoder = encoder
    if dtype is None:
        return
    encoder.dtype = dtype
def init_information_transformer_layes(self):
    """Create the transformer stack and its zero-initialised output projection.

    ``information_transformer_layes``: two ``ResidualAttentionBlock`` layers
    (width 640, 8 heads). ``spatial_ch_projs``: a 640 -> 320 ``nn.Linear``
    whose parameters start at zero. Both attribute names are state-dict key
    prefixes, so their spelling must not change.
    """
    width, heads, depth, proj_width = 640, 8, 2, 320
    blocks = [ResidualAttentionBlock(width, heads) for _ in range(depth)]
    self.information_transformer_layes = nn.Sequential(*blocks)
    self.spatial_ch_projs = zero_module(nn.Linear(width, proj_width))
def init_ControlNetConditioningEmbedding(self, channel=512):
    """Create ``condition_embedding``, mapping ``channel`` input channels to 320 output channels."""
    self.condition_embedding = ControlNetConditioningEmbedding(
        conditioning_embedding_channels=320,
        conditioning_channels=channel,
    )
def init_extra_weights(self):
    """Create ``agg_net`` as an (initially empty) ``nn.ModuleList``."""
    self.agg_net = nn.ModuleList()
def load_additional_layers(self, dtype: Optional[torch.dtype] = torch.float16, channel: int = 512, weight_path: Optional[str] = None):
    """Create any missing FaithDiff extension modules and optionally load their weights.

    Steps:
    1. Initialise any missing extension modules: ``denoise_encoder``,
       ``information_transformer_layes`` + ``spatial_ch_projs``,
       ``condition_embedding`` and ``agg_net``.
    2. If ``weight_path`` is given, load it through ``self.load_state_dict``
       (strict for the extension modules).
    3. Move the whole model to the device of its first parameter and cast it
       to ``dtype``.

    Args:
        dtype (torch.dtype, optional): Target dtype for the whole model;
            ``None`` keeps the dtype of the first parameter. Defaults to
            torch.float16.
        channel (int): Input channels of the conditioning embedding. Defaults to 512.
        weight_path (str, optional): Path to the weight file.
    """
    if self.denoise_encoder is None:
        self.init_vae_encoder(dtype)
    if self.information_transformer_layes is None:
        self.init_information_transformer_layes()
    if self.condition_embedding is None:
        self.init_ControlNetConditioningEmbedding(channel)
    if self.agg_net is None:
        self.init_extra_weights()

    if weight_path is not None:
        # NOTE(security): weights_only=False unpickles arbitrary objects, so
        # only load checkpoints from trusted sources.
        # map_location="cpu" lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors onto the model's parameters
        # wherever they live.
        state_dict = torch.load(weight_path, map_location="cpu", weights_only=False)
        self.load_state_dict(state_dict, strict=True)

    # Newly created modules start on CPU; align everything with the first
    # parameter's device and the requested dtype.
    device = next(self.parameters()).device
    target_dtype = dtype if dtype is not None else next(self.parameters()).dtype
    self.to(device=device, dtype=target_dtype)
def to(self, *args, **kwargs):
    """Move/cast the model, then apply the same ``to`` call to each FaithDiff extension module.

    Returns ``self``, matching ``nn.Module.to`` chaining.
    """
    super().to(*args, **kwargs)
    # NOTE(review): these modules are assigned as attributes and therefore
    # registered as children, so super().to() presumably already moved them;
    # the explicit pass is kept as-is.
    extension_modules = (
        self.denoise_encoder,
        self.information_transformer_layes,
        self.condition_embedding,
        self.agg_net,
        self.spatial_ch_projs,
    )
    for extra in (m for m in extension_modules if m is not None):
        extra.to(*args, **kwargs)
    return self
def load_state_dict(self, state_dict, strict=True):
    """Load a state dict, routing FaithDiff extension keys to their own modules.

    A key whose first dotted component is one of ``denoise_encoder``,
    ``information_transformer_layes``, ``condition_embedding``, ``agg_net`` or
    ``spatial_ch_projs`` is stripped of that prefix. It is then loaded into the
    matching module with ``strict``.

    Every other key goes to the base UNet with ``strict=False``, so missing or
    unexpected base-UNet keys are silently ignored. Extension modules that are
    ``None``, or that receive no keys, are skipped. Returns ``None``.

    Args:
        state_dict (dict): State dictionary to load.
        strict (bool, optional): Strictness applied to the extension modules
            only. Defaults to True.
    """
    extension_names = (
        'denoise_encoder',
        'information_transformer_layes',
        'condition_embedding',
        'agg_net',
        'spatial_ch_projs',
    )
    routed = {name: {} for name in extension_names}
    base_entries = {}
    for full_key, tensor in state_dict.items():
        head, dot, tail = full_key.partition('.')
        if dot and head in routed:
            routed[head][tail] = tensor
        else:
            base_entries[full_key] = tensor

    super().load_state_dict(base_entries, strict=False)

    for name in extension_names:
        sub_dict = routed[name]
        target = getattr(self, name, None)
        if target is None or not sub_dict:
            continue
        target.load_state_dict(sub_dict, strict=strict)
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    input_embedding: Optional[torch.Tensor] = None,
    add_sample: bool = True,
    return_dict: bool = True,
    use_condition_embedding: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    """Run the UNet, optionally fusing a FaithDiff conditioning embedding into the input features.

    Apart from step 2 (pre-processing), the flow is the standard conditional-UNet forward:
    1. time/class/added embeddings
    2. ``conv_in`` plus the FaithDiff fusion
    3. down blocks, with optional ControlNet / T2I-Adapter residuals
    4. mid block
    5. up blocks
    6. output projection

    Args:
        sample (torch.FloatTensor): The noisy input tensor with shape `(batch, channel, height, width)`.
        timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
        encoder_hidden_states (torch.Tensor): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
        class_labels (torch.Tensor, optional): Optional class labels for conditioning.
        timestep_cond (torch.Tensor, optional): Conditional embeddings for timestep.
        attention_mask (torch.Tensor, optional): An attention mask of shape `(batch, key_tokens)`;
            1 keeps a token, 0 masks it.
        cross_attention_kwargs (Dict[str, Any], optional): A kwargs dictionary for the AttentionProcessor.
            A ``"scale"`` entry is popped and used as the LoRA scale; a ``"gligen"`` entry is
            replaced by the output of ``self.position_net``.
        added_cond_kwargs (Dict[str, torch.Tensor], optional): Additional embeddings to add to the UNet blocks.
        down_block_additional_residuals (Tuple[torch.Tensor], optional): ControlNet residuals for the down blocks.
        mid_block_additional_residual (torch.Tensor, optional): ControlNet residual for the middle block.
        down_intrablock_additional_residuals (Tuple[torch.Tensor], optional): T2I-Adapter residuals
            consumed inside the down blocks. NOTE(review): entries are removed with ``.pop(0)``,
            so callers must actually pass a mutable list despite the ``Tuple`` annotation.
        encoder_attention_mask (torch.Tensor, optional): A cross-attention mask of shape `(batch, sequence_length)`.
        input_embedding (torch.Tensor, optional): Conditioning features fused with the ``conv_in``
            output. Ignored unless ``condition_embedding`` and ``information_transformer_layes``
            are set on the model.
        add_sample (bool): If True, the fused features are added to the ``conv_in`` output as a
            residual; if False, they replace it. Defaults to True.
        return_dict (bool): Whether to return a UNet2DConditionOutput. Defaults to True.
        use_condition_embedding (bool): Whether to pass ``input_embedding`` through
            ``condition_embedding`` first. When False, ``input_embedding`` is used as-is and must
            already match the ``conv_in`` output's shape. Defaults to True.
    Returns:
        Union[UNet2DConditionOutput, Tuple]: The processed sample tensor, either as a
        UNet2DConditionOutput or as a 1-tuple when ``return_dict`` is False.
    """
    # Each upsampler doubles the resolution. If height/width are not multiples of the total
    # factor, the exact skip-connection size is forwarded to the up blocks (see step 5).
    default_overall_up_factor = 2**self.num_upsamplers
    forward_upsample_size = False
    upsample_size = None
    for dim in sample.shape[-2:]:
        if dim % default_overall_up_factor != 0:
            forward_upsample_size = True
            break
    # Turn 0/1 masks into additive biases: 1 -> 0.0 (keep), 0 -> -10000.0 (suppress).
    # A singleton axis is inserted at dim 1, presumably so the bias broadcasts over query tokens.
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)
    if encoder_attention_mask is not None:
        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0
    # 1. time
    t_emb = self.get_time_embed(sample=sample, timestep=timestep)
    emb = self.time_embedding(t_emb, timestep_cond)
    aug_emb = None
    class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
    if class_emb is not None:
        if self.config.class_embeddings_concat:
            emb = torch.cat([emb, class_emb], dim=-1)
        else:
            emb = emb + class_emb
    aug_emb = self.get_aug_embed(
        emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
    )
    if self.config.addition_embed_type == "image_hint":
        # "image_hint" returns (embedding, hint); the hint is concatenated onto the input channels.
        aug_emb, hint = aug_emb
        sample = torch.cat([sample, hint], dim=1)
    emb = emb + aug_emb if aug_emb is not None else emb
    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)
    encoder_hidden_states = self.process_encoder_hidden_states(
        encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
    )
    # 2. pre-process (following the original modified logic)
    sample = self.conv_in(sample) # [B, 4, H, W] -> [B, 320, H, W]
    # FaithDiff fusion: only runs when an embedding is supplied and the extension modules exist.
    # NOTE(review): `spatial_ch_projs` is not None-checked here but is used below.
    if input_embedding is not None and self.condition_embedding is not None and self.information_transformer_layes is not None:
        if use_condition_embedding:
            input_embedding = self.condition_embedding(input_embedding) # [B, 320, H, W]
        # `sample` must have the same channel count and spatial size as `input_embedding`
        # for the concat + view below to succeed.
        batch_size, channel, height, width = input_embedding.shape
        # Concatenate along channels and flatten pixels into a token sequence: [B, H*W, 2*C].
        concat_feat = torch.cat([sample, input_embedding], dim=1).view(batch_size, 2 * channel, height * width).transpose(1, 2)
        concat_feat = self.information_transformer_layes(concat_feat)
        # Project back (the view requires `channel` features per token) and restore [B, C, H, W].
        feat_alpha = self.spatial_ch_projs(concat_feat).transpose(1, 2).view(batch_size, channel, height, width)
        sample = sample + feat_alpha if add_sample else feat_alpha # Update sample as in the original version
    # 2.5 GLIGEN position net (kept from the original version)
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        gligen_args = cross_attention_kwargs.pop("gligen")
        cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
    # 3. down (continues the standard flow)
    # The "scale" entry is a LoRA scale; it is removed so it is not forwarded to attention processors.
    if cross_attention_kwargs is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0
    if USE_PEFT_BACKEND:
        scale_lora_layers(self, lora_scale)
    # ControlNet supplies both down and mid residuals; T2I-Adapter supplies intra-block residuals.
    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    is_adapter = down_intrablock_additional_residuals is not None
    # Backward compatibility: down residuals without a mid residual are treated as adapter residuals.
    if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
        deprecate(
            "T2I should not use down_block_additional_residuals",
            "1.3.0",
            "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
            standard_warn=False,
        )
        down_intrablock_additional_residuals = down_block_additional_residuals
        is_adapter = True
    # Skip connections collected for the up path, starting with the conv_in/fusion output.
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            # Cross-attention down blocks take the adapter residual as a keyword argument.
            additional_residuals = {}
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                **additional_residuals,
            )
        else:
            sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
            # Plain down blocks: the adapter residual is added to the block output instead.
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                sample += down_intrablock_additional_residuals.pop(0)
        down_block_res_samples += res_samples
    if is_controlnet:
        # Add the ControlNet residuals element-wise to the collected skip connections.
        new_down_block_res_samples = ()
        for down_block_res_sample, down_block_additional_residual in zip(
            down_block_res_samples, down_block_additional_residuals
        ):
            down_block_res_sample = down_block_res_sample + down_block_additional_residual
            new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
        down_block_res_samples = new_down_block_res_samples
    # 4. mid
    if self.mid_block is not None:
        if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = self.mid_block(sample, emb)
        # A leftover adapter residual is applied at the mid block only if its shape matches.
        if is_adapter and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape:
            sample += down_intrablock_additional_residuals.pop(0)
    if is_controlnet:
        sample = sample + mid_block_additional_residual
    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1
        # Consume skip connections from the end, one per resnet in this up block.
        res_samples = down_block_res_samples[-len(upsample_block.resnets):]
        down_block_res_samples = down_block_res_samples[:-len(upsample_block.resnets)]
        # For odd spatial sizes, upsample to the size of the next skip connection explicitly.
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]
        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
            )
    # 6. post-process
    if self.conv_norm_out:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    if USE_PEFT_BACKEND:
        # Undo the LoRA scaling applied before the down blocks.
        unscale_lora_layers(self, lora_scale)
    if not return_dict:
        return (sample,)
    return UNet2DConditionOutput(sample=sample)
================================================
FILE: FaithDiff/pipelines/__init__.py
================================================
================================================
FILE: FaithDiff/pipelines/pipeline_FaithDiff_tlc.py
================================================
import inspect
import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import cv2
import torch
from transformers import (
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer
)
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import (
FromSingleFileMixin,
IPAdapterMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import (
AttnProcessor2_0,
FusedAttnProcessor2_0,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
USE_PEFT_BACKEND,
deprecate,
is_invisible_watermark_available,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import FaithDiffStableDiffusionXLPipelineOutput
import PIL.Image
if is_invisible_watermark_available():
from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
from diffusers import AutoencoderKL, DDPMScheduler
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Usage example text for this module's pipeline documentation.
# NOTE(review): this is the stock SDXL text-to-image example (StableDiffusionXLPipeline), not a
# FaithDiff super-resolution example. It is presumably consumed by `replace_example_docstring`;
# confirm and replace it with a FaithDiff-specific snippet.
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import StableDiffusionXLPipeline
>>> pipe = StableDiffusionXLPipeline.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt).images[0]
```
"""
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert numpy image array(s) to torch tensors in CHW layout.

    Args:
        imgs (list[ndarray] | ndarray): Input image(s) in HWC layout. A 2-D ``(H, W)``
            grayscale array is treated as a single-channel image.
        bgr2rgb (bool): Whether to convert 3-channel images from BGR to RGB order.
        float32 (bool): Whether to cast the result to ``torch.float32``.

    Returns:
        list[tensor] | tensor: CHW tensor(s). A list input yields a list of tensors;
        a single array yields a single tensor.
    """

    def _totensor(img, bgr2rgb, float32):
        if img.ndim == 2:
            # Grayscale (H, W): add a channel axis so the HWC -> CHW transpose below works
            # (previously `img.shape[2]` raised IndexError for 2-D input).
            img = img[:, :, None]
        if img.shape[2] == 3 and bgr2rgb:
            # cv2.cvtColor does not accept float64 input.
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img.transpose(2, 0, 1))
        if float32:
            img = img.float()
        return img

    if isinstance(imgs, list):
        return [_totensor(img, bgr2rgb, float32) for img in imgs]
    return _totensor(imgs, bgr2rgb, float32)
class LocalAttention:
    """Tile a single-image tensor into overlapping patches and blend processed patches back.

    :meth:`grids` cuts the input into tiles of ``kernel_size`` (clamped to the input size) and
    records where each tile came from. :meth:`grids_inverse` pastes processed tiles back,
    weighting each with a separable Gaussian mask and dividing by the accumulated weights,
    so overlapping regions are blended smoothly.
    """

    def __init__(self, kernel_size=None, overlap=0.5):
        """Initialize the tiler.

        Args:
            kernel_size (tuple[int, int], optional): Tile size ``(height, width)``. Must be set
                before any splitting method is called. Defaults to None.
            overlap (float): Stride between neighbouring tiles as a fraction of the tile size.
                ``0.5`` moves by half a tile (50% overlap); ``1.0`` gives no overlap. A stride
                that would round to zero is clamped to one pixel. Defaults to 0.5.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.overlap = overlap

    def _split(self, x):
        """Cut ``x`` into tiles; shared by :meth:`grids` and :meth:`grids_list`.

        Side effects: sets ``self.original_size``, ``self.nr`` and ``self.nc``.

        Args:
            x (torch.Tensor): Input tensor of shape (1, channels, height, width).

        Returns:
            tuple: ``(parts, idxes, k1, k2)``, where:
                - ``parts`` is the list of tile tensors;
                - ``idxes`` holds the top-left offsets as ``{'i': row, 'j': col}`` dicts;
                - ``k1, k2`` are the effective tile height and width.
        """
        import math

        b, c, h, w = x.shape
        self.original_size = (b, c, h, w)
        assert b == 1
        k1, k2 = self.kernel_size
        # A tile can never be larger than the input itself.
        k1 = min(k1, h)
        k2 = min(k2, w)
        num_row = (h - 1) // k1 + 1
        num_col = (w - 1) // k2 + 1
        self.nr = num_row
        self.nc = num_col
        # Stride between tiles. Clamped to >= 1 so a small `overlap` (e.g. 0.0) cannot yield a
        # zero stride, which previously made the loops below spin forever.
        step_j = k2 if num_col == 1 else max(1, math.ceil(k2 * self.overlap))
        step_i = k1 if num_row == 1 else max(1, math.ceil(k1 * self.overlap))
        parts = []
        idxes = []
        i = 0
        last_i = False
        while i < h and not last_i:
            j = 0
            if i + k1 >= h:
                # Snap the last row of tiles flush with the bottom edge.
                i = h - k1
                last_i = True
            last_j = False
            while j < w and not last_j:
                if j + k2 >= w:
                    # Snap the last tile of the row flush with the right edge.
                    j = w - k2
                    last_j = True
                parts.append(x[:, :, i:i + k1, j:j + k2])
                idxes.append({'i': i, 'j': j})
                j = j + step_j
            i = i + step_i
        return parts, idxes, k1, k2

    def grids_list(self, x):
        """Split the input tensor into a list of (possibly overlapping) tile patches.

        Unlike :meth:`grids`, this does not record tile offsets or blending weights, so it
        does not prepare the instance for :meth:`grids_inverse`.

        Args:
            x (torch.Tensor): Input tensor of shape (1, channels, height, width).
        Returns:
            list[torch.Tensor]: List of tensor patches.
        """
        parts, _, _, _ = self._split(x)
        return parts

    def grids(self, x):
        """Split the input tensor into overlapping grid patches and concatenate them.

        Also records the tile offsets (``self.idxes``) and the Gaussian blending mask
        (``self.tile_weights``) that :meth:`grids_inverse` needs.

        Args:
            x (torch.Tensor): Input tensor of shape (1, channels, height, width).
        Returns:
            torch.Tensor: All tiles concatenated along the batch dimension.
        """
        parts, idxes, k1, k2 = self._split(x)
        # Mask sized to the effective tile, with one plane per input channel, built on the
        # input's device rather than a hard-coded CUDA device.
        self.tile_weights = self._gaussian_weights(k2, k1, channels=x.shape[1], device=x.device)
        self.idxes = idxes
        return torch.cat(parts, dim=0)

    def _gaussian_weights(self, tile_width, tile_height, channels=4, device=None):
        """Generate a separable Gaussian weight mask for tile contributions.

        Args:
            tile_width (int): Width of the tile.
            tile_height (int): Height of the tile.
            channels (int): Number of channel planes to repeat the mask over. Defaults to 4.
            device (torch.device | str, optional): Device of the returned mask. Defaults to
                CUDA when available, otherwise CPU.
        Returns:
            torch.Tensor: float64 weight tensor of shape (channels, tile_height, tile_width).
        """
        from numpy import pi, exp, sqrt
        import numpy as np

        var = 0.01

        def _profile(length):
            # Centre on the middle index, (length - 1) / 2, on BOTH axes so the mask is symmetric.
            # The height axis previously used length / 2, shifting it by half a pixel.
            midpoint = (length - 1) / 2
            return [
                exp(-(x - midpoint) * (x - midpoint) / (length * length) / (2 * var)) / sqrt(2 * pi * var)
                for x in range(length)
            ]

        weights = np.outer(_profile(tile_height), _profile(tile_width))
        if device is None:
            device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        return torch.tile(torch.tensor(weights, device=device), (channels, 1, 1))

    def grids_inverse(self, outs):
        """Reconstruct the original tensor from processed grid patches with overlap blending.

        Must be called after :meth:`grids`, whose recorded offsets and weights it uses.

        Args:
            outs (torch.Tensor): Processed grid patches, in the order produced by :meth:`grids`.
        Returns:
            torch.Tensor: Reconstructed float32 tensor of the original size.
        """
        b, c, h, w = self.original_size
        k1, k2 = self.kernel_size
        weights = self.tile_weights.to(outs.device)
        preds = torch.zeros(self.original_size).to(outs.device)
        # One weight accumulator plane per mask plane (equal to the input channel count).
        count_mt = torch.zeros((b, weights.shape[0], h, w)).to(outs.device)
        for cnt, each_idx in enumerate(self.idxes):
            i = each_idx['i']
            j = each_idx['j']
            # Slices clamp to the tensor bounds, so a kernel larger than the input is fine.
            preds[0, :, i:i + k1, j:j + k2] += outs[cnt, :, :, :] * weights
            count_mt[0, :, i:i + k1, j:j + k2] += weights
        del outs
        torch.cuda.empty_cache()
        return preds / count_mt

    def _pad(self, x):
        """Reflect-pad the input so its height and width are multiples of the kernel size.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, channels, height, width).
        Returns:
            tuple: Padded tensor and the ``(left, right, top, bottom)`` padding amounts.
        """
        b, c, h, w = x.shape
        k1, k2 = self.kernel_size
        mod_pad_h = (k1 - h % k1) % k1
        mod_pad_w = (k2 - w % k2) % k2
        pad = (mod_pad_w//2, mod_pad_w-mod_pad_w//2, mod_pad_h//2, mod_pad_h-mod_pad_h//2)
        x = F.pad(x, pad, 'reflect')
        return x, pad

    def forward(self, x):
        """Split into grids and immediately reconstruct, blending the overlaps.

        Args:
            x (torch.Tensor): Input tensor of shape (1, channels, height, width).
        Returns:
            torch.Tensor: Processed tensor of original size.
        """
        qkv = self.grids(x)
        out = self.grids_inverse(qkv)
        return out
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg (same behavior).
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Blend a guided noise prediction with a copy rescaled to the text prediction's spread.

    From Section 3.4 of "Common Diffusion Noise Schedules and Sample Steps are Flawed"
    (https://arxiv.org/pdf/2305.08891.pdf). The guided prediction is rescaled so that its
    per-sample standard deviation (over all non-batch dims) matches the text-conditioned
    prediction, which counters over-exposure. It is then mixed back with the unscaled guided
    prediction by ``guidance_rescale``, which avoids overly plain images.

    Args:
        noise_cfg (torch.Tensor): Classifier-free-guidance noise prediction.
        noise_pred_text (torch.Tensor): Predicted noise from the text-conditioned model.
        guidance_rescale (float): Weight of the rescaled prediction in the final blend
            (0.0 returns ``noise_cfg`` unchanged in value). Defaults to 0.0.

    Returns:
        torch.Tensor: The blended noise prediction.
    """
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    text_spread = noise_pred_text.std(dim=text_dims, keepdim=True)
    cfg_spread = noise_cfg.std(dim=cfg_dims, keepdim=True)
    rescaled = noise_cfg * (text_spread / cfg_spread)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents (same behavior).
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from an encoder (e.g. VAE) output.

    When the output has a ``latent_dist``, the result is either a sample from it
    (``sample_mode="sample"``, using ``generator``) or its mode (``sample_mode="argmax"``).
    Otherwise, or for any other ``sample_mode``, a ``latents`` attribute is returned if present.

    Args:
        encoder_output (torch.Tensor): Output from an encoder (e.g., VAE).
        generator (torch.Generator, optional): Random generator for sampling. Defaults to None.
        sample_mode (str): ``"sample"`` or ``"argmax"``. Defaults to ``"sample"``.

    Returns:
        torch.Tensor: Retrieved latent tensor.

    Raises:
        AttributeError: If no usable latents can be found on ``encoder_output``.
    """
    if hasattr(encoder_output, "latent_dist"):
        if sample_mode == "sample":
            return encoder_output.latent_dist.sample(generator)
        if sample_mode == "argmax":
            return encoder_output.latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps (same behavior).
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Configure ``scheduler`` via ``set_timesteps`` and return its resulting schedule.

    Extra keyword arguments are forwarded unchanged to ``scheduler.set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Used only when `timesteps` is `None`.
        device (`str` or `torch.device`, *optional*):
            Device to move the timesteps to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            A custom timestep schedule. When given, `num_inference_steps` is ignored and the
            step count is taken from the schedule's length.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timestep schedule and the number of inference steps.

    Raises:
        ValueError: If custom `timesteps` are given but the scheduler's `set_timesteps` does not
            accept a `timesteps` argument.
    """
    if timesteps is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    supported_params = inspect.signature(scheduler.set_timesteps).parameters
    if "timesteps" not in supported_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" timestep schedules. Please check whether you are using the correct scheduler."
        )
    scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
class FaithDiffStableDiffusionXLPipeline(
DiffusionPipeline,
StableDiffusionMixin,
FromSingleFileMixin,
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
IPAdapterMixin,
):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion XL uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([` CLIPTextModelWithProjection`]):
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
specifically the
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`CLIPTokenizer`):
Second Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
`stabilityai/stable-diffusion-xl-base-1-0`.
add_watermarker (`bool`, *optional*):
Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
watermark output images. If not defined, it will default to True if the package is installed, otherwise no
watermarker will be used.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->denoise_encoder->unet->vae"
_optional_components = [
"tokenizer",
"tokenizer_2",
"text_encoder",
"text_encoder_2",
"feature_extractor",
]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"add_text_embeds",
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
]
def __init__(
    self,
    vae: AutoencoderKL,
    denoise_encoder: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_2: CLIPTextModelWithProjection,
    tokenizer: CLIPTokenizer,
    tokenizer_2: CLIPTokenizer,
    unet: UNet2DConditionModel,
    scheduler: KarrasDiffusionSchedulers,
    DDPM_scheduler: DDPMScheduler,
    force_zeros_for_empty_prompt: bool = True,
    add_watermarker: Optional[bool] = None,
):
    """Register the pipeline components and derive image-processing settings.

    Args:
        vae: VAE used to encode/decode images to/from latents.
        denoise_encoder: Additional VAE-type encoder registered as the ``denoise_encoder`` component.
        text_encoder, text_encoder_2: The two SDXL text encoders.
        tokenizer, tokenizer_2: Tokenizers matching the two text encoders.
        unet: Denoising UNet.
        scheduler: Scheduler used for sampling.
        DDPM_scheduler: A DDPMScheduler registered as a separate component.
        force_zeros_for_empty_prompt: Stored in the config. Used by ``encode_prompt`` to zero
            the negative embeddings when no negative prompt is given.
        add_watermarker: Whether to create an invisible watermarker. ``None`` means
            "only if the invisible-watermark package is available".
    """
    super().__init__()
    self.register_modules(
        vae=vae,
        denoise_encoder=denoise_encoder,
        text_encoder=text_encoder,
        text_encoder_2=text_encoder_2,
        tokenizer=tokenizer,
        tokenizer_2=tokenizer_2,
        unet=unet,
        scheduler=scheduler,
        DDPM_scheduler=DDPM_scheduler,
    )
    self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)

    # The VAE halves the resolution between consecutive block_out_channels levels.
    vae_levels = len(self.vae.config.block_out_channels)
    self.vae_scale_factor = 2 ** (vae_levels - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    self.default_sample_size = self.unet.config.sample_size

    if add_watermarker is None:
        add_watermarker = is_invisible_watermark_available()
    self.watermark = StableDiffusionXLWatermarker() if add_watermarker else None
def encode_prompt(
    self,
    prompt: str,
    prompt_2: Optional[str] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[str] = None,
    negative_prompt_2: Optional[str] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders
        device: (`torch.device`, *optional*):
            torch device to place the embeddings on. Defaults to the pipeline's execution device.
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            less than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        lora_scale (`float`, *optional*):
            A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `tuple`: ``(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds)``.
        The negative entries are only computed when `do_classifier_free_guidance` is True; otherwise they
        are returned exactly as passed in (``None`` by default).

    Raises:
        TypeError: If `negative_prompt` and `prompt` have different types.
        ValueError: If the negative prompt batch size does not match the prompt batch size.
    """
    # Honour the caller-supplied device and fall back to the pipeline's execution device,
    # instead of unconditionally assuming a CUDA device (which broke CPU / non-default-GPU use).
    device = device or self._execution_device

    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if self.text_encoder is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)

        if self.text_encoder_2 is not None:
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
                scale_lora_layers(self.text_encoder_2, lora_scale)

    prompt = [prompt] if isinstance(prompt, str) else prompt

    if prompt is not None:
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    # Define tokenizers and text encoders
    tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
    text_encoders = (
        [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
    )
    dtype = text_encoders[0].dtype

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # textual inversion: process multi-vector tokens if necessary
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        # Loop variable is `text_prompt` (not `prompt`), so the caller's normalized prompt list
        # stays intact for the negative-prompt validation and error messages below.
        for text_prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                text_prompt = self.maybe_convert_prompt(text_prompt, tokenizer)

            text_inputs = tokenizer(
                text_prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            untruncated_ids = tokenizer(text_prompt, padding="longest", return_tensors="pt").input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer.model_max_length} tokens: {removed_text}"
                )
            # Cast each text encoder (in place) to the first encoder's dtype so the hidden states
            # concatenated below share a single dtype.
            text_encoder = text_encoder.to(dtype)
            prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            if clip_skip is None:
                prompt_embeds = prompt_embeds.hidden_states[-2]
            else:
                # "2" because SDXL always indexes from the penultimate layer.
                prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]

            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

    # get unconditional embeddings for classifier free guidance
    zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
    if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
        negative_prompt_embeds = torch.zeros_like(prompt_embeds)
        negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
    elif do_classifier_free_guidance and negative_prompt_embeds is None:
        negative_prompt = negative_prompt or ""
        negative_prompt_2 = negative_prompt_2 or negative_prompt

        # normalize str to list
        negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
        negative_prompt_2 = (
            batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
        )

        uncond_tokens: List[str]
        if prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = [negative_prompt, negative_prompt_2]

        negative_prompt_embeds_list = []
        for negative_text, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
            if isinstance(self, TextualInversionLoaderMixin):
                negative_text = self.maybe_convert_prompt(negative_text, tokenizer)

            # Pad negatives to the positive sequence length so both can be batched together.
            max_length = prompt_embeds.shape[1]
            uncond_input = tokenizer(
                negative_text,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            negative_prompt_embeds = text_encoder(
                uncond_input.input_ids.to(device),
                output_hidden_states=True,
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            negative_pooled_prompt_embeds = negative_prompt_embeds[0]
            negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

            negative_prompt_embeds_list.append(negative_prompt_embeds)

        negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

    if self.text_encoder_2 is not None:
        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
    else:
        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        if self.text_encoder_2 is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
        else:
            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)

        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
        bs_embed * num_images_per_prompt, -1
    )
    if do_classifier_free_guidance:
        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
            bs_embed * num_images_per_prompt, -1
        )

    if self.text_encoder is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder, lora_scale)

    if self.text_encoder_2 is not None:
        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder_2, lora_scale)

    return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs (same behavior).
def prepare_extra_step_kwargs(self, generator, eta):
    """Build the optional keyword arguments for ``self.scheduler.step``.

    Schedulers differ in their ``step`` signature, so each argument is included only if the
    scheduler's ``step`` accepts it.

    Args:
        generator: Random generator, included as ``generator`` when supported.
        eta (float): DDIM's eta (https://arxiv.org/abs/2010.02502), expected in [0, 1].
            Included as ``eta`` when supported; schedulers without it ignore it.

    Returns:
        dict: Keyword arguments to pass to ``self.scheduler.step``.
    """
    step_params = set(inspect.signature(self.scheduler.step).parameters.keys())
    extra_step_kwargs = {}
    if "eta" in step_params:
        extra_step_kwargs["eta"] = eta
    if "generator" in step_params:
        extra_step_kwargs["generator"] = generator
    return extra_step_kwargs
def check_inputs(
    self,
    lr_img,
    prompt,
    prompt_2,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    negative_prompt_2=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    pooled_prompt_embeds=None,
    negative_pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
):
    """Validate the arguments of ``__call__``; raise ``ValueError`` on the first problem found.

    Checks, in order: the low-resolution image is present, ``height``/``width`` are multiples of 8,
    ``callback_steps`` is a positive int (when given), every requested callback tensor name is listed in
    ``self._callback_tensor_inputs``, prompts and pre-computed embeddings are not passed together (and at
    least one of ``prompt``/``prompt_embeds`` is given), prompt types are ``str`` or ``list``, positive and
    negative embeddings have matching shapes, and pooled embeddings accompany their embeddings.

    Raises:
        ValueError: Describing the first invalid argument.
    """
    if lr_img is None:
        # Name the parameter exactly as `__call__` spells it so the error is actionable.
        raise ValueError("`lr_img` must be provided!")

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    if callback_on_step_end_tensor_inputs is not None and not all(
        k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
    ):
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
        )

    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt_2 is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )
    elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

    if prompt_embeds is not None and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
        raise ValueError(
            "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
        )
# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
    """Return the initial latents, scaled by the scheduler's ``init_noise_sigma``.

    When ``latents`` is given it is only moved to ``device`` (its dtype is left untouched); otherwise a
    fresh tensor of shape ``(batch_size, num_channels_latents, height // vae_scale_factor,
    width // vae_scale_factor)`` is sampled with ``randn_tensor``.

    Raises:
        ValueError: If ``generator`` is a list whose length differs from ``batch_size``.
    """
    factor = self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, height // factor, width // factor)

    generator_count_mismatch = isinstance(generator, list) and len(generator) != batch_size
    if generator_count_mismatch:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    initial = (
        latents.to(device)
        if latents is not None
        else randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)
    )
    # Scale the starting noise to the standard deviation the scheduler expects.
    return initial * self.scheduler.init_noise_sigma
def upcast_vae(self):
    """Cast the VAE to float32, keeping a few decoder parts in the original dtype when that is safe.

    If the decoder mid-block's first attention uses one of the memory-efficient processors listed below
    (PyTorch 2.0 SDPA, xFormers, or their LoRA / fused variants), ``post_quant_conv``, ``decoder.conv_in``
    and ``decoder.mid_block`` are cast back to the VAE's previous dtype to save memory.
    """
    previous_dtype = self.vae.dtype
    self.vae.to(dtype=torch.float32)

    mid_attention_processor = self.vae.decoder.mid_block.attentions[0].processor
    memory_efficient_processors = (
        AttnProcessor2_0,
        XFormersAttnProcessor,
        LoRAXFormersAttnProcessor,
        LoRAAttnProcessor2_0,
        FusedAttnProcessor2_0,
    )
    if not isinstance(mid_attention_processor, memory_efficient_processors):
        return

    # These attention implementations do not need float32, so revert the parts that would otherwise
    # cost the most memory.
    for submodule in (self.vae.post_quant_conv, self.vae.decoder.conv_in, self.vae.decoder.mid_block):
        submodule.to(previous_dtype)
# Adapted from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
    self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
) -> torch.FloatTensor:
    """
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Sinusoidal embedding of ``w * 1000``: the first half of each row holds sines, the second half cosines.

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance values used to generate embedding vectors that enrich the timestep
            embeddings. It may live on any device; the result is created on ``w.device``.
        embedding_dim (`int`, *optional*, defaults to 512):
            Dimension of the embeddings to generate. Odd sizes get one zero-padded trailing column.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            Data type of the generated embeddings.

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.

    Raises:
        ValueError: If ``w`` is not 1-D.
    """
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if w.dim() != 1:
        raise ValueError(f"`w` must be a 1-D tensor, but got shape {tuple(w.shape)}.")

    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Build the frequency table on `w`'s device; a CPU table would fail against a CUDA `w`.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
def set_encoder_tile_settings(
    self,
    denoise_encoder_tile_sample_min_size=1024,
    denoise_encoder_sample_overlap_factor=0.25,
    vae_sample_size=1024,
    vae_tile_overlap_factor=0.25,
):
    """Configure tile size and overlap for the denoise encoder and the VAE.

    The settings only matter once tiling is enabled (see ``enable_vae_tiling``).

    Args:
        denoise_encoder_tile_sample_min_size (`int`, defaults to 1024):
            Stored as ``self.denoise_encoder.tile_sample_min_size``.
        denoise_encoder_sample_overlap_factor (`float`, defaults to 0.25):
            Stored as ``self.denoise_encoder.tile_overlap_factor``.
        vae_sample_size (`int`, defaults to 1024):
            Written to ``self.vae.config.sample_size`` (kept for backward compatibility). It is also written to
            the VAE's ``tile_sample_min_size`` / ``tile_latent_min_size`` attributes when they exist. diffusers'
            ``AutoencoderKL`` derives those attributes from the config only once, at construction, and its
            tiled encode/decode reads the attributes. Updating only the config therefore had no effect.
        vae_tile_overlap_factor (`float`, defaults to 0.25):
            Stored as ``self.vae.tile_overlap_factor``.
    """
    self.denoise_encoder.tile_sample_min_size = denoise_encoder_tile_sample_min_size
    self.denoise_encoder.tile_overlap_factor = denoise_encoder_sample_overlap_factor

    self.vae.config.sample_size = vae_sample_size
    if hasattr(self.vae, "tile_sample_min_size"):
        self.vae.tile_sample_min_size = vae_sample_size
    block_out_channels = getattr(self.vae.config, "block_out_channels", None)
    if hasattr(self.vae, "tile_latent_min_size") and block_out_channels:
        pixel_size = vae_sample_size[0] if isinstance(vae_sample_size, (list, tuple)) else vae_sample_size
        # The latent tile is the pixel tile divided by the VAE's total downsampling factor.
        self.vae.tile_latent_min_size = int(pixel_size / (2 ** (len(block_out_channels) - 1)))
    self.vae.tile_overlap_factor = vae_tile_overlap_factor
def enable_vae_tiling(self):
    r"""
    Switch on tiled processing for the VAE and for the denoise encoder.

    With tiling enabled, large inputs are split into tiles that are encoded/decoded in several steps,
    which saves a large amount of memory and makes bigger images feasible. The VAE is switched first,
    then the denoise encoder.
    """
    for component_name in ("vae", "denoise_encoder"):
        getattr(self, component_name).enable_tiling()
def disable_vae_tiling(self):
    r"""
    Switch tiled processing off again for the VAE and for the denoise encoder.

    Undoes ``enable_vae_tiling``: both components go back to processing their input in a single pass.
    The VAE is switched first, then the denoise encoder.
    """
    for component_name in ("vae", "denoise_encoder"):
        getattr(self, component_name).disable_tiling()
@property
def guidance_scale(self):
    """Classifier-free guidance weight, as stored by ``__call__``."""
    return self._guidance_scale

@property
def guidance_rescale(self):
    """Guidance-rescale factor, as stored by ``__call__``."""
    return self._guidance_rescale

@property
def clip_skip(self):
    """Number of CLIP layers to skip, as stored by ``__call__``."""
    return self._clip_skip

# `guidance_scale` plays the role of the guidance weight `w` in equation (2) of the Imagen paper
# (https://arxiv.org/pdf/2205.11487.pdf); `guidance_scale = 1` means classifier-free guidance is off.
@property
def do_classifier_free_guidance(self):
    """Whether CFG runs: needs ``guidance_scale > 1`` and a UNet without ``time_cond_proj_dim``."""
    scale_above_one = self._guidance_scale > 1
    if not scale_above_one:
        return scale_above_one
    return self.unet.config.time_cond_proj_dim is None

@property
def cross_attention_kwargs(self):
    """Extra attention-processor kwargs, as stored by ``__call__``."""
    return self._cross_attention_kwargs

@property
def denoising_end(self):
    """Fraction of the schedule after which denoising stops early, as stored by ``__call__``."""
    return self._denoising_end

@property
def num_timesteps(self):
    """Number of timesteps of the current run, as stored by ``__call__``."""
    return self._num_timesteps

@property
def interrupt(self):
    """When true, the denoising loop skips its remaining steps."""
    return self._interrupt
def prepare_image_latents(
    self, image, batch_size, num_images_per_prompt, dtype, device, do_classifier_free_guidance, generator=None
):
    """Encode the preprocessed low-resolution image into conditioning latents.

    Args:
        image (`torch.Tensor`):
            Preprocessed image batch. A tensor with 4 channels is assumed to already be latents and is used
            as-is; anything else is encoded with ``self.denoise_encoder``.
        batch_size (`int`):
            Number of prompts.
        num_images_per_prompt (`int`):
            Images per prompt; the effective batch is ``batch_size * num_images_per_prompt``.
        dtype (`torch.dtype`):
            Dtype ``image`` is cast to before encoding.
        device:
            Device ``image`` is moved to before encoding.
        do_classifier_free_guidance (`bool`):
            Kept for API compatibility; this method does not use it.
        generator:
            Kept for API compatibility; this method does not use it.

    Returns:
        `torch.Tensor`: The image latents, cast to ``self.vae.dtype``. When the image batch is smaller than the
        effective batch and divides it, the latents are repeated to reach it (with a deprecation warning).

    Raises:
        ValueError: If ``image`` is not a ``torch.Tensor``, or its batch cannot be repeated to the effective
            batch size.
    """
    # Only tensors are supported: the `.to()` call below would fail on PIL images or lists.
    if not isinstance(image, torch.Tensor):
        raise ValueError(f"`image` has to be of type `torch.Tensor` but is {type(image)}")

    image = image.to(device=device, dtype=dtype)
    batch_size = batch_size * num_images_per_prompt

    if image.shape[1] == 4:
        image_latents = image
    else:
        image_latents = self.denoise_encoder(image)

    if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
        # expand image_latents for batch_size
        deprecation_message = (
            f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial"
            " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
            " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
            " your script to pass as many initial images as text prompts to suppress this warning."
        )
        deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
        additional_image_per_prompt = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
    elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
        raise ValueError(
            f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
        )
    else:
        # Copy so the returned tensor never aliases the caller's input.
        image_latents = torch.cat([image_latents], dim=0)

    # No unconditional (all-zeros) copy is built here: it was never used, and allocating it only wasted a
    # latent-sized buffer.
    if image_latents.dtype != self.vae.dtype:
        image_latents = image_latents.to(dtype=self.vae.dtype)

    return image_latents
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
    self,
    lr_img: PipelineImageInput = None,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    start_point: Optional[str] = "noise",
    timesteps: List[int] = None,
    denoising_end: Optional[float] = None,
    overlap: float = 0.5,
    guidance_scale: float = 5.0,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    negative_prompt_2: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
    original_size: Optional[Tuple[int, int]] = None,
    target_size: Optional[Tuple[int, int]] = None,
    clip_skip: Optional[int] = None,
    callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
    callback_on_step_end_tensor_inputs: List[str] = ["latents"],
    add_sample: bool = True,
    **kwargs,
):
    r"""
    Function invoked when calling the pipeline for generation.

    The low-resolution image is encoded by ``denoise_encoder``. Both the image latents and the noisy latents
    are split into overlapping tiles with ``LocalAttention.grids``. Every denoising step runs the UNet tile by
    tile, and each tile keeps its own copy of the scheduler state. The tiles are merged back with
    ``grids_inverse`` at the end of the step.

    Note: this method always generates a single image (``batch_size`` and ``num_images_per_prompt`` are
    overridden to 1) and always uses ``torch.device("cuda")``. ``encode_prompt`` is called with ``prompt``,
    ``negative_prompt`` and the LoRA scale only. The pre-computed embedding arguments are validated but not
    forwarded to it.

    Args:
        lr_img (PipelineImageInput, optional): Low-resolution input image for conditioning the generation process.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
            instead.
        prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
            used in both text-encoders.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image. This is set to 1024 by default for the best results.
            Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image. This is set to 1024 by default for the best results.
            Anything below 512 pixels won't work well for
            [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
            and checkpoints that are not specifically fine-tuned on low resolutions.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        start_point (str, *optional*):
            The starting point for the generation process. Can be "noise" (random noise) or "lr" (the VAE encoding
            of the low-resolution image, noised to timestep 999). Defaults to "noise".
        timesteps (`List[int]`, *optional*):
            Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
            in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
            passed will be used. Must be in descending order.
        denoising_end (`float`, *optional*):
            When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
            completed before it is intentionally prematurely terminated. As a result, the returned sample will
            still retain a substantial amount of noise as determined by the discrete timesteps selected by the
            scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
            "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
            Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
        overlap (float):
            Overlap factor for local attention tiling (between 0.0 and 1.0). Controls the overlap between adjacent
            grid patches during processing. Defaults to 0.5.
        guidance_scale (`float`, *optional*, defaults to 5.0):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
            not greater than `1`).
        negative_prompt_2 (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
            `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt. Currently overridden to 1.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
            [`schedulers.DDIMScheduler`], will be ignored for others.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, pooled text embeddings will be generated from `prompt` input argument.
        negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
            input argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
            of a plain tuple.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
            Guidance rescale factor should fix overexposure when using zero terminal SNR.
        original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
            `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
            explained in section 2.2 of
            [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
        target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
            For most cases, `target_size` should be set to the desired height and width of the generated image. If
            not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
            section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
            It also sets the tile size of the latent tiling (`target_size // 8` per side).
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
            the output of the pre-final layer will be used for computing the prompt embeddings.
        callback_on_step_end (`Callable`, *optional*):
            A function called once at the end of each denoising step, after the tiles have been merged back. It
            is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int,
            timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as
            specified by `callback_on_step_end_tensor_inputs`.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
            will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
            `._callback_tensor_inputs` attribute of your pipeline class.
        add_sample (bool):
            Whether to include sample conditioning (e.g., low-resolution image) in the UNet during denoising.
            Defaults to True.

    Examples:

    Returns:
        [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
        [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
        `tuple`. When returning a tuple, the first element is a list with the generated images.
    """
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    if callback is not None:
        deprecate(
            "callback",
            "1.0.0",
            "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )
    if callback_steps is not None:
        deprecate(
            "callback_steps",
            "1.0.0",
            "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
        )

    # 0. Default height and width to unet
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor
    original_size = original_size or (height, width)
    target_size = target_size or (height, width)

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        lr_img,
        prompt,
        prompt_2,
        height,
        width,
        callback_steps,
        negative_prompt,
        negative_prompt_2,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        callback_on_step_end_tensor_inputs,
    )

    self._guidance_scale = guidance_scale
    self._guidance_rescale = guidance_rescale
    self._clip_skip = clip_skip
    self._cross_attention_kwargs = cross_attention_kwargs
    self._denoising_end = denoising_end
    self._interrupt = False
    # The denoising loop indexes both tilers with the same tile index, so they are built identically.
    self.tlc_vae_latents = LocalAttention((target_size[0] // 8, target_size[1] // 8), overlap)
    self.tlc_vae_img = LocalAttention((target_size[0] // 8, target_size[1] // 8), overlap)

    # 2. Define call parameters
    # The tiled loop below only handles a single image, so these override the arguments.
    batch_size = 1
    num_images_per_prompt = 1
    # NOTE(review): hard-coded CUDA instead of self._execution_device; CPU/MPS-only setups fail here.
    device = torch.device("cuda")  # self._execution_device

    # 3. Encode input prompt
    lora_scale = (
        self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
    )

    num_samples = num_images_per_prompt
    with torch.inference_mode():
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt,
            num_images_per_prompt=num_samples,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
            lora_scale=lora_scale,
        )

    lr_img_list = [lr_img]
    lr_img = self.image_processor.preprocess(lr_img_list, height=height, width=width).to(
        device, dtype=prompt_embeds.dtype
    )

    # 4. Prepare timesteps
    timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
    image_latents = self.prepare_image_latents(
        lr_img,
        batch_size,
        num_images_per_prompt,
        prompt_embeds.dtype,
        device,
        self.do_classifier_free_guidance,
    )
    image_latents = self.tlc_vae_img.grids(image_latents)

    # 5. Prepare latent variables
    num_channels_latents = self.vae.config.latent_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    if start_point == "lr":
        # Start from the VAE encoding of the LR image, noised to timestep 999, instead of pure noise.
        latents_condition_image = self.vae.encode(lr_img).latent_dist.sample()
        latents_condition_image = latents_condition_image * self.vae.config.scaling_factor
        start_steps_tensor = torch.randint(999, 999 + 1, (latents.shape[0],), device=latents.device)
        start_steps_tensor = start_steps_tensor.long()
        latents = self.DDPM_scheduler.add_noise(latents_condition_image[0:1, ...], latents, start_steps_tensor)

    latents = self.tlc_vae_latents.grids(latents)

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
    # One independent scheduler-state copy per tile. `[copy.deepcopy(...)] * n` would make every tile share
    # the same dict (and any mutable buffers in it), so in-place updates by one tile would leak into the others.
    views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__) for _ in range(image_latents.shape[0])]

    # 7. Prepare added time ids & embeddings
    add_text_embeds = pooled_prompt_embeds
    if self.do_classifier_free_guidance:
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = add_text_embeds.to(device)

    # 8. Denoising loop
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

    # 8.1 Apply denoising_end
    if (
        self.denoising_end is not None
        and isinstance(self.denoising_end, float)
        and self.denoising_end > 0
        and self.denoising_end < 1
    ):
        discrete_timestep_cutoff = int(
            round(
                self.scheduler.config.num_train_timesteps
                - (self.denoising_end * self.scheduler.config.num_train_timesteps)
            )
        )
        num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
        timesteps = timesteps[:num_inference_steps]

    # 9. Optionally get Guidance Scale Embedding
    timestep_cond = None
    if self.unet.config.time_cond_proj_dim is not None:
        guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
        timestep_cond = self.get_guidance_scale_embedding(
            guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
        ).to(device=device, dtype=latents.dtype)

    self._num_timesteps = len(timesteps)
    sub_latents_num = latents.shape[0]
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Check before re-tiling: skipping a step after `grids` would leave `latents` tiled, and the
            # next step (or the VAE decode) would then work on the wrong tensor.
            if self.interrupt:
                continue
            if i >= 1:
                # `latents` was merged back to full size at the end of the previous step.
                latents = self.tlc_vae_latents.grids(latents).to(dtype=latents.dtype)

            concat_grid = []
            for sub_num in range(sub_latents_num):
                # Restore this tile's own scheduler state before stepping it.
                self.scheduler.__dict__.update(views_scheduler_status[sub_num])
                sub_latents = latents[sub_num, :, :, :].unsqueeze(0)
                img_sub_latents = image_latents[sub_num, :, :, :].unsqueeze(0)
                latent_model_input = torch.cat([sub_latents] * 2) if self.do_classifier_free_guidance else sub_latents
                img_sub_latents = (
                    torch.cat([img_sub_latents] * 2) if self.do_classifier_free_guidance else img_sub_latents
                )
                scaled_latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # The tile position goes into the crop-coordinates slot of the SDXL time ids.
                pos_height = self.tlc_vae_latents.idxes[sub_num]["i"]
                pos_width = self.tlc_vae_latents.idxes[sub_num]["j"]
                add_time_ids = [
                    torch.tensor([original_size]),
                    torch.tensor([[pos_height, pos_width]]),
                    torch.tensor([target_size]),
                ]
                add_time_ids = torch.cat(add_time_ids, dim=1).to(img_sub_latents.device, dtype=img_sub_latents.dtype)
                # One row per UNet batch entry: 2 with classifier-free guidance, 1 without. The previous
                # hard-coded 2 broke guidance_scale <= 1.
                add_time_ids = add_time_ids.repeat(latent_model_input.shape[0], 1).to(dtype=img_sub_latents.dtype)

                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                with torch.amp.autocast(device.type, dtype=latents.dtype, enabled=latents.dtype != self.unet.dtype):
                    noise_pred = self.unet(
                        scaled_latent_model_input,
                        t,
                        encoder_hidden_states=prompt_embeds,
                        timestep_cond=timestep_cond,
                        cross_attention_kwargs=self.cross_attention_kwargs,
                        input_embedding=img_sub_latents,
                        add_sample=add_sample,
                        added_cond_kwargs=added_cond_kwargs,
                        return_dict=False,
                    )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = sub_latents.dtype
                sub_latents = self.scheduler.step(noise_pred, t, sub_latents, **extra_step_kwargs, return_dict=False)[0]
                # Cast before the tile is stored; casting after the append had no effect.
                if sub_latents.dtype != latents_dtype and torch.backends.mps.is_available():
                    # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                    sub_latents = sub_latents.to(latents_dtype)

                views_scheduler_status[sub_num] = copy.deepcopy(self.scheduler.__dict__)
                concat_grid.append(sub_latents)

                if XLA_AVAILABLE:
                    xm.mark_step()

            # Merge the denoised tiles back into a full-size latent.
            latents = self.tlc_vae_latents.grids_inverse(torch.cat(concat_grid, dim=0)).to(sub_latents.dtype)

            # Step-end callbacks run once per timestep, on the merged latents.
            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
                negative_pooled_prompt_embeds = callback_outputs.pop(
                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                )
                # `add_time_ids` is rebuilt for every tile, so it is not read back from the callback.

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    step_idx = i // getattr(self.scheduler, "order", 1)
                    callback(step_idx, t, latents)

    if not output_type == "latent":
        # make sure the VAE is in float32 mode, as it overflows in float16
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
        elif latents.dtype != self.vae.dtype:
            if torch.backends.mps.is_available():
                # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                self.vae = self.vae.to(latents.dtype)

        # unscale/denormalize the latents
        # denormalize with the mean and std if available and not None
        has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
        has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
        if has_latents_mean and has_latents_std:
            latents_mean = (
                torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents_std = (
                torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
            )
            latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
        else:
            latents = latents / self.vae.config.scaling_factor

        image = self.vae.decode(latents, return_dict=False)[0]

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)
    else:
        image = latents

    if not output_type == "latent":
        # apply watermark if available
        if self.watermark is not None:
            image = self.watermark.apply_watermark(image)

        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return FaithDiffStableDiffusionXLPipelineOutput(images=image)
================================================
FILE: FaithDiff/pipelines/pipeline_output.py
================================================
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from diffusers.utils import BaseOutput, is_flax_available
@dataclass
class FaithDiffStableDiffusionXLPipelineOutput(BaseOutput):
    """
    Return value of the FaithDiff Stable Diffusion XL pipeline.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            The denoised images: either a list of PIL images of length `batch_size`, or a numpy array of
            shape `(batch_size, height, width, num_channels)`.
    """

    # Denoised output images (PIL list or numpy array, depending on the requested output type).
    images: Union[List[PIL.Image.Image], np.ndarray]
if is_flax_available():
    import flax

    @flax.struct.dataclass
    class FlaxFaithDiffStableDiffusionXLPipelineOutput(BaseOutput):
        """
        Flax counterpart of the FaithDiff Stable Diffusion XL pipeline output.

        Only defined when Flax is installed.

        Args:
            images (`np.ndarray`)
                Array of shape `(batch_size, height, width, num_channels)` holding the
                images produced by the diffusion pipeline.
        """

        images: np.ndarray
================================================
FILE: FaithDiff/training_utils.py
================================================
import contextlib
import copy
import gc
import math
import random
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from FaithDiff.models.unet_2d_condition_vae_extension import UNet2DConditionModel
from diffusers.schedulers import SchedulerMixin
from diffusers.utils import (
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
is_peft_available,
is_torch_npu_available,
is_torchvision_available,
is_transformers_available,
)
if is_transformers_available():
import transformers
if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
import deepspeed
if is_peft_available():
from peft import set_peft_model_state_dict
if is_torchvision_available():
from torchvision import transforms
if is_torch_npu_available():
import torch_npu # noqa: F401
def set_seed(seed: int):
    """
    Seed Python's `random`, NumPy and PyTorch so that runs are reproducible.

    Args:
        seed (`int`): The value every random-number generator is seeded with.

    Returns:
        `None`
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Seed all accelerator devices as well: NPU when available, CUDA otherwise.
    # (The CUDA call is safe even if CUDA is not available.)
    seed_all_devices = torch.npu.manual_seed_all if is_torch_npu_available() else torch.cuda.manual_seed_all
    seed_all_devices(seed)
def compute_snr(noise_scheduler, timesteps):
    """
    Compute the signal-to-noise ratio (SNR) of the given timesteps, as in
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    SNR(t) = alpha_t**2 / sigma_t**2 with alpha_t = sqrt(alphas_cumprod[t]) and
    sigma_t = sqrt(1 - alphas_cumprod[t]).

    Args:
        noise_scheduler (`NoiseScheduler`):
            Any object exposing an `alphas_cumprod` tensor.
        timesteps (`torch.Tensor`):
            Integer timesteps used to index `alphas_cumprod`.

    Returns:
        `torch.Tensor`: float32 SNR values with the same shape as `timesteps`.
    """
    cumprod = noise_scheduler.alphas_cumprod

    def _gather_and_expand(values):
        # Index by timestep on the timesteps' device, then broadcast to its shape.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        picked = values.to(device=timesteps.device)[timesteps].float()
        while picked.ndim < timesteps.ndim:
            picked = picked[..., None]
        return picked.expand(timesteps.shape)

    alpha = _gather_and_expand(cumprod**0.5)
    sigma = _gather_and_expand((1.0 - cumprod) ** 0.5)
    return (alpha / sigma) ** 2
def resolve_interpolation_mode(interpolation_type: str):
    """
    Translate an interpolation name into torchvision's `InterpolationMode` enum.

    The supported names are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`,
    `hamming` and `lanczos`. See
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`): The interpolation method name.

    Returns:
        `torchvision.transforms.InterpolationMode`: the enum member that torchvision's
        `resize` transform expects.

    Raises:
        ImportError: if torchvision is not installed.
        ValueError: if `interpolation_type` is not one of the supported names.
    """
    if not is_torchvision_available():
        raise ImportError(
            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
        )
    # (user-facing name, InterpolationMode member name)
    known_modes = (
        ("bilinear", "BILINEAR"),
        ("bicubic", "BICUBIC"),
        ("box", "BOX"),
        ("nearest", "NEAREST"),
        ("nearest_exact", "NEAREST_EXACT"),
        ("hamming", "HAMMING"),
        ("lanczos", "LANCZOS"),
    )
    for user_name, member_name in known_modes:
        if interpolation_type == user_name:
            return getattr(transforms.InterpolationMode, member_name)
    raise ValueError(
        f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
        f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
    )
def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Apply DREAM (Diffusion Rectification and Estimation-Adaptive Models,
    http://arxiv.org/abs/2312.00210) to a training batch.

    An extra gradient-free UNet forward pass estimates the noise. The gap between
    the true and predicted noise, scaled by lambda = sqrt(1 - alphas_cumprod[t]) ** p,
    is folded back into both the noisy latents and the target.

    Args:
        `unet`: The UNet used for the extra prediction; its output must expose `.sample`.
        `noise_scheduler`: Scheduler providing `alphas_cumprod` and `config.prediction_type`.
        `timesteps`: Timesteps of the batch (used to index `alphas_cumprod`).
        `noise`: The noise that was added to form `noisy_latents`.
        `noisy_latents`: Noisy latents from the training loop.
        `target`: The ground-truth tensor the model is trained to predict.
        `encoder_hidden_states`: Text embeddings passed to the UNet.
        `dream_detail_preservation`: The exponent p of lambda (detail preservation level).

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: adjusted noisy_latents and target.

    Raises:
        NotImplementedError: for `v_prediction` schedulers.
        ValueError: for any other unknown prediction type.
    """
    # Index with [t, None, None, None] so the values broadcast over the trailing three dims.
    alpha_bar = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sigma = (1.0 - alpha_bar) ** 0.5
    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sigma**dream_detail_preservation

    with torch.no_grad():
        predicted_noise = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    if noise_scheduler.config.prediction_type != "epsilon":
        if noise_scheduler.config.prediction_type == "v_prediction":
            raise NotImplementedError("DREAM has not been implemented for v-prediction")
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    delta_noise = (noise - predicted_noise).detach()
    delta_noise.mul_(dream_lambda)
    return noisy_latents.add(sigma * delta_noise), target.add(delta_noise)
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
    r"""
    Collect the LoRA weights attached to a UNet.

    Every sub-module that exposes `set_lora_layer` is inspected. When its `lora_layer`
    attribute is not None, each tensor of that layer's state dict is stored under
    `"<module name>.lora.<tensor name>"`.

    Returns:
        A state dict containing just the LoRA parameters.
    """
    collected = {}
    for module_name, module in unet.named_modules():
        if not hasattr(module, "set_lora_layer"):
            continue
        lora_layer = module.lora_layer
        if lora_layer is None:
            continue
        # Tensor names are expected to be "down" or "up".
        for matrix_name, tensor in lora_layer.state_dict().items():
            collected[f"{module_name}.lora.{matrix_name}"] = tensor
    return collected
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Cast the trainable parameters of one or more models to `dtype`, in place.

    Parameters with `requires_grad=False` keep their current dtype.

    Args:
        model: A module, or a list of modules, whose trainable parameters are cast.
        dtype: Target dtype for the trainable parameters (default `torch.float32`).
    """
    modules = model if isinstance(model, list) else [model]
    trainable = (param for module in modules for param in module.parameters() if param.requires_grad)
    for param in trainable:
        # Reassigning `.data` keeps the same Parameter object, so optimizers and
        # other references to it stay valid.
        param.data = param.to(dtype)
def _set_state_dict_into_text_encoder(
    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
    """
    Load the LoRA weights that belong to `text_encoder` from a combined state dict.

    Keys starting with `prefix` are selected and the prefix is removed. The result is
    converted to diffusers and then PEFT naming, and loaded into the "default" adapter.

    Args:
        lora_state_dict: The combined LoRA state dictionary.
        prefix: Leading key identifier of the entries that belong to `text_encoder`.
        text_encoder: The `transformers` text encoder that receives the weights.
    """
    # Strip the prefix only from the start of each key. `str.replace` would also
    # delete any later occurrence of the prefix inside the key.
    text_encoder_state_dict = {
        key[len(prefix):]: value for key, value in lora_state_dict.items() if key.startswith(prefix)
    }
    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
    device: Union[torch.device, str] = "cpu",
    generator: Optional[torch.Generator] = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme: "logit_normal" (sigmoid of a normal sample), "mode"
            (mode-sampling transform of a uniform sample), or anything else for
            plain uniform samples.
        batch_size: Number of values to draw.
        logit_mean: Mean of the normal distribution (required for "logit_normal").
        logit_std: Standard deviation of the normal distribution (required for "logit_normal").
        mode_scale: Scale of the mode transform (required for "mode").
        device: Device on which the samples are drawn. Defaults to "cpu".
        generator: Optional `torch.Generator` for reproducible sampling. It must
            live on `device`.

    Returns:
        A 1-D tensor of `batch_size` samples.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(
            mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator
        )
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
    return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """
    Return per-sample loss weights for SD3 training.

    - "sigma_sqrt": sigmas ** -2, cast to float32.
    - "cosmap": 2 / (pi * (1 - 2 * sigmas + 2 * sigmas ** 2)).
    - anything else: a tensor of ones shaped like `sigmas`.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "cosmap":
        denominator = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * denominator)
    if weighting_scheme == "sigma_sqrt":
        return (sigmas**-2.0).float()
    return torch.ones_like(sigmas)
def free_memory():
    """
    Run Python garbage collection, then empty the allocator cache of the first
    available accelerator backend, checked in the order CUDA, MPS, NPU.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        return
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
        return
    if is_torch_npu_available():
        torch_npu.npu.empty_cache()
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
def should_update_ema(args, step):
    """
    Decide whether the EMA weights should be updated at training `step`.

    Returns True on every step when `args.ema_update_interval` is None. Otherwise it
    returns True only on steps that are a multiple of the interval.
    """
    interval = args.ema_update_interval
    return True if interval is None else step % interval == 0
class EMAModel:
    """
    Exponential Moving Average of models weights.

    Keeps a detached "shadow" copy of a list of parameters. Each effective `step()`
    moves every trainable shadow tensor towards its live counterpart by
    `1 - decay` and copies frozen (`requires_grad=False`) parameters verbatim.
    A few `torch.nn.Module` methods are emulated at the end of the class.

    `step()` and `pin_memory()` read `ema_update_interval` (through
    `should_update_ema`), `ema_device` and `ema_cpu_only` from the training `args`
    namespace given to `__init__`.
    """

    def __init__(
        self,
        args,
        accelerator,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        foreach: bool = True,
        model_cls: Optional[Any] = None,
        model_config: Dict[str, Any] = None,
        **kwargs,
    ):
        """
        Args:
            args: Training-arguments namespace (see the class docstring for the
                attributes that are read).
            accelerator: Object whose `.device` receives the shadow parameters during
                `step()` when `args.ema_device == "cpu"` and `args.ema_cpu_only` is false.
            parameters (Iterable[torch.nn.Parameter]): The parameters to track. Passing a
                `torch.nn.Module` is deprecated and also forces `use_ema_warmup=True`.
            decay (float): The maximum decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
            model_cls: Model class used by `save_pretrained` to rebuild a model.
            model_config: Config used by `save_pretrained` to instantiate `model_cls`.
            **kwargs: Deprecated aliases. `max_value` maps to `decay` and `min_value`
                maps to `min_decay`. If `device` is given, the shadow parameters are
                moved there with `to()`. Without it, the shadow parameters are clones
                that stay on the same device as `parameters`.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """
        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()
            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        if kwargs.get("max_value", None) is not None:
            deprecation_message = (
                "The `max_value` argument is deprecated. Please use `decay` instead."
            )
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            deprecation_message = (
                "The `device` argument is deprecated. Please use `to` instead."
            )
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`
        self.foreach = foreach

        self.model_cls = model_cls
        self.model_config = model_config
        self.args = args
        self.accelerator = accelerator
        self.training = True  # To emulate nn.Module's training mode

    @staticmethod
    def _get_logger():
        """Return the logger for this module (training_utils defines no module-level logger)."""
        import logging

        return logging.getLogger(__name__)

    def save_state_dict(self, path: str) -> None:
        """
        Save the EMA model's state directly to a file with `torch.save`.

        Args:
            path (str): The file path where the EMA state will be saved. Missing
                parent directories are created.
        """
        import os

        # `os.path.dirname` returns "" for a bare file name, and `os.makedirs("")`
        # raises, so only create a directory when the path actually has one.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.state_dict(), path)
        self._get_logger().info(f"EMA model state saved to {path}")

    def _load_metadata(self, state_dict: Dict[str, Any]) -> None:
        """Apply the scalar EMA settings in `state_dict`, keeping current values for missing keys."""
        self.decay = state_dict.get("decay", self.decay)
        self.min_decay = state_dict.get("min_decay", self.min_decay)
        self.optimization_step = state_dict.get(
            "optimization_step", self.optimization_step
        )
        self.update_after_step = state_dict.get(
            "update_after_step", self.update_after_step
        )
        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        self.power = state_dict.get("power", self.power)

    def load_state_dict(self, path: Union[str, Dict[str, Any]]) -> None:
        """
        Load the EMA model's state and apply it to this instance.

        Args:
            path: Either a file written by `save_state_dict` (loaded on CPU with
                `weights_only=True`) or an already-loaded state dict in the format
                produced by `state_dict()`.

        Raises:
            ValueError: if the number of `shadow_params.<idx>` entries does not match
                this instance. Nothing is modified in that case.
        """
        if isinstance(path, dict):
            state_dict = path
        else:
            state_dict = torch.load(path, map_location="cpu", weights_only=True)

        # Collect shadow parameters stored as consecutive "shadow_params.<idx>" keys.
        shadow_params = []
        idx = 0
        while f"shadow_params.{idx}" in state_dict:
            shadow_params.append(state_dict[f"shadow_params.{idx}"])
            idx += 1
        # Validate before mutating anything so a bad file leaves the EMA untouched.
        if len(shadow_params) != len(self.shadow_params):
            raise ValueError(
                f"Mismatch in number of shadow parameters: expected {len(self.shadow_params)}, "
                f"but found {len(shadow_params)} in the state dict."
            )

        self._load_metadata(state_dict)
        for current_param, loaded_param in zip(self.shadow_params, shadow_params):
            current_param.data.copy_(loaded_param.data)

        if isinstance(path, dict):
            self._get_logger().info("EMA model state loaded from state dict")
        else:
            self._get_logger().info(f"EMA model state loaded from {path}")

    @classmethod
    def from_pretrained(cls, path, model_cls, args=None, accelerator=None) -> "EMAModel":
        """
        Rebuild an EMA model from a directory written by `save_pretrained`.

        The saved model weights become the shadow parameters. The EMA settings that
        `save_pretrained` registered in the model config are restored.

        Args:
            path: Directory passed to `model_cls.load_config` / `model_cls.from_pretrained`.
            model_cls: Model class that can load `path`.
            args: Training-arguments namespace forwarded to `__init__` (needed by `step()`).
            accelerator: Accelerator forwarded to `__init__` (needed by `step()`).
        """
        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)
        ema_model = cls(
            args,
            accelerator,
            model.parameters(),
            model_cls=model_cls,
            model_config=model.config,
        )
        # `save_pretrained` registers `state_dict(exclude_params=True)`, so
        # `ema_kwargs` carries only scalar settings, not shadow parameters.
        ema_model._load_metadata(ema_kwargs)
        return ema_model

    def save_pretrained(self, path, max_shard_size: str = "10GB"):
        """
        Save the EMA weights as a regular `model_cls` checkpoint.

        The EMA settings are stored in the model config.

        Raises:
            ValueError: if `model_cls` or `model_config` was not given at construction.
        """
        if self.model_cls is None:
            raise ValueError(
                "`save_pretrained` can only be used if `model_cls` was defined at __init__."
            )
        if self.model_config is None:
            raise ValueError(
                "`save_pretrained` can only be used if `model_config` was defined at __init__."
            )
        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict(exclude_params=True)
        state_dict.pop("shadow_params", None)
        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path, max_shard_size=max_shard_size)

    def get_decay(self, optimization_step: int = None) -> float:
        """
        Compute the decay factor for the exponential moving average.

        Returns 0.0 until `update_after_step + 1` steps have passed. After that it
        follows the warmup schedule (if enabled) or (1 + step) / (10 + step), clamped
        to [min_decay, decay].
        """
        if optimization_step is None:
            optimization_step = self.optimization_step
        step = max(0, optimization_step - self.update_after_step - 1)
        if step <= 0:
            return 0.0
        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)
        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter], global_step: int = None):
        """
        Move the shadow parameters towards `parameters`.

        Nothing happens when `should_update_ema(self.args, global_step)` is false.
        Trainable parameters are blended with the current decay; frozen ones are
        copied as-is.

        Args:
            parameters: The live parameters, in the same order as at construction.
                Passing a `torch.nn.Module` is deprecated.
            global_step: Current training step. When given, it replaces the internal
                counter; otherwise the counter is incremented. It must be provided
                when `args.ema_update_interval` is set, because it is used with `%`.
        """
        if not should_update_ema(self.args, global_step):
            return

        if self.args.ema_device == "cpu" and not self.args.ema_cpu_only:
            # Move EMA to accelerator for faster update.
            self.to(device=self.accelerator.device, non_blocking=True)

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()
        parameters = list(parameters)

        if global_step is not None:
            # When we're updating the EMA periodically, we can't trust the counter.
            self.optimization_step = global_step
        else:
            self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        zero3_enabled = (
            is_transformers_available()
            and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled()
        )
        if zero3_enabled:
            import deepspeed

        if self.foreach:
            # `with` receives a context-manager instance in both cases: the
            # GatheredParameters object under ZeRO-3, otherwise `nullcontext()`.
            context_manager = (
                deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
                if zero3_enabled
                else contextlib.nullcontext()
            )
            with context_manager:
                cpu_only = self.args.ema_cpu_only
                params_grad = [
                    param.cpu() if cpu_only else param
                    for param in parameters
                    if param.requires_grad
                ]
                s_params_grad = [
                    s_param
                    for s_param, param in zip(self.shadow_params, parameters)
                    if param.requires_grad
                ]
                # Frozen parameters are copied verbatim (on CPU when `ema_cpu_only`)
                # instead of being averaged.
                if len(params_grad) < len(parameters):
                    torch._foreach_copy_(
                        [
                            s_param
                            for s_param, param in zip(self.shadow_params, parameters)
                            if not param.requires_grad
                        ],
                        [
                            param.cpu() if cpu_only else param
                            for param in parameters
                            if not param.requires_grad
                        ],
                        non_blocking=True,
                    )
                # shadow <- shadow - (1 - decay) * (shadow - param)
                torch._foreach_sub_(
                    s_params_grad,
                    torch._foreach_sub(s_params_grad, params_grad),
                    alpha=one_minus_decay,
                )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                context_manager = (
                    deepspeed.zero.GatheredParameters(param, modifier_rank=None)
                    if zero3_enabled
                    else contextlib.nullcontext()
                )
                with context_manager:
                    if param.requires_grad:
                        s_param.sub_(
                            one_minus_decay * (s_param - param.to(s_param.device))
                        )
                    else:
                        s_param.copy_(param)

        if self.args.ema_device == "cpu" and not self.args.ema_cpu_only:
            # Move back to CPU for safe-keeping.
            self.to(device=self.args.ema_device, non_blocking=True)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter` to overwrite with the moving
                averages, in the same order as the shadow parameters. Each shadow
                tensor is moved to its target's device before copying.
        """
        parameters = list(parameters)
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters],
                [
                    s_param.to(param.device).data
                    for s_param, param in zip(self.shadow_params, parameters)
                ],
            )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.to(param.device).data)

    def pin_memory(self) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
        offloading EMA params to the host.

        Skipped on Apple MPS and when `args.ema_cpu_only` is set.
        """
        if torch.backends.mps.is_available():
            self._get_logger().warning("Apple silicon does not support pinned memory. Skipping.")
            return
        if self.args.ema_cpu_only:
            return
        # NOTE(review): `Tensor.pin_memory()` applies to CPU tensors; confirm the
        # shadow parameters live on the CPU when this is called.
        self.shadow_params = [p.pin_memory() for p in self.shadow_params]

    def to(self, *args, **kwargs):
        """Apply `Tensor.to(*args, **kwargs)` to every shadow parameter; returns self."""
        for param in self.shadow_params:
            param.data = param.data.to(*args, **kwargs)
        return self

    def cuda(self, device=None):
        return self.to(device="cuda" if device is None else f"cuda:{device}")

    def cpu(self):
        return self.to(device="cpu")

    def state_dict(
        self, destination=None, prefix="", keep_vars=False, exclude_params: bool = False
    ):
        r"""
        Returns a dictionary containing a whole state of the EMA model.

        Shadow parameters are stored as `f"{prefix}shadow_params.{idx}"` unless
        `exclude_params` is true. `destination` is accepted for nn.Module API
        compatibility but is not used.
        """
        state_dict = {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
        }
        if exclude_params:
            return state_dict
        for idx, param in enumerate(self.shadow_params):
            state_dict[f"{prefix}shadow_params.{idx}"] = (
                param if keep_vars else param.detach()
            )
        return state_dict

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Save a CPU copy of the current parameters for restoring later.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method.

        Raises:
            RuntimeError: if `store()` was not called first.
        """
        if self.temp_stored_params is None:
            raise RuntimeError(
                "This ExponentialMovingAverage has no `store()`ed weights "
                "to `restore()`"
            )
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters],
                [c_param.data for c_param in self.temp_stored_params],
            )
        else:
            for c_param, param in zip(self.temp_stored_params, parameters):
                param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def parameter_count(self) -> int:
        return sum(p.numel() for p in self.shadow_params)

    # Implementing nn.Module methods to emulate its behavior
    def named_children(self):
        # No child modules
        return iter([])

    def children(self):
        return iter([])

    def modules(self):
        yield self

    def named_modules(self, memo=None, prefix=""):
        yield prefix, self

    def parameters(self, recurse=True):
        return iter(self.shadow_params)

    def named_parameters(self, prefix="", recurse=True):
        for i, param in enumerate(self.shadow_params):
            name = f"{prefix}shadow_params.{i}"
            yield name, param

    def buffers(self, recurse=True):
        return iter([])

    def named_buffers(self, prefix="", recurse=True):
        return iter([])

    def train(self, mode=True):
        self.training = mode
        return self

    def eval(self):
        return self.train(False)

    def zero_grad(self):
        # No gradients to zero in EMA model
        pass
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2025 Junyang Chen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
### (CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolution
[🤗 FaithDiff on Hugging Face](https://huggingface.co/jychen9811/FaithDiff)

> [[Project Page](https://jychen9811.github.io/FaithDiff_page/)]   [[Paper](https://openaccess.thecvf.com/content/CVPR2025/papers/Chen_FaithDiff_Unleashing_Diffusion_Priors_for_Faithful_Image_Super-resolution_CVPR_2025_paper.pdf)]
> [Junyang Chen](https://jychen9811.github.io/), [Jinshan Pan](https://jspan.github.io/), [Jiangxin Dong](https://scholar.google.com/citations?user=ruebFVEAAAAJ&hl=zh-CN&oi=ao) <br>
> [IMAG Lab](https://imag-njust.net/), Nanjing University of Science and Technology
> If FaithDiff is helpful to you, please star the GitHub repo. Thanks!
> You are welcome to visit our information-service platform dedicated to low-level vision (专注底层视觉领域的信息服务平台): [https://lowlevelcv.com/](https://lowlevelcv.com/)
---
😊 You may also want to check our relevant works:
1. **STCDiT (CVPR 2026)** [Paper](https://arxiv.org/abs/2511.18786) | [Code](https://github.com/JyChen9811/STCDiT)
A motion-aware VAE and anchor-frame-guided DiT framework enables stable video restoration, even under complex camera motions.
2. **CODSR (CVPR 2026)** [Paper](https://arxiv.org/abs/2512.14061) | [Code](https://github.com/Chanson94/CODSR)
A one-step diffusion SR framework enabling region-discriminative activation of generative priors and precise semantic grounding.
### 🚩 **New Features/Updates**
- ✅ April 3, 2025. The code has been integrated into [Diffusers](https://github.com/huggingface/diffusers/blob/main/examples/community/pipeline_faithdiff_stable_diffusion_xl.py). Many thanks to Eliseu Silva!
- ✅ April 1, 2025. Supports FP8 inference and CPU offloading, significantly reducing memory usage. Thanks Eliseu Silva!!!
- ✅ March 28, 2025. Add a Gradio demo.
- ✅ March 24, 2025. Release the training code.
- ✅ February 09, 2025. Support ultra-high-resolution (8K and above) image restoration on 24GB GPUs.
- ✅ February 08, 2025. Release [RealDeg](https://drive.google.com/file/d/1B8BaaMjXJ-1TfcTgE9MrAg8ufvaGkndP/view?usp=sharing). It includes 238 images with unknown degradations, consisting of old photographs, social media images, and classic film stills.
- ✅ February 07, 2025. Release the testing code and [pre-trained model](https://huggingface.co/jychen9811/FaithDiff).
- ✅ November 25, 2024. Create the repository and the [project page](https://jychen9811.github.io/FaithDiff_page/).
### ⚡ **To do**
- FaithDiff-SD3-Large
- ~~Release the training code~~
- ~~Release the testing code and pre-trained model~~
---
### 📷 Real-World Enhancement Results
[<img src="figs/nezha.jpg" width="500px" height="320px"/>](https://imgsli.com/MzQ3NDQx) [<img src="figs/wukong.jpg" height="320px"/>](https://imgsli.com/MzQ3NDM5)
[<img src="figs/old_photo.jpg" width="500px" height="320px"/>](https://imgsli.com/MzQ3NDYx) [<img src="figs/social_media.jpg" height="320px"/>](https://imgsli.com/MzQ3NDU2)
<!--  -->
---
### 🌈 AIGC Enhancement Results
[<img src="figs/pikaqiu.jpg" width="270px" height="270px"/>](https://imgsli.com/MzQ3NjEz)
[<img src="figs/cat_and_snake.jpg" width="270px" height="270px"/>](https://imgsli.com/MzQ3NjAx)
[<img src="figs/yangtuo.jpg" width="270px" height="270px"/>](https://imgsli.com/MzQ3NTk0)
[<img src="figs/duolaameng.jpg" width="270px" height="270px"/>](https://imgsli.com/MzQ3NTk2)
[<img src="figs/tiger.jpg" width="270px" height="270px"/>](https://imgsli.com/MzQ3NjA0)
[<img src="figs/little_girl.jpg" width="270px" height="270px"/>](https://imgsli.com/MzQ3NjA2)
[<img src="figs/boy.jpg" width="270px" height="270px"/>](https://imgsli.com/MzQ3NjE1)
[<img src="figs/girl_and_cat.jpg" width="270px" height="270px"/>](https://imgsli.com/MzQ3NjA5)
[<img src="figs/astronaut.jpg" width="270px" height="270px"/>](https://imgsli.com/MzQ3NjEw)
---
### :gift: Gradio Demo
```Shell
python gradio_demo.py
```
#### Additional parameters
You can add the following parameters to the gradio application.
- `--cpu_offload`: Offloads the weights of the pipeline components to CPU RAM. If your GPU has less than 12GB of memory, it is a good idea to use this parameter.
- `--use_fp8`: Changes the diffusion model precision from FP16 to FP8, significantly reducing GPU memory requirements. Combined with `--cpu_offload`, a 2x upscale requires only 5GB of VRAM.

```Shell
# FP8 Inference and CPU offloading
python gradio_demo.py --cpu_offload --use_fp8
# FP8 Inference, CPU offloading and without LLaVA
python gradio_demo.py --cpu_offload --use_fp8 --no_llava
```

---
### ⚡ How to train
#### Environment
```
conda env create --name faithdiff -f environment.yml
```
#### Training Script
```Shell
# Stage 1
bash train_stage_1.sh
# After Stage 1 training, enter the checkpoints folder.
cd ./train_FaithDiff_stage_1_offline/checkpoint-6000
python zero_to_fp32.py ./ ./pretrain.bin
# Stage 2
bash train_stage_2.sh
# After Stage 2 training, enter the checkpoints folder.
cd ./train_FaithDiff_stage_2_offline/checkpoint
python zero_to_fp32.py ./ ./FaithDiff.bin
```
#### Tips for Human Face data preparation
- *To quickly filter out low-quality data in the FFHQ dataset, we recommend using topiq to assess image quality. Here are the [official results](https://github.com/chaofengc/IQA-PyTorch/blob/a7f2be4363f3a4c765c6868239336f6eeba33c93/tests/FFHQ_score_topiq_nr-face.csv). We empirically selected images with a metric above 0.72.*
- *During training, we recommend resizing face images to a resolution between 512 and 768.*
- *If you need to improve the restoration performance of portrait images, [Unsplash](https://unsplash.com/) offers high-quality portrait images. You can search for different clothing names to obtain full-body portrait data.*
---
### 🚀 How to evaluate
#### Download Dependent Models
- [FaithDiff Pre-trained model](https://huggingface.co/jychen9811/FaithDiff)
- [SDXL RealVisXL_V4.0](https://huggingface.co/SG161222/RealVisXL_V4.0)
- [SDXL VAE FP16](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)
- [LLaVA CLIP](https://huggingface.co/openai/clip-vit-large-patch14-336)
- [LLaVA v1.5 13B](https://huggingface.co/liuhaotian/llava-v1.5-13b)
- [BSRNet](https://drive.usercontent.google.com/download?id=1JGJLiENPkOqi39bvQYa_jlIPlMk24iKH&export=download&authuser=0&confirm=t&uuid=ebaa5d11-ac76-4f54-aabf-90fa43997dec&at=AEz70l4zk_8LTafpGtR0ZSE50F1N:1742369984793)
- Put them in the `./checkpoints` folder and update the corresponding paths in `CKPT_PTH.py`.
#### Val Dataset
RealDeg: [Google Drive](https://drive.google.com/file/d/1B8BaaMjXJ-1TfcTgE9MrAg8ufvaGkndP/view?usp=sharing)
*To evaluate the performance of our method in real-world scenarios, we collect a dataset of 238 images with unknown degradations, consisting of old photographs, social media images, and classic film stills. The category of old photographs includes black-and-white images, faded photographs, and colorized versions. Social media images are uploaded by us to various social media platforms (e.g., WeChat, RedNote, Sina Weibo and Zhihu), undergoing one or multiple rounds of cross-platform processing. The classic film stills are selected from iconic films spanning the 1980s to 2000s, such as The Shawshank Redemption, Harry Potter, and Spider-Man, etc. The images feature diverse content, including people, buildings, animals, and various natural elements. In addition, the shortest side of the image resolution is at least 720 pixels.*
#### Inference Script
```Shell
# Script that supports two GPUs.
CUDA_VISIBLE_DEVICES=0,1 python test.py --img_dir='./dataset/RealDeg' --save_dir=./save/RealDeg --upscale=2 --guidance_scale=5 --num_inference_steps=20 --load_8bit_llava
# Scripts that support only one GPU.
CUDA_VISIBLE_DEVICES=0 python test_generate_caption.py --img_dir='./dataset/RealDeg' --save_dir=./save/RealDeg_caption --load_8bit_llava
CUDA_VISIBLE_DEVICES=0 python test_wo_llava.py --img_dir='./dataset/RealDeg' --json_dir=./save/RealDeg_caption --save_dir=./save/RealDeg --upscale=2 --guidance_scale=5 --num_inference_steps=20
# If attempting ultra-high-resolution image restoration, add --use_tile_vae in the scripts. The same applies to test_wo_llava.
CUDA_VISIBLE_DEVICES=0,1 python test.py --img_dir='./dataset/RealDeg' --save_dir=./save/RealDeg --use_tile_vae --upscale=8 --guidance_scale=5 --num_inference_steps=20 --load_8bit_llava
```
---
### BibTeX
```bibtex
@inproceedings{chen2024faithdiff,
title={FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolution},
author={Chen, Junyang and Pan, Jinshan and Dong, Jiangxin},
booktitle={CVPR},
year={2025}
}
```
---
### Contact
If you have any questions, please feel free to reach out to me at `jychen9811@gmail.com`.
---
### Acknowledgments
Our project is based on [diffusers](https://github.com/huggingface/diffusers/tree/main), [SUPIR](https://github.com/Fanghua-Yu/SUPIR), [TLC](https://github.com/megvii-research/TLC) and [SimpleTuner](https://github.com/bghira/SimpleTuner/tree/main). Thanks for their awesome work.
================================================
FILE: dataloader/Realesrgan_offline_dataset.py
================================================
import os
import glob
import torch
import random
import numpy as np
from PIL import Image
from functools import partial
import torch.nn.functional as F
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img
from basicsr.utils.img_process_util import filter2D
from PIL import Image
import json
from transformers import CLIPImageProcessor
from torch import nn
from torchvision import transforms
from torch.utils import data as data
from torchvision.transforms.functional import normalize
from .realesrgan import RealESRGAN_degradation
import cv2
import random
from glob import glob
from collections import OrderedDict
import yaml
from PIL import Image
def ordered_yaml():
    """Build a YAML (Loader, Dumper) pair that maps YAML mappings to OrderedDict.

    The libyaml-backed ``CLoader``/``CDumper`` are used when available, with a
    fallback to the pure-Python classes. The constructor/representer are
    registered on those Loader/Dumper classes themselves.

    Returns:
        tuple: ``(Loader, Dumper)``.
    """
    try:
        from yaml import CDumper as Dumper, CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def _represent_ordered(dumper, mapping):
        # Dump an OrderedDict as a plain YAML mapping, keeping insertion order.
        return dumper.represent_dict(mapping.items())

    def _construct_ordered(loader, node):
        # Load every YAML mapping as an OrderedDict so key order is preserved.
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, _represent_ordered)
    Loader.add_constructor(mapping_tag, _construct_ordered)
    return Loader, Dumper
def opt_parse(opt_path):
    """Read the YAML option file at *opt_path*; mappings come back as OrderedDict.

    NOTE(review): ``ordered_yaml`` returns PyYAML's full ``Loader``/``CLoader``,
    which can construct arbitrary Python objects. Only use this on trusted
    config files.
    """
    with open(opt_path, mode='r') as stream:
        loader_cls = ordered_yaml()[0]
        return yaml.load(stream, Loader=loader_cls)  # ignore_security_alert_wait_for_fix RCE
def convert_image_to_fn(img_type, image, minsize=512, eps=0.02):
    """Upscale *image* so its shorter side reaches *minsize*, then convert its mode.

    Args:
        img_type: Target image mode (e.g. ``"RGB"``).
        image: A PIL-style image exposing ``size``, ``mode``, ``resize`` and ``convert``.
        minsize: Minimum length of the shorter side. Smaller images are enlarged.
        eps: Extra margin added to the exact ``minsize / shorter_side`` ratio.

    Returns:
        The (possibly resized) image in mode *img_type*. If no change is needed,
        the same object is returned.
    """
    # `math` is not imported at module level in this file; without this local
    # import every image smaller than `minsize` raised NameError.
    import math

    width, height = image.size
    shorter = min(width, height)
    if shorter < minsize:
        scale = minsize / shorter + eps
        image = image.resize((math.ceil(width * scale), math.ceil(height * scale)))
    if image.mode != img_type:
        return image.convert(img_type)
    return image
def exists(x):
    """Return True unless *x* is None (falsy values such as 0 or '' still exist)."""
    if x is None:
        return False
    return True
class LocalImageDataset(data.Dataset):
    """Paired GT/LQ image dataset with JSON captions for SR diffusion training.

    Two sources are combined: 'nature' images (``img_file``) and 'face' images
    (``face_file``). Each source is a 3-item sequence of directory lists,
    ``[gt_dirs, lq_dirs, caption_json_dirs]``. PNG/JSON files are collected
    recursively and paired purely by sorted order, so the three lists of a
    source must line up one-to-one. LQ images are treated as 4x smaller than GT.

    ``yml_kernel``, ``center_crop``, ``random_flip`` and ``convert_image_to``
    are accepted for interface compatibility but are not used by this class.

    Each item is a dict with ``lq_image``/``image`` tensors scaled to [-1, 1],
    token ids for both tokenizers (caption and empty prompt), SDXL-style
    ``original_size``/``crop_coords_top_left``/``target_size`` tensors, and ``gt_path``.
    """

    def __init__(self,
                 img_file=None,
                 face_file=None,
                 yml_kernel=None,
                 image_size=512,
                 tokenizer=None,
                 tokenizer_2=None,
                 center_crop=False,
                 random_flip=True,
                 resize_bak=True,
                 convert_image_to="RGB",
                 t_drop_rate=0.05
                 ):
        super(LocalImageDataset, self).__init__()
        self.tokenizer = tokenizer
        self.tokenizer_2 = tokenizer_2
        self.resize_bak = resize_bak
        self.crop_size = image_size
        self.t_drop_rate = t_drop_rate  # probability of replacing the caption with ""
        self.data_types = ['nature', 'face']
        self.data_prob = [0.875, 0.125]  # not referenced within this class

        nature_paths, nature_lr_paths, nature_jsons = self._collect_files(img_file)
        face_paths, lq_face_paths, face_jsons = self._collect_files(face_file)

        self.data_collection = {
            'nature': (np.array(nature_paths), np.array(nature_jsons), np.array(nature_lr_paths)),
            'face': (np.array(face_paths), np.array(face_jsons), np.array(lq_face_paths)),
        }

        # Report GT / caption / LQ counts so mismatched folders are easy to spot.
        gt_lens = {'nature': len(nature_paths), 'face': len(face_paths)}
        json_lens = {'nature': len(nature_jsons), 'face': len(face_jsons)}
        lq_lens = {'nature': len(nature_lr_paths), 'face': len(lq_face_paths)}
        print(gt_lens)
        print(json_lens)
        print(lq_lens)
        self.data_lens = lq_lens  # kept for compatibility (holds the LQ counts)

        # Index routing in __getitem__ (and therefore __len__) uses the GT counts.
        self.datatypes_lens = [len(nature_paths), len(face_paths)]
        self.cumulative_lens = np.cumsum([0] + self.datatypes_lens)

    @staticmethod
    def _collect_files(groups):
        """Return sorted (gt_pngs, lq_pngs, caption_jsons) lists for one source.

        *groups* is ``[gt_dirs, lq_dirs, json_dirs]``; ``None`` yields empty lists.
        """
        if groups is None:
            return [], [], []

        def _glob_all(dirs, pattern):
            found = []
            for directory in dirs:
                found += sorted(glob(os.path.join(directory, '**', pattern), recursive=True))
            return found

        return (_glob_all(groups[0], '*.png'),
                _glob_all(groups[1], '*.png'),
                _glob_all(groups[2], '*.json'))

    @staticmethod
    def _clean_caption(caption):
        """Drop the first three words of *caption* and keep at most two sentences.

        Returns "" when nothing remains after dropping the first three words.
        """
        words = caption.split()[3:]
        if not words:
            return ""
        words[0] = words[0].capitalize()
        sentences = ' '.join(words).split('. ')
        return '. '.join(sentences[:2]) + '.'

    @staticmethod
    def _tokenize(tokenizer, text):
        """Tokenize *text* padded/truncated to the tokenizer's max length; return input_ids."""
        return tokenizer(
            text,
            max_length=tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids

    def __getitem__(self, index):
        scale = 4  # GT / LQ resolution ratio assumed throughout

        # Map the flat index onto (data type, index within that type).
        data_type_idx = np.where(self.cumulative_lens <= index)[0][-1]
        data_type = self.data_types[data_type_idx]
        index = index - self.cumulative_lens[data_type_idx]

        gt_paths, json_paths, lq_paths = self.data_collection[data_type]
        img_path = gt_paths[index]
        json_path = json_paths[index]
        lq_img_path = lq_paths[index]
        gt_path = img_path

        # Use a context manager so the caption file handle is closed.
        with open(json_path) as f:
            text = self._clean_caption(json.load(f)["caption"])

        image = Image.open(img_path).convert('RGB')
        # FFHQ paths: randomly swap to the 'LR_crops_2' variant (presumably a
        # second pre-generated LQ set; confirm against the data layout).
        if 'FFHQ' in lq_img_path and random.random() < 0.5:
            lq_img_path = lq_img_path.replace('LR_crops_1', 'LR_crops_2')
        lq_image = Image.open(lq_img_path).convert('RGB')
        if 'FFHQ' in img_path:
            # Randomly rescale face pairs: LQ side in [128, 192], GT at 4x.
            random_size = random.randint(128, 192)
            lq_image = lq_image.resize((random_size, random_size), Image.BICUBIC)
            image = image.resize((random_size * scale, random_size * scale), Image.BICUBIC)

        w, h = lq_image.size
        pil_img = np.array(image)
        pil_lr_img = np.array(lq_image)
        pil_img, pil_lr_img = augment([pil_img, pil_lr_img], hflip=True, rotation=False)

        crop_pad_size = self.crop_size // scale  # LQ crop side length
        # Reflect-pad both images when the LQ is smaller than the crop.
        if h < crop_pad_size or w < crop_pad_size:
            pil_lr_img = cv2.copyMakeBorder(
                pil_lr_img, 0, max(0, crop_pad_size - h), 0, max(0, crop_pad_size - w),
                cv2.BORDER_REFLECT_101)
            pil_img = cv2.copyMakeBorder(
                pil_img, 0, max(0, self.crop_size - h * scale), 0, max(0, self.crop_size - w * scale),
                cv2.BORDER_REFLECT_101)

        # Random aligned crop: GT coordinates are the LQ ones times `scale`.
        if pil_lr_img.shape[0] > crop_pad_size or pil_lr_img.shape[1] > crop_pad_size:
            h, w = pil_lr_img.shape[0:2]
            top = random.randint(0, h - crop_pad_size)
            left = random.randint(0, w - crop_pad_size)
            pil_lr_img = pil_lr_img[top:top + crop_pad_size, left:left + crop_pad_size, ...]
            pil_img = pil_img[top * scale:(top + crop_pad_size) * scale,
                              left * scale:(left + crop_pad_size) * scale, ...]
        else:
            top = 0
            left = 0

        # Upsample the LQ crop back to GT size with a randomly chosen filter.
        lq_image = Image.fromarray(pil_lr_img)
        mode = random.choice([Image.NEAREST, Image.BILINEAR, Image.BICUBIC])
        lr_w, lr_h = lq_image.size
        lq_image = lq_image.resize((lr_w * scale, lr_h * scale), mode)
        image = Image.fromarray(pil_img)

        original_size = torch.tensor([h * scale, w * scale])
        crop_coords_top_left = torch.tensor([top * scale, left * scale])

        GT_image_t = np.asarray(image) / 255.
        LR_image_t = np.asarray(lq_image) / 255.
        GT_image_t, LR_image_t = img2tensor([GT_image_t, LR_image_t], bgr2rgb=False, float32=True)
        # Map [0, 1] -> [-1, 1].
        LR_image_t = LR_image_t * 2.0 - 1.0
        GT_image_t = GT_image_t * 2.0 - 1.0

        # Caption dropout (used for classifier-free guidance training).
        if random.random() < self.t_drop_rate:
            text = ""

        return {
            'lq_image': LR_image_t,
            "image": GT_image_t,
            "text_input_ids": self._tokenize(self.tokenizer, text),
            "text_input_ids_2": self._tokenize(self.tokenizer_2, text),
            "original_size": original_size,
            "crop_coords_top_left": crop_coords_top_left,
            "target_size": torch.tensor([self.crop_size, self.crop_size]),
            'gt_path': gt_path,
            'null_text_input_ids': self._tokenize(self.tokenizer, ''),
            "null_text_input_ids_2": self._tokenize(self.tokenizer_2, ''),
        }

    def __len__(self):
        # Must agree with __getitem__'s routing, which uses the cumulative GT
        # counts; summing the LQ counts could yield out-of-range indices.
        return int(self.cumulative_lens[-1])
================================================
FILE: dataloader/accelerate_config.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 4
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
main_process_port: 7399
================================================
FILE: dataloader/realesrgan.py
================================================
import os
import numpy as np
import cv2
import glob
import math
import yaml
import random
from collections import OrderedDict
import torch
import torch.nn.functional as F
from basicsr.data.transforms import augment
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img
from basicsr.utils.img_process_util import filter2D
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,
normalize, rgb_to_grayscale)
# Absolute path of the directory that contains this module.
cur_path = os.path.split(os.path.abspath(__file__))[0]
def ordered_yaml():
    """Build a YAML (Loader, Dumper) pair that maps YAML mappings to OrderedDict.

    The libyaml-backed ``CLoader``/``CDumper`` are used when available, with a
    fallback to the pure-Python classes. The constructor/representer are
    registered on those Loader/Dumper classes themselves.

    Returns:
        tuple: ``(Loader, Dumper)``.
    """
    try:
        from yaml import CDumper as Dumper, CLoader as Loader
    except ImportError:
        from yaml import Dumper, Loader

    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def _represent_ordered(dumper, mapping):
        # Dump an OrderedDict as a plain YAML mapping, keeping insertion order.
        return dumper.represent_dict(mapping.items())

    def _construct_ordered(loader, node):
        # Load every YAML mapping as an OrderedDict so key order is preserved.
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, _represent_ordered)
    Loader.add_constructor(mapping_tag, _construct_ordered)
    return Loader, Dumper
def opt_parse(opt_path):
    """Read the YAML option file at *opt_path*; mappings come back as OrderedDict.

    NOTE(review): ``ordered_yaml`` returns PyYAML's full ``Loader``/``CLoader``,
    which can construct arbitrary Python objects. Only use this on trusted
    config files.
    """
    with open(opt_path, mode='r') as stream:
        loader_cls = ordered_yaml()[0]
        return yaml.load(stream, Loader=loader_cls)  # ignore_security_alert_wait_for_fix RCE
class RealESRGAN_degradation(object):
    """Synthesize low-quality (LQ) images with a Real-ESRGAN style degradation chain.

    The chain applies two rounds of blur -> random resize -> noise -> JPEG,
    then a final [resize + sinc filter] / JPEG stage (in random order), optional
    grayscale and color jitter, and an optional resize back to GT size. All
    probabilities and ranges are read from a YAML option file: top-level keys
    for resize/noise/JPEG and a ``kernel_info`` section for blur kernels.
    """

    def __init__(self, opt_name='params_realesrgan.yml', device='cpu'):
        opt_path = f'{opt_name}'
        self.opt = opt_parse(opt_path)
        self.device = device
        optk = self.opt['kernel_info']
        # blur settings for the first degradation
        self.blur_kernel_size = optk['blur_kernel_size']
        self.kernel_list = optk['kernel_list']
        self.kernel_prob = optk['kernel_prob']
        self.blur_sigma = optk['blur_sigma']
        self.betag_range = optk['betag_range']
        self.betap_range = optk['betap_range']
        self.sinc_prob = optk['sinc_prob']
        # blur settings for the second degradation
        self.blur_kernel_size2 = optk['blur_kernel_size2']
        self.kernel_list2 = optk['kernel_list2']
        self.kernel_prob2 = optk['kernel_prob2']
        self.blur_sigma2 = optk['blur_sigma2']
        self.betag_range2 = optk['betag_range2']
        self.betap_range2 = optk['betap_range2']
        self.sinc_prob2 = optk['sinc_prob2']
        # a final sinc filter
        self.final_sinc_prob = optk['final_sinc_prob']
        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21
        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect
        self.pulse_tensor[10, 10] = 1
        self.jpeger = DiffJPEG(differentiable=False).to(self.device)
        self.usm_shaper = USMSharp().to(self.device)

    def color_jitter_pt(self, img, brightness, contrast, saturation, hue):
        """Apply brightness/contrast/saturation/hue jitter in a random order.

        Each argument is a ``(low, high)`` range or ``None`` to skip that adjustment.
        """
        fn_idx = torch.randperm(4)
        for fn_id in fn_idx:
            if fn_id == 0 and brightness is not None:
                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
                img = adjust_brightness(img, brightness_factor)
            if fn_id == 1 and contrast is not None:
                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
                img = adjust_contrast(img, contrast_factor)
            if fn_id == 2 and saturation is not None:
                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
                img = adjust_saturation(img, saturation_factor)
            if fn_id == 3 and hue is not None:
                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
                img = adjust_hue(img, hue_factor)
        return img

    def random_augment(self, img_gt):
        """Convert an HWC numpy image into a 1xCxHxW float32 tensor.

        No augmentation is currently applied (flip / color jitter / grayscale
        were disabled). Channel order is kept as-is (``bgr2rgb=False``).
        """
        img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0)
        return img_gt

    def _sample_blur_kernel(self, kernel_list, kernel_prob, blur_sigma, betag_range, betap_range, sinc_prob):
        """Sample one blur kernel (sinc or mixed) zero-padded to 21x21; return (kernel, kernel_size)."""
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < sinc_prob:
            # this sinc filter setting is for kernels ranging from [7, 21]
            if kernel_size < 13:
                omega_c = np.random.uniform(np.pi / 3, np.pi)
            else:
                omega_c = np.random.uniform(np.pi / 5, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
        else:
            kernel = random_mixed_kernels(
                kernel_list,
                kernel_prob,
                kernel_size,
                blur_sigma,
                blur_sigma, [-math.pi, math.pi],
                betag_range,
                betap_range,
                noise_range=None)
        pad_size = (21 - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
        return kernel, kernel_size

    def random_kernels(self):
        """Sample the first/second-degradation blur kernels and the final sinc kernel.

        Returns:
            (kernel, kernel2, sinc_kernel, kernel_size): three 21x21 float tensors,
            and the last sampled kernel size (the sinc one if a final sinc was
            drawn, else the second kernel's).
        """
        kernel, _ = self._sample_blur_kernel(
            self.kernel_list, self.kernel_prob, self.blur_sigma,
            self.betag_range, self.betap_range, self.sinc_prob)
        kernel2, kernel_size = self._sample_blur_kernel(
            self.kernel_list2, self.kernel_prob2, self.blur_sigma2,
            self.betag_range2, self.betap_range2, self.sinc_prob2)

        if np.random.uniform() < self.final_sinc_prob:
            kernel_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = torch.FloatTensor(circular_lowpass_kernel(omega_c, kernel_size, pad_to=21))
        else:
            sinc_kernel = self.pulse_tensor
        return torch.FloatTensor(kernel), torch.FloatTensor(kernel2), sinc_kernel, kernel_size

    @staticmethod
    def _random_resize_scale(resize_prob, resize_range):
        """Pick 'up'/'down'/'keep' with weights *resize_prob* and return the resize factor."""
        updown_type = random.choices(['up', 'down', 'keep'], resize_prob)[0]
        if updown_type == 'up':
            return np.random.uniform(1, resize_range[1])
        if updown_type == 'down':
            return np.random.uniform(resize_range[0], 1)
        return 1

    @staticmethod
    def _add_random_noise(out, gaussian_prob, sigma_range, poisson_scale_range, gray_prob):
        """Add Gaussian noise with probability *gaussian_prob*, otherwise Poisson noise."""
        if np.random.uniform() < gaussian_prob:
            return random_add_gaussian_noise_pt(
                out, sigma_range=sigma_range, clip=True, rounds=False, gray_prob=gray_prob)
        return random_add_poisson_noise_pt(
            out,
            scale_range=poisson_scale_range,
            gray_prob=gray_prob,
            clip=True,
            rounds=False)

    def _random_jpeg(self, out, jpeg_range):
        """Clamp to [0, 1] and JPEG-compress with a quality drawn uniformly from *jpeg_range*."""
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*jpeg_range)
        out = torch.clamp(out, 0, 1)
        return self.jpeger(out, quality=jpeg_p)

    @staticmethod
    def _resize_and_sinc(out, size, sinc_kernel):
        """Resize *out* to *size* with a random interpolation mode, then apply *sinc_kernel*."""
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, size=size, mode=mode)
        return filter2D(out, sinc_kernel)

    @torch.no_grad()
    def degrade_process(self, img_gt, resize_bak=False, seed=42):
        """Degrade *img_gt* and return ``(img_gt, img_lq)`` as 1xCxHxW tensors.

        Args:
            img_gt: HWC numpy image (presumably float in [0, 1]; confirm with the caller).
            resize_bak: If True, resize the LQ back to the GT resolution;
                otherwise it is 1/4 of the GT size.
            seed: Unused; no seeding is performed.

        Returns:
            The GT tensor (on ``self.device``) and the LQ tensor, quantized to 1/255 steps.
        """
        img_gt = self.random_augment(img_gt)
        kernel1, kernel2, sinc_kernel, _ = self.random_kernels()
        img_gt = img_gt.to(self.device)
        kernel1 = kernel1.to(self.device)
        kernel2 = kernel2.to(self.device)
        sinc_kernel = sinc_kernel.to(self.device)

        ori_h, ori_w = img_gt.size()[2:4]
        scale_final = 4
        lq_size = (ori_h // scale_final, ori_w // scale_final)

        # ----------------------- The first degradation process ----------------------- #
        out = filter2D(img_gt, kernel1)
        scale = self._random_resize_scale(self.opt['resize_prob'], self.opt['resize_range'])
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        out = self._add_random_noise(
            out, self.opt['gaussian_noise_prob'], self.opt['noise_range'],
            self.opt['poisson_scale_range'], self.opt['gray_noise_prob'])
        out = self._random_jpeg(out, self.opt['jpeg_range'])

        # ----------------------- The second degradation process ----------------------- #
        if np.random.uniform() < self.opt['second_blur_prob']:
            out = filter2D(out, kernel2)
        scale = self._random_resize_scale(self.opt['resize_prob2'], self.opt['resize_range2'])
        mode = random.choice(['area', 'bilinear', 'bicubic'])
        out = F.interpolate(
            out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode)
        out = self._add_random_noise(
            out, self.opt['gaussian_noise_prob2'], self.opt['noise_range2'],
            self.opt['poisson_scale_range2'], self.opt['gray_noise_prob2'])

        # Final stage: [resize back + sinc filter] and JPEG, in one of two orders.
        # Other orders (sinc + JPEG + resize) empirically introduce twisted lines.
        if np.random.uniform() < 0.5:
            out = self._resize_and_sinc(out, lq_size, sinc_kernel)
            out = self._random_jpeg(out, self.opt['jpeg_range2'])
        else:
            out = self._random_jpeg(out, self.opt['jpeg_range2'])
            out = self._resize_and_sinc(out, lq_size, sinc_kernel)

        if np.random.uniform() < self.opt['gray_prob']:
            # Keep 3 channels: a 1-channel LQ would no longer match the 3-channel GT.
            out = rgb_to_grayscale(out, num_output_channels=3)

        if np.random.uniform() < self.opt['color_jitter_prob']:
            brightness = self.opt.get('brightness', (0.5, 1.5))
            contrast = self.opt.get('contrast', (0.5, 1.5))
            saturation = self.opt.get('saturation', (0, 1.5))
            hue = self.opt.get('hue', (-0.1, 0.1))
            out = self.color_jitter_pt(out, brightness, contrast, saturation, hue)

        if resize_bak:
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, size=(ori_h, ori_w), mode=mode)

        # clamp and round
        img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
        return img_gt, img_lq
================================================
FILE: dataloader/train_kernel.yml
================================================
scale: 4
color_jitter_prob: 0.0
gray_prob: 0.0
# the first degradation process
resize_prob: [0.2, 0.7, 0.1] # up, down, keep
resize_range: [0.3, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 15]
poisson_scale_range: [0.05, 2.0]
gray_noise_prob: 0.4
jpeg_range: [60, 95]
# the second degradation process
second_blur_prob: 0.5
resize_prob2: [0.3, 0.4, 0.3] # up, down, keep
resize_range2: [0.6, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 12]
poisson_scale_range2: [0.05, 1.0]
gray_noise_prob2: 0.4
jpeg_range2: [60, 100]
kernel_info:
blur_kernel_size: 21
kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob: 0.1
blur_sigma: [0.2, 3]
betag_range: [0.5, 4]
betap_range: [1, 2]
blur_kernel_size2: 21
kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
sinc_prob2: 0.1
blur_sigma2: [0.2, 1.5]
betag_range2: [0.5, 4]
betap_range2: [1, 2]
final_sinc_prob: 0.8
================================================
FILE: environment.yml
================================================
name: FaithDiff
channels:
- pytorch
- nvidia
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- brotli-python=1.0.9=py310h6a678d5_9
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.11.26=h06a4308_0
- certifi=2024.12.14=py310h06a4308_0
- charset-normalizer=3.3.2=pyhd3eb1b0_0
- cuda-cudart=12.4.127=0
- cuda-cupti=12.4.127=0
- cuda-libraries=12.4.1=0
- cuda-nvrtc=12.4.127=0
- cuda-nvtx=12.4.127=0
- cuda-opencl=12.6.77=0
- cuda-runtime=12.4.1=0
- cuda-version=12.6=3
- ffmpeg=4.3=hf484d3e_0
- filelock=3.13.1=py310h06a4308_0
- freetype=2.12.1=h4a9f257_0
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py310heeb90bb_0
- gnutls=3.6.15=he1e5248_0
- idna=3.7=py310h06a4308_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- jinja2=3.1.4=py310h06a4308_1
- jpeg=9e=h5eee18b_3
- lame=3.100=h7b6447c_0
- lcms2=2.16=hb9589c4_0
- ld_impl_linux-64=2.40=h12ee557_0
- lerc=4.0.0=h6a678d5_0
- libcublas=12.4.5.8=0
- libcufft=11.2.1.3=0
- libcufile=1.11.1.6=0
- libcurand=10.3.7.77=0
- libcusolver=11.6.1.9=0
- libcusparse=12.3.1.170=0
- libdeflate=1.22=h5eee18b_0
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.16=h5eee18b_3
- libidn2=2.3.4=h5eee18b_0
- libjpeg-turbo=2.0.0=h9bf148f_0
- libnpp=12.2.5.30=0
- libnvfatbin=12.6.77=0
- libnvjitlink=12.4.127=0
- libnvjpeg=12.3.1.117=0
- libpng=1.6.39=h5eee18b_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=hffd6297_1
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0
- libwebp-base=1.3.2=h5eee18b_1
- llvm-openmp=14.0.6=h9e868ea_0
- lz4-c=1.9.4=h6a678d5_1
- markupsafe=2.1.3=py310h5eee18b_0
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py310h5eee18b_1
- mkl_fft=1.3.11=py310h5eee18b_0
- mkl_random=1.2.8=py310h1128e8f_0
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py310h06a4308_0
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- networkx=3.2.1=py310h06a4308_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.5.2=he7f1fd0_0
- openssl=3.0.15=h5eee18b_0
- pip=24.2=py310h06a4308_0
- pysocks=1.7.1=py310h06a4308_0
- python=3.10.16=he870216_1
- pytorch=2.4.0=py3.10_cuda12.4_cudnn9.1.0_0
- pytorch-cuda=12.4=hc786d27_7
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.2=py310h5eee18b_0
- readline=8.2=h5eee18b_0
- requests=2.32.3=py310h06a4308_1
- setuptools=75.1.0=py310h06a4308_0
- sqlite=3.45.3=h5eee18b_0
- sympy=1.13.3=py310h06a4308_0
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.14=h39e8969_0
- torchaudio=2.4.0=py310_cu124
- torchtriton=3.0.0=py310
- torchvision=0.19.0=py310_cu124
- typing_extensions=4.12.2=py310h06a4308_0
- urllib3=2.2.3=py310h06a4308_0
- wheel=0.44.0=py310h06a4308_0
- xz=5.4.6=h5eee18b_1
- yaml=0.2.5=h7b6447c_0
- zlib=1.2.13=h5eee18b_1
- zstd=1.5.6=hc292b87_0
- pip:
- absl-py==2.1.0
- accelerate==1.1.1
- addict==2.4.0
- aiofiles==23.2.1
- aiohappyeyeballs==2.4.4
- aiohttp==3.11.11
- aiosignal==1.3.2
- annotated-types==0.7.0
- anyio==4.8.0
- asttokens==3.0.0
- astunparse==1.6.3
- async-timeout==5.0.1
- attrs==24.3.0
- beautifulsoup4==4.12.3
- bitsandbytes==0.45.0
- cfgv==3.4.0
- click==8.1.8
- colorama==0.4.6
- contourpy==1.3.1
- cycler==0.12.1
- datasets==3.2.0
- decorator==5.1.1
- deepspeed==0.15.1
- diffusers==0.28.0
- dill==0.3.8
- distlib==0.3.9
- einops==0.7.0
- exceptiongroup==1.2.2
- executing==2.1.0
- facexlib==0.3.0
- fastapi==0.115.6
- ffmpy==0.5.0
- filterpy==1.4.5
- flatbuffers==24.12.23
- fonttools==4.55.3
- frozenlist==1.5.0
- fsspec==2024.9.0
- ftfy==6.3.1
- future==1.0.0
- gast==0.6.0
- gdown==5.2.0
- google-pasta==0.2.0
- gradio==5.23.1
- gradio-client==1.8.0
- grpcio==1.69.0
- h11==0.14.0
- h5py==3.12.1
- hjson==3.1.0
- httpcore==1.0.7
- httpx==0.28.1
- huggingface-hub==0.25.2
- icecream==2.1.3
- identify==2.6.5
- imageio==2.34.0
- imageio-ffmpeg==0.5.1
- iniconfig==2.0.0
- ipdb==0.13.13
- ipython==8.31.0
- jedi==0.19.2
- keras==3.8.0
- kiwisolver==1.4.8
- lazy-loader==0.4
- libclang==18.1.1
- llvmlite==0.43.0
- lmdb==1.6.2
- loguru==0.7.2
- lpips==0.1.4
- markdown==3.7
- markdown-it-py==3.0.0
- matplotlib==3.10.0
- matplotlib-inline==0.1.7
- mdurl==0.1.2
- ml-dtypes==0.4.1
- msgpack==1.1.0
- multidict==6.1.0
- multiprocess==0.70.16
- namex==0.0.8
- ninja==1.11.1.3
- nodeenv==1.9.1
- numba==0.60.0
- numpy==1.26.4
- nvidia-ml-py==12.560.30
- omegaconf==2.4.0.dev3
- openai-clip==1.0.1
- opencv-python==4.9.0.80
- opencv-python-headless==4.10.0.84
- opt-einsum==3.4.0
- optree==0.13.1
- orjson==3.10.13
- packaging==24.2
- pandas==2.2.3
- parso==0.8.4
- pexpect==4.9.0
- pillow==10.4.0
- platformdirs==4.3.6
- pluggy==1.5.0
- pre-commit==4.0.1
- prompt-toolkit==3.0.48
- propcache==0.2.1
- protobuf==5.29.2
- psutil==6.1.1
- ptyprocess==0.7.0
- pure-eval==0.2.3
- py-cpuinfo==9.0.0
- pyarrow==18.1.0
- pydantic==2.10.4
- pydantic-core==2.27.2
- pydub==0.25.1
- pygments==2.19.1
- pyiqa==0.1.13
- pyparsing==3.2.1
- pytest==8.3.4
- python-dateutil==2.9.0.post0
- python-multipart==0.0.20
- pytz==2024.2
- regex==2024.11.6
- rich==13.9.4
- ruff==0.8.6
- safetensors==0.4.3
- scikit-image==0.25.0
- scipy==1.15.0
- semantic-version==2.10.0
- sentencepiece==0.2.0
- shellingham==1.5.4
- six==1.17.0
- sniffio==1.3.1
- soupsieve==2.6
- stack-data==0.6.3
- starlette==0.41.3
- tensorboard==2.18.0
- tensorboard-data-server==0.7.2
- tensorflow-io-gcs-filesystem==0.37.1
- termcolor==2.5.0
- tifffile==2024.12.12
- timm==1.0.12
- tokenizers==0.15.2
- tomli==2.2.1
- tomlkit==0.12.0
- tqdm==4.67.1
- traitlets==5.14.3
- transformers==4.46.1
- typer==0.15.1
- tzdata==2024.2
- uvicorn==0.34.0
- virtualenv==20.28.1
- wcwidth==0.2.13
- websockets==12.0
- werkzeug==3.1.3
- wrapt==1.17.2
- xxhash==3.5.0
- yapf==0.43.0
- yarl==1.18.3
================================================
FILE: gradio_demo.py
================================================
import torch
import torch.utils.checkpoint
import torch.cuda
import random
import gradio as gr
import numpy as np
import argparse
from typing import List
from PIL import Image
from FaithDiff.create_FaithDiff_model import FaithDiff_pipeline
from PIL import Image
from CKPT_PTH import LLAVA_MODEL_PATH, SDXL_PATH, FAITHDIFF_PATH, VAE_FP16_PATH
from utils.color_fix import wavelet_color_fix, adain_color_fix
from utils.image_process import check_image_size, create_hdr_effect
from llava.llm_agent import LLavaAgent
from utils.system import torch_gc
# Largest seed offered by the UI: the int32 maximum (2**31 - 1).
MAX_SEED = np.iinfo('int32').max
# Command-line options for the demo server.
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default='127.0.0.1')
parser.add_argument("--port", type=int, default=6688)
parser.add_argument("--no_llava", action='store_true', default=False)
parser.add_argument("--cpu_offload", action='store_true', default=False)
parser.add_argument("--use_fp8", action='store_true', default=False)
args = parser.parse_args()

server_ip, server_port = args.ip, args.port
use_llava = not args.no_llava
cpu_offload, use_fp8 = args.cpu_offload, args.use_fp8
# Device placement: diffusion always runs on cuda:0; with two or more GPUs,
# LLaVA gets cuda:1, otherwise it shares cuda:0. CPU-only is rejected.
if torch.cuda.device_count() < 1:
    raise ValueError('Currently support CUDA only.')
Diffusion_device = 'cuda:0'
LLaVA_device = 'cuda:1' if torch.cuda.device_count() >= 2 else 'cuda:0'
# Optional LLaVA captioner, loaded in 8-bit on its own device.
llava_agent = (
    LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=True, load_4bit=False)
    if use_llava
    else None
)

# FaithDiff diffusion pipeline on the diffusion device, with encoder tile
# settings and VAE tiling enabled; optionally use model CPU offload.
pipe = FaithDiff_pipeline(
    sdxl_path=SDXL_PATH,
    VAE_FP16_path=VAE_FP16_PATH,
    FaithDiff_path=FAITHDIFF_PATH,
    use_fp8=use_fp8,
).to(Diffusion_device)
pipe.set_encoder_tile_settings()
pipe.enable_vae_tiling()
if cpu_offload:
    pipe.enable_model_cpu_offload()
@torch.no_grad()
def caption_process(
    image: Image.Image,
) -> str:
    """Generate a text caption for *image* with LLaVA.

    Returns a fixed placeholder message when LLaVA is disabled (``--no_llava``).
    The return annotation previously claimed ``List[np.ndarray]``, but a string
    has always been returned.
    """
    if not use_llava:
        return 'Caption Generation is not available. Please add text manually.'
    return llava_agent.gen_image_caption([image])[0]
def clear_result():
    """Return a Gradio update that empties the output image component."""
    cleared = gr.update(value=None)
    return cleared
def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
    """Return a fresh seed in ``[0, MAX_SEED]`` if requested, else *generation_seed*."""
    if not randomize_seed:
        return generation_seed
    return random.randint(0, MAX_SEED)
_COLOR_FIX_MODES = ('nofix', 'wavelet', 'adain')


def _strip_caption_preamble(text):
    """Drop the first three words of a caption and keep at most two sentences.

    Prompts with three or fewer words are returned unchanged; previously they
    crashed with IndexError (e.g. when the prompt box was left empty).
    """
    words = text.split()
    if len(words) <= 3:
        return text
    words = words[3:]
    words[0] = words[0].capitalize()
    sentences = ' '.join(words).split('. ')
    return '. '.join(sentences[:2]) + '.'


@torch.no_grad()
def process(
    image: Image.Image,
    user_prompt: str,
    num_inference_steps: int,
    scale_factor: int,
    guidance_scale: float,
    seed: int,
    latent_tiled_size: int,
    latent_tiled_overlap: int,
    color_fix: str,
    start_point: str,
    hdr: float
) -> np.ndarray:
    """Upscale *image* by *scale_factor* and restore it with the FaithDiff pipeline.

    Args:
        image: Input PIL image.
        user_prompt: Text prompt. When LLaVA is enabled, it is shortened with
            ``_strip_caption_preamble``.
        num_inference_steps, guidance_scale, start_point: Passed to ``pipe``.
        scale_factor: Upscale factor applied with LANCZOS before restoration.
        seed: Seed for the diffusion ``torch.Generator``.
        latent_tiled_size, latent_tiled_overlap: Passed to ``pipe`` as
            ``target_size`` (square) and ``overlap``.
        color_fix: One of ``'nofix'``, ``'wavelet'``, ``'adain'``.
        hdr: Strength passed to ``create_hdr_effect``.

    Returns:
        The restored image as a numpy array (``np.array`` of a PIL image).

    Raises:
        ValueError: If *color_fix* is not a supported mode. This is checked
            before the expensive diffusion run; previously an unknown value
            caused UnboundLocalError after it.
    """
    if color_fix not in _COLOR_FIX_MODES:
        raise ValueError(f"Unknown color_fix {color_fix!r}; expected one of {_COLOR_FIX_MODES}")

    w, h = image.size
    # int() guards against a float scale factor from the UI; PIL resize needs ints.
    w, h = int(w * scale_factor), int(h * scale_factor)
    image = image.resize((w, h), Image.LANCZOS)
    input_image, width_init, height_init, width_now, height_now = check_image_size(image)

    if use_llava:
        user_prompt = _strip_caption_preamble(user_prompt)

    negative_prompt_init = ""
    generator = torch.Generator(device=Diffusion_device).manual_seed(seed)
    input_image = create_hdr_effect(input_image, hdr)
    gen_image = pipe(lr_img=input_image, prompt=user_prompt, negative_prompt=negative_prompt_init,
                     num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
                     generator=generator, start_point=start_point, height=height_now, width=width_now,
                     overlap=latent_tiled_overlap,
                     target_size=(latent_tiled_size, latent_tiled_size)).images[0]
    torch_gc()

    # Crop back to the (upscaled) input size before optional color correction.
    cropped_image = gen_image.crop((0, 0, width_init, height_init))
    if color_fix == 'wavelet':
        out_image = wavelet_color_fix(cropped_image, image)
    elif color_fix == 'adain':
        out_image = adain_color_fix(cropped_image, image)
    else:  # 'nofix'
        out_image = cropped_image
    return np.array(out_image)
# Custom stylesheet for the demo page; passed to gr.Blocks(css=css).
css = """
body {
font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
margin: 0;
padding: 0;
}
.gradio-container {
border-radius: 15px;
padding: 30px 40px;
box-shadow: 0 8px 30px rgba(0, 0, 0, 0.3);
margin: 40px 340px;
}
.gradio-container h1 {
text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.2);
}
.fillable {
width: 100% !important;
max-width: unset !important;
}
.gradio-slider-input {
input[type="number"] {
width: 8em;
}
}
.slider-input-right > .wrap > .head {
display: flex;
}
.slider-input-right > .wrap > .head > .tab-like-container {
margin-left: auto;
}
#examples_container {
margin: auto;
width: 90%;
}
#examples_row {
justify-content: center;
}
#tips_row{
padding-left: 20px;
}
.sidebar {
border-radius: 10px;
padding: 10px;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
.sidebar .toggle-button {
background: linear-gradient(90deg, #34d399, #10b981) !important;
border: none;
padding: 12px 24px;
text-transform: uppercase;
font-weight: bold;
letter-spacing: 1px;
border-radius: 5px;
cursor: pointer;
transition: transform 0.2s ease-in-out;
}
.toggle-button:hover {
transform: scale(1.05);
}
"""
title = """<h1 align="center">FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolution</h1>
<div style="display: flex; flex-direction: column; justify-content: center; align-items: center; text-align: center; overflow:hidden;">
<span>💻 <a href="https://github.com/JyChen9811/FaithDiff/">GitHub Code</a> | 📜 <a href="https://arxiv.org/abs/2411.18824"> Paper</a></span>
<span>If FaithDiff is helpful for you, please help star the GitHub Repo. Thanks!</span>
</div>
"""
# Build the Gradio UI: image previews, prompt box and buttons in the main area,
# sampling parameters in a sidebar, one clickable example, then wire the
# buttons to the captioning and restoration callbacks and start the server.
block = gr.Blocks(css=css, theme=gr.themes.Ocean(), title="FaithDiff").queue()
with block:
    gr.Markdown(title)

    # Main area: input / output previews, prompt box and action buttons.
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Input Image", sources=["upload"], height=500)
                with gr.Column():
                    result = gr.Image(label="Generated Image", show_label=True, format="png", interactive=False, scale=1, height=500, min_width=670)
            with gr.Row():
                with gr.Accordion("Input Prompt", open=True):
                    with gr.Column():
                        user_prompt = gr.Textbox(lines=2, label="User Prompt", value="")
            with gr.Row():
                run_button = gr.Button(value="Restoration Run", variant="primary")
                llava_button = gr.Button(value="Caption Generation Run")

    # Sidebar: sampling, tiling and post-processing parameters.
    with gr.Sidebar(label="Parameters", open=True):
        gr.Markdown("### General parameters")
        with gr.Row():
            cfg_scale = gr.Slider(label="CFG Scale", elem_classes="gradio-slider-input slider-input-right", info="Set a value larger than 1 to enable it!", minimum=0.1, maximum=10.0, value=5, step=0.1)
            num_inference_steps = gr.Slider(label="Inference Steps", minimum=10, maximum=50, value=20, step=1)
            generation_seed = gr.Slider(label="Seed", elem_classes="gradio-slider-input", minimum=0, maximum=MAX_SEED, step=1, value=42)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
        with gr.Row():
            latent_tiled_size = gr.Slider(label="Tile Size", elem_classes="gradio-slider-input slider-input-right", minimum=1024, maximum=1280, value=1024, step=1)
            latent_tiled_overlap = gr.Slider(label="Tile Overlap", elem_classes="gradio-slider-input slider-input-right", minimum=0.1, maximum=0.9, value=0.5, step=0.1)
            # `process` multiplies the image width/height by this value and
            # hands the result to `Image.resize`, which needs positive integer
            # sizes. precision=0 makes Gradio deliver an int (fractional entries
            # are rounded instead of crashing the resize), and minimum=1 rejects
            # zero/negative scales with a UI error before the pipeline runs.
            scale_factor = gr.Number(label="SR Scale", value=2, minimum=1, precision=0)
            color_fix = gr.Dropdown(label="Color Fix", choices=["wavelet", "adain", "nofix"], value="adain")
            hdr = gr.Slider(label="HDR Effect", elem_classes="gradio-slider-input", minimum=0, maximum=2, value=0, step=0.1)
            start_point = gr.Dropdown(label="Start Point", choices=["lr", "noise"], value="lr")

    # Single source of truth for the inputs of `process`. Gradio passes them
    # positionally, so this order must match `process`'s parameter list, and
    # every example row below must follow the same order.
    inputs = [
        input_image,
        user_prompt,
        num_inference_steps,
        scale_factor,
        cfg_scale,
        generation_seed,
        latent_tiled_size,
        latent_tiled_overlap,
        color_fix,
        start_point,
        hdr,
    ]

    with gr.Accordion(label="Example Images", open=True):
        with gr.Row(elem_id="examples_row"):
            with gr.Column(scale=12, elem_id="examples_container"):
                gr.Examples(
                    examples=[
                        [
                            "./examples/band.png",
                            "Three men posing for a picture with their guitars.",
                            20,       # num_inference_steps
                            2,        # scale_factor
                            5,        # cfg_scale
                            42,       # generation_seed
                            1024,     # latent_tiled_size
                            0.5,      # latent_tiled_overlap
                            "adain",  # color_fix
                            "lr",     # start_point
                            0,        # hdr
                        ],
                    ],
                    inputs=inputs,
                    fn=process,
                    outputs=result,
                    cache_examples=False,
                )

    # Restoration: clear the previous output, run randomize_seed_fn on the
    # seed/checkbox pair (unqueued, not exposed as an API endpoint), then run
    # the restoration pipeline.
    run_button.click(
        fn=clear_result,
        inputs=None,
        outputs=result,
    ).then(
        fn=randomize_seed_fn,
        inputs=[generation_seed, randomize_seed],
        outputs=generation_seed,
        queue=False,
        api_name=False,
    ).then(fn=process, inputs=inputs, outputs=[result])

    # Captioning: caption_process's output for the uploaded image is written
    # into the prompt box.
    llava_button.click(fn=caption_process, inputs=[input_image], outputs=[user_prompt])

block.launch(server_name=server_ip, server_port=server_port)
================================================
FILE: llava/__init__.py
================================================
from .model import LlavaLlamaForCausalLM
================================================
FILE: llava/constants.py
================================================
"""Package-wide constants for LLaVA."""

# ---------------------------------------------------------------------------
# Serving: heartbeat settings (presumably consumed by the controller/workers).
# ---------------------------------------------------------------------------
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

# Log directory.
LOGDIR = "."

# ---------------------------------------------------------------------------
# Model constants
# ---------------------------------------------------------------------------
# Sentinel integer indices.
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200

# Special token strings related to image content.
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
================================================
FILE: llava/conversation.py
================================================
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import base64
from io import BytesIO
from PIL import Image
class SeparatorStyle(Enum):
    """Separator styles selectable through ``Conversation.sep_style``.

    Values are spelled out explicitly; they are the same 1..5 sequence that
    ``enum.auto()`` assigns in declaration order.
    """

    SINGLE = 1
    TWO = 2
    MPT = 3
    PLAIN = 4
    LLAMA_2 = 5
@dataclasses.dataclass
class Conversation:
"""A class that keeps all conversation history."""
system: str
roles: List[str]
messages: List[List[str]]
offset: int
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
sep: str = "###"
sep
gitextract_k3bn1c1l/
├── .gitignore
├── CKPT_PTH.py
├── FaithDiff/
│ ├── create_FaithDiff_model.py
│ ├── models/
│ │ ├── bsrnet_arch.py
│ │ └── unet_2d_condition_vae_extension.py
│ ├── pipelines/
│ │ ├── __init__.py
│ │ ├── pipeline_FaithDiff_tlc.py
│ │ └── pipeline_output.py
│ └── training_utils.py
├── LICENSE
├── README.md
├── dataloader/
│ ├── Realesrgan_offline_dataset.py
│ ├── accelerate_config.yaml
│ ├── realesrgan.py
│ └── train_kernel.yml
├── environment.yml
├── gradio_demo.py
├── llava/
│ ├── __init__.py
│ ├── constants.py
│ ├── conversation.py
│ ├── eval/
│ │ ├── eval_gpt_review.py
│ │ ├── eval_gpt_review_bench.py
│ │ ├── eval_gpt_review_visual.py
│ │ ├── eval_pope.py
│ │ ├── eval_science_qa.py
│ │ ├── eval_science_qa_gpt4.py
│ │ ├── eval_science_qa_gpt4_requery.py
│ │ ├── eval_textvqa.py
│ │ ├── generate_webpage_data_from_table.py
│ │ ├── m4c_evaluator.py
│ │ ├── model_qa.py
│ │ ├── model_vqa.py
│ │ ├── model_vqa_loader.py
│ │ ├── model_vqa_mmbench.py
│ │ ├── model_vqa_science.py
│ │ ├── qa_baseline_gpt35.py
│ │ ├── run_llava.py
│ │ ├── summarize_gpt_review.py
│ │ ├── table/
│ │ │ ├── answer/
│ │ │ │ ├── answer_alpaca-13b.jsonl
│ │ │ │ ├── answer_bard.jsonl
│ │ │ │ ├── answer_gpt35.jsonl
│ │ │ │ ├── answer_llama-13b.jsonl
│ │ │ │ └── answer_vicuna-13b.jsonl
│ │ │ ├── caps_boxes_coco2014_val_80.jsonl
│ │ │ ├── model.jsonl
│ │ │ ├── prompt.jsonl
│ │ │ ├── question.jsonl
│ │ │ ├── results/
│ │ │ │ ├── test_sqa_llava_13b_v0.json
│ │ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json
│ │ │ ├── review/
│ │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl
│ │ │ │ ├── review_bard_vicuna-13b.jsonl
│ │ │ │ ├── review_gpt35_vicuna-13b.jsonl
│ │ │ │ └── review_llama-13b_vicuna-13b.jsonl
│ │ │ ├── reviewer.jsonl
│ │ │ └── rule.json
│ │ └── webpage/
│ │ ├── index.html
│ │ ├── script.js
│ │ └── styles.css
│ ├── llm_agent.py
│ ├── mm_utils.py
│ ├── model/
│ │ ├── __init__.py
│ │ ├── apply_delta.py
│ │ ├── builder.py
│ │ ├── consolidate.py
│ │ ├── language_model/
│ │ │ ├── llava_llama.py
│ │ │ ├── llava_mistral.py
│ │ │ └── llava_mpt.py
│ │ ├── llava_arch.py
│ │ ├── make_delta.py
│ │ ├── multimodal_encoder/
│ │ │ ├── builder.py
│ │ │ └── clip_encoder.py
│ │ ├── multimodal_projector/
│ │ │ └── builder.py
│ │ └── utils.py
│ ├── serve/
│ │ ├── __init__.py
│ │ ├── cli.py
│ │ ├── controller.py
│ │ ├── gradio_web_server.py
│ │ ├── model_worker.py
│ │ ├── register_worker.py
│ │ ├── sglang_worker.py
│ │ └── test_message.py
│ ├── train/
│ │ ├── llama_flash_attn_monkey_patch.py
│ │ ├── llama_xformers_attn_monkey_patch.py
│ │ ├── llava_trainer.py
│ │ ├── train.py
│ │ ├── train_mem.py
│ │ └── train_xformers.py
│ └── utils.py
├── requirements.txt
├── test.py
├── test_generate_caption.py
├── test_metrics.py
├── test_wo_llava.py
├── train_SDXL_stage_1.py
├── train_SDXL_stage_2.py
├── train_stage_1.sh
├── train_stage_2.sh
└── utils/
├── color_fix.py
├── image_process.py
└── system.py
SYMBOL INDEX (466 symbols across 60 files)
FILE: FaithDiff/create_FaithDiff_model.py
function FaithDiff_pipeline (line 10) | def FaithDiff_pipeline(sdxl_path, VAE_FP16_path, FaithDiff_path, use_fp8...
function create_bsrnet (line 35) | def create_bsrnet(bsrnet_path):
FILE: FaithDiff/models/bsrnet_arch.py
function initialize_weights (line 8) | def initialize_weights(net_l, scale=1):
function make_layer (line 28) | def make_layer(block, n_layers):
class ResidualDenseBlock_5C (line 35) | class ResidualDenseBlock_5C(nn.Module):
method __init__ (line 36) | def __init__(self, nf=64, gc=32, bias=True):
method forward (line 49) | def forward(self, x):
class RRDB (line 58) | class RRDB(nn.Module):
method __init__ (line 61) | def __init__(self, nf, gc=32):
method forward (line 67) | def forward(self, x):
class RRDBNet (line 74) | class RRDBNet(nn.Module):
method __init__ (line 75) | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
method check_image_size (line 93) | def check_image_size(self, x, scale):
method forward (line 101) | def forward(self, x):
method deg_remove (line 116) | def deg_remove(self, input, tile_size=512, tile_pad=16):
FILE: FaithDiff/models/unet_2d_condition_vae_extension.py
function zero_module (line 33) | def zero_module(module):
class Encoder (line 39) | class Encoder(nn.Module):
method __init__ (line 41) | def __init__(
method to_rgb_init (line 110) | def to_rgb_init(self):
method enable_tiling (line 118) | def enable_tiling(self):
method encode (line 122) | def encode(self, sample: torch.FloatTensor) -> torch.FloatTensor:
method blend_v (line 150) | def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int)...
method blend_h (line 157) | def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int)...
method tiled_encode (line 164) | def tiled_encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
method forward (line 192) | def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
class ControlNetConditioningEmbedding (line 199) | class ControlNetConditioningEmbedding(nn.Module):
method __init__ (line 201) | def __init__(
method forward (line 213) | def forward(self, conditioning):
class QuickGELU (line 222) | class QuickGELU(nn.Module):
method forward (line 224) | def forward(self, x: torch.Tensor):
class LayerNorm (line 229) | class LayerNorm(nn.LayerNorm):
method forward (line 231) | def forward(self, x: torch.Tensor):
class ResidualAttentionBlock (line 238) | class ResidualAttentionBlock(nn.Module):
method __init__ (line 240) | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor ...
method attention (line 251) | def attention(self, x: torch.Tensor):
method forward (line 256) | def forward(self, x: torch.Tensor):
class UNet2DConditionOutput (line 264) | class UNet2DConditionOutput(BaseOutput):
class UNet2DConditionModel (line 269) | class UNet2DConditionModel(OriginalUNet2DConditionModel, ConfigMixin, UN...
method __init__ (line 274) | def __init__(
method init_vae_encoder (line 389) | def init_vae_encoder(self, dtype):
method init_information_transformer_layes (line 393) | def init_information_transformer_layes(self):
method init_ControlNetConditioningEmbedding (line 400) | def init_ControlNetConditioningEmbedding(self, channel=512):
method init_extra_weights (line 402) | def init_extra_weights(self):
method load_additional_layers (line 405) | def load_additional_layers(self, dtype: Optional[torch.dtype] = torch....
method to (line 435) | def to(self, *args, **kwargs):
method load_state_dict (line 444) | def load_state_dict(self, state_dict, strict=True):
method forward (line 480) | def forward(
FILE: FaithDiff/pipelines/pipeline_FaithDiff_tlc.py
function img2tensor (line 74) | def img2tensor(imgs, bgr2rgb=True, float32=True):
class LocalAttention (line 101) | class LocalAttention:
method __init__ (line 104) | def __init__(self, kernel_size=None, overlap=0.5):
method grids_list (line 115) | def grids_list(self, x):
method grids (line 160) | def grids(self, x):
method _gaussian_weights (line 207) | def _gaussian_weights(self, tile_width, tile_height):
method grids_inverse (line 230) | def grids_inverse(self, outs):
method _pad (line 254) | def _pad(self, x):
method forward (line 271) | def forward(self, x):
function rescale_noise_cfg (line 286) | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
function retrieve_latents (line 308) | def retrieve_latents(
function retrieve_timesteps (line 331) | def retrieve_timesteps(
class FaithDiffStableDiffusionXLPipeline (line 374) | class FaithDiffStableDiffusionXLPipeline(
method __init__ (line 445) | def __init__(
method encode_prompt (line 485) | def encode_prompt(
method prepare_extra_step_kwargs (line 720) | def prepare_extra_step_kwargs(self, generator, eta):
method check_inputs (line 737) | def check_inputs(
method prepare_latents (line 820) | def prepare_latents(self, batch_size, num_channels_latents, height, wi...
method upcast_vae (line 837) | def upcast_vae(self):
method get_guidance_scale_embedding (line 858) | def get_guidance_scale_embedding(
method set_encoder_tile_settings (line 888) | def set_encoder_tile_settings(self,
method enable_vae_tiling (line 898) | def enable_vae_tiling(self):
method disable_vae_tiling (line 907) | def disable_vae_tiling(self):
method guidance_scale (line 916) | def guidance_scale(self):
method guidance_rescale (line 920) | def guidance_rescale(self):
method clip_skip (line 924) | def clip_skip(self):
method do_classifier_free_guidance (line 931) | def do_classifier_free_guidance(self):
method cross_attention_kwargs (line 935) | def cross_attention_kwargs(self):
method denoising_end (line 939) | def denoising_end(self):
method num_timesteps (line 943) | def num_timesteps(self):
method interrupt (line 947) | def interrupt(self):
method prepare_image_latents (line 950) | def prepare_image_latents(
method __call__ (line 1008) | def __call__(
FILE: FaithDiff/pipelines/pipeline_output.py
class FaithDiffStableDiffusionXLPipelineOutput (line 11) | class FaithDiffStableDiffusionXLPipelineOutput(BaseOutput):
class FlaxFaithDiffStableDiffusionXLPipelineOutput (line 28) | class FlaxFaithDiffStableDiffusionXLPipelineOutput(BaseOutput):
FILE: FaithDiff/training_utils.py
function set_seed (line 40) | def set_seed(seed: int):
function compute_snr (line 60) | def compute_snr(noise_scheduler, timesteps):
function resolve_interpolation_mode (line 97) | def resolve_interpolation_mode(interpolation_type: str):
function compute_dream_and_update_latents (line 141) | def compute_dream_and_update_latents(
function unet_lora_state_dict (line 195) | def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch....
function cast_training_params (line 214) | def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Mod...
function _set_state_dict_into_text_encoder (line 231) | def _set_state_dict_into_text_encoder(
function compute_density_for_timestep_sampling (line 250) | def compute_density_for_timestep_sampling(
function compute_loss_weighting_for_sd3 (line 272) | def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
function free_memory (line 290) | def free_memory():
function should_update_ema (line 305) | def should_update_ema(args, step):
class EMAModel (line 316) | class EMAModel:
method __init__ (line 321) | def __init__(
method save_state_dict (line 414) | def save_state_dict(self, path: str) -> None:
method load_state_dict (line 429) | def load_state_dict(self, path: str) -> None:
method from_pretrained (line 470) | def from_pretrained(cls, path, model_cls) -> "EMAModel":
method save_pretrained (line 481) | def save_pretrained(self, path, max_shard_size: str = "10GB"):
method get_decay (line 500) | def get_decay(self, optimization_step: int = None) -> float:
method step (line 523) | def step(self, parameters: Iterable[torch.nn.Parameter], global_step: ...
method copy_to (line 633) | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
method pin_memory (line 655) | def pin_memory(self) -> None:
method to (line 670) | def to(self, *args, **kwargs):
method cuda (line 675) | def cuda(self, device=None):
method cpu (line 678) | def cpu(self):
method state_dict (line 681) | def state_dict(
method store (line 704) | def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
method restore (line 710) | def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
method parameter_count (line 731) | def parameter_count(self) -> int:
method named_children (line 736) | def named_children(self):
method children (line 740) | def children(self):
method modules (line 743) | def modules(self):
method named_modules (line 746) | def named_modules(self, memo=None, prefix=""):
method parameters (line 749) | def parameters(self, recurse=True):
method named_parameters (line 752) | def named_parameters(self, prefix="", recurse=True):
method buffers (line 757) | def buffers(self, recurse=True):
method named_buffers (line 760) | def named_buffers(self, prefix="", recurse=True):
method train (line 763) | def train(self, mode=True):
method eval (line 767) | def eval(self):
method zero_grad (line 770) | def zero_grad(self):
FILE: dataloader/Realesrgan_offline_dataset.py
function ordered_yaml (line 26) | def ordered_yaml():
function opt_parse (line 50) | def opt_parse(opt_path):
function convert_image_to_fn (line 57) | def convert_image_to_fn(img_type, image, minsize=512, eps=0.02):
function exists (line 66) | def exists(x):
class LocalImageDataset (line 70) | class LocalImageDataset(data.Dataset):
method __init__ (line 71) | def __init__(self,
method __getitem__ (line 144) | def __getitem__(self, index):
method __len__ (line 276) | def __len__(self):
FILE: dataloader/realesrgan.py
function ordered_yaml (line 22) | def ordered_yaml():
function opt_parse (line 46) | def opt_parse(opt_path):
class RealESRGAN_degradation (line 53) | class RealESRGAN_degradation(object):
method __init__ (line 54) | def __init__(self, opt_name='params_realesrgan.yml', device='cpu'):
method color_jitter_pt (line 88) | def color_jitter_pt(self, img, brightness, contrast, saturation, hue):
method random_augment (line 108) | def random_augment(self, img_gt):
method random_kernels (line 129) | def random_kernels(self):
method degrade_process (line 191) | def degrade_process(self, img_gt, resize_bak=False, seed=42):
FILE: gradio_demo.py
function caption_process (line 59) | def caption_process(
function clear_result (line 69) | def clear_result():
function randomize_seed_fn (line 72) | def randomize_seed_fn(generation_seed: int, randomize_seed: bool) -> int:
function process (line 78) | def process(
FILE: llava/conversation.py
class SeparatorStyle (line 9) | class SeparatorStyle(Enum):
class Conversation (line 19) | class Conversation:
method get_prompt (line 32) | def get_prompt(self):
method append_message (line 109) | def append_message(self, role, message):
method process_image (line 112) | def process_image(self, image, image_process_mode, return_pil=False, i...
method get_images (line 152) | def get_images(self, return_pil=False):
method to_gradio_chatbot (line 162) | def to_gradio_chatbot(self):
method copy (line 180) | def copy(self):
method dict (line 191) | def dict(self):
FILE: llava/eval/eval_gpt_review.py
function get_eval (line 13) | def get_eval(content: str, max_tokens: int):
function parse_score (line 39) | def parse_score(review):
FILE: llava/eval/eval_gpt_review_bench.py
function get_eval (line 11) | def get_eval(content: str, max_tokens: int):
function parse_score (line 36) | def parse_score(review):
FILE: llava/eval/eval_gpt_review_visual.py
function get_eval (line 11) | def get_eval(content: str, max_tokens: int):
function parse_score (line 36) | def parse_score(review):
FILE: llava/eval/eval_pope.py
function eval_pope (line 5) | def eval_pope(answers, label_file):
FILE: llava/eval/eval_science_qa.py
function get_args (line 8) | def get_args():
function convert_caps (line 19) | def convert_caps(results):
function get_pred_idx (line 28) | def get_pred_idx(prediction, choices, options):
FILE: llava/eval/eval_science_qa_gpt4.py
function get_args (line 9) | def get_args():
function convert_caps (line 19) | def convert_caps(results):
function get_pred_idx (line 28) | def get_pred_idx(prediction, choices, options):
FILE: llava/eval/eval_science_qa_gpt4_requery.py
function get_args (line 9) | def get_args():
function convert_caps (line 21) | def convert_caps(results):
function get_pred_idx (line 30) | def get_pred_idx(prediction, choices, options):
FILE: llava/eval/eval_textvqa.py
function get_args (line 9) | def get_args():
function prompt_processor (line 17) | def prompt_processor(prompt):
function eval_single (line 35) | def eval_single(annotation_file, result_file):
FILE: llava/eval/generate_webpage_data_from_table.py
function read_jsonl (line 10) | def read_jsonl(path: str, key: str=None):
function trim_hanging_lines (line 23) | def trim_hanging_lines(s: str, n: int) -> str:
FILE: llava/eval/m4c_evaluator.py
class EvalAIAnswerProcessor (line 7) | class EvalAIAnswerProcessor:
method __init__ (line 178) | def __init__(self, *args, **kwargs):
method word_tokenize (line 181) | def word_tokenize(self, word):
method process_punctuation (line 186) | def process_punctuation(self, in_text):
method process_digit_article (line 198) | def process_digit_article(self, in_text):
method __call__ (line 213) | def __call__(self, item):
class TextVQAAccuracyEvaluator (line 221) | class TextVQAAccuracyEvaluator:
method __init__ (line 222) | def __init__(self):
method _compute_answer_scores (line 225) | def _compute_answer_scores(self, raw_answers):
method eval_pred_list (line 248) | def eval_pred_list(self, pred_list):
class STVQAAccuracyEvaluator (line 260) | class STVQAAccuracyEvaluator:
method __init__ (line 261) | def __init__(self):
method eval_pred_list (line 264) | def eval_pred_list(self, pred_list):
class STVQAANLSEvaluator (line 276) | class STVQAANLSEvaluator:
method __init__ (line 277) | def __init__(self):
method get_anls (line 282) | def get_anls(self, s1, s2):
method eval_pred_list (line 289) | def eval_pred_list(self, pred_list):
class TextCapsBleu4Evaluator (line 301) | class TextCapsBleu4Evaluator:
method __init__ (line 302) | def __init__(self):
method eval_pred_list (line 321) | def eval_pred_list(self, pred_list):
FILE: llava/eval/model_qa.py
function eval_model (line 14) | def eval_model(model_name, questions_file, answers_file):
FILE: llava/eval/model_vqa.py
function split_list (line 18) | def split_list(lst, n):
function get_chunk (line 24) | def get_chunk(lst, n, k):
function eval_model (line 29) | def eval_model(args):
FILE: llava/eval/model_vqa_loader.py
function split_list (line 19) | def split_list(lst, n):
function get_chunk (line 25) | def get_chunk(lst, n, k):
class CustomDataset (line 31) | class CustomDataset(Dataset):
method __init__ (line 32) | def __init__(self, questions, image_folder, tokenizer, image_processor...
method __getitem__ (line 39) | def __getitem__(self, index):
method __len__ (line 60) | def __len__(self):
function collate_fn (line 64) | def collate_fn(batch):
function create_data_loader (line 72) | def create_data_loader(questions, image_folder, tokenizer, image_process...
function eval_model (line 79) | def eval_model(args):
FILE: llava/eval/model_vqa_mmbench.py
function split_list (line 22) | def split_list(lst, n):
function get_chunk (line 28) | def get_chunk(lst, n, k):
function is_none (line 33) | def is_none(value):
function get_options (line 44) | def get_options(row, options):
function eval_model (line 54) | def eval_model(args):
FILE: llava/eval/model_vqa_science.py
function split_list (line 18) | def split_list(lst, n):
function get_chunk (line 24) | def get_chunk(lst, n, k):
function eval_model (line 29) | def eval_model(args):
FILE: llava/eval/qa_baseline_gpt35.py
function get_answer (line 16) | def get_answer(question_id: int, question: str, max_tokens: int):
FILE: llava/eval/run_llava.py
function image_parser (line 28) | def image_parser(args):
function load_image (line 33) | def load_image(image_file):
function load_images (line 42) | def load_images(image_files):
function eval_model (line 50) | def eval_model(args):
FILE: llava/eval/summarize_gpt_review.py
function parse_args (line 9) | def parse_args():
FILE: llava/eval/webpage/script.js
function text2Markdown (line 35) | function text2Markdown(text) {
function capitalizeFirstChar (line 41) | function capitalizeFirstChar(str) {
function updateQuestionSelect (line 48) | function updateQuestionSelect(question_id) {
function updateModelSelect (line 64) | function updateModelSelect() {
function populateModels (line 70) | function populateModels(models) {
function populateQuestions (line 81) | function populateQuestions(questions) {
function displayQuestion (line 110) | function displayQuestion(index) {
function displayAnswers (line 116) | function displayAnswers(index) {
function switchQuestionAndCategory (line 203) | function switchQuestionAndCategory() {
function updateExpandButtonVisibility (line 226) | function updateExpandButtonVisibility(card) {
FILE: llava/llm_agent.py
class LLavaAgent (line 18) | class LLavaAgent:
method __init__ (line 19) | def __init__(self, model_path, device='cuda', conv_mode='vicuna_v1', l...
method update_qs (line 49) | def update_qs(self, qs=None):
method gen_image_caption (line 65) | def gen_image_caption(self, imgs, temperature=0.2, top_p=0.7, num_beam...
FILE: llava/mm_utils.py
function select_best_resolution (line 12) | def select_best_resolution(original_size, possible_resolutions):
function resize_and_pad_image (line 42) | def resize_and_pad_image(image, target_resolution):
function divide_to_patches (line 77) | def divide_to_patches(image, patch_size):
function get_anyres_image_grid_shape (line 99) | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
function process_anyres_image (line 119) | def process_anyres_image(image, processor, grid_pinpoints):
function load_image_from_base64 (line 148) | def load_image_from_base64(image):
function expand2square (line 152) | def expand2square(pil_img, background_color):
function process_images (line 166) | def process_images(images, image_processor, model_cfg):
function tokenizer_image_token (line 185) | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOK...
function get_model_name_from_path (line 207) | def get_model_name_from_path(model_path):
class KeywordsStoppingCriteria (line 215) | class KeywordsStoppingCriteria(StoppingCriteria):
method __init__ (line 216) | def __init__(self, keywords, tokenizer, input_ids):
method call_for_batch (line 230) | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.F...
method __call__ (line 243) | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTe...
FILE: llava/model/apply_delta.py
function apply_delta (line 13) | def apply_delta(base_model_path, target_model_path, delta_path):
FILE: llava/model/builder.py
function load_pretrained_model (line 26) | def load_pretrained_model(model_path, model_base, model_name, load_8bit=...
FILE: llava/model/consolidate.py
function consolidate_ckpt (line 13) | def consolidate_ckpt(src_path, dst_path):
FILE: llava/model/language_model/llava_llama.py
class LlavaConfig (line 30) | class LlavaConfig(LlamaConfig):
class LlavaLlamaModel (line 34) | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
method __init__ (line 37) | def __init__(self, config: LlamaConfig):
class LlavaLlamaForCausalLM (line 41) | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
method __init__ (line 44) | def __init__(self, config):
method get_model (line 54) | def get_model(self):
method forward (line 57) | def forward(
method generate (line 106) | def generate(
method prepare_inputs_for_generation (line 145) | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
FILE: llava/model/language_model/llava_mistral.py
class LlavaMistralConfig (line 31) | class LlavaMistralConfig(MistralConfig):
class LlavaMistralModel (line 35) | class LlavaMistralModel(LlavaMetaModel, MistralModel):
method __init__ (line 38) | def __init__(self, config: MistralConfig):
class LlavaMistralForCausalLM (line 42) | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
method __init__ (line 45) | def __init__(self, config):
method get_model (line 54) | def get_model(self):
method forward (line 57) | def forward(
method generate (line 105) | def generate(
method prepare_inputs_for_generation (line 144) | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
FILE: llava/model/language_model/llava_mpt.py
class LlavaMptConfig (line 25) | class LlavaMptConfig(MptConfig):
class LlavaMptModel (line 29) | class LlavaMptModel(LlavaMetaModel, MptModel):
method __init__ (line 32) | def __init__(self, config: MptConfig):
method embed_tokens (line 36) | def embed_tokens(self, x):
class LlavaMptForCausalLM (line 40) | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
method __init__ (line 44) | def __init__(self, config):
method get_model (line 53) | def get_model(self):
method _set_gradient_checkpointing (line 56) | def _set_gradient_checkpointing(self, module, value=False):
method forward (line 60) | def forward(
method prepare_inputs_for_generation (line 87) | def prepare_inputs_for_generation(self, input_ids, past_key_values=Non...
FILE: llava/model/llava_arch.py
class LlavaMetaModel (line 29) | class LlavaMetaModel:
method __init__ (line 31) | def __init__(self, config):
method get_vision_tower (line 43) | def get_vision_tower(self):
method initialize_vision_modules (line 49) | def initialize_vision_modules(self, model_args, fsdp=None):
function unpad_image (line 100) | def unpad_image(tensor, original_size):
class LlavaMetaForCausalLM (line 131) | class LlavaMetaForCausalLM(ABC):
method get_model (line 134) | def get_model(self):
method get_vision_tower (line 137) | def get_vision_tower(self):
method encode_images (line 140) | def encode_images(self, images):
method prepare_inputs_labels_for_multimodal (line 145) | def prepare_inputs_labels_for_multimodal(
method initialize_vision_tokenizer (line 326) | def initialize_vision_tokenizer(self, model_args, tokenizer):
FILE: llava/model/make_delta.py
function make_delta (line 13) | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_...
FILE: llava/model/multimodal_encoder/builder.py
function build_vision_tower (line 5) | def build_vision_tower(vision_tower_cfg, **kwargs):
FILE: llava/model/multimodal_encoder/clip_encoder.py
class CLIPVisionTower (line 7) | class CLIPVisionTower(nn.Module):
method __init__ (line 8) | def __init__(self, vision_tower, args, delay_load=False):
method load_model (line 24) | def load_model(self, device_map=None):
method feature_select (line 35) | def feature_select(self, image_forward_outs):
method forward (line 46) | def forward(self, images):
method dummy_feature (line 60) | def dummy_feature(self):
method dtype (line 64) | def dtype(self):
method device (line 68) | def device(self):
method config (line 72) | def config(self):
method hidden_size (line 79) | def hidden_size(self):
method num_patches_per_side (line 83) | def num_patches_per_side(self):
method num_patches (line 87) | def num_patches(self):
FILE: llava/model/multimodal_projector/builder.py
class IdentityMap (line 6) | class IdentityMap(nn.Module):
method __init__ (line 7) | def __init__(self):
method forward (line 10) | def forward(self, x, *args, **kwargs):
method config (line 14) | def config(self):
class SimpleResBlock (line 18) | class SimpleResBlock(nn.Module):
method __init__ (line 19) | def __init__(self, channels):
method forward (line 28) | def forward(self, x):
function build_vision_projector (line 33) | def build_vision_projector(config, delay_load=False, **kwargs):
FILE: llava/model/utils.py
function auto_upgrade (line 4) | def auto_upgrade(config):
FILE: llava/serve/cli.py
function load_image (line 18) | def load_image(image_file):
function main (line 27) | def main(args):
FILE: llava/serve/controller.py
class DispatchMethod (line 28) | class DispatchMethod(Enum):
method from_str (line 33) | def from_str(cls, name):
class WorkerInfo (line 43) | class WorkerInfo:
function heart_beat_controller (line 51) | def heart_beat_controller(controller):
class Controller (line 57) | class Controller:
method __init__ (line 58) | def __init__(self, dispatch_method: str):
method register_worker (line 69) | def register_worker(self, worker_name: str, check_heart_beat: bool,
method get_worker_status (line 88) | def get_worker_status(self, worker_name: str):
method remove_worker (line 101) | def remove_worker(self, worker_name: str):
method refresh_all_workers (line 104) | def refresh_all_workers(self):
method list_models (line 112) | def list_models(self):
method get_worker_address (line 120) | def get_worker_address(self, model_name: str):
method receive_heart_beat (line 173) | def receive_heart_beat(self, worker_name: str, queue_length: int):
method remove_stable_workers_by_expiration (line 183) | def remove_stable_workers_by_expiration(self):
method worker_api_generate_stream (line 193) | def worker_api_generate_stream(self, params):
method worker_api_get_status (line 220) | def worker_api_get_status(self):
function register_worker (line 243) | async def register_worker(request: Request):
function refresh_all_workers (line 251) | async def refresh_all_workers():
function list_models (line 256) | async def list_models():
function get_worker_address (line 262) | async def get_worker_address(request: Request):
function receive_heart_beat (line 269) | async def receive_heart_beat(request: Request):
function worker_api_generate_stream (line 277) | async def worker_api_generate_stream(request: Request):
function worker_api_get_status (line 284) | async def worker_api_get_status(request: Request):
FILE: llava/serve/gradio_web_server.py
function get_conv_log_filename (line 32) | def get_conv_log_filename():
function get_model_list (line 38) | def get_model_list():
function load_demo (line 58) | def load_demo(url_params, request: gr.Request):
function load_demo_refresh_model_list (line 71) | def load_demo_refresh_model_list(request: gr.Request):
function vote_last_response (line 82) | def vote_last_response(state, vote_type, model_selector, request: gr.Req...
function upvote_last_response (line 94) | def upvote_last_response(state, model_selector, request: gr.Request):
function downvote_last_response (line 100) | def downvote_last_response(state, model_selector, request: gr.Request):
function flag_last_response (line 106) | def flag_last_response(state, model_selector, request: gr.Request):
function regenerate (line 112) | def regenerate(state, image_process_mode, request: gr.Request):
function clear_history (line 122) | def clear_history(request: gr.Request):
function add_text (line 128) | def add_text(state, text, image, image_process_mode, request: gr.Request):
function http_bot (line 154) | def http_bot(state, model_selector, temperature, top_p, max_new_tokens, ...
function build_demo (line 315) | def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
FILE: llava/serve/model_worker.py
function heart_beat_worker (line 37) | def heart_beat_worker(controller):
class ModelWorker (line 44) | class ModelWorker:
method __init__ (line 45) | def __init__(self, controller_addr, worker_addr,
method register_to_controller (line 75) | def register_to_controller(self):
method send_heart_beat (line 87) | def send_heart_beat(self):
method get_queue_length (line 108) | def get_queue_length(self):
method get_status (line 115) | def get_status(self):
method generate_stream (line 123) | def generate_stream(self, params):
method generate_stream_gate (line 195) | def generate_stream_gate(self, params):
function release_model_semaphore (line 225) | def release_model_semaphore(fn=None):
function generate_stream (line 232) | async def generate_stream(request: Request):
function get_status (line 248) | async def get_status(request: Request):
FILE: llava/serve/sglang_worker.py
function heart_beat_worker (line 38) | def heart_beat_worker(controller):
function pipeline (line 45) | def pipeline(s, prompt, max_tokens):
class ModelWorker (line 54) | class ModelWorker:
method __init__ (line 55) | def __init__(self, controller_addr, worker_addr, sgl_endpoint,
method register_to_controller (line 85) | def register_to_controller(self):
method send_heart_beat (line 97) | def send_heart_beat(self):
method get_queue_length (line 118) | def get_queue_length(self):
method get_status (line 125) | def get_status(self):
method generate_stream (line 132) | async def generate_stream(self, params):
method generate_stream_gate (line 172) | async def generate_stream_gate(self, params):
function release_model_semaphore (line 195) | def release_model_semaphore(fn=None):
function generate_stream (line 202) | async def generate_stream(request: Request):
function get_status (line 218) | async def get_status(request: Request):
FILE: llava/serve/test_message.py
function main (line 9) | def main():
FILE: llava/train/llama_flash_attn_monkey_patch.py
function forward (line 16) | def forward(
function _prepare_decoder_attention_mask (line 98) | def _prepare_decoder_attention_mask(
function replace_llama_attn_with_flash_attn (line 105) | def replace_llama_attn_with_flash_attn():
FILE: llava/train/llama_xformers_attn_monkey_patch.py
function replace_llama_attn_with_xformers_attn (line 19) | def replace_llama_attn_with_xformers_attn():
function xformers_forward (line 23) | def xformers_forward(
FILE: llava/train/llava_trainer.py
function maybe_zero_3 (line 18) | def maybe_zero_3(param, ignore_status=False, name=None):
function get_mm_adapter_state_maybe_zero_3 (line 32) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
function split_to_even_chunks (line 38) | def split_to_even_chunks(indices, lengths, num_chunks):
function get_modality_length_grouped_indices (line 60) | def get_modality_length_grouped_indices(lengths, batch_size, world_size,...
function get_length_grouped_indices (line 88) | def get_length_grouped_indices(lengths, batch_size, world_size, generato...
class LengthGroupedSampler (line 99) | class LengthGroupedSampler(Sampler):
method __init__ (line 105) | def __init__(
method __len__ (line 122) | def __len__(self):
method __iter__ (line 125) | def __iter__(self):
class LLaVATrainer (line 133) | class LLaVATrainer(Trainer):
method _get_train_sampler (line 135) | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
method create_optimizer (line 150) | def create_optimizer(self):
method _save_checkpoint (line 230) | def _save_checkpoint(self, model, trial, metrics=None):
method _save (line 251) | def _save(self, output_dir: Optional[str] = None, state_dict=None):
FILE: llava/train/train.py
function rank0_print (line 44) | def rank0_print(*args):
class ModelArguments (line 54) | class ModelArguments:
class DataArguments (line 70) | class DataArguments:
class TrainingArguments (line 80) | class TrainingArguments(transformers.TrainingArguments):
function maybe_zero_3 (line 115) | def maybe_zero_3(param, ignore_status=False, name=None):
function get_peft_state_maybe_zero_3 (line 130) | def get_peft_state_maybe_zero_3(named_params, bias):
function get_peft_state_non_lora_maybe_zero_3 (line 155) | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only...
function get_mm_adapter_state_maybe_zero_3 (line 163) | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
function find_all_linear_names (line 169) | def find_all_linear_names(model):
function safe_save_model_for_hf_trainer (line 185) | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
function smart_tokenizer_and_embedding_resize (line 224) | def smart_tokenizer_and_embedding_resize(
function _tokenize_fn (line 249) | def _tokenize_fn(strings: Sequence[str],
function _mask_targets (line 276) | def _mask_targets(target, tokenized_lens, speakers):
function _add_speaker_and_signal (line 287) | def _add_speaker_and_signal(header, source, get_conversation=True):
function preprocess_multimodal (line 308) | def preprocess_multimodal(
function preprocess_llama_2 (line 332) | def preprocess_llama_2(
function preprocess_v1 (line 414) | def preprocess_v1(
function preprocess_mpt (line 500) | def preprocess_mpt(
function preprocess_plain (line 588) | def preprocess_plain(
function preprocess (line 610) | def preprocess(
class LazySupervisedDataset (line 658) | class LazySupervisedDataset(Dataset):
method __init__ (line 661) | def __init__(self, data_path: str,
method __len__ (line 672) | def __len__(self):
method lengths (line 676) | def lengths(self):
method modality_lengths (line 684) | def modality_lengths(self):
method __getitem__ (line 692) | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
class DataCollatorForSupervisedDataset (line 743) | class DataCollatorForSupervisedDataset(object):
method __call__ (line 748) | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
function make_supervised_data_module (line 776) | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokeni...
function train (line 788) | def train(attn_implementation=None):
FILE: llava/utils.py
function build_logger (line 17) | def build_logger(logger_name, logger_filename):
class StreamToLogger (line 60) | class StreamToLogger(object):
method __init__ (line 64) | def __init__(self, logger, log_level=logging.INFO):
method __getattr__ (line 70) | def __getattr__(self, attr):
method write (line 73) | def write(self, buf):
method flush (line 87) | def flush(self):
function disable_torch_init (line 93) | def disable_torch_init():
function violates_moderation (line 102) | def violates_moderation(text):
function pretty_print_semaphore (line 123) | def pretty_print_semaphore(semaphore):
FILE: test_metrics.py
function get_timestamp (line 19) | def get_timestamp():
function setup_logger (line 23) | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=Fa...
function dict2str (line 53) | def dict2str(opt, indent=1):
function main (line 74) | def main():
FILE: train_SDXL_stage_1.py
function ordered_yaml (line 74) | def ordered_yaml():
function image_grid (line 97) | def image_grid(imgs, rows, cols):
function image_grid (line 109) | def image_grid(imgs, rows, cols):
function get_concat_h (line 121) | def get_concat_h(im1, im2):
function log_validation (line 127) | def log_validation(unet, args, accelerator, weight_dtype, step):
function parse_args (line 200) | def parse_args():
function convert_to_np (line 547) | def convert_to_np(image, resolution):
function main (line 553) | def main():
FILE: train_SDXL_stage_2.py
function ordered_yaml (line 74) | def ordered_yaml():
function image_grid (line 97) | def image_grid(imgs, rows, cols):
function image_grid (line 108) | def image_grid(imgs, rows, cols):
function get_concat_h (line 119) | def get_concat_h(im1, im2):
function log_validation (line 125) | def log_validation(unet, args, accelerator, weight_dtype, step):
function parse_args (line 207) | def parse_args():
function convert_to_np (line 559) | def convert_to_np(image, resolution):
function main (line 565) | def main():
FILE: utils/color_fix.py
function adain_color_fix (line 14) | def adain_color_fix(target: Image, source: Image):
function wavelet_color_fix (line 29) | def wavelet_color_fix(target: Image, source: Image):
function calc_mean_std (line 44) | def calc_mean_std(feat: Tensor, eps=1e-5):
function adaptive_instance_normalization (line 59) | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tens...
function wavelet_blur (line 73) | def wavelet_blur(image: Tensor, radius: int):
function wavelet_decomposition (line 94) | def wavelet_decomposition(image: Tensor, levels=5):
function wavelet_reconstruction (line 108) | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):
FILE: utils/image_process.py
function check_image_size (line 7) | def check_image_size(x, padder_size=8):
function image2tensor (line 24) | def image2tensor(img):
function tensor2image (line 30) | def tensor2image(img):
function create_hdr_effect (line 37) | def create_hdr_effect(original_image, hdr):
FILE: utils/system.py
function torch_gc (line 6) | def torch_gc():
function quantize_8bit (line 13) | def quantize_8bit(unet):
Condensed preview — 100 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (9,433K chars).
[
{
"path": ".gitignore",
"chars": 147,
"preview": "__pycache__/\n*.py[cod]\n*$py.class\n/outputs/\n/input/\n/dataset/\n/save/\n/output/\ncheckpoints/\n/.vs\n.vscode/\n.idea/\nvenv/\n.v"
},
{
"path": "CKPT_PTH.py",
"chars": 272,
"preview": "LLAVA_CLIP_PATH = './checkpoints/CLIP_VIT/'\nLLAVA_MODEL_PATH = './checkpoints/llava_v1.5-13b/llava'\nSDXL_PATH = './check"
},
{
"path": "FaithDiff/create_FaithDiff_model.py",
"chars": 1410,
"preview": "from utils.system import quantize_8bit\nfrom .pipelines.pipeline_FaithDiff_tlc import FaithDiffStableDiffusionXLPipeline\n"
},
{
"path": "FaithDiff/models/bsrnet_arch.py",
"chars": 7258,
"preview": "import functools\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.nn.init as init\nimport "
},
{
"path": "FaithDiff/models/unet_2d_condition_vae_extension.py",
"chars": 32551,
"preview": "# Copyright 2024 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
},
{
"path": "FaithDiff/pipelines/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "FaithDiff/pipelines/pipeline_FaithDiff_tlc.py",
"chars": 71668,
"preview": "import inspect\nimport copy\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\nimport torch\nimport torc"
},
{
"path": "FaithDiff/pipelines/pipeline_output.py",
"chars": 1062,
"preview": "from dataclasses import dataclass\nfrom typing import List, Union\n\nimport numpy as np\nimport PIL.Image\n\nfrom diffusers.ut"
},
{
"path": "FaithDiff/training_utils.py",
"chars": 29816,
"preview": "import contextlib\nimport copy\nimport gc\nimport math\nimport random\nfrom typing import Any, Dict, Iterable, List, Optional"
},
{
"path": "LICENSE",
"chars": 1073,
"preview": "MIT License\n\nCopyright (c) [2025] [Junyang Chen]\n\nPermission is hereby granted, free of charge, to any person obtaining "
},
{
"path": "README.md",
"chars": 9272,
"preview": "### (CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolution \n[:\n label_list = [json.loads(q)['label'] for "
},
{
"path": "llava/eval/eval_science_qa.py",
"chars": 3920,
"preview": "import argparse\nimport json\nimport os\nimport re\nimport random\n\n\ndef get_args():\n parser = argparse.ArgumentParser()\n "
},
{
"path": "llava/eval/eval_science_qa_gpt4.py",
"chars": 3675,
"preview": "import argparse\nimport json\nimport os\nimport re\nimport random\nfrom collections import defaultdict\n\n\ndef get_args():\n "
},
{
"path": "llava/eval/eval_science_qa_gpt4_requery.py",
"chars": 5774,
"preview": "import argparse\nimport json\nimport os\nimport re\nimport random\nfrom collections import defaultdict\n\n\ndef get_args():\n "
},
{
"path": "llava/eval/eval_textvqa.py",
"chars": 2226,
"preview": "import os\nimport argparse\nimport json\nimport re\n\nfrom llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator\n\n\ndef get"
},
{
"path": "llava/eval/generate_webpage_data_from_table.py",
"chars": 4088,
"preview": "\"\"\"Generate json file for webpage.\"\"\"\nimport json\nimport os\nimport re\n\n# models = ['llama', 'alpaca', 'gpt35', 'bard']\nm"
},
{
"path": "llava/eval/m4c_evaluator.py",
"chars": 10265,
"preview": "# Copyright (c) Facebook, Inc. and its affiliates.\nimport re\n\nfrom tqdm import tqdm\n\n\nclass EvalAIAnswerProcessor:\n \""
},
{
"path": "llava/eval/model_qa.py",
"chars": 2430,
"preview": "import argparse\nfrom transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria\nimport torch\nimport os\nim"
},
{
"path": "llava/eval/model_vqa.py",
"chars": 4115,
"preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llava.constants import I"
},
{
"path": "llava/eval/model_vqa_loader.py",
"chars": 5975,
"preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llava.constants import I"
},
{
"path": "llava/eval/model_vqa_mmbench.py",
"chars": 6408,
"preview": "import argparse\nimport torch\nimport os\nimport json\nimport pandas as pd\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llav"
},
{
"path": "llava/eval/model_vqa_science.py",
"chars": 4592,
"preview": "import argparse\nimport torch\nimport os\nimport json\nfrom tqdm import tqdm\nimport shortuuid\n\nfrom llava.constants import I"
},
{
"path": "llava/eval/qa_baseline_gpt35.py",
"chars": 2345,
"preview": "\"\"\"Generate answers with GPT-3.5\"\"\"\n# Note: you need to be using OpenAI Python v0.27.0 for the code below to work\nimport"
},
{
"path": "llava/eval/run_llava.py",
"chars": 4443,
"preview": "import argparse\nimport torch\n\nfrom llava.constants import (\n IMAGE_TOKEN_INDEX,\n DEFAULT_IMAGE_TOKEN,\n DEFAULT_"
},
{
"path": "llava/eval/summarize_gpt_review.py",
"chars": 2438,
"preview": "import json\nimport os\nfrom collections import defaultdict\n\nimport numpy as np\n\nimport argparse\n\ndef parse_args():\n pa"
},
{
"path": "llava/eval/table/answer/answer_alpaca-13b.jsonl",
"chars": 57071,
"preview": "{\"question_id\": 1, \"text\": \"Improving time management skills involves setting priorities, breaking tasks into smaller ch"
},
{
"path": "llava/eval/table/answer/answer_bard.jsonl",
"chars": 112274,
"preview": "{\"answer_id\": \"3oW4JY265ZPJGTYi2CgRYF\", \"model_id\": \"bard:20230327\", \"question_id\": 1, \"text\": \"Here are some tips on ho"
},
{
"path": "llava/eval/table/answer/answer_gpt35.jsonl",
"chars": 107603,
"preview": "{\"answer_id\": \"BZGowHM7L3RvtWRktKZjLT\", \"model_id\": \"gpt-3.5-turbo:20230327\", \"question_id\": 1, \"text\": \"Here are some t"
},
{
"path": "llava/eval/table/answer/answer_llama-13b.jsonl",
"chars": 76353,
"preview": "{\"answer_id\": \"J3UA6eGXGyFeUGqGpP3g34\", \"model_id\": \"llama-13b:v1\", \"question_id\": 1, \"text\": \"The following are some st"
},
{
"path": "llava/eval/table/answer/answer_vicuna-13b.jsonl",
"chars": 131904,
"preview": "{\"answer_id\": \"cV4zXygaNP6CXEsgdHMEqz\", \"model_id\": \"vicuna-13b:20230322-clean-lang\", \"question_id\": 1, \"text\": \"Improvi"
},
{
"path": "llava/eval/table/caps_boxes_coco2014_val_80.jsonl",
"chars": 58574,
"preview": "{\"id\": \"000000296284\", \"image\": \"000000296284.jpg\", \"captions\": [\"A donut shop is full of different flavors of donuts.\","
},
{
"path": "llava/eval/table/model.jsonl",
"chars": 681,
"preview": "{\"model_id\": \"vicuna-13b:20230322-clean-lang\", \"model_name\": \"vicuna-13b\", \"model_version\": \"20230322-clean-lang\", \"mode"
},
{
"path": "llava/eval/table/prompt.jsonl",
"chars": 5129,
"preview": "{\"prompt_id\": 1, \"system_prompt\": \"You are a helpful and precise assistant for checking the quality of the answer.\", \"pr"
},
{
"path": "llava/eval/table/question.jsonl",
"chars": 12885,
"preview": "{\"question_id\": 1, \"text\": \"How can I improve my time management skills?\", \"category\": \"generic\"}\n{\"question_id\": 2, \"te"
},
{
"path": "llava/eval/table/results/test_sqa_llava_13b_v0.json",
"chars": 3950324,
"preview": "{\n \"acc\": 90.8983730252299,\n \"correct\": 3855,\n \"count\": 4241,\n \"results\": {\n \"4\": 1,\n \"5\": 1,\n \"11\": 1,\n "
},
{
"path": "llava/eval/table/results/test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json",
"chars": 3830902,
"preview": "{\n \"acc\": 91.08700778118369,\n \"correct\": 3863,\n \"count\": 4241,\n \"results\": {\n \"4\": 1,\n \"5\": 1,\n \"11\": 1,\n "
},
{
"path": "llava/eval/table/review/review_alpaca-13b_vicuna-13b.jsonl",
"chars": 73131,
"preview": "{\"review_id\": \"QM5m5nnioWr8M2LFHsaQvu\", \"question_id\": 1, \"answer1_id\": \"kEL9ifUHDeYuAXzevje2se\", \"answer2_id\": \"cV4zXyg"
},
{
"path": "llava/eval/table/review/review_bard_vicuna-13b.jsonl",
"chars": 73145,
"preview": "{\"review_id\": \"4CeMvEQyE6fKMJwvSLY3P4\", \"question_id\": 1, \"answer1_id\": \"3oW4JY265ZPJGTYi2CgRYF\", \"answer2_id\": \"cV4zXyg"
},
{
"path": "llava/eval/table/review/review_gpt35_vicuna-13b.jsonl",
"chars": 73399,
"preview": "{\"review_id\": \"jyhS7AFj2mrFNqoRXQJDPS\", \"question_id\": 1, \"answer1_id\": \"BZGowHM7L3RvtWRktKZjLT\", \"answer2_id\": \"cV4zXyg"
},
{
"path": "llava/eval/table/review/review_llama-13b_vicuna-13b.jsonl",
"chars": 67249,
"preview": "{\"review_id\": \"WFp5i5yjjFethrgugKTDmX\", \"question_id\": 1, \"answer1_id\": \"J3UA6eGXGyFeUGqGpP3g34\", \"answer2_id\": \"cV4zXyg"
},
{
"path": "llava/eval/table/reviewer.jsonl",
"chars": 604,
"preview": "{\"reviewer_id\": \"gpt-4-0328-default\", \"prompt_id\": 1, \"metadata\": {\"temperature\": 0.2, \"max_tokens\": 1024}, \"description"
},
{
"path": "llava/eval/table/rule.json",
"chars": 9098,
"preview": "{\n \"coding\": {\"role\": \"Assistant\", \"prompt\": \"Your task is to evaluate the coding abilities of the above two assistan"
},
{
"path": "llava/eval/webpage/index.html",
"chars": 7664,
"preview": "<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n <meta charset=\"UTF-8\">\n <meta name=\"viewport\" content=\"width=device-width"
},
{
"path": "llava/eval/webpage/script.js",
"chars": 9967,
"preview": "// Description: Script for the evaluation webpage.\n\nlet currentQuestionIndex = 1;\n\n// Store the model name mapping for l"
},
{
"path": "llava/eval/webpage/styles.css",
"chars": 1822,
"preview": "body {\n font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n background-color: #f8f9fa;\n}\n\n.navbar-dark "
},
{
"path": "llava/llm_agent.py",
"chars": 5579,
"preview": "import torch\nimport os\nimport json\nfrom tqdm import tqdm\n\nfrom llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_T"
},
{
"path": "llava/mm_utils.py",
"chars": 9555,
"preview": "from PIL import Image\nfrom io import BytesIO\nimport base64\nimport torch\nimport math\nimport ast\n\nfrom transformers import"
},
{
"path": "llava/model/__init__.py",
"chars": 237,
"preview": "\nfrom .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig\nfrom .language_model.llava_mpt import LlavaM"
},
{
"path": "llava/model/apply_delta.py",
"chars": 1956,
"preview": "\"\"\"\nUsage:\npython3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --de"
},
{
"path": "llava/model/builder.py",
"chars": 8076,
"preview": "# Copyright 2023 Haotian Liu\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
},
{
"path": "llava/model/consolidate.py",
"chars": 914,
"preview": "\"\"\"\nUsage:\npython3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate\n"
},
{
"path": "llava/model/language_model/llava_llama.py",
"chars": 5406,
"preview": "# Copyright 2023 Haotian Liu\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
},
{
"path": "llava/model/language_model/llava_mistral.py",
"chars": 5386,
"preview": "# Copyright 2023 Haotian Liu\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
},
{
"path": "llava/model/language_model/llava_mpt.py",
"chars": 3487,
"preview": "# Copyright 2023 Haotian Liu\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
},
{
"path": "llava/model/llava_arch.py",
"chars": 18110,
"preview": "# Copyright 2023 Haotian Liu\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not "
},
{
"path": "llava/model/make_delta.py",
"chars": 2257,
"preview": "\"\"\"\nUsage:\npython3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~"
},
{
"path": "llava/model/multimodal_encoder/builder.py",
"chars": 556,
"preview": "import os\nfrom .clip_encoder import CLIPVisionTower\n\n\ndef build_vision_tower(vision_tower_cfg, **kwargs):\n vision_tow"
},
{
"path": "llava/model/multimodal_encoder/clip_encoder.py",
"chars": 3101,
"preview": "import torch\nimport torch.nn as nn\n\nfrom transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig\nfrom "
},
{
"path": "llava/model/multimodal_projector/builder.py",
"chars": 1437,
"preview": "import torch\nimport torch.nn as nn\nimport re\n\n\nclass IdentityMap(nn.Module):\n def __init__(self):\n super().__i"
},
{
"path": "llava/model/utils.py",
"chars": 927,
"preview": "from transformers import AutoConfig\n\n\ndef auto_upgrade(config):\n cfg = AutoConfig.from_pretrained(config)\n if 'lla"
},
{
"path": "llava/serve/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "llava/serve/cli.py",
"chars": 4808,
"preview": "import argparse\nimport torch\n\nfrom llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN"
},
{
"path": "llava/serve/controller.py",
"chars": 9949,
"preview": "\"\"\"\nA controller manages distributed workers.\nIt sends worker addresses to clients.\n\"\"\"\nimport argparse\nimport asyncio\ni"
},
{
"path": "llava/serve/gradio_web_server.py",
"chars": 18841,
"preview": "import argparse\nimport datetime\nimport json\nimport os\nimport time\n\nimport gradio as gr\nimport requests\n\nfrom llava.conve"
},
{
"path": "llava/serve/model_worker.py",
"chars": 11176,
"preview": "\"\"\"\nA model worker executes the model.\n\"\"\"\nimport argparse\nimport asyncio\nimport json\nimport time\nimport threading\nimpor"
},
{
"path": "llava/serve/register_worker.py",
"chars": 734,
"preview": "\"\"\"\nManually register workers.\n\nUsage:\npython3 -m fastchat.serve.register_worker --controller http://localhost:21001 --w"
},
{
"path": "llava/serve/sglang_worker.py",
"chars": 8678,
"preview": "\"\"\"\nA model worker executes the model.\n\"\"\"\nimport argparse\nimport asyncio\nfrom concurrent.futures import ThreadPoolExecu"
},
{
"path": "llava/serve/test_message.py",
"chars": 2022,
"preview": "import argparse\nimport json\n\nimport requests\n\nfrom llava.conversation import default_conversation\n\n\ndef main():\n if a"
},
{
"path": "llava/train/llama_flash_attn_monkey_patch.py",
"chars": 4404,
"preview": "from typing import Optional, Tuple\nimport warnings\n\nimport torch\n\nimport transformers\nfrom transformers.models.llama.mod"
},
{
"path": "llava/train/llama_xformers_attn_monkey_patch.py",
"chars": 4916,
"preview": "\"\"\"\nDirectly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_a"
},
{
"path": "llava/train/llava_trainer.py",
"chars": 11076,
"preview": "import os\nimport torch\nimport torch.nn as nn\n\nfrom torch.utils.data import Sampler\n\nfrom transformers import Trainer\nfro"
},
{
"path": "llava/train/train.py",
"chars": 38414,
"preview": "# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:\n# Adopted from tatsu-lab@stanford_al"
},
{
"path": "llava/train/train_mem.py",
"chars": 115,
"preview": "from llava.train.train import train\n\nif __name__ == \"__main__\":\n train(attn_implementation=\"flash_attention_2\")\n"
},
{
"path": "llava/train/train_xformers.py",
"chars": 366,
"preview": "# Make it more memory efficient by monkey patching the LLaMA model with xformers attention.\n\n# Need to call this before "
},
{
"path": "llava/utils.py",
"chars": 4003,
"preview": "import datetime\nimport logging\nimport logging.handlers\nimport os\nimport sys\n\nimport requests\n\nfrom llava.constants impor"
},
{
"path": "requirements.txt",
"chars": 51,
"preview": "torch==2.4.0\ndiffusers==0.28.0\ntransformers==4.46.1"
},
{
"path": "test.py",
"chars": 5397,
"preview": "import torch.cuda\nimport argparse\nfrom FaithDiff.create_FaithDiff_model import FaithDiff_pipeline\nfrom PIL import Image\n"
},
{
"path": "test_generate_caption.py",
"chars": 2257,
"preview": "import torch.cuda\nimport argparse\nfrom PIL import Image\nfrom llava.llm_agent import LLavaAgent\nfrom CKPT_PTH import LLAV"
},
{
"path": "test_metrics.py",
"chars": 8344,
"preview": "# Image Quality Assessment Script\n# Evaluates metrics like PSNR, SSIM, LPIPS, FID, DISTS, etc., for a set of images.\n\nim"
},
{
"path": "test_wo_llava.py",
"chars": 3993,
"preview": "import torch.cuda\nimport argparse\nfrom FaithDiff.create_FaithDiff_model import FaithDiff_pipeline\nfrom PIL import Image\n"
},
{
"path": "train_SDXL_stage_1.py",
"chars": 41133,
"preview": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2023 Harutatsu Akiyama and The HuggingFace Inc. team. All rights reserv"
},
{
"path": "train_SDXL_stage_2.py",
"chars": 42931,
"preview": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2023 Harutatsu Akiyama and The HuggingFace Inc. team. All rights reserv"
},
{
"path": "train_stage_1.sh",
"chars": 661,
"preview": "accelerate launch --config_file='./dataloader/accelerate_config.yaml' --mixed_precision=\"fp16\" ./train_SDXL_stage_1.py -"
},
{
"path": "train_stage_2.sh",
"chars": 684,
"preview": "accelerate launch --config_file='./dataloader/accelerate_config.yaml' --mixed_precision=\"fp16\" ./train_SDXL_stage_2.py -"
},
{
"path": "utils/color_fix.py",
"chars": 4481,
"preview": "'''\n# --------------------------------------------------------------------------------\n# Color fixed script from Li Yi"
},
{
"path": "utils/image_process.py",
"chars": 2817,
"preview": "from PIL import Image\nimport cv2\nimport numpy as np\nimport torch\n\nfrom utils.system import torch_gc\ndef check_image_size"
},
{
"path": "utils/system.py",
"chars": 1012,
"preview": "import gc\nimport torch\n\nfrom FaithDiff.models.unet_2d_condition_vae_extension import Encoder\n\ndef torch_gc():\n gc.col"
}
]
About this extraction
This page contains the full source code of the JyChen9811/FaithDiff GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 100 files (8.8 MB), approximately 2.3M tokens, and a symbol index with 466 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.