Repository: csslc/CCSR
Branch: CCSR-v2.0
Commit: 878f5adf2ba8
Files: 80
Total size: 842.8 KB

Directory structure:
gitextract_wzfbvs4n/
├── .idea/
│   ├── CCSR.iml
│   ├── inspectionProfiles/
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── ADD/
│   ├── dnnlib/
│   │   ├── __init__.py
│   │   └── util.py
│   ├── layers/
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── block.py
│   │   ├── dino_head.py
│   │   ├── drop_path.py
│   │   ├── layer_scale.py
│   │   ├── mlp.py
│   │   ├── patch_embed.py
│   │   └── swiglu_ffn.py
│   ├── models/
│   │   ├── discriminator.py
│   │   └── vit.py
│   ├── th_utils/
│   │   ├── __init__.py
│   │   ├── custom_ops.py
│   │   ├── misc.py
│   │   └── ops/
│   │       ├── __init__.py
│   │       ├── bias_act.cpp
│   │       ├── bias_act.cu
│   │       ├── bias_act.h
│   │       ├── bias_act.py
│   │       ├── conv2d_gradfix.py
│   │       ├── conv2d_resample.py
│   │       ├── filtered_lrelu.cpp
│   │       ├── filtered_lrelu.cu
│   │       ├── filtered_lrelu.h
│   │       ├── filtered_lrelu.py
│   │       ├── filtered_lrelu_ns.cu
│   │       ├── filtered_lrelu_rd.cu
│   │       ├── filtered_lrelu_wr.cu
│   │       ├── fma.py
│   │       ├── grid_sample_gradfix.py
│   │       ├── upfirdn2d.cpp
│   │       ├── upfirdn2d.cu
│   │       ├── upfirdn2d.h
│   │       └── upfirdn2d.py
│   └── utils/
│       └── util_net.py
├── LICENSE
├── README.md
├── dataloaders/
│   ├── paired_dataset_txt.py
│   ├── params_ccsr.yml
│   └── realesrgan.py
├── models/
│   ├── DiffAugment.py
│   ├── controlnet.py
│   ├── losses/
│   │   ├── __init__.py
│   │   ├── contperceptual.py
│   │   └── vqperceptual.py
│   ├── shared.py
│   ├── unet_2d_blocks.py
│   ├── unet_2d_condition.py
│   └── vit_utils.py
├── myutils/
│   ├── devices.py
│   ├── img_util.py
│   ├── misc.py
│   ├── vaehook.py
│   └── wavelet_color_fix.py
├── pipelines/
│   └── pipeline_ccsr.py
├── requirements.txt
├── scripts/
│   ├── get_path.py
│   ├── test/
│   │   ├── test_ccsr_multistep.sh
│   │   ├── test_ccsr_onestep.sh
│   │   └── test_ccsr_tile.sh
│   └── train/
│       ├── train_ccsr_stage1.sh
│       ├── train_ccsr_stage2.sh
│       └── train_controlnet.sh
├── test_ccsr_tile.py
├── train_ccsr_stage1.py
├── train_ccsr_stage2.py
├── train_controlnet.py
└── utils/
    ├── devices.py
    ├── img_util.py
    ├── misc.py
    ├── vaehook.py
    └── wavelet_color_fix.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .idea/CCSR.iml
================================================

================================================
FILE: .idea/inspectionProfiles/Project_Default.xml
================================================

================================================
FILE: .idea/inspectionProfiles/profiles_settings.xml
================================================

================================================
FILE: .idea/modules.xml
================================================

================================================
FILE: .idea/vcs.xml
================================================

================================================
FILE: .idea/workspace.xml
================================================
================================================
FILE: ADD/dnnlib/__init__.py
================================================
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .util import EasyDict, make_cache_dir_path

================================================
FILE: ADD/dnnlib/util.py
================================================
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Miscellaneous utility classes and functions."""

import ctypes
import fnmatch
import importlib
import inspect
import os
import sys
import types
import io
import pickle
import re
import requests
import html
import hashlib
import glob
import tempfile
import urllib
import urllib.request
import uuid

from typing import Any, List, Tuple, Union, Optional
from distutils.util import strtobool
import shutil
import numpy as np

# Util classes
# ------------------------------------------------------------------------------------------

class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


class Logger(object):
    """Redirect stderr to stdout, optionally print stdout to a file, and
    optionally force flushing on both stdout and the file."""

    def __init__(self, file_name: Optional[str] = None, file_mode: str = "w", should_flush: bool = True):
        self.file = None
        if file_name is not None:
            self.file = open(file_name, file_mode)
        self.should_flush = should_flush
        self.stdout = sys.stdout
        self.stderr = sys.stderr
        sys.stdout = self
        sys.stderr = self

    def __enter__(self) -> "Logger":
        return self

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        self.close()

    def write(self, text: Union[str, bytes]) -> None:
        """Write text to stdout (and a file) and optionally flush."""
        if isinstance(text, bytes):
            text = text.decode()
        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
            return
        if self.file is not None:
            self.file.write(text)
        self.stdout.write(text)
        if self.should_flush:
            self.flush()

    def flush(self) -> None:
        """Flush written text to both stdout and a file, if open."""
        if self.file is not None:
            self.file.flush()
        self.stdout.flush()

    def close(self) -> None:
        """Flush, close possible files, and remove stdout/stderr mirroring."""
        self.flush()
        # if using multiple loggers, prevent closing in wrong order
        if sys.stdout is self:
            sys.stdout = self.stdout
        if sys.stderr is self:
            sys.stderr = self.stderr
        if self.file is not None:
            self.file.close()
            self.file = None

# Cache directories
# ------------------------------------------------------------------------------------------

_dnnlib_cache_dir = None

def set_cache_dir(path: str) -> None:
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path

def make_cache_dir_path(*paths: str) -> str:
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    if 'DNNLIB_CACHE_DIR' in os.environ:
        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
    if 'HOME' in os.environ:
        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
    if 'USERPROFILE' in os.environ:
        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)

# Small util functions
# ------------------------------------------------------------------------------------------

def format_time(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    s = int(np.rint(seconds))
    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
    else:
        return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)

def format_time_brief(seconds: Union[int, float]) -> str:
    """Convert the seconds to human readable string with days, hours, minutes and seconds."""
    s = int(np.rint(seconds))
    if s < 60:
        return "{0}s".format(s)
    elif s < 60 * 60:
        return "{0}m {1:02}s".format(s // 60, s % 60)
    elif s < 24 * 60 * 60:
        return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
    else:
        return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)

def ask_yes_no(question: str) -> bool:
    """Ask the user the question until the user inputs a valid answer."""
    while True:
        try:
            print("{0} [y/n]".format(question))
            return strtobool(input().lower())
        except ValueError:
            pass

def tuple_product(t: Tuple) -> Any:
    """Calculate the product of the tuple elements."""
    result = 1
    for v in t:
        result *= v
    return result

_str_to_ctype = {
    "uint8": ctypes.c_ubyte,
    "uint16": ctypes.c_uint16,
    "uint32": ctypes.c_uint32,
    "uint64": ctypes.c_uint64,
    "int8": ctypes.c_byte,
    "int16": ctypes.c_int16,
    "int32": ctypes.c_int32,
    "int64": ctypes.c_int64,
    "float32": ctypes.c_float,
    "float64": ctypes.c_double
}

def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
    """Given a type name string (or an object having a __name__ attribute),
    return matching Numpy and ctypes types that have the same size in bytes."""
    type_str = None
    if isinstance(type_obj, str):
        type_str = type_obj
    elif hasattr(type_obj, "__name__"):
        type_str = type_obj.__name__
    elif hasattr(type_obj, "name"):
        type_str = type_obj.name
    else:
        raise RuntimeError("Cannot infer type name from input")
    assert type_str in _str_to_ctype.keys()
    my_dtype = np.dtype(type_str)
    my_ctype = _str_to_ctype[type_str]
    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
    return my_dtype, my_ctype

def is_pickleable(obj: Any) -> bool:
    try:
        with io.BytesIO() as stream:
            pickle.dump(obj, stream)
        return True
    except:
        return False

# Functionality to import modules/objects by name, and call functions by name
# ------------------------------------------------------------------------------------------

def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
    """Searches for the underlying module behind the name to some python object.
    Returns the module and the object name (original name with module part removed)."""

    # allow convenience shorthands, substitute them by full names
    obj_name = re.sub("^np.", "numpy.", obj_name)
    obj_name = re.sub("^tf.", "tensorflow.", obj_name)

    # list alternatives for (module_name, local_obj_name)
    parts = obj_name.split(".")
    name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]

    # try each alternative in turn
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
            return module, local_obj_name
        except:
            pass

    # maybe some of the modules themselves contain errors?
    for module_name, _local_obj_name in name_pairs:
        try:
            importlib.import_module(module_name) # may raise ImportError
        except ImportError:
            if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
                raise

    # maybe the requested attribute is missing?
    for module_name, local_obj_name in name_pairs:
        try:
            module = importlib.import_module(module_name) # may raise ImportError
            get_obj_from_module(module, local_obj_name) # may raise AttributeError
        except ImportError:
            pass

    # we are out of luck, but we have no idea why
    raise ImportError(obj_name)

def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
    """Traverses the object name and returns the last (rightmost) python object."""
    if obj_name == '':
        return module
    obj = module
    for part in obj_name.split("."):
        obj = getattr(obj, part)
    return obj

def get_obj_by_name(name: str) -> Any:
    """Finds the python object with the given name."""
    module, obj_name = get_module_from_obj_name(name)
    return get_obj_from_module(module, obj_name)

def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
    """Finds the python object with the given name and calls it as a function."""
    assert func_name is not None
    func_obj = get_obj_by_name(func_name)
    assert callable(func_obj)
    return func_obj(*args, **kwargs)

def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
    """Finds the python class with the given name and constructs it with the given arguments."""
    return call_func_by_name(*args, func_name=class_name, **kwargs)

def get_module_dir_by_obj_name(obj_name: str) -> str:
    """Get the directory path of the module containing the given object name."""
    module, _ = get_module_from_obj_name(obj_name)
    return os.path.dirname(inspect.getfile(module))

def is_top_level_function(obj: Any) -> bool:
    """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__

def get_top_level_function_name(obj: Any) -> str:
    """Return the fully-qualified name of a top-level function."""
    assert is_top_level_function(obj)
    module = obj.__module__
    if module == '__main__':
        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
    return module + "." + obj.__name__

# File system helpers
# ------------------------------------------------------------------------------------------

def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
    """List all files recursively in a given directory while ignoring given file and directory names.
    Returns list of tuples containing both absolute and relative paths."""
    assert os.path.isdir(dir_path)
    base_name = os.path.basename(os.path.normpath(dir_path))

    if ignores is None:
        ignores = []

    result = []

    for root, dirs, files in os.walk(dir_path, topdown=True):
        for ignore_ in ignores:
            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]

            # dirs need to be edited in-place
            for d in dirs_to_remove:
                dirs.remove(d)

            files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]

        absolute_paths = [os.path.join(root, f) for f in files]
        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]

        if add_base_to_relative:
            relative_paths = [os.path.join(base_name, p) for p in relative_paths]

        assert len(absolute_paths) == len(relative_paths)
        result += zip(absolute_paths, relative_paths)

    return result

def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
    """Takes in a list of tuples of (src, dst) paths and copies files.
    Will create all necessary directories."""
    for file in files:
        target_dir_name = os.path.dirname(file[1])

        # will create all intermediate-level directories
        if not os.path.exists(target_dir_name):
            os.makedirs(target_dir_name)

        shutil.copyfile(file[0], file[1])

# URL helpers
# ------------------------------------------------------------------------------------------

def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
    """Determine whether the given object is a valid URL string."""
    if not isinstance(obj, str) or not "://" in obj:
        return False
    if allow_file_urls and obj.startswith('file://'):
        return True
    try:
        res = requests.compat.urlparse(obj)
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
        res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
        if not res.scheme or not res.netloc or not "." in res.netloc:
            return False
    except:
        return False
    return True

def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
    """Download the given URL and return a binary-mode file object to access the data."""
    assert num_attempts >= 1
    assert not (return_filename and (not cache))

    # Doesn't look like an URL scheme so interpret it as a local filename.
    if not re.match('^[a-z]+://', url):
        return url if return_filename else open(url, "rb")

    # Handle file URLs. This code handles unusual file:// patterns that
    # arise on Windows:
    #
    # file:///c:/foo.txt
    #
    # which would translate to a local '/c:/foo.txt' filename that's
    # invalid. Drop the forward slash for such pathnames.
    #
    # If you touch this code path, you should test it on both Linux and
    # Windows.
    #
    # Some internet resources suggest using urllib.request.url2pathname(),
    # but that converts forward slashes to backslashes and this causes
    # its own set of problems.
    if url.startswith('file://'):
        filename = urllib.parse.urlparse(url).path
        if re.match(r'^/[a-zA-Z]:', filename):
            filename = filename[1:]
        return filename if return_filename else open(filename, "rb")

    assert is_url(url)

    # Lookup from cache.
    if cache_dir is None:
        cache_dir = make_cache_dir_path('downloads')

    url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
    if cache:
        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
        if len(cache_files) == 1:
            filename = cache_files[0]
            return filename if return_filename else open(filename, "rb")

    # Download.
    url_name = None
    url_data = None
    with requests.Session() as session:
        if verbose:
            print("Downloading %s ..." % url, end="", flush=True)
        for attempts_left in reversed(range(num_attempts)):
            try:
                with session.get(url) as res:
                    res.raise_for_status()
                    if len(res.content) == 0:
                        raise IOError("No data received")

                    if len(res.content) < 8192:
                        content_str = res.content.decode("utf-8")
                        if "download_warning" in res.headers.get("Set-Cookie", ""):
                            links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
                            if len(links) == 1:
                                url = requests.compat.urljoin(url, links[0])
                                raise IOError("Google Drive virus checker nag")
                        if "Google Drive - Quota exceeded" in content_str:
                            raise IOError("Google Drive download quota exceeded -- please try again later")

                    match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
                    url_name = match[1] if match else url
                    url_data = res.content
                    if verbose:
                        print(" done")
                    break
            except KeyboardInterrupt:
                raise
            except:
                if not attempts_left:
                    if verbose:
                        print(" failed")
                    raise
                if verbose:
                    print(".", end="", flush=True)

    # Save to cache.
    if cache:
        safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
        safe_name = safe_name[:min(len(safe_name), 128)]
        cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
        temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
        os.makedirs(cache_dir, exist_ok=True)
        with open(temp_file, "wb") as f:
            f.write(url_data)
        os.replace(temp_file, cache_file) # atomic
        if return_filename:
            return cache_file

    # Return data as file object.
    assert not return_filename
    return io.BytesIO(url_data)
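A minimal usage sketch for the two helpers that ADD/dnnlib/__init__.py re-exports (paths shown are illustrative):

import ADD.dnnlib as dnnlib

# Attribute-style access on a plain dict.
cfg = dnnlib.EasyDict(lr=1e-4, batch=16)
cfg.epochs = 10            # equivalent to cfg['epochs'] = 10
assert cfg.lr == cfg['lr']

# Per-user cache directory; honors set_cache_dir()/DNNLIB_CACHE_DIR,
# else falls back to ~/.cache/dnnlib (or the temp dir as a last resort).
print(dnnlib.make_cache_dir_path('downloads'))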
================================================
FILE: ADD/layers/__init__.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
================================================
FILE: ADD/layers/attention.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn

logger = logging.getLogger("dinov2")

XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Attention)")
    else:
        warnings.warn("xFormers is disabled (Attention)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Attention)")


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
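A quick shape check for the two attention variants above (a minimal sketch; dims are illustrative, and MemEffAttention silently falls back to the plain path when xFormers is unavailable):

import torch
from ADD.layers.attention import Attention, MemEffAttention

tokens = torch.randn(2, 197, 384)   # (B, N, C): 196 patch tokens + 1 cls token
attn = Attention(dim=384, num_heads=6)
assert attn(tokens).shape == tokens.shape       # attention preserves (B, N, C)
mem_attn = MemEffAttention(dim=384, num_heads=6)
assert mem_attn(tokens).shape == tokens.shape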
================================================
FILE: ADD/layers/block.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp

logger = logging.getLogger("dinov2")

XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (Block)")
    else:
        warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Block)")


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path2(ffn_residual_func(x))  # was drop_path1 with a "FIXME: drop_path2" note; both modules share the same rate, so behavior is unchanged
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
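A minimal sketch of a single block wired the way vit_small in ADD/models/vit.py does it (dims illustrative; in eval mode the plain residual path is taken, so no xFormers is needed for tensor inputs):

import torch
from ADD.layers.block import Block
from ADD.layers.attention import MemEffAttention

blk = Block(dim=384, num_heads=6, mlp_ratio=4.0, qkv_bias=True,
            attn_class=MemEffAttention, init_values=1.0, drop_path=0.1)
blk.eval()                        # disables the stochastic-depth branches
x = torch.randn(2, 197, 384)
assert blk(x).shape == x.shape    # pre-norm attention + FFN, both residual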
================================================
FILE: ADD/layers/dino_head.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)
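A minimal sketch of the head (the 65536-way prototype dimension is illustrative; the MLP bottlenecks to 256 dims, L2-normalizes, then applies the weight-normed last layer):

import torch
from ADD.layers import DINOHead

head = DINOHead(in_dim=384, out_dim=65536)
logits = head(torch.randn(8, 384))   # (batch, in_dim) -> (batch, out_dim)
assert logits.shape == (8, 65536)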
================================================
FILE: ADD/layers/drop_path.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py

from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

================================================
FILE: ADD/layers/layer_scale.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

================================================
FILE: ADD/layers/mlp.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py

from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
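How these three pieces compose in a residual branch, mirroring Block.forward above (a minimal sketch; during training DropPath zeroes the branch for roughly drop_prob of the samples and rescales survivors by 1/keep_prob, keeping the expected value unchanged):

import torch
from ADD.layers import Mlp
from ADD.layers.drop_path import DropPath
from ADD.layers.layer_scale import LayerScale

x = torch.randn(4, 197, 384)
branch = Mlp(in_features=384, hidden_features=1536)
ls = LayerScale(384, init_values=1e-5)   # learnable per-channel gain, starts near zero
dp = DropPath(0.1)
dp.train()                               # stochastic depth is active only in training mode
y = x + dp(ls(branch(x)))                # same composition Block.forward uses
assert y.shape == x.shape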
================================================
FILE: ADD/layers/patch_embed.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x
    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
        # self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    # def flops(self) -> float:
    #     Ho, Wo = self.patches_resolution
    #     flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
    #     if self.norm is not None:
    #         flops += Ho * Wo * self.embed_dim
    #     return flops
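A shape check for the (B,C,H,W) -> (B,N,D) contract documented above (a minimal sketch):

import torch
from ADD.layers import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=384)
tokens = embed(torch.randn(1, 3, 224, 224))
assert tokens.shape == (1, 196, 384)   # N = (224 // 16) ** 2 = 196 patch tokens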
================================================
FILE: ADD/layers/swiglu_ffn.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from typing import Callable, Optional
import warnings

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        warnings.warn("xFormers is available (SwiGLU)")
    else:
        warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (SwiGLU)")


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
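The SwiGLU trick in isolation (a minimal sketch): w12 produces both the gate and the value in a single matmul, and SwiGLUFFNFused shrinks the hidden width to 2/3 (rounded up to a multiple of 8) so the parameter count roughly matches a 4x GELU MLP.

import torch
from ADD.layers import SwiGLUFFN, SwiGLUFFNFused

x = torch.randn(2, 196, 384)
assert SwiGLUFFN(in_features=384, hidden_features=1024)(x).shape == x.shape
# 1536 * 2 / 3 = 1024 (already a multiple of 8) -> effective hidden width 1024.
assert SwiGLUFFNFused(in_features=384, hidden_features=1536)(x).shape == x.shape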
================================================
FILE: ADD/models/discriminator.py
================================================
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""
Projected discriminator architecture from
"StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis".
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.spectral_norm import SpectralNorm
from torchvision.transforms import RandomCrop, Normalize
import timm
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from ADD.th_utils import misc
from models.shared import ResidualBlock, FullyConnectedLayer
from models.vit_utils import make_vit_backbone, forward_vit, make_sd_backbone
from models.DiffAugment import DiffAugment
from ADD.utils.util_net import reload_model_

from functools import partial


class SpectralConv1d(nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12)


class BatchNormLocal(nn.Module):
    def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 3, eps: float = 1e-5):
        super().__init__()
        self.virtual_bs = virtual_bs
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shape = x.size()

        # Reshape batch into groups.
        G = np.ceil(x.size(0) / self.virtual_bs).astype(int)
        x = x.view(G, -1, x.size(-2), x.size(-1))

        # Calculate stats.
        mean = x.mean([1, 3], keepdim=True)
        var = x.var([1, 3], keepdim=True, unbiased=False)
        x = (x - mean) / (torch.sqrt(var + self.eps))

        if self.affine:
            x = x * self.weight[None, :, None] + self.bias[None, :, None]

        return x.view(shape)


def make_block(channels: int, kernel_size: int) -> nn.Module:
    return nn.Sequential(
        SpectralConv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            padding_mode='circular',
        ),
        # BatchNormLocal(channels),
        nn.GroupNorm(4, channels),
        nn.LeakyReLU(0.2, True),
    )


class DiscHead(nn.Module):
    def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):
        super().__init__()
        self.channels = channels
        self.c_dim = c_dim
        self.cmap_dim = cmap_dim

        self.main = nn.Sequential(
            make_block(channels, kernel_size=1),
            ResidualBlock(make_block(channels, kernel_size=9))
        )

        if self.c_dim > 0:
            self.cmapper = FullyConnectedLayer(self.c_dim, cmap_dim)
            self.cls = SpectralConv1d(channels, cmap_dim, kernel_size=1, padding=0)
        else:
            self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        h = self.main(x)
        out = self.cls(h)

        if self.c_dim > 0:
            cmap = self.cmapper(c).unsqueeze(-1)
            out = (out * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))

        return out


class DINO(torch.nn.Module):
    def __init__(self, hooks: list[int] = [2, 5, 8, 11], hook_patch: bool = True):
        super().__init__()
        self.n_hooks = len(hooks) + int(hook_patch)

        self.model = make_vit_backbone(
            timm.create_model('vit_small_patch16_224.dino', pretrained=False),
            patch_size=[16, 16], hooks=hooks, hook_patch=hook_patch,
        )
        reload_model_(self.model, torch.load('preset/models/dino/dino_deitsmall16_pretrain.pth'))
        self.model = self.model.eval().requires_grad_(False)

        self.img_resolution = self.model.model.patch_embed.img_size[0]
        self.embed_dim = self.model.model.embed_dim
        self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''input: x in [0, 1]; output: dict of activations'''
        x = F.interpolate(x, self.img_resolution, mode='area')
        x = self.norm(x)
        features = forward_vit(self.model, x)
        return features


class ProjectedDiscriminator(nn.Module):
    def __init__(self, c_dim: int, diffaug: bool = True, p_crop: float = 0.5):
        super().__init__()
        self.c_dim = c_dim
        self.diffaug = diffaug
        self.p_crop = p_crop

        self.dino = DINO()

        heads = []
        for i in range(self.dino.n_hooks):
            heads.append((str(i), DiscHead(self.dino.embed_dim, c_dim)))
        self.heads = nn.ModuleDict(heads)

    def train(self, mode: bool = True):
        self.dino = self.dino.train(False)
        self.heads = self.heads.train(mode)
        return self

    def eval(self):
        return self.train(False)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # Apply augmentation (x in [-1, 1]).
        if self.diffaug:
            x = DiffAugment(x, policy='translation,cutout')

        # Transform to [0, 1].
        x = x.add(1).div(2)

        # Take crops with probability p_crop if the image is larger.
        if x.size(-1) > self.dino.img_resolution and np.random.random() < self.p_crop:
            x = RandomCrop(self.dino.img_resolution)(x)

        # Forward pass through DINO ViT.
        features = self.dino(x)

        # Apply discriminator heads.
        logits = []
        for k, head in self.heads.items():
            features[k].requires_grad_(True)
            logits.append(head(features[k], c).view(x.size(0), -1))
        # logits = torch.cat(logits, dim=1)

        return logits, features
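A hedged usage sketch for the discriminator (it requires the DINO-S/16 checkpoint at preset/models/dino/dino_deitsmall16_pretrain.pth plus the repo's models.shared and models.vit_utils modules, so it only runs inside a fully set-up checkout):

import torch
from ADD.models.discriminator import ProjectedDiscriminator

disc = ProjectedDiscriminator(c_dim=0)        # c_dim=0 -> unconditional heads
fake = torch.rand(2, 3, 224, 224) * 2 - 1     # inputs are expected in [-1, 1]
logits, features = disc(fake, c=torch.empty(2, 0))
print(len(logits))                            # one logit tensor per hooked DINO layer (5 with the defaults)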
================================================
FILE: ADD/models/vit.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from ADD.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block

logger = logging.getLogger("dinov2")


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        fea_list = []
        counter = 0
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)
        fea_list.append(x[:, self.num_register_tokens + 1 :].permute(0, 2, 1))

        for blk in self.blocks:
            x = blk(x)
            counter += 1
            if counter % 3 == 0:
                fea_list.append(x[:, self.num_register_tokens + 1 :].permute(0, 2, 1))

        x_norm = self.norm(x)
        return fea_list, x_norm[:, 0]

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return ret  # self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


# net = vit_small(patch_size=14, img_size=518, block_chunks=0, init_values=1.0)
# prefile = torch.load('../weights/dinov2_vits14_pretrain.pth')
# net.load_state_dict(prefile, True)
# out = net(torch.rand(1, 3, 518, 518))
# print(out.shape)

# net = vit_large(patch_size=14, img_size=526, block_chunks=0, init_values=1.0, num_register_tokens=4)
# prefile = torch.load('../weights/dinov2_vitl14_reg4_pretrain.pth')
# net.load_state_dict(prefile, True)
# out = net(torch.rand(1, 3, 70, 70))
# print(out.shape)
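Note that this fork's forward_features departs from upstream DINOv2: for a tensor input it returns a list of patch-token feature maps (the embedded input plus a snapshot after every third block, each transposed to (B, D, N)) together with the final CLS token. A minimal sketch with random weights:

import torch
from ADD.models.vit import vit_small

net = vit_small(patch_size=16, img_size=224, block_chunks=0)  # random init, no checkpoint loaded
fea_list, cls = net(torch.rand(1, 3, 224, 224))
print(len(fea_list))      # 1 input snapshot + 12 // 3 block snapshots = 5
print(fea_list[0].shape)  # (1, 384, 196)
print(cls.shape)          # (1, 384)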
================================================
FILE: ADD/th_utils/__init__.py
================================================
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

# empty

================================================
FILE: ADD/th_utils/custom_ops.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import glob
import hashlib
import importlib
import os
import re
import shutil
import uuid

import torch
import torch.utils.cpp_extension
from torch.utils.file_baton import FileBaton

#----------------------------------------------------------------------------
# Global options.

verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'

#----------------------------------------------------------------------------
# Internal helper funcs.

def _find_compiler_bindir():
    patterns = [
        'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files*/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None

#----------------------------------------------------------------------------

def _get_mangled_gpu_name():
    name = torch.cuda.get_device_name().lower()
    out = []
    for c in name:
        if re.match('[a-z0-9_-]+', c):
            out.append(c)
        else:
            out.append('-')
    return ''.join(out)

#----------------------------------------------------------------------------
# Main entry point for compiling and loading C++/CUDA plugins.

_cached_plugins = dict()

def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
    assert verbosity in ['none', 'brief', 'full']
    if headers is None:
        headers = []
    if source_dir is not None:
        sources = [os.path.join(source_dir, fname) for fname in sources]
        headers = [os.path.join(source_dir, fname) for fname in headers]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
    verbose_build = (verbosity == 'full')

    # Compile and load.
    try: # pylint: disable=too-many-nested-blocks
        # Make sure we can find the necessary compiler binaries.
        if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
        # break the build or unnecessarily restrict what's available to nvcc.
        # Unset it to let nvcc decide based on what's available on the
        # machine.
        os.environ['TORCH_CUDA_ARCH_LIST'] = ''

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        #
        # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
        # around the *.cu dependency bug in ninja config.
        #
        all_source_files = sorted(sources + headers)
        all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
        if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):

            # Compute combined hash digest for all source files.
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())

            # Select cached build directory name.
            source_digest = hash_md5.hexdigest()
            build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
            cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')

            if not os.path.isdir(cached_build_dir):
                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
                os.makedirs(tmpdir)
                for src in all_source_files:
                    shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
                try:
                    os.replace(tmpdir, cached_build_dir) # atomic
                except OSError:
                    # source directory already exists, delete tmpdir and its contents.
                    shutil.rmtree(tmpdir)
                    if not os.path.isdir(cached_build_dir):
                        raise

            # Compile.
            cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
            torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
                verbose=verbose_build, sources=cached_sources, **build_kwargs)
        else:
            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)

        # Load.
        module = importlib.import_module(module_name)

    except:
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache dict.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module

#----------------------------------------------------------------------------
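A sketch of how get_plugin is typically called (this mirrors the lazy _init() pattern of the StyleGAN-style ops under ADD/th_utils/ops/, e.g. bias_act; it needs a CUDA toolchain, compiles once, and returns the cached extension module on subsequent calls):

import os
from ADD.th_utils import custom_ops

ops_dir = os.path.join('ADD', 'th_utils', 'ops')   # illustrative source location
plugin = custom_ops.get_plugin(
    module_name='bias_act_plugin',
    sources=['bias_act.cpp', 'bias_act.cu'],
    headers=['bias_act.h'],
    source_dir=ops_dir,
    extra_cuda_cflags=['--use_fast_math'],          # forwarded to torch.utils.cpp_extension.load
)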
# # This optimization is done only in case all the source files reside in # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR # environment variable is set (we take this as a signal that the user # actually cares about this.) # # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work # around the *.cu dependency bug in ninja config. # all_source_files = sorted(sources + headers) all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): # Compute combined hash digest for all source files. hash_md5 = hashlib.md5() for src in all_source_files: with open(src, 'rb') as f: hash_md5.update(f.read()) # Select cached build directory name. source_digest = hash_md5.hexdigest() build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') if not os.path.isdir(cached_build_dir): tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' os.makedirs(tmpdir) for src in all_source_files: shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) try: os.replace(tmpdir, cached_build_dir) # atomic except OSError: # source directory already exists, delete tmpdir and its contents. shutil.rmtree(tmpdir) if not os.path.isdir(cached_build_dir): raise # Compile. cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, verbose=verbose_build, sources=cached_sources, **build_kwargs) else: torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) # Load. module = importlib.import_module(module_name) except: if verbosity == 'brief': print('Failed!') raise # Print status and add to cache dict. if verbosity == 'full': print(f'Done setting up PyTorch plugin "{module_name}".') elif verbosity == 'brief': print('Done.') _cached_plugins[module_name] = module return module #---------------------------------------------------------------------------- ================================================ FILE: ADD/th_utils/misc.py ================================================ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. import re import contextlib import numpy as np import torch import warnings import ADD.dnnlib as dnnlib #---------------------------------------------------------------------------- # Cached construction of constant tensors. Avoids CPU=>GPU copy when the # same constant is used multiple times.
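# Example (sketch): constant() below memoizes by (value, shape, dtype, device,
# memory_format), so repeated calls return the very same GPU tensor instead of
# re-uploading it.
# f = constant([1, 3, 3, 1], dtype=torch.float32, device=torch.device('cuda'))
# g = constant([1, 3, 3, 1], dtype=torch.float32, device=torch.device('cuda'))
# assert f is g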
_constant_cache = dict() def constant(value, shape=None, dtype=None, device=None, memory_format=None): value = np.asarray(value) if shape is not None: shape = tuple(shape) if dtype is None: dtype = torch.get_default_dtype() if device is None: device = torch.device('cpu') if memory_format is None: memory_format = torch.contiguous_format key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) tensor = _constant_cache.get(key, None) if tensor is None: tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) if shape is not None: tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) tensor = tensor.contiguous(memory_format=memory_format) _constant_cache[key] = tensor return tensor #---------------------------------------------------------------------------- # Replace NaN/Inf with specified numerical values. try: nan_to_num = torch.nan_to_num # 1.8.0a0 except AttributeError: def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin assert isinstance(input, torch.Tensor) if posinf is None: posinf = torch.finfo(input.dtype).max if neginf is None: neginf = torch.finfo(input.dtype).min assert nan == 0 return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) #---------------------------------------------------------------------------- # Symbolic assert. try: symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access except AttributeError: symbolic_assert = torch.Assert # 1.7.0 #---------------------------------------------------------------------------- # Context manager to temporarily suppress known warnings in torch.jit.trace(). # Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672 @contextlib.contextmanager def suppress_tracer_warnings(): flt = ('ignore', None, torch.jit.TracerWarning, None, 0) warnings.filters.insert(0, flt) yield warnings.filters.remove(flt) #---------------------------------------------------------------------------- # Assert that the shape of a tensor matches the given list of integers. # None indicates that the size of a dimension is allowed to vary. # Performs symbolic assertion when used in torch.jit.trace(). def assert_shape(tensor, ref_shape): if tensor.ndim != len(ref_shape): raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): if ref_size is None: pass elif isinstance(ref_size, torch.Tensor): with suppress_tracer_warnings(): # as_tensor results are registered as constants symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') elif isinstance(size, torch.Tensor): with suppress_tracer_warnings(): # as_tensor results are registered as constants symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') elif size != ref_size: raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') #---------------------------------------------------------------------------- # Function decorator that calls torch.autograd.profiler.record_function(). 
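# Example (sketch): assert_shape() above treats None as a wildcard dimension.
# x = torch.zeros(4, 3, 64, 64)
# assert_shape(x, [None, 3, None, None])  # passes
# assert_shape(x, [None, 1, None, None])  # raises AssertionError for dim 1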
def profiled_function(fn): def decorator(*args, **kwargs): with torch.autograd.profiler.record_function(fn.__name__): return fn(*args, **kwargs) decorator.__name__ = fn.__name__ return decorator #---------------------------------------------------------------------------- # Sampler for torch.utils.data.DataLoader that loops over the dataset # indefinitely, shuffling items as it goes. class InfiniteSampler(torch.utils.data.Sampler): def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): assert len(dataset) > 0 assert num_replicas > 0 assert 0 <= rank < num_replicas assert 0 <= window_size <= 1 super().__init__(dataset) self.dataset = dataset self.rank = rank self.num_replicas = num_replicas self.shuffle = shuffle self.seed = seed self.window_size = window_size def __iter__(self): order = np.arange(len(self.dataset)) rnd = None window = 0 if self.shuffle: rnd = np.random.RandomState(self.seed) rnd.shuffle(order) window = int(np.rint(order.size * self.window_size)) idx = 0 while True: i = idx % order.size if idx % self.num_replicas == self.rank: yield order[i] if window >= 2: j = (i - rnd.randint(window)) % order.size order[i], order[j] = order[j], order[i] idx += 1 #---------------------------------------------------------------------------- # Utilities for operating with torch.nn.Module parameters and buffers. def spectral_to_cpu(model: torch.nn.Module): def wrapped_in_spectral(m): return hasattr(m, 'weight_v') children = get_children(model) for child in children: if wrapped_in_spectral(child): child.weight = child.weight.cpu() return model def get_children(model: torch.nn.Module): children = list(model.children()) flatt_children = [] if children == []: return model else: for child in children: try: flatt_children.extend(get_children(child)) except TypeError: flatt_children.append(get_children(child)) return flatt_children def params_and_buffers(module): assert isinstance(module, torch.nn.Module) return list(module.parameters()) + list(module.buffers()) def named_params_and_buffers(module): assert isinstance(module, torch.nn.Module) return list(module.named_parameters()) + list(module.named_buffers()) def copy_params_and_buffers(src_module, dst_module, require_all=False): assert isinstance(src_module, torch.nn.Module) assert isinstance(dst_module, torch.nn.Module) src_tensors = dict(named_params_and_buffers(src_module)) for name, tensor in named_params_and_buffers(dst_module): assert (name in src_tensors) or (not require_all) if name in src_tensors: tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) #---------------------------------------------------------------------------- # Context manager for easily enabling/disabling DistributedDataParallel # synchronization. @contextlib.contextmanager def ddp_sync(module, sync): assert isinstance(module, torch.nn.Module) if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): yield else: with module.no_sync(): yield #---------------------------------------------------------------------------- # Check DistributedDataParallel consistency across processes. def check_ddp_consistency(module, ignore_regex=None): assert isinstance(module, torch.nn.Module) for name, tensor in named_params_and_buffers(module): fullname = type(module).__name__ + '.' 
+ name if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): continue tensor = tensor.detach() if tensor.is_floating_point(): tensor = nan_to_num(tensor) other = tensor.clone() torch.distributed.broadcast(tensor=other, src=0) assert (tensor == other).all(), fullname #---------------------------------------------------------------------------- # Print summary table of module hierarchy. def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): assert isinstance(module, torch.nn.Module) assert not isinstance(module, torch.jit.ScriptModule) assert isinstance(inputs, (tuple, list)) # Register hooks. entries = [] nesting = [0] def pre_hook(_mod, _inputs): nesting[0] += 1 def post_hook(mod, _inputs, outputs): nesting[0] -= 1 if nesting[0] <= max_nesting: outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] outputs = [t for t in outputs if isinstance(t, torch.Tensor)] entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] # Run module. outputs = module(*inputs) for hook in hooks: hook.remove() # Identify unique outputs, parameters, and buffers. tensors_seen = set() for e in entries: e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} # Filter out redundant entries. if skip_redundant: entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] # Construct table. rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] rows += [['---'] * len(rows[0])] param_total = 0 buffer_total = 0 submodule_names = {mod: name for name, mod in module.named_modules()} for e in entries: name = '' if e.mod is module else submodule_names[e.mod] param_size = sum(t.numel() for t in e.unique_params) buffer_size = sum(t.numel() for t in e.unique_buffers) output_shapes = [str(list(t.shape)) for t in e.outputs] output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] rows += [[ name + (':0' if len(e.outputs) >= 2 else ''), str(param_size) if param_size else '-', str(buffer_size) if buffer_size else '-', (output_shapes + ['-'])[0], (output_dtypes + ['-'])[0], ]] for idx in range(1, len(e.outputs)): rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] param_total += param_size buffer_total += buffer_size rows += [['---'] * len(rows[0])] rows += [['Total', str(param_total), str(buffer_total), '-', '-']] # Print table. widths = [max(len(cell) for cell in column) for column in zip(*rows)] print() for row in rows: print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) print() return outputs ================================================ FILE: ADD/th_utils/ops/__init__.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. 
Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. # empty ================================================ FILE: ADD/th_utils/ops/bias_act.cpp ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include <torch/extension.h> #include <ATen/cuda/CUDAContext.h> #include <c10/cuda/CUDAGuard.h> #include "bias_act.h" //------------------------------------------------------------------------ static bool has_same_layout(torch::Tensor x, torch::Tensor y) { if (x.dim() != y.dim()) return false; for (int64_t i = 0; i < x.dim(); i++) { if (x.size(i) != y.size(i)) return false; if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) return false; } return true; } //------------------------------------------------------------------------ static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) { // Validate arguments. TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); TORCH_CHECK(b.dim() == 1, "b must have rank 1"); TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); TORCH_CHECK(grad >= 0, "grad must be non-negative"); // Validate layout. TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); // Create output tensor. const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); torch::Tensor y = torch::empty_like(x); TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); // Initialize CUDA kernel parameters. bias_act_kernel_params p; p.x = x.data_ptr(); p.b = (b.numel()) ? b.data_ptr() : NULL; p.xref = (xref.numel()) ? xref.data_ptr() : NULL; p.yref = (yref.numel()) ? yref.data_ptr() : NULL; p.dy = (dy.numel()) ?
dy.data_ptr() : NULL; p.y = y.data_ptr(); p.grad = grad; p.act = act; p.alpha = alpha; p.gain = gain; p.clamp = clamp; p.sizeX = (int)x.numel(); p.sizeB = (int)b.numel(); p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; // Choose CUDA kernel. void* kernel; AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { kernel = choose_bias_act_kernel<scalar_t>(p); }); TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); // Launch CUDA kernel. p.loopX = 4; int blockSize = 4 * 32; int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; void* args[] = {&p}; AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); return y; } //------------------------------------------------------------------------ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("bias_act", &bias_act); } //------------------------------------------------------------------------ ================================================ FILE: ADD/th_utils/ops/bias_act.cu ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include <c10/util/Half.h> #include "bias_act.h" //------------------------------------------------------------------------ // Helpers. template <class T> struct InternalType; template <> struct InternalType<double> { typedef double scalar_t; }; template <> struct InternalType<float> { typedef float scalar_t; }; template <> struct InternalType<c10::Half> { typedef float scalar_t; }; //------------------------------------------------------------------------ // CUDA kernel. template <class T, int A> __global__ void bias_act_kernel(bias_act_kernel_params p) { typedef typename InternalType<T>::scalar_t scalar_t; int G = p.grad; scalar_t alpha = (scalar_t)p.alpha; scalar_t gain = (scalar_t)p.gain; scalar_t clamp = (scalar_t)p.clamp; scalar_t one = (scalar_t)1; scalar_t two = (scalar_t)2; scalar_t expRange = (scalar_t)80; scalar_t halfExpRange = (scalar_t)40; scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; // Loop over elements. int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) { // Load. scalar_t x = (scalar_t)((const T*)p.x)[xi]; scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; scalar_t yy = (gain != 0) ? yref / gain : 0; scalar_t y = 0; // Apply bias. ((G == 0) ? x : xref) += b; // linear if (A == 1) { if (G == 0) y = x; if (G == 1) y = x; } // relu if (A == 2) { if (G == 0) y = (x > 0) ? x : 0; if (G == 1) y = (yy > 0) ? x : 0; } // lrelu if (A == 3) { if (G == 0) y = (x > 0) ? x : x * alpha; if (G == 1) y = (yy > 0) ? x : x * alpha; } // tanh if (A == 4) { if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ?
one : (c - d) / (c + d); } if (G == 1) y = x * (one - yy * yy); if (G == 2) y = x * (one - yy * yy) * (-two * yy); } // sigmoid if (A == 5) { if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); if (G == 1) y = x * yy * (one - yy); if (G == 2) y = x * yy * (one - yy) * (one - two * yy); } // elu if (A == 6) { if (G == 0) y = (x >= 0) ? x : exp(x) - one; if (G == 1) y = (yy >= 0) ? x : x * (yy + one); if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); } // selu if (A == 7) { if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); } // softplus if (A == 8) { if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); if (G == 1) y = x * (one - exp(-yy)); if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } } // swish if (A == 9) { if (G == 0) y = (x < -expRange) ? 0 : x / (exp(-x) + one); else { scalar_t c = exp(xref); scalar_t d = c + one; if (G == 1) y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); else y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; } } // Apply gain. y *= gain * dy; // Clamp. if (clamp >= 0) { if (G == 0) y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; else y = (yref > -clamp & yref < clamp) ? y : 0; } // Store. ((T*)p.y)[xi] = (T)y; } } //------------------------------------------------------------------------ // CUDA kernel selection. template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) { if (p.act == 1) return (void*)bias_act_kernel<T, 1>; if (p.act == 2) return (void*)bias_act_kernel<T, 2>; if (p.act == 3) return (void*)bias_act_kernel<T, 3>; if (p.act == 4) return (void*)bias_act_kernel<T, 4>; if (p.act == 5) return (void*)bias_act_kernel<T, 5>; if (p.act == 6) return (void*)bias_act_kernel<T, 6>; if (p.act == 7) return (void*)bias_act_kernel<T, 7>; if (p.act == 8) return (void*)bias_act_kernel<T, 8>; if (p.act == 9) return (void*)bias_act_kernel<T, 9>; return NULL; } //------------------------------------------------------------------------ // Template specializations. template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p); template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p); template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p); //------------------------------------------------------------------------ ================================================ FILE: ADD/th_utils/ops/bias_act.h ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. //------------------------------------------------------------------------ // CUDA kernel parameters.
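// Note: grad selects what the kernel computes -- grad=0 the forward activation,
// grad=1 the first derivative (the x slot then carries dy), grad=2 the
// second-order term; this mirrors the G == 0/1/2 branches in bias_act.cu above.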
struct bias_act_kernel_params { const void* x; // [sizeX] const void* b; // [sizeB] or NULL const void* xref; // [sizeX] or NULL const void* yref; // [sizeX] or NULL const void* dy; // [sizeX] or NULL void* y; // [sizeX] int grad; int act; float alpha; float gain; float clamp; int sizeX; int sizeB; int stepB; int loopX; }; //------------------------------------------------------------------------ // CUDA kernel selection. template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p); //------------------------------------------------------------------------ ================================================ FILE: ADD/th_utils/ops/bias_act.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom PyTorch ops for efficient bias and activation.""" import os import numpy as np import torch import ADD.dnnlib as dnnlib from .. import custom_ops from .. import misc #---------------------------------------------------------------------------- activation_funcs = { 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), } #---------------------------------------------------------------------------- _plugin = None _null_tensor = torch.empty([0]) def _init(): global _plugin if _plugin is None: _plugin = custom_ops.get_plugin( module_name='bias_act_plugin', sources=['bias_act.cpp', 'bias_act.cu'], headers=['bias_act.h'], source_dir=os.path.dirname(__file__), extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], ) return True #---------------------------------------------------------------------------- def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): r"""Fused bias and activation function. Adds bias `b` to activation tensor `x`, evaluates activation function `act`, and scales the result by `gain`. Each of the steps is optional.
In most cases, the fused op is considerably more efficient than performing the same calculation using standard PyTorch ops. It supports first and second order gradients, but not third order gradients. Args: x: Input activation tensor. Can be of any shape. b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type as `x`. The shape must be known, and it must match the dimension of `x` corresponding to `dim`. dim: The dimension in `x` corresponding to the elements of `b`. The value of `dim` is ignored if `b` is not specified. act: Name of the activation function to evaluate, or `"linear"` to disable. Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. See `activation_funcs` for a full list. `None` is not allowed. alpha: Shape parameter for the activation function, or `None` to use the default. gain: Scaling factor for the output tensor, or `None` to use default. See `activation_funcs` for the default scaling of each activation function. If unsure, consider specifying 1. clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable the clamping (default). impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). Returns: Tensor of the same shape and datatype as `x`. """ assert isinstance(x, torch.Tensor) assert impl in ['ref', 'cuda'] if impl == 'cuda' and x.device.type == 'cuda' and _init(): return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) #---------------------------------------------------------------------------- @misc.profiled_function def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): """Slow reference implementation of `bias_act()` using standard PyTorch ops. """ assert isinstance(x, torch.Tensor) assert clamp is None or clamp >= 0 spec = activation_funcs[act] alpha = float(alpha if alpha is not None else spec.def_alpha) gain = float(gain if gain is not None else spec.def_gain) clamp = float(clamp if clamp is not None else -1) # Add bias. if b is not None: assert isinstance(b, torch.Tensor) and b.ndim == 1 assert 0 <= dim < x.ndim assert b.shape[0] == x.shape[dim] x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) # Evaluate activation function. alpha = float(alpha) x = spec.func(x, alpha=alpha) # Scale by gain. gain = float(gain) if gain != 1: x = x * gain # Clamp. if clamp >= 0: x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type return x #---------------------------------------------------------------------------- _bias_act_cuda_cache = dict() def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): """Fast CUDA implementation of `bias_act()` using custom ops. """ # Parse arguments. assert clamp is None or clamp >= 0 spec = activation_funcs[act] alpha = float(alpha if alpha is not None else spec.def_alpha) gain = float(gain if gain is not None else spec.def_gain) clamp = float(clamp if clamp is not None else -1) # Lookup from cache. key = (dim, act, alpha, gain, clamp) if key in _bias_act_cuda_cache: return _bias_act_cuda_cache[key] # Forward op.
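# Example (sketch): the reference path is plain PyTorch, so with an explicit
# gain of 1.0 it matches a hand-written bias + leaky-relu exactly.
# x = torch.randn(8, 16, 32, 32); b = torch.randn(16)
# y1 = bias_act(x, b, dim=1, act='lrelu', alpha=0.2, gain=1.0, impl='ref')
# y2 = torch.nn.functional.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2)
# assert torch.allclose(y1, y2, atol=1e-6)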
class BiasActCuda(torch.autograd.Function): @staticmethod def forward(ctx, x, b): # pylint: disable=arguments-differ ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format x = x.contiguous(memory_format=ctx.memory_format) b = b.contiguous() if b is not None else _null_tensor y = x if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) ctx.save_for_backward( x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, y if 'y' in spec.ref else _null_tensor) return y @staticmethod def backward(ctx, dy): # pylint: disable=arguments-differ dy = dy.contiguous(memory_format=ctx.memory_format) x, b, y = ctx.saved_tensors dx = None db = None if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: dx = dy if act != 'linear' or gain != 1 or clamp >= 0: dx = BiasActCudaGrad.apply(dy, x, b, y) if ctx.needs_input_grad[1]: db = dx.sum([i for i in range(dx.ndim) if i != dim]) return dx, db # Backward op. class BiasActCudaGrad(torch.autograd.Function): @staticmethod def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) ctx.save_for_backward( dy if spec.has_2nd_grad else _null_tensor, x, b, y) return dx @staticmethod def backward(ctx, d_dx): # pylint: disable=arguments-differ d_dx = d_dx.contiguous(memory_format=ctx.memory_format) dy, x, b, y = ctx.saved_tensors d_dy = None d_x = None d_b = None d_y = None if ctx.needs_input_grad[0]: d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) if spec.has_2nd_grad and ctx.needs_input_grad[2]: d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) return d_dy, d_x, d_b, d_y # Add to cache. _bias_act_cuda_cache[key] = BiasActCuda return BiasActCuda #---------------------------------------------------------------------------- ================================================ FILE: ADD/th_utils/ops/conv2d_gradfix.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom replacement for `torch.nn.functional.conv2d` that supports arbitrarily high order gradients with zero performance penalty.""" import contextlib import torch from pkg_resources import parse_version # pylint: disable=redefined-builtin # pylint: disable=arguments-differ # pylint: disable=protected-access #---------------------------------------------------------------------------- enabled = False # Enable the custom op by setting this to true. weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 
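# Example (sketch): typical use of the two module flags above; conv2d() and
# the no_weight_gradients() context manager are defined just below.
# conv2d_gradfix.enabled = True                 # opt in to the custom op
# with conv2d_gradfix.no_weight_gradients():    # e.g. during a regularization pass
#     y = conv2d_gradfix.conv2d(x, w, padding=1)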
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 @contextlib.contextmanager def no_weight_gradients(disable=True): global weight_gradients_disabled old = weight_gradients_disabled if disable: weight_gradients_disabled = True yield weight_gradients_disabled = old #---------------------------------------------------------------------------- def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): if _should_use_custom_op(input): return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): if _should_use_custom_op(input): return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) #---------------------------------------------------------------------------- def _should_use_custom_op(input): assert isinstance(input, torch.Tensor) if (not enabled) or (not torch.backends.cudnn.enabled): return False if _use_pytorch_1_11_api: # The work-around code doesn't work on PyTorch 1.11.0 onwards return False if input.device.type != 'cuda': return False return True def _tuple_of_ints(xs, ndim): xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim assert len(xs) == ndim assert all(isinstance(x, int) for x in xs) return xs #---------------------------------------------------------------------------- _conv2d_gradfix_cache = dict() _null_tensor = torch.empty([0]) def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): # Parse arguments. ndim = 2 weight_shape = tuple(weight_shape) stride = _tuple_of_ints(stride, ndim) padding = _tuple_of_ints(padding, ndim) output_padding = _tuple_of_ints(output_padding, ndim) dilation = _tuple_of_ints(dilation, ndim) # Lookup from cache. key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) if key in _conv2d_gradfix_cache: return _conv2d_gradfix_cache[key] # Validate arguments. assert groups >= 1 assert len(weight_shape) == ndim + 2 assert all(stride[i] >= 1 for i in range(ndim)) assert all(padding[i] >= 0 for i in range(ndim)) assert all(dilation[i] >= 0 for i in range(ndim)) if not transpose: assert all(output_padding[i] == 0 for i in range(ndim)) else: # transpose assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) # Helpers. common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) def calc_output_padding(input_shape, output_shape): if transpose: return [0, 0] return [ input_shape[i + 2] - (output_shape[i + 2] - 1) * stride[i] - (1 - 2 * padding[i]) - dilation[i] * (weight_shape[i + 2] - 1) for i in range(ndim) ] # Forward & backward. 
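# Note: calc_output_padding() above inverts the conv_transpose2d size relation
#   in = (out - 1)*stride - 2*pad + dilation*(k - 1) + 1 + output_padding,
# solving for output_padding so that the gradient conv reproduces input_shape
# exactly for every valid combination of stride/padding/dilation.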
class Conv2d(torch.autograd.Function): @staticmethod def forward(ctx, input, weight, bias): assert weight.shape == weight_shape ctx.save_for_backward( input if weight.requires_grad else _null_tensor, weight if input.requires_grad else _null_tensor, ) ctx.input_shape = input.shape # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) # General case => cuDNN. if transpose: return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) @staticmethod def backward(ctx, grad_output): input, weight = ctx.saved_tensors input_shape = ctx.input_shape grad_input = None grad_weight = None grad_bias = None if ctx.needs_input_grad[0]: p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) grad_input = op.apply(grad_output, weight, None) assert grad_input.shape == input_shape if ctx.needs_input_grad[1] and not weight_gradients_disabled: grad_weight = Conv2dGradWeight.apply(grad_output, input) assert grad_weight.shape == weight_shape if ctx.needs_input_grad[2]: grad_bias = grad_output.sum([0, 2, 3]) return grad_input, grad_weight, grad_bias # Gradient with respect to the weights. class Conv2dGradWeight(torch.autograd.Function): @staticmethod def forward(ctx, grad_output, input): ctx.save_for_backward( grad_output if input.requires_grad else _null_tensor, input if grad_output.requires_grad else _null_tensor, ) ctx.grad_output_shape = grad_output.shape ctx.input_shape = input.shape # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) # General case => cuDNN. 
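# Note: the 1x1 fast path above computes the weight gradient as a batched
# matmul, grad_w = grad_out @ x^T per group, which is the im2col form of
# backward-weight when the kernel is 1x1 with unit stride and no padding.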
name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) @staticmethod def backward(ctx, grad2_grad_weight): grad_output, input = ctx.saved_tensors grad_output_shape = ctx.grad_output_shape input_shape = ctx.input_shape grad2_grad_output = None grad2_input = None if ctx.needs_input_grad[0]: grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) assert grad2_grad_output.shape == grad_output_shape if ctx.needs_input_grad[1]: p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) grad2_input = op.apply(grad_output, grad2_grad_weight, None) assert grad2_input.shape == input_shape return grad2_grad_output, grad2_input _conv2d_gradfix_cache[key] = Conv2d return Conv2d #---------------------------------------------------------------------------- ================================================ FILE: ADD/th_utils/ops/conv2d_resample.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """2D convolution with optional up/downsampling.""" import torch from .. import misc from . import conv2d_gradfix from . import upfirdn2d from .upfirdn2d import _parse_padding from .upfirdn2d import _get_filter_size #---------------------------------------------------------------------------- def _get_weight_shape(w): with misc.suppress_tracer_warnings(): # this value will be treated as a constant shape = [int(sz) for sz in w.shape] misc.assert_shape(w, shape) return shape #---------------------------------------------------------------------------- def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. """ _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) # Flip weight if requested. # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). if not flip_weight and (kw > 1 or kh > 1): w = w.flip([2, 3]) # Execute using conv2d_gradfix. op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d return op(x, w, stride=stride, padding=padding, groups=groups) #---------------------------------------------------------------------------- @misc.profiled_function def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): r"""2D convolution with optional up/downsampling. Padding is performed only once at the beginning, not between the operations. Args: x: Input tensor of shape `[batch_size, in_channels, in_height, in_width]`. w: Weight tensor of shape `[out_channels, in_channels//groups, kernel_height, kernel_width]`. f: Low-pass filter for up/downsampling. Must be prepared beforehand by calling upfirdn2d.setup_filter(). 
None = identity (default). up: Integer upsampling factor (default: 1). down: Integer downsampling factor (default: 1). padding: Padding with respect to the upsampled image. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). groups: Split input channels into N groups (default: 1). flip_weight: False = convolution, True = correlation (default: True). flip_filter: False = convolution, True = correlation (default: False). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ # Validate arguments. assert isinstance(x, torch.Tensor) and (x.ndim == 4) assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) assert isinstance(up, int) and (up >= 1) assert isinstance(down, int) and (down >= 1) assert isinstance(groups, int) and (groups >= 1) out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) fw, fh = _get_filter_size(f) px0, px1, py0, py1 = _parse_padding(padding) # Adjust padding to account for up/downsampling. if up > 1: px0 += (fw + up - 1) // 2 px1 += (fw - up) // 2 py0 += (fh + up - 1) // 2 py1 += (fh - up) // 2 if down > 1: px0 += (fw - down + 1) // 2 px1 += (fw - down) // 2 py0 += (fh - down + 1) // 2 py1 += (fh - down) // 2 # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. if kw == 1 and kh == 1 and (down > 1 and up == 1): x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) return x # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. if kw == 1 and kh == 1 and (up > 1 and down == 1): x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) return x # Fast path: downsampling only => use strided convolution. if down > 1 and up == 1: x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) return x # Fast path: upsampling with optional downsampling => use transpose strided convolution. if up > 1: if groups == 1: w = w.transpose(0, 1) else: w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) w = w.transpose(1, 2) w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) px0 -= kw - 1 px1 -= kw - up py0 -= kh - 1 py1 -= kh - up pxt = max(min(-px0, -px1), 0) pyt = max(min(-py0, -py1), 0) x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) if down > 1: x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) return x # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. if up == 1 and down == 1: if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) # Fallback: Generic reference implementation. 
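# Example (sketch): an antialiased 2x upsampling convolution; the binomial
# low-pass filter is prepared with upfirdn2d.setup_filter() as the docstring
# above requires.
# f = upfirdn2d.setup_filter([1, 3, 3, 1])
# x = torch.randn(1, 8, 16, 16); w = torch.randn(4, 8, 3, 3)
# y = conv2d_resample(x=x, w=w, f=f, up=2, padding=1)  # -> [1, 4, 32, 32]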
x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) if down > 1: x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) return x #---------------------------------------------------------------------------- ================================================ FILE: ADD/th_utils/ops/filtered_lrelu.cpp ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include <torch/extension.h> #include <ATen/cuda/CUDAContext.h> #include <c10/cuda/CUDAGuard.h> #include "filtered_lrelu.h" //------------------------------------------------------------------------ static std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu( torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si, int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns) { // Set CUDA device. TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); // Validate arguments. TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), "all input tensors must reside on the same device"); TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, "fu and fd must be float32"); TORCH_CHECK(b.dtype() == x.dtype(), "x and b must have the same dtype"); TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, "x and b must be float16 or float32"); TORCH_CHECK(x.dim() == 4, "x must be rank 4"); TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); TORCH_CHECK(x.numel() > 0, "x is empty"); TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), "fu and fd must be rank 1 or 2"); TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, "fu is too large"); TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, "fd is too large"); TORCH_CHECK(fu.numel() > 0, "fu is empty"); TORCH_CHECK(fd.numel() > 0, "fd is empty"); TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), "b must be a vector with the same number of channels as x"); TORCH_CHECK(up >= 1 && down >= 1, "up and down must be at least 1"); // Figure out how much shared memory is available on the device. int maxSharedBytes = 0; AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index())); int sharedKB = maxSharedBytes >> 10; // Populate enough launch parameters to check if a CUDA kernel exists. filtered_lrelu_kernel_params p; p.up = up; p.down = down; p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter. p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0); filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB); if (!test_spec.exec) { // No kernel found - return empty tensors and indicate missing kernel with return code of -1.
return std::make_tuple(torch::Tensor(), torch::Tensor(), -1); } // Input/output element size. int64_t sz = (x.dtype() == torch::kHalf) ? 2 : 4; // Input sizes. int64_t xw = (int)x.size(3); int64_t xh = (int)x.size(2); int64_t fut_w = (int)fu.size(-1) - 1; int64_t fut_h = (int)fu.size(0) - 1; int64_t fdt_w = (int)fd.size(-1) - 1; int64_t fdt_h = (int)fd.size(0) - 1; // Logical size of upsampled buffer. int64_t cw = xw * up + (px0 + px1) - fut_w; int64_t ch = xh * up + (py0 + py1) - fut_h; TORCH_CHECK(cw > fdt_w && ch > fdt_h, "upsampled buffer must be at least the size of downsampling filter"); TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, "upsampled buffer is too large"); // Compute output size and allocate. int64_t yw = (cw - fdt_w + (down - 1)) / down; int64_t yh = (ch - fdt_h + (down - 1)) / down; TORCH_CHECK(yw > 0 && yh > 0, "output must be at least 1x1"); TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, "output is too large"); torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format()); // Allocate sign tensor. torch::Tensor so; torch::Tensor s = si; bool readSigns = !!s.numel(); int64_t sw_active = 0; // Active width of sign tensor. if (writeSigns) { sw_active = yw * down - (down - 1) + fdt_w; // Active width in elements. int64_t sh = yh * down - (down - 1) + fdt_h; // Height = active height. int64_t sw = (sw_active + 15) & ~15; // Width = active width in elements, rounded up to multiple of 16. TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, "signs is too large"); s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); } else if (readSigns) sw_active = s.size(3) << 2; // Validate sign tensor if in use. if (readSigns || writeSigns) { TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, "signs is too large"); } // Populate rest of CUDA kernel parameters. p.x = x.data_ptr(); p.y = y.data_ptr(); p.b = b.data_ptr(); p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; p.fu = fu.data_ptr(); p.fd = fd.data_ptr(); p.pad0 = make_int2(px0, py0); p.gain = gain; p.slope = slope; p.clamp = clamp; p.flip = (flip_filters) ? 1 : 0; p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); p.yShape = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous. p.sOfs = make_int2(sx, sy); p.swLimit = (sw_active + 3) >> 2; // Rounded up to bytes. // x, y, b strides are in bytes. p.xStride = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0)); p.yStride = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0)); p.bStride = sz * b.stride(0); // fu, fd strides are in elements. p.fuStride = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0); p.fdStride = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0); // Determine if indices don't fit in int32. Support negative strides although Torch currently never produces those. 
bool index64b = false; if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true; if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true; if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) > INT_MAX) index64b = true; if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true; if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) > INT_MAX) index64b = true; if (s.numel() > INT_MAX) index64b = true; // Choose CUDA kernel. filtered_lrelu_kernel_spec spec = { 0 }; AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_cuda", [&] { if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. constexpr prevents template instantiation. { // Choose kernel based on index type, datatype and sign read/write modes. if (!index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true, false>(p, sharedKB); else if (!index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true>(p, sharedKB); else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB); else if ( index64b && writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true, false>(p, sharedKB); else if ( index64b && !writeSigns && readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true>(p, sharedKB); else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB); } }); TORCH_CHECK(spec.exec, "internal error - CUDA kernel not found") // This should not happen because we tested earlier that kernel exists. // Launch CUDA kernel. void* args[] = {&p}; int bx = spec.numWarps * 32; int gx = (p.yShape.x - 1) / spec.tileOut.x + 1; int gy = (p.yShape.y - 1) / spec.tileOut.y + 1; int gz = p.yShape.z * p.yShape.w; // Repeat multiple horizontal tiles in a CTA? if (spec.xrep) { p.tilesXrep = spec.xrep; p.tilesXdim = gx; gx = (gx + p.tilesXrep - 1) / p.tilesXrep; std::swap(gx, gy); } else { p.tilesXrep = 0; p.tilesXdim = 0; } // Launch filter setup kernel. AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream())); // Copy kernels to constant memory. if ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true, false>(at::cuda::getCurrentCUDAStream()))); else if (!writeSigns && readSigns) AT_CUDA_CHECK((copy_filters<false, true>(at::cuda::getCurrentCUDAStream()))); else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream()))); // Set cache and shared memory configurations for main kernel. AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared)); if (spec.dynamicSharedKB) // Need dynamically allocated shared memory? AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10)); AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte)); // Launch main kernel. const int maxSubGz = 65535; // CUDA maximum for block z dimension. for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.
{ p.blockZofs = zofs; int subGz = std::min(maxSubGz, gz - zofs); AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream())); } // Done. return std::make_tuple(y, so, 0); } //------------------------------------------------------------------------ static torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns) { // Set CUDA device. TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); // Validate arguments. TORCH_CHECK(x.dim() == 4, "x must be rank 4"); TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, "x is too large"); TORCH_CHECK(x.numel() > 0, "x is empty"); TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, "x must be float16, float32 or float64"); // Output signs if we don't have sign input. torch::Tensor so; torch::Tensor s = si; bool readSigns = !!s.numel(); if (writeSigns) { int64_t sw = x.size(3); sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing. s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous); } // Validate sign tensor if in use. if (readSigns || writeSigns) { TORCH_CHECK(s.is_contiguous(), "signs must be contiguous"); TORCH_CHECK(s.dtype() == torch::kUInt8, "signs must be uint8"); TORCH_CHECK(s.device() == x.device(), "signs must reside on the same device as x"); TORCH_CHECK(s.dim() == 4, "signs must be rank 4"); TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), "signs must have same batch & channels as x"); TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, "signs tensor is too large"); } // Initialize CUDA kernel parameters. filtered_lrelu_act_kernel_params p; p.x = x.data_ptr(); p.s = (readSigns || writeSigns) ? s.data_ptr() : 0; p.gain = gain; p.slope = slope; p.clamp = clamp; p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); p.xStride = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0)); p.sShape = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. Contiguous. p.sOfs = make_int2(sx, sy); // Choose CUDA kernel. void* func = 0; AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "filtered_lrelu_act_cuda", [&] { if (writeSigns) func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>(); else if (readSigns) func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>(); else func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>(); }); TORCH_CHECK(func, "internal error - CUDA kernel not found"); // Launch CUDA kernel. void* args[] = {&p}; int bx = 128; // 4 warps per block. // Logical size of launch = writeSigns ? p.s : p.x uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x; uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y; uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use. gx = (gx - 1) / bx + 1; // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest. const uint32_t gmax = 65535; gy = std::min(gy, gmax); gz = std::min(gz, gmax); // Launch.

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("filtered_lrelu",      &filtered_lrelu);      // The whole thing.
    m.def("filtered_lrelu_act_", &filtered_lrelu_act);  // Activation and sign tensor handling only. Modifies data tensor in-place.
}

//------------------------------------------------------------------------

================================================
FILE: ADD/th_utils/ops/filtered_lrelu.cu
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
#include "filtered_lrelu.h"
#include <cstdint>

//------------------------------------------------------------------------
// Helpers.

enum // Filter modes.
{
    MODE_SUSD = 0,  // Separable upsampling, separable downsampling.
    MODE_FUSD = 1,  // Full upsampling, separable downsampling.
    MODE_SUFD = 2,  // Separable upsampling, full downsampling.
    MODE_FUFD = 3,  // Full upsampling, full downsampling.
};

template <class T> struct InternalType;
template <> struct InternalType<double>
{
    typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;
    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }
    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }
    __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }
};
template <> struct InternalType<float>
{
    typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
    __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};
template <> struct InternalType<c10::Half> // Half inputs are computed internally in float.
{
    typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;
    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }
    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }
    __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }
};

#define MIN(A, B)       ((A) < (B) ? (A) : (B))
#define MAX(A, B)       ((A) > (B) ? (A) : (B))
#define CEIL_DIV(A, B) (((B)==1) ? (A) : \
                        ((B)==2) ? ((int)((A)+1) >> 1) : \
                        ((B)==4) ? ((int)((A)+3) >> 2) : \
                        (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))

// This works only up to blocks of size 256 x 256 and for all N that are powers of two.
template <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)
{
    if ((N & (N-1)) && N <= 256)
        y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.
    else
        y = i/N;

    x = i - y*N;
}

// Type cast stride before reading it.
template <class T> __device__ __forceinline__ T get_stride(const int64_t& x)
{
    return *reinterpret_cast<const T*>(&x);
}
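// Worked example (added, illustrative): CEIL_DIV rounds positive values up
// and negative values toward zero, e.g. CEIL_DIV(7, 2) == 4 while
// CEIL_DIV(-7, 2) == -3; the main kernel relies on this when mapping output
// tile origins (which can go negative due to padding) back to input tiles.
// fast_div_mod<N> computes (x, y) = (i % N, i / N) with a multiply-shift
// instead of a hardware divide when N is small and not a power of two.
// get_stride<int32_t>(s) reinterprets a 64-bit byte stride so that address
// arithmetic stays in 32-bit registers; this is only safe because the host
// code above verified that all offsets fit in INT_MAX before selecting a
// 32-bit indexing kernel.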

//------------------------------------------------------------------------
// Filters, setup kernel, copying function.

#define MAX_FILTER_SIZE 32

// Combined up/down filter buffers so that transfer can be done with one copy.
__device__ float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.
__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.

// Accessors to combined buffers to index up/down filters individually.
#define c_fu (c_fbuf)
#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)
#define g_fu (g_fbuf)
#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)

// Set up filters into global memory buffer.
static __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)
{
    for (int idx = threadIdx.x; idx < MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)
    {
        int x, y;
        fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);

        int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);
        int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);
        if (p.fuShape.y > 0)
            g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];
        else
            g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];

        int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);
        int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);
        if (p.fdShape.y > 0)
            g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];
        else
            g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];
    }
}

// Host function to copy filters written by setup kernel into constant buffer for main kernel.
// (Constant memory cannot be written from device code, so the setup kernel first materializes
// the flipped, zero-padded filters in g_fbuf and the host copies them into c_fbuf.)
template <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)
{
    void* src = 0;
    cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);
    if (err) return err;
    return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);
}

//------------------------------------------------------------------------
// Coordinate spaces:
// - Relative to input tensor:   inX, inY, tileInX, tileInY
// - Relative to input tile:     relInX, relInY, tileInW, tileInH
// - Relative to upsampled tile: relUpX, relUpY, tileUpW, tileUpH
// - Relative to output tile:    relOutX, relOutY, tileOutW, tileOutH
// - Relative to output tensor:  outX, outY, tileOutX, tileOutY
//
// Relationships between coordinate spaces:
// - inX = tileInX + relInX
// - inY = tileInY + relInY
// - relUpX = relInX * up + phaseInX
// - relUpY = relInY * up + phaseInY
// - relUpX = relOutX * down
// - relUpY = relOutY * down
// - outX = tileOutX + relOutX
// - outY = tileOutY + relOutY

extern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.

template <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>
static __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)
{
    // Check that we don't try to support non-existing filter modes.
static_assert(up == 1 || up == 2 || up == 4, "only up=1, up=2, up=4 scales supported"); static_assert(down == 1 || down == 2 || down == 4, "only down=1, down=2, down=4 scales supported"); static_assert(fuSize >= up, "upsampling filter size must be at least upsampling factor"); static_assert(fdSize >= down, "downsampling filter size must be at least downsampling factor"); static_assert(fuSize % up == 0, "upsampling filter size must be divisible with upsampling factor"); static_assert(fdSize % down == 0, "downsampling filter size must be divisible with downsampling factor"); static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, "filter size greater than MAX_FILTER_SIZE"); static_assert(up != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "up=1 supported only for 1x1 full filters"); static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "down=1 supported only for 1x1 full filters"); static_assert(!(up == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), "full filters not supported for up=4"); static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), "full filters not supported for down=4"); // Static definitions. typedef typename InternalType::scalar_t scalar_t; typedef typename InternalType::vec2_t vec2_t; typedef typename InternalType::vec4_t vec4_t; const int tileUpW = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3; // Upsampled tile width, rounded up to multiple of 4. const int tileUpH = tileOutH * down + (fdSize - 1) - (down - 1); // Upsampled tile height. const int tileInW = CEIL_DIV(tileUpW + (fuSize - 1), up); // Input tile width. const int tileInH = CEIL_DIV(tileUpH + (fuSize - 1), up); // Input tile height. const int tileUpH_up = CEIL_DIV(tileUpH, up) * up; // Upsampled tile height rounded up to a multiple of up. const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up); // For allocations only, to avoid shared memory read overruns with up=2 and up=4. // Merge 1x1 downsampling into last upsampling step for upf1 and ups2. const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD)); // Sizes of logical buffers. const int szIn = tileInH_up * tileInW; const int szUpX = tileInH_up * tileUpW; const int szUpXY = downInline ? 0 : (tileUpH * tileUpW); const int szDownX = tileUpH * tileOutW; // Sizes for shared memory arrays. const int s_buf0_size_base = (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) : (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) : (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) : (filterMode == MODE_FUFD) ? szIn : -1; const int s_buf1_size_base = (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) : (filterMode == MODE_FUSD) ? szUpXY : (filterMode == MODE_SUFD) ? szUpX : (filterMode == MODE_FUFD) ? szUpXY : -1; // Ensure U128 alignment. const int s_buf0_size = (s_buf0_size_base + 3) & ~3; const int s_buf1_size = (s_buf1_size_base + 3) & ~3; // Check at compile time that we don't use too much shared memory. static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), "shared memory overflow"); // Declare shared memory arrays. scalar_t* s_buf0; scalar_t* s_buf1; if (sharedKB <= 48) { // Allocate shared memory arrays here. __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused. 
s_buf0 = s_buf0_st; s_buf1 = s_buf0 + s_buf0_size; } else { // Use the dynamically allocated shared memory array. s_buf0 = (scalar_t*)s_buf_raw; s_buf1 = s_buf0 + s_buf0_size; } // Pointers to the buffers. scalar_t* s_tileIn; // Input tile: [relInX * tileInH + relInY] scalar_t* s_tileUpX; // After horizontal upsampling: [relInY * tileUpW + relUpX] scalar_t* s_tileUpXY; // After upsampling: [relUpY * tileUpW + relUpX] scalar_t* s_tileDownX; // After horizontal downsampling: [relUpY * tileOutW + relOutX] if (filterMode == MODE_SUSD) { s_tileIn = s_buf0; s_tileUpX = s_buf1; s_tileUpXY = s_buf0; s_tileDownX = s_buf1; } else if (filterMode == MODE_FUSD) { s_tileIn = s_buf0; s_tileUpXY = s_buf1; s_tileDownX = s_buf0; } else if (filterMode == MODE_SUFD) { s_tileIn = s_buf0; s_tileUpX = s_buf1; s_tileUpXY = s_buf0; } else if (filterMode == MODE_FUFD) { s_tileIn = s_buf0; s_tileUpXY = s_buf1; } // Allow large grids in z direction via per-launch offset. int channelIdx = blockIdx.z + p.blockZofs; int batchIdx = channelIdx / p.yShape.z; channelIdx -= batchIdx * p.yShape.z; // Offset to output feature map. In bytes. index_t mapOfsOut = channelIdx * get_stride(p.yStride.z) + batchIdx * get_stride(p.yStride.w); // Sign shift amount. uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6; // Inner tile loop. #pragma unroll 1 for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++) { // Locate output tile. int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x; int tileOutX = tileX * tileOutW; int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH; // Locate input tile. int tmpX = tileOutX * down - p.pad0.x; int tmpY = tileOutY * down - p.pad0.y; int tileInX = CEIL_DIV(tmpX, up); int tileInY = CEIL_DIV(tmpY, up); const int phaseInX = tileInX * up - tmpX; const int phaseInY = tileInY * up - tmpY; // Extra sync if input and output buffers are the same and we are not on first tile. if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline))) __syncthreads(); // Load input tile & apply bias. Unrolled. scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride(p.bStride))); index_t mapOfsIn = channelIdx * get_stride(p.xStride.z) + batchIdx * get_stride(p.xStride.w); int idx = threadIdx.x; const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock); #pragma unroll for (int loop = 0; loop < loopCountIN; loop++) { int relInX, relInY; fast_div_mod(relInX, relInY, idx); int inX = tileInX + relInX; int inY = tileInY + relInY; scalar_t v = 0; if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y) v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride(p.xStride.x) + inY * get_stride(p.xStride.y) + mapOfsIn))) + b; bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH); if (!skip) s_tileIn[idx] = v; idx += threadsPerBlock; } if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter. { // Horizontal upsampling. 
__syncthreads(); if (up == 4) { for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) { int relUpX0, relInY; fast_div_mod(relUpX0, relInY, idx); int relInX0 = relUpX0 / up; int src0 = relInX0 + tileInW * relInY; int dst = relInY * tileUpW + relUpX0; vec4_t v = InternalType::zero_vec4(); scalar_t a = s_tileIn[src0]; if (phaseInX == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.y += a * (scalar_t)c_fu[step * up + 3]; v.z += a * (scalar_t)c_fu[step * up + 2]; v.w += a * (scalar_t)c_fu[step * up + 1]; } } else if (phaseInX == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.z += a * (scalar_t)c_fu[step * up + 3]; v.w += a * (scalar_t)c_fu[step * up + 2]; } } else if (phaseInX == 2) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 2]; v.y += a * (scalar_t)c_fu[step * up + 1]; v.z += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.w += a * (scalar_t)c_fu[step * up + 3]; } } else // (phaseInX == 3) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 3]; v.y += a * (scalar_t)c_fu[step * up + 2]; v.z += a * (scalar_t)c_fu[step * up + 1]; v.w += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; } } s_tileUpX[dst+0] = v.x; s_tileUpX[dst+1] = v.y; s_tileUpX[dst+2] = v.z; s_tileUpX[dst+3] = v.w; } } else if (up == 2) { bool p0 = (phaseInX == 0); for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up) { int relUpX0, relInY; fast_div_mod(relUpX0, relInY, idx); int relInX0 = relUpX0 / up; int src0 = relInX0 + tileInW * relInY; int dst = relInY * tileUpW + relUpX0; vec2_t v = InternalType::zero_vec2(); scalar_t a = s_tileIn[src0]; if (p0) // (phaseInX == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; v.y += a * (scalar_t)c_fu[step * up + 1]; } } else // (phaseInX == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileIn[src0 + step + 1]; } } s_tileUpX[dst+0] = v.x; s_tileUpX[dst+1] = v.y; } } // Vertical upsampling & nonlinearity. __syncthreads(); int groupMask = 15 << ((threadIdx.x & 31) & ~3); int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes. if (up == 4) { minY -= 3; // Adjust according to block height. 
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) { int relUpX, relInY0; fast_div_mod(relUpX, relInY0, idx); int relUpY0 = relInY0 * up; int src0 = relInY0 * tileUpW + relUpX; int dst = relUpY0 * tileUpW + relUpX; vec4_t v = InternalType::zero_vec4(); scalar_t a = s_tileUpX[src0]; if (phaseInY == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.y += a * (scalar_t)c_fu[step * up + 3]; v.z += a * (scalar_t)c_fu[step * up + 2]; v.w += a * (scalar_t)c_fu[step * up + 1]; } } else if (phaseInY == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.z += a * (scalar_t)c_fu[step * up + 3]; v.w += a * (scalar_t)c_fu[step * up + 2]; } } else if (phaseInY == 2) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 2]; v.y += a * (scalar_t)c_fu[step * up + 1]; v.z += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.w += a * (scalar_t)c_fu[step * up + 3]; } } else // (phaseInY == 3) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 3]; v.y += a * (scalar_t)c_fu[step * up + 2]; v.z += a * (scalar_t)c_fu[step * up + 1]; v.w += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; } } int x = tileOutX * down + relUpX; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); index_t si1 = si0 + p.sShape.x; index_t si2 = si0 + p.sShape.x * 2; index_t si3 = si0 + p.sShape.x * 3; v.x *= (scalar_t)((float)up * (float)up * p.gain); v.y *= (scalar_t)((float)up * (float)up * p.gain); v.z *= (scalar_t)((float)up * (float)up * p.gain); v.w *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write signs. int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; int sz = __float_as_uint(v.z) >> 31 << 16; int sw = __float_as_uint(v.w) >> 31 << 24; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (sz) v.z *= p.slope; if (sw) v.w *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } if ((uint32_t)signXb < p.swLimit && signY >= minY) { // Combine signs. uint32_t s = sx + sy + sw + sz; s <<= (signX & 3) << 1; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } } } else { // Determine and write signs. 
if ((uint32_t)signXb < p.swLimit && signY >= minY) { int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; int sz = __float_as_uint(v.z) >> 31 << 16; int sw = __float_as_uint(v.w) >> 31 << 24; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (sz) v.z *= p.slope; if (sw) v.w *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType::clamp(v.z, p.clamp); } if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType::clamp(v.w, p.clamp); } // Combine signs. uint32_t s = sx + sy + sw + sz; s <<= (signX & 3) << 1; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); } if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); } } else { // Just compute the values. if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } } } else if (signRead) // Read signs and apply. { if ((uint32_t)signXb < p.swLimit) { int ss = (signX & 3) << 1; if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; } if ((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; } } } else // Forward pass with no sign write. { if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } s_tileUpXY[dst + 0 * tileUpW] = v.x; if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y; if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z; if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w; } } else if (up == 2) { minY -= 1; // Adjust according to block height. 
for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x) { int relUpX, relInY0; fast_div_mod(relUpX, relInY0, idx); int relUpY0 = relInY0 * up; int src0 = relInY0 * tileUpW + relUpX; int dst = relUpY0 * tileUpW + relUpX; vec2_t v = InternalType::zero_vec2(); scalar_t a = s_tileUpX[src0]; if (phaseInY == 0) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; v.y += a * (scalar_t)c_fu[step * up + 1]; } } else // (phaseInY == 1) { #pragma unroll for (int step = 0; step < fuSize / up; step++) { v.x += a * (scalar_t)c_fu[step * up + 1]; v.y += a * (scalar_t)c_fu[step * up + 0]; a = s_tileUpX[src0 + (step + 1) * tileUpW]; } } int x = tileOutX * down + relUpX; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); index_t si1 = si0 + p.sShape.x; v.x *= (scalar_t)((float)up * (float)up * p.gain); v.y *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write signs. int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } if ((uint32_t)signXb < p.swLimit && signY >= minY) { // Combine signs. int s = sx + sy; s <<= signXo; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } } } else { // Determine and write signs. if ((uint32_t)signXb < p.swLimit && signY >= minY) { int sx = __float_as_uint(v.x) >> 31 << 0; int sy = __float_as_uint(v.y) >> 31 << 8; if (sx) v.x *= p.slope; if (sy) v.y *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType::clamp(v.x, p.clamp); } if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType::clamp(v.y, p.clamp); } // Combine signs. int s = sx + sy; s <<= signXo; s |= __shfl_xor_sync(groupMask, s, 1); s |= __shfl_xor_sync(groupMask, s, 2); // Write signs. if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >> 0); } if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >> 8); } } else { // Just compute the values. if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); } } } else if (signRead) // Read signs and apply. { if ((uint32_t)signXb < p.swLimit) { if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; } if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; } } } else // Forward pass with no sign write. { if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); } if (!downInline) { // Write into temporary buffer. s_tileUpXY[dst] = v.x; if (relUpY0 < tileUpH - 1) s_tileUpXY[dst + tileUpW] = v.y; } else { // Write directly into output buffer. 
if ((uint32_t)x < p.yShape.x) { int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down); index_t ofs = x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut; if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = (T)(v.x * (scalar_t)c_fd[0]); if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]); } } } } } else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD) { // Full upsampling filter. if (up == 2) { // 2 x 2-wide. __syncthreads(); int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs. for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4) { int relUpX0, relUpY0; fast_div_mod(relUpX0, relUpY0, idx); int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up); int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up); int src0 = relInX0 + tileInW * relInY0; int tap0y = (relInY0 * up + phaseInY - relUpY0); #define X_LOOP(TAPY, PX) \ for (int sx = 0; sx < fuSize / up; sx++) \ { \ v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \ v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \ } vec4_t v = InternalType::zero_vec4(); if (tap0y == 0 && phaseInX == 0) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(0, 0) } if (tap0y == 0 && phaseInX == 1) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(0, 1) } if (tap0y == 1 && phaseInX == 0) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(1, 0) } if (tap0y == 1 && phaseInX == 1) #pragma unroll for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1]; #pragma unroll X_LOOP(1, 1) } #undef X_LOOP int x = tileOutX * down + relUpX0; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); v.x *= (scalar_t)((float)up * (float)up * p.gain); v.y *= (scalar_t)((float)up * (float)up * p.gain); v.z *= (scalar_t)((float)up * (float)up * p.gain); v.w *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write signs. 
int sx = __float_as_uint(v.x) >> 31; int sy = __float_as_uint(v.y) >> 31; int sz = __float_as_uint(v.z) >> 31; int sw = __float_as_uint(v.w) >> 31; if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); } } else { // Determine and write signs. if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { int sx = __float_as_uint(v.x) >> 31; int sy = __float_as_uint(v.y) >> 31; int sz = __float_as_uint(v.z) >> 31; int sw = __float_as_uint(v.w) >> 31; if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType::clamp(v.x, p.clamp); } if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType::clamp(v.y, p.clamp); } if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType::clamp(v.z, p.clamp); } if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType::clamp(v.w, p.clamp); } p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6); } else { // Just compute the values. if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } } } else if (signRead) // Read sign and apply. { if ((uint32_t)signY < p.sShape.y) { int s = 0; if ((uint32_t)signXb < p.swLimit) s = p.s[si]; if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8; s >>= (signX & 3) << 1; if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f; if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f; if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f; if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f; } } else // Forward pass with no sign write. { if (v.x < 0.f) v.x *= p.slope; v.x = InternalType::clamp(v.x, p.clamp); if (v.y < 0.f) v.y *= p.slope; v.y = InternalType::clamp(v.y, p.clamp); if (v.z < 0.f) v.z *= p.slope; v.z = InternalType::clamp(v.z, p.clamp); if (v.w < 0.f) v.w *= p.slope; v.w = InternalType::clamp(v.w, p.clamp); } s_tileUpXY[idx + 0] = v.x; s_tileUpXY[idx + 1] = v.y; s_tileUpXY[idx + 2] = v.z; s_tileUpXY[idx + 3] = v.w; } } else if (up == 1) { __syncthreads(); uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3); int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs. for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x) { int relUpX0, relUpY0; fast_div_mod(relUpX0, relUpY0, idx); scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter. int x = tileOutX * down + relUpX0; int y = tileOutY * down + relUpY0; int signX = x + p.sOfs.x; int signY = y + p.sOfs.y; int signZ = blockIdx.z + p.blockZofs; int signXb = signX >> 2; index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ); v *= (scalar_t)((float)up * (float)up * p.gain); if (signWrite) { if (!enableWriteSkip) { // Determine and write sign. 
uint32_t s = 0; uint32_t signXbit = (1u << signXo); if (v < 0.f) { s = signXbit; v *= p.slope; } if (fabsf(v) > p.clamp) { s = signXbit * 2; v = InternalType::clamp(v, p.clamp); } if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. p.s[si] = s; // Write. } } else { // Determine and write sign. if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY) { uint32_t s = 0; uint32_t signXbit = (1u << signXo); if (v < 0.f) { s = signXbit; v *= p.slope; } if (fabsf(v) > p.clamp) { s = signXbit * 2; v = InternalType::clamp(v, p.clamp); } s += __shfl_xor_sync(groupMask, s, 1); // Coalesce. s += __shfl_xor_sync(groupMask, s, 2); // Coalesce. p.s[si] = s; // Write. } else { // Just compute the value. if (v < 0.f) v *= p.slope; v = InternalType::clamp(v, p.clamp); } } } else if (signRead) { // Read sign and apply if within sign tensor bounds. if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y) { int s = p.s[si]; s >>= signXo; if (s & 1) v *= p.slope; if (s & 2) v = 0.f; } } else // Forward pass with no sign write. { if (v < 0.f) v *= p.slope; v = InternalType::clamp(v, p.clamp); } if (!downInline) // Write into temporary buffer. s_tileUpXY[idx] = v; else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer *((T*)((char*)p.y + (x * get_stride(p.yStride.x) + y * get_stride(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]); } } } // Downsampling. if (filterMode == MODE_SUSD || filterMode == MODE_FUSD) { // Horizontal downsampling. __syncthreads(); if (down == 4 && tileOutW % 4 == 0) { // Calculate 4 pixels at a time. for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4) { int relOutX0, relUpY; fast_div_mod(relOutX0, relUpY, idx); int relUpX0 = relOutX0 * down; int src0 = relUpY * tileUpW + relUpX0; vec4_t v = InternalType::zero_vec4(); #pragma unroll for (int step = 0; step < fdSize; step++) { v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; v.y += s_tileUpXY[src0 + 4 + step] * (scalar_t)c_fd[step]; v.z += s_tileUpXY[src0 + 8 + step] * (scalar_t)c_fd[step]; v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step]; } s_tileDownX[idx+0] = v.x; s_tileDownX[idx+1] = v.y; s_tileDownX[idx+2] = v.z; s_tileDownX[idx+3] = v.w; } } else if ((down == 2 || down == 4) && (tileOutW % 2 == 0)) { // Calculate 2 pixels at a time. for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2) { int relOutX0, relUpY; fast_div_mod(relOutX0, relUpY, idx); int relUpX0 = relOutX0 * down; int src0 = relUpY * tileUpW + relUpX0; vec2_t v = InternalType::zero_vec2(); #pragma unroll for (int step = 0; step < fdSize; step++) { v.x += s_tileUpXY[src0 + 0 + step] * (scalar_t)c_fd[step]; v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step]; } s_tileDownX[idx+0] = v.x; s_tileDownX[idx+1] = v.y; } } else { // Calculate 1 pixel at a time. for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x) { int relOutX0, relUpY; fast_div_mod(relOutX0, relUpY, idx); int relUpX0 = relOutX0 * down; int src = relUpY * tileUpW + relUpX0; scalar_t v = 0.f; #pragma unroll for (int step = 0; step < fdSize; step++) v += s_tileUpXY[src + step] * (scalar_t)c_fd[step]; s_tileDownX[idx] = v; } } // Vertical downsampling & store output tile. 
__syncthreads(); for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) { int relOutX, relOutY0; fast_div_mod(relOutX, relOutY0, idx); int relUpY0 = relOutY0 * down; int src0 = relUpY0 * tileOutW + relOutX; scalar_t v = 0; #pragma unroll for (int step = 0; step < fdSize; step++) v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step]; int outX = tileOutX + relOutX; int outY = tileOutY + relOutY0; if (outX < p.yShape.x & outY < p.yShape.y) *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; } } else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD) { // Full downsampling filter. if (down == 2) { // 2-wide. __syncthreads(); for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2) { int relOutX0, relOutY0; fast_div_mod(relOutX0, relOutY0, idx); int relUpX0 = relOutX0 * down; int relUpY0 = relOutY0 * down; int src0 = relUpY0 * tileUpW + relUpX0; vec2_t v = InternalType::zero_vec2(); #pragma unroll for (int sy = 0; sy < fdSize; sy++) #pragma unroll for (int sx = 0; sx < fdSize; sx++) { v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE]; } int outX = tileOutX + relOutX0; int outY = tileOutY + relOutY0; if ((uint32_t)outY < p.yShape.y) { index_t ofs = outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut; if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x; if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride(p.yStride.x))) = (T)v.y; } } } else if (down == 1 && !downInline) { // Thread per pixel. __syncthreads(); for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x) { int relOutX0, relOutY0; fast_div_mod(relOutX0, relOutY0, idx); scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter. int outX = tileOutX + relOutX0; int outY = tileOutY + relOutY0; if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y) *((T*)((char*)p.y + (outX * get_stride(p.yStride.x) + outY * get_stride(p.yStride.y) + mapOfsOut))) = (T)v; } } } if (!enableXrep) break; } } //------------------------------------------------------------------------ // Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant. // Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 64-bit indexing is always used. template static __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p) { typedef typename InternalType::scalar_t scalar_t; // Indexing. int32_t x = threadIdx.x + blockIdx.x * blockDim.x; int32_t ymax = signWrite ? p.sShape.y : p.xShape.y; int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index. // Loop to accommodate oversized tensors. for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z) for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y) { // Extract z and w (channel, minibatch index). int32_t w = q / p.xShape.z; int32_t z = q - w * p.xShape.z; // Choose behavior based on sign read/write mode. if (signWrite) { // Process value if in p.x. uint32_t s = 0; if (x < p.xShape.x && y < p.xShape.y) { int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; T* pv = ((T*)p.x) + ix; scalar_t v = (scalar_t)(*pv); // Gain, LReLU, clamp. v *= p.gain; if (v < 0.f) { v *= p.slope; s = 1; // Sign. 
} if (fabsf(v) > p.clamp) { v = InternalType::clamp(v, p.clamp); s = 2; // Clamp. } *pv = (T)v; // Write value. } // Coalesce into threads 0 and 16 of warp. uint32_t m = (threadIdx.x & 16) ? 0xffff0000u : 0x0000ffffu; s <<= ((threadIdx.x & 15) << 1); // Shift into place. s |= __shfl_xor_sync(m, s, 1); // Distribute. s |= __shfl_xor_sync(m, s, 2); s |= __shfl_xor_sync(m, s, 4); s |= __shfl_xor_sync(m, s, 8); // Write signs if leader and in p.s. if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in. { uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous. ((uint32_t*)p.s)[is >> 4] = s; } } else if (signRead) { // Process value if in p.x. if (x < p.xShape.x) // y is always in. { int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; T* pv = ((T*)p.x) + ix; scalar_t v = (scalar_t)(*pv); v *= p.gain; // Apply sign buffer offset. uint32_t sx = x + p.sOfs.x; uint32_t sy = y + p.sOfs.y; // Read and apply signs if we land inside valid region of sign buffer. if (sx < p.sShape.x && sy < p.sShape.y) { uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous. unsigned char s = p.s[is]; s >>= (sx & 3) << 1; // Shift into place. if (s & 1) // Sign? v *= p.slope; if (s & 2) // Clamp? v = 0.f; } *pv = (T)v; // Write value. } } else { // Forward pass with no sign write. Process value if in p.x. if (x < p.xShape.x) // y is always in. { int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w; T* pv = ((T*)p.x) + ix; scalar_t v = (scalar_t)(*pv); v *= p.gain; if (v < 0.f) v *= p.slope; if (fabsf(v) > p.clamp) v = InternalType::clamp(v, p.clamp); *pv = (T)v; // Write value. } } } } template void* choose_filtered_lrelu_act_kernel(void) { return (void*)filtered_lrelu_act_kernel; } //------------------------------------------------------------------------ // CUDA kernel selection. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB) { filtered_lrelu_kernel_spec s = { 0 }; // Return the first matching kernel. #define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \ if (sharedKB >= SH) \ if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \ if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \ if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \ { \ static_assert((D*TW % 4) == 0, "down * tileWidth must be divisible by 4"); \ static_assert(FU % U == 0, "upscaling filter size must be multiple of upscaling factor"); \ static_assert(FD % D == 0, "downscaling filter size must be multiple of downscaling factor"); \ s.setup = (void*)setup_filters_kernel; \ s.exec = (void*)filtered_lrelu_kernel; \ s.tileOut = make_int2(TW, TH); \ s.numWarps = W; \ s.xrep = XR; \ s.dynamicSharedKB = (SH == 48) ? 0 : SH; \ return s; \ } // Launch parameters for various kernel specializations. // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first. // Kernels that use more shared memory must be listed before those that use less, for the same reason. 
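// Illustrative note (added, not part of the original source): matching below
// is first-hit. For a typical StyleGAN3 layer with up=2 and a 12-tap
// separable up-filter, the 4-tap rows are skipped because the test
// p.fuShape.x <= FU fails (12 > 8), and the first 6-tap row with the right
// mode matches instead. Listing a large-filter row first would shadow every
// smaller specialization, hence the ordering rule above.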
    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/1,1,  /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64,  178, 32, 0,  0) // 1t-upf1-downf1
    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/1,1,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95,  16, 0,  0) // 4t-ups2-downf1
    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/2,8,  /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56,  22,  16, 0,  0) // 4t-upf1-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/2,8,  /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56,  29,  16, 11, 0) // 4t-ups2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/2,8,  /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60,  28,  16, 0,  0) // 4t-upf2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/2,8,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56,  28,  16, 0,  0) // 4t-ups2-downf2
    CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8,  /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56,  31,  16, 11, 0) // 4t-ups4-downs2
    CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56,  36,  16, 0,  0) // 4t-ups4-downf2
    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16,  22,  16, 12, 0) // 4t-ups2-downs4
    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29,  15,  16, 0,  0) // 4t-upf2-downs4
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96,  150, 28, 0,  0) // 6t-ups2-downf1
    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32,  35,  24, 0,  0) // 6t-upf1-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  46,  16, 10, 0) // 6t-ups2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58,  28,  24, 8,  0) // 6t-upf2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52,  28,  16, 0,  0) // 6t-ups2-downf2
    CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  51,  16, 5,  0) // 6t-ups4-downs2
    CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32,  56,  16, 6,  0) // 6t-ups4-downf2
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16,  18,  16, 12, 0) // 6t-ups2-downs4
    CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27,  31,  32, 6,  0) // 6t-upf2-downs4 96kB
    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27,  13,  24, 0,  0) // 6t-upf2-downs4
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89,  24, 0,  0) // 8t-ups2-downf1
    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32,  31,  16, 5,  0) // 8t-upf1-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  41,  16, 9,  0) // 8t-ups2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56,  26,  24, 0,  0) // 8t-upf2-downs2
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32,  40,  16, 0,  0) // 8t-ups2-downf2
    CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  46,  24, 5,  0) // 8t-ups4-downs2
    CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32,  50,  16, 0,  0) // 8t-ups4-downf2
    CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24,  24,  32, 12, 1) // 8t-ups2-downs4 96kB
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16,  13,  16, 10, 1) // 8t-ups2-downs4
    CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25,  28,  28, 4,  0) // 8t-upf2-downs4 96kB
    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25,  10,  24, 0,  0) // 8t-upf2-downs4

    #undef CASE
    return s; // No kernel found.
}

//------------------------------------------------------------------------

================================================
FILE: ADD/th_utils/ops/filtered_lrelu.h
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <cuda_runtime.h>

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct filtered_lrelu_kernel_params
{
    // These parameters decide which kernel to use.
    int             up;         // upsampling ratio (1, 2, 4)
    int             down;       // downsampling ratio (1, 2, 4)
    int2            fuShape;    // [size, 1] | [size, size]
    int2            fdShape;    // [size, 1] | [size, size]

    int             _dummy;     // Alignment.

    // Rest of the parameters.
    const void*     x;          // Input tensor.
    void*           y;          // Output tensor.
    const void*     b;          // Bias tensor.
    unsigned char*  s;          // Sign tensor in/out. NULL if unused.
    const float*    fu;         // Upsampling filter.
    const float*    fd;         // Downsampling filter.

    int2            pad0;       // Left/top padding.
    float           gain;       // Additional gain factor.
    float           slope;      // Leaky ReLU slope on negative side.
    float           clamp;      // Clamp after nonlinearity.
    int             flip;       // Filter kernel flip for gradient computation.

    int             tilesXdim;  // Original number of horizontal output tiles.
    int             tilesXrep;  // Number of horizontal tiles per CTA.
    int             blockZofs;  // Block z offset to support large minibatch, channel dimensions.

    int4            xShape;     // [width, height, channel, batch]
    int4            yShape;     // [width, height, channel, batch]
    int2            sShape;     // [width, height] - width is in bytes. Contiguous. Zeros if unused.
    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
    int             swLimit;    // Active width of sign tensor in bytes.

    longlong4       xStride;    // Strides of all tensors except signs, same component order as shapes.
    longlong4       yStride;    //
    int64_t         bStride;    //
    longlong3       fuStride;   //
    longlong3       fdStride;   //
};

struct filtered_lrelu_act_kernel_params
{
    void*           x;          // Input/output, modified in-place.
    unsigned char*  s;          // Sign tensor in/out. NULL if unused.

    float           gain;       // Additional gain factor.
    float           slope;      // Leaky ReLU slope on negative side.
    float           clamp;      // Clamp after nonlinearity.

    int4            xShape;     // [width, height, channel, batch]
    longlong4       xStride;    // Input/output tensor strides, same order as in shape.
    int2            sShape;     // [width, height] - width is in elements. Contiguous. Zeros if unused.
    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.
};
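// Note (added): the shape/stride vectors above are ordered [width, height,
// channel, batch], i.e. the reverse of the PyTorch NCHW dimension order; the
// host code fills them as, for example,
//   p.xShape = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
// so the .x component always refers to the fastest-varying image axis.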

//------------------------------------------------------------------------
// CUDA kernel specialization.

struct filtered_lrelu_kernel_spec
{
    void*   setup;           // Function for filter kernel setup.
    void*   exec;            // Function for main operation.
    int2    tileOut;         // Width/height of launch tile.
    int     numWarps;        // Number of warps per thread block, determines launch block size.
    int     xrep;            // For processing multiple horizontal tiles per thread block.
    int     dynamicSharedKB; // How much dynamic shared memory the exec kernel wants.
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);
template <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);
template <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);

//------------------------------------------------------------------------

================================================
FILE: ADD/th_utils/ops/filtered_lrelu.py
================================================
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import torch
import warnings

from .. import custom_ops
from .. import misc
from . import upfirdn2d
from . import bias_act

#----------------------------------------------------------------------------

_plugin = None

def _init():
    global _plugin
    if _plugin is None:
        _plugin = custom_ops.get_plugin(
            module_name='filtered_lrelu_plugin',
            sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],
            headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],
            source_dir=os.path.dirname(__file__),
            extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],
        )
    return True

def _get_filter_size(f):
    if f is None:
        return 1, 1
    assert isinstance(f, torch.Tensor)
    assert 1 <= f.ndim <= 2
    return f.shape[-1], f.shape[0] # width, height

def _parse_padding(padding):
    if isinstance(padding, int):
        padding = [padding, padding]
    assert isinstance(padding, (list, tuple))
    assert all(isinstance(x, (int, np.integer)) for x in padding)
    padding = [int(x) for x in padding]
    if len(padding) == 2:
        px, py = padding
        padding = [px, px, py, py]
    px0, px1, py0, py1 = padding
    return px0, px1, py0, py1
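# Examples (added for illustration):
#   _parse_padding(3)            -> (3, 3, 3, 3)
#   _parse_padding([2, 5])       -> (2, 2, 5, 5)   # [x, y]
#   _parse_padding([1, 2, 3, 4]) -> (1, 2, 3, 4)   # [x_before, x_after, y_before, y_after]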

#----------------------------------------------------------------------------

def filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False, impl='cuda'):
    r"""Filtered leaky ReLU for a batch of 2D images.

    Performs the following sequence of operations for each channel:

    1. Add channel-specific bias if provided (`b`).
    2. Upsample the image by inserting N-1 zeros after each pixel (`up`).
    3. Pad the image with the specified number of zeros on each side (`padding`).
       Negative padding corresponds to cropping the image.
    4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it
       so that the footprint of all output pixels lies within the input image.
    5. Multiply each value by the provided gain factor (`gain`).
    6. Apply leaky ReLU activation function to each value.
    7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.
    8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking
       it so that the footprint of all output pixels lies within the input image.
    9. Downsample the image by keeping every Nth pixel (`down`).

    The fused op is considerably more efficient than performing the same calculation
    using standard PyTorch ops. It supports gradients of arbitrary order.

    Args:
        x:           Float32/float16/float64 input tensor of the shape
                     `[batch_size, num_channels, in_height, in_width]`.
        fu:          Float32 upsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        fd:          Float32 downsampling FIR filter of the shape
                     `[filter_height, filter_width]` (non-separable),
                     `[filter_taps]` (separable), or `None` (identity).
        b:           Bias vector, or `None` to disable. Must be a 1D tensor of the same
                     type as `x`. The length of the vector must match the channel
                     dimension of `x`.
        up:          Integer upsampling factor (default: 1).
        down:        Integer downsampling factor (default: 1).
        padding:     Padding with respect to the upsampled image. Can be a single number
                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
                     (default: 0).
        gain:        Overall scaling factor for signal magnitude (default: sqrt(2)).
        slope:       Slope on the negative side of leaky ReLU (default: 0.2).
        clamp:       Maximum magnitude for leaky ReLU output (default: None).
        flip_filter: False = convolution, True = correlation (default: False).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    assert isinstance(x, torch.Tensor)
    assert impl in ['ref', 'cuda']
    if impl == 'cuda' and x.device.type == 'cuda' and _init():
        return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)
    return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)
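# Minimal usage sketch (added; shapes and filter values are illustrative
# assumptions, not part of the original source):
#
#   x  = torch.randn(1, 4, 16, 16, device='cuda')
#   fu = torch.tensor([1., 3., 3., 1.], device='cuda'); fu = fu / fu.sum()
#   y  = filtered_lrelu(x, fu=fu, fd=fu, up=2, down=2, padding=3, clamp=256)
#
# With these 4-tap separable filters, out_w = (16*2 + 3+3 - 3 - 3 + 1) // 2 = 16,
# so spatial resolution is preserved while the nonlinearity is evaluated in the
# 2x-upsampled domain.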

#----------------------------------------------------------------------------

@misc.profiled_function
def _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Slow and memory-inefficient reference implementation of `filtered_lrelu()` using
    existing `upfirdn2d()` and `bias_act()` ops.
    """
    assert isinstance(x, torch.Tensor) and x.ndim == 4
    fu_w, fu_h = _get_filter_size(fu)
    fd_w, fd_h = _get_filter_size(fd)
    if b is not None:
        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype
        misc.assert_shape(b, [x.shape[1]])
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    assert slope == float(slope) and slope >= 0
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)

    # Calculate output size.
    batch_size, channels, in_h, in_w = x.shape
    in_dtype = x.dtype
    out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down
    out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down

    # Compute using existing ops.
    x = bias_act.bias_act(x=x, b=b) # Apply bias.
    x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
    x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.
    x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.

    # Check output shape & dtype.
    misc.assert_shape(x, [batch_size, channels, out_h, out_w])
    assert x.dtype == in_dtype
    return x

#----------------------------------------------------------------------------

_filtered_lrelu_cuda_cache = dict()

def _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):
    """Fast CUDA implementation of `filtered_lrelu()` using custom ops.
    """
    assert isinstance(up, int) and up >= 1
    assert isinstance(down, int) and down >= 1
    px0, px1, py0, py1 = _parse_padding(padding)
    assert gain == float(gain) and gain > 0
    gain = float(gain)
    assert slope == float(slope) and slope >= 0
    slope = float(slope)
    assert clamp is None or (clamp == float(clamp) and clamp >= 0)
    clamp = float(clamp if clamp is not None else 'inf')

    # Lookup from cache.
    key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)
    if key in _filtered_lrelu_cuda_cache:
        return _filtered_lrelu_cuda_cache[key]

    # Forward op.
    class FilteredLReluCuda(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ
            assert isinstance(x, torch.Tensor) and x.ndim == 4

            # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).
            if fu is None:
                fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            if fd is None:
                fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)
            assert 1 <= fu.ndim <= 2
            assert 1 <= fd.ndim <= 2

            # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.
            if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:
                fu = fu.square()[None]
            if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:
                fd = fd.square()[None]

            # Missing sign input tensor.
            if si is None:
                si = torch.empty([0])

            # Missing bias tensor.
            if b is None:
                b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)

            # Construct internal sign tensor only if gradients are needed.
            write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)

            # Warn if input storage strides are not in decreasing order due to e.g. channels-last layout.
            strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]
            if any(a < b for a, b in zip(strides[:-1], strides[1:])):
                warnings.warn("low-performance memory layout detected in filtered_lrelu input", RuntimeWarning)

            # Call C++/Cuda plugin if datatype is supported.
            if x.dtype in [torch.float16, torch.float32]:
                if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):
                    warnings.warn("filtered_lrelu called with non-default cuda stream but concurrent execution is not supported", RuntimeWarning)
                y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)
            else:
                return_code = -1

            # No Cuda kernel found? Fall back to generic implementation. Still more memory
            # efficient than the reference implementation because only the bit-packed sign
            # tensor is retained for gradient computation.
            if return_code < 0:
                warnings.warn("filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback", RuntimeWarning)
                y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.
                y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.
                so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. Modifies y in-place.
                y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.

            # Prepare for gradient computation.
            ctx.save_for_backward(fu, fd, (si if si.numel() else so))
            ctx.x_shape = x.shape
            ctx.y_shape = y.shape
            ctx.s_ofs = sx, sy
            return y
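        # Note (added): backward below reuses the same fused op with up/down
        # swapped and fu/fd exchanged, replaying the stored sign bits instead
        # of recomputing the nonlinearity, so the leaky ReLU gradient is exact
        # even after clamping.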
ctx.save_for_backward(fu, fd, (si if si.numel() else so)) ctx.x_shape = x.shape ctx.y_shape = y.shape ctx.s_ofs = sx, sy return y @staticmethod def backward(ctx, dy): # pylint: disable=arguments-differ fu, fd, si = ctx.saved_tensors _, _, xh, xw = ctx.x_shape _, _, yh, yw = ctx.y_shape sx, sy = ctx.s_ofs dx = None # 0 dfu = None; assert not ctx.needs_input_grad[1] dfd = None; assert not ctx.needs_input_grad[2] db = None # 3 dsi = None; assert not ctx.needs_input_grad[4] dsx = None; assert not ctx.needs_input_grad[5] dsy = None; assert not ctx.needs_input_grad[6] if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]: pp = [ (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0, xw * up - yw * down + px0 - (up - 1), (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0, xh * up - yh * down + py0 - (up - 1), ] gg = gain * (up ** 2) / (down ** 2) ff = (not flip_filter) sx = sx - (fu.shape[-1] - 1) + px0 sy = sy - (fu.shape[0] - 1) + py0 dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy) if ctx.needs_input_grad[3]: db = dx.sum([0, 2, 3]) return dx, dfu, dfd, db, dsi, dsx, dsy # Add to cache. _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda return FilteredLReluCuda #---------------------------------------------------------------------------- ================================================ FILE: ADD/th_utils/ops/filtered_lrelu_ns.cu ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include "filtered_lrelu.cu" // Template/kernel specializations for no signs mode (no gradients required). // Full op, 32-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Full op, 64-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Activation/signs only for generic variant. 64-bit indexing. template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); // Copy filters to constant memory. template cudaError_t copy_filters(cudaStream_t stream); ================================================ FILE: ADD/th_utils/ops/filtered_lrelu_rd.cu ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include "filtered_lrelu.cu" // Template/kernel specializations for sign read mode. 
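// Added note (not in the original source): the _ns/_rd/_wr translation units
// split the template instantiations by sign mode: no signs (inference),
// sign read (used by the backward pass), and sign write (forward pass with
// gradients enabled). Splitting them keeps each compilation unit small.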
// Full op, 32-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Full op, 64-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Activation/signs only for generic variant. 64-bit indexing. template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); // Copy filters to constant memory. template cudaError_t copy_filters(cudaStream_t stream); ================================================ FILE: ADD/th_utils/ops/filtered_lrelu_wr.cu ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include "filtered_lrelu.cu" // Template/kernel specializations for sign write mode. // Full op, 32-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Full op, 64-bit indexing. template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); // Activation/signs only for generic variant. 64-bit indexing. template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); template void* choose_filtered_lrelu_act_kernel(void); // Copy filters to constant memory. template cudaError_t copy_filters(cudaStream_t stream); ================================================ FILE: ADD/th_utils/ops/fma.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. 
"""Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" import torch #---------------------------------------------------------------------------- def fma(a, b, c): # => a * b + c return _FusedMultiplyAdd.apply(a, b, c) #---------------------------------------------------------------------------- class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c @staticmethod def forward(ctx, a, b, c): # pylint: disable=arguments-differ out = torch.addcmul(c, a, b) ctx.save_for_backward(a, b) ctx.c_shape = c.shape return out @staticmethod def backward(ctx, dout): # pylint: disable=arguments-differ a, b = ctx.saved_tensors c_shape = ctx.c_shape da = None db = None dc = None if ctx.needs_input_grad[0]: da = _unbroadcast(dout * b, a.shape) if ctx.needs_input_grad[1]: db = _unbroadcast(dout * a, b.shape) if ctx.needs_input_grad[2]: dc = _unbroadcast(dout, c_shape) return da, db, dc #---------------------------------------------------------------------------- def _unbroadcast(x, shape): extra_dims = x.ndim - len(shape) assert extra_dims >= 0 dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] if len(dim): x = x.sum(dim=dim, keepdim=True) if extra_dims: x = x.reshape(-1, *x.shape[extra_dims+1:]) assert x.shape == shape return x #---------------------------------------------------------------------------- ================================================ FILE: ADD/th_utils/ops/grid_sample_gradfix.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom replacement for `torch.nn.functional.grid_sample` that supports arbitrarily high order gradients between the input and output. Only works on 2D images and assumes `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" import torch from pkg_resources import parse_version # pylint: disable=redefined-builtin # pylint: disable=arguments-differ # pylint: disable=protected-access #---------------------------------------------------------------------------- enabled = False # Enable the custom op by setting this to true. 
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11

#----------------------------------------------------------------------------

def grid_sample(input, grid):
    if _should_use_custom_op():
        return _GridSample2dForward.apply(input, grid)
    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    return enabled

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
        return grad_input, grad_grid

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        if _use_pytorch_1_11_api:
            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
        else:
            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None
        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid

#----------------------------------------------------------------------------

================================================
FILE: ADD/th_utils/ops/upfirdn2d.cpp
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------

static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
    // Validate arguments.
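    // Added note (not in the original source): the INT_MAX checks below bound
    // both the element counts and the strided memory footprints so that the
    // CUDA kernels can safely use 32-bit indexing arithmetic throughout.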
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); TORCH_CHECK(x.numel() > 0, "x has zero size"); TORCH_CHECK(f.numel() > 0, "f has zero size"); TORCH_CHECK(x.dim() == 4, "x must be rank 4"); TORCH_CHECK(f.dim() == 2, "f must be rank 2"); TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); // Create output tensor. const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); // Initialize CUDA kernel parameters. upfirdn2d_kernel_params p; p.x = x.data_ptr(); p.f = f.data_ptr(); p.y = y.data_ptr(); p.up = make_int2(upx, upy); p.down = make_int2(downx, downy); p.pad0 = make_int2(padx0, pady0); p.flip = (flip) ? 1 : 0; p.gain = gain; p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; // Choose CUDA kernel. upfirdn2d_kernel_spec spec; AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { spec = choose_upfirdn2d_kernel(p); }); // Set looping options. p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; p.loopMinor = spec.loopMinor; p.loopX = spec.loopX; p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; // Compute grid size. dim3 blockSize, gridSize; if (spec.tileOutW < 0) // large { blockSize = dim3(4, 32, 1); gridSize = dim3( ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, p.launchMajor); } else // small { blockSize = dim3(256, 1, 1); gridSize = dim3( ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, p.launchMajor); } // Launch CUDA kernel. 
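    // Added note (not in the original source): in the grid computed above,
    // blockIdx.x walks output rows times the minor (channel) slices,
    // blockIdx.y walks output columns in chunks of loopX tiles, and
    // blockIdx.z walks the major (batch, and channel when contiguous)
    // dimension in chunks of loopMajor.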
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}

//------------------------------------------------------------------------

================================================
FILE: ADD/th_utils/ops/upfirdn2d.cu
================================================
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <c10/util/Half.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------
// Helpers.

template <class T> struct InternalType;
template <> struct InternalType<double>    { typedef double scalar_t; };
template <> struct InternalType<float>     { typedef float  scalar_t; };
template <> struct InternalType<c10::Half> { typedef float  scalar_t; };

static __device__ __forceinline__ int floor_div(int a, int b)
{
    int t = 1 - a / b;
    return (a + t * b) / b - t;
}

//------------------------------------------------------------------------
// Generic CUDA implementation for large filters.

template <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)
{
    typedef typename InternalType<T>::scalar_t scalar_t;

    // Calculate thread index.
    int minorBase = blockIdx.x * blockDim.x + threadIdx.x;
    int outY = minorBase / p.launchMinor;
    minorBase -= outY * p.launchMinor;
    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;
    int majorBase = blockIdx.z * p.loopMajor;
    if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)
        return;

    // Setup Y receptive field.
    int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;
    int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);
    int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;
    int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;
    if (p.flip)
        filterY = p.filterSize.y - 1 - filterY;

    // Loop over major, minor, and X.
    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)
    for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)
    {
        int nc = major * p.sizeMinor + minor;
        int n = nc / p.inSize.z;
        int c = nc - n * p.inSize.z;
        for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)
        {
            // Setup X receptive field.
            int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;
            int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);
            int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;
            int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;
            if (p.flip)
                filterX = p.filterSize.x - 1 - filterX;

            // Initialize pointers.
            const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];
            const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];
            int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;
            int filterStepY = ((p.flip) ? p.up.y : -p.up.y) * p.filterStride.y;

            // Inner loop.
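            // Added note (not in the original source): the loop below accumulates
            // x*f over the h-by-w input window; stepping the filter pointer by
            // +/-up per input pixel folds the zero-insertion upsampling into
            // pointer arithmetic instead of materializing the upsampled image.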
scalar_t v = 0; for (int y = 0; y < h; y++) { for (int x = 0; x < w; x++) { v += (scalar_t)(*xp) * (scalar_t)(*fp); xp += p.inStride.x; fp += filterStepX; } xp += p.inStride.y - w * p.inStride.x; fp += filterStepY - w * filterStepX; } // Store result. v *= p.gain; ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; } } } //------------------------------------------------------------------------ // Specialized CUDA implementation for small filters. template static __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p) { typedef typename InternalType::scalar_t scalar_t; const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1; const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1; __shared__ volatile scalar_t sf[filterH][filterW]; __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor]; // Calculate tile index. int minorBase = blockIdx.x; int tileOutY = minorBase / p.launchMinor; minorBase -= tileOutY * p.launchMinor; minorBase *= loopMinor; tileOutY *= tileOutH; int tileOutXBase = blockIdx.y * p.loopX * tileOutW; int majorBase = blockIdx.z * p.loopMajor; if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor) return; // Load filter (flipped). for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x) { int fy = tapIdx / filterW; int fx = tapIdx - fy * filterW; scalar_t v = 0; if (fx < p.filterSize.x & fy < p.filterSize.y) { int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx; int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy; v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y]; } sf[fy][fx] = v; } // Loop over major and X. for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++) { int baseNC = major * p.sizeMinor + minorBase; int n = baseNC / p.inSize.z; int baseC = baseNC - n * p.inSize.z; for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW) { // Load input pixels. int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x; int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y; int tileInX = floor_div(tileMidX, upx); int tileInY = floor_div(tileMidY, upy); __syncthreads(); for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x) { int relC = inIdx; int relInX = relC / loopMinor; int relInY = relInX / tileInW; relC -= relInX * loopMinor; relInX -= relInY * tileInW; int c = baseC + relC; int inX = tileInX + relInX; int inY = tileInY + relInY; scalar_t v = 0; if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z) v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w]; sx[relInY][relInX][relC] = v; } // Loop over output pixels. __syncthreads(); for (int outIdx = threadIdx.x; outIdx < tileOutH * tileOutW * loopMinor; outIdx += blockDim.x) { int relC = outIdx; int relOutX = relC / loopMinor; int relOutY = relOutX / tileOutW; relC -= relOutX * loopMinor; relOutX -= relOutY * tileOutW; int c = baseC + relC; int outX = tileOutX + relOutX; int outY = tileOutY + relOutY; // Setup receptive field. int midX = tileMidX + relOutX * downx; int midY = tileMidY + relOutY * downy; int inX = floor_div(midX, upx); int inY = floor_div(midY, upy); int relInX = inX - tileInX; int relInY = inY - tileInY; int filterX = (inX + 1) * upx - midX - 1; // flipped int filterY = (inY + 1) * upy - midY - 1; // flipped // Inner loop. 
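            // Added note (not in the original source): filterX/filterY select the
            // polyphase component of the (flipped) filter for this output pixel;
            // the unrolled loops then stride through shared memory by up, so each
            // output touches only the taps that align with real input samples.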
if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z) { scalar_t v = 0; #pragma unroll for (int y = 0; y < filterH / upy; y++) #pragma unroll for (int x = 0; x < filterW / upx; x++) v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx]; v *= p.gain; ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v; } } } } } //------------------------------------------------------------------------ // CUDA kernel selection. template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p) { int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y; upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large, -1,-1,1, 4}; // contiguous if (s == 1) spec = {(void*)upfirdn2d_kernel_large, -1,-1,4, 1}; // channels_last // No up/downsampling. if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) { // contiguous if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; if (s != 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; if (s != 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; if (s != 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; if (s != 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; // channels_last if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; if (s == 1 && fx <= 7 && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; if (s == 1 && fx <= 5 && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; if (s == 1 && fx <= 3 && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; if (s == 1 && fx <= 8 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; } // 2x upsampling. 
if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) { // contiguous if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 64,16,1, 1}; // channels_last if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 16,16,8, 1}; } if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) { // contiguous if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; // channels_last if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; } if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1) { // contiguous if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; // channels_last if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; } // 2x downsampling. 
if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2) { // contiguous if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; if (s != 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; if (s != 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; if (s != 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; if (s != 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; // channels_last if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 16,16,1, 1}; if (s == 1 && fx <= 8 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; if (s == 1 && fx <= 6 && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; if (s == 1 && fx <= 4 && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; if (s == 1 && fx <= 2 && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small, 8,8,8, 1}; } if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1) { // contiguous if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; if (s != 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,8,1, 1}; // channels_last if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; if (s == 1 && fx <= 8 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 64,1,8, 1}; } if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2) { // contiguous if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 32,16,1, 1}; // channels_last if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small, 1,64,8, 1}; } // 4x upsampling. 
if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) { // contiguous if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 64,32,1, 1}; // channels_last if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; } if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1) { // contiguous if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,8,1, 1}; // channels_last if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 128,1,16, 1}; } if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1) { // contiguous if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,32,1, 1}; // channels_last if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,128,16, 1}; } // 4x downsampling (inefficient). if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1) { // contiguous if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; // channels_last if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small, 32,1,8, 1}; } if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4) { // contiguous if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 32,8,1, 1}; // channels_last if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small, 1,32,8, 1}; } return spec; } //------------------------------------------------------------------------ // Template specializations. template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); template upfirdn2d_kernel_spec choose_upfirdn2d_kernel (const upfirdn2d_kernel_params& p); template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); //------------------------------------------------------------------------ ================================================ FILE: ADD/th_utils/ops/upfirdn2d.h ================================================ // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // NVIDIA CORPORATION and its licensors retain all intellectual property // and proprietary rights in and to this software, related documentation // and any modifications thereto. Any use, reproduction, disclosure or // distribution of this software and related documentation without an express // license agreement from NVIDIA CORPORATION is strictly prohibited. #include //------------------------------------------------------------------------ // CUDA kernel parameters. 
struct upfirdn2d_kernel_params { const void* x; const float* f; void* y; int2 up; int2 down; int2 pad0; int flip; float gain; int4 inSize; // [width, height, channel, batch] int4 inStride; int2 filterSize; // [width, height] int2 filterStride; int4 outSize; // [width, height, channel, batch] int4 outStride; int sizeMinor; int sizeMajor; int loopMinor; int loopMajor; int loopX; int launchMinor; int launchMajor; }; //------------------------------------------------------------------------ // CUDA kernel specialization. struct upfirdn2d_kernel_spec { void* kernel; int tileOutW; int tileOutH; int loopMinor; int loopX; }; //------------------------------------------------------------------------ // CUDA kernel selection. template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); //------------------------------------------------------------------------ ================================================ FILE: ADD/th_utils/ops/upfirdn2d.py ================================================ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Custom PyTorch ops for efficient resampling of 2D images.""" import os import numpy as np import torch from .. import custom_ops from .. import misc from . import conv2d_gradfix #---------------------------------------------------------------------------- _plugin = None def _init(): global _plugin if _plugin is None: _plugin = custom_ops.get_plugin( module_name='upfirdn2d_plugin', sources=['upfirdn2d.cpp', 'upfirdn2d.cu'], headers=['upfirdn2d.h'], source_dir=os.path.dirname(__file__), extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'], ) return True def _parse_scaling(scaling): if isinstance(scaling, int): scaling = [scaling, scaling] assert isinstance(scaling, (list, tuple)) assert all(isinstance(x, int) for x in scaling) sx, sy = scaling assert sx >= 1 and sy >= 1 return sx, sy def _parse_padding(padding): if isinstance(padding, int): padding = [padding, padding] assert isinstance(padding, (list, tuple)) assert all(isinstance(x, int) for x in padding) if len(padding) == 2: padx, pady = padding padding = [padx, padx, pady, pady] padx0, padx1, pady0, pady1 = padding return padx0, padx1, pady0, pady1 def _get_filter_size(f): if f is None: return 1, 1 assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] fw = f.shape[-1] fh = f.shape[0] with misc.suppress_tracer_warnings(): fw = int(fw) fh = int(fh) misc.assert_shape(f, [fh, fw][:f.ndim]) assert fw >= 1 and fh >= 1 return fw, fh #---------------------------------------------------------------------------- def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. Args: f: Torch tensor, numpy array, or python list of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), `[]` (impulse), or `None` (identity). device: Result device (default: cpu). normalize: Normalize the filter so that it retains the magnitude for constant input signal (DC)? (default: True). flip_filter: Flip the filter? (default: False). 
gain: Overall scaling factor for signal magnitude (default: 1). separable: Return a separable filter? (default: select automatically). Returns: Float32 tensor of the shape `[filter_height, filter_width]` (non-separable) or `[filter_taps]` (separable). """ # Validate. if f is None: f = 1 f = torch.as_tensor(f, dtype=torch.float32) assert f.ndim in [0, 1, 2] assert f.numel() > 0 if f.ndim == 0: f = f[np.newaxis] # Separable? if separable is None: separable = (f.ndim == 1 and f.numel() >= 8) if f.ndim == 1 and not separable: f = f.ger(f) assert f.ndim == (1 if separable else 2) # Apply normalize, flip, gain, and device. if normalize: f /= f.sum() if flip_filter: f = f.flip(list(range(f.ndim))) f = f * (gain ** (f.ndim / 2)) f = f.to(device=device) return f #---------------------------------------------------------------------------- def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): r"""Pad, upsample, filter, and downsample a batch of 2D images. Performs the following sequence of operations for each channel: 1. Upsample the image by inserting N-1 zeros after each pixel (`up`). 2. Pad the image with the specified number of zeros on each side (`padding`). Negative padding corresponds to cropping the image. 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it so that the footprint of all output pixels lies within the input image. 4. Downsample the image by keeping every Nth pixel (`down`). This sequence of operations bears close resemblance to scipy.signal.upfirdn(). The fused op is considerably more efficient than performing the same calculation using standard PyTorch ops. It supports gradients of arbitrary order. Args: x: Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`. f: Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity). up: Integer upsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). down: Integer downsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 1). padding: Padding with respect to the upsampled image. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0). flip_filter: False = convolution, True = correlation (default: False). gain: Overall scaling factor for signal magnitude (default: 1). impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`). Returns: Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. """ assert isinstance(x, torch.Tensor) assert impl in ['ref', 'cuda'] if impl == 'cuda' and x.device.type == 'cuda' and _init(): return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f) return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) #---------------------------------------------------------------------------- @misc.profiled_function def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. """ # Validate arguments. 
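    # Added usage note (not in the original source): this reference path is what
    # the wrappers further below reduce to, e.g.
    #   f = setup_filter([1, 3, 3, 1])   # 4-tap separable binomial filter
    #   y = upsample2d(x, f)             # doubles H and W
    # is equivalent to upfirdn2d(x, f, up=2, padding=[2, 1, 2, 1], gain=4).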
assert isinstance(x, torch.Tensor) and x.ndim == 4 if f is None: f = torch.ones([1, 1], dtype=torch.float32, device=x.device) assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] assert f.dtype == torch.float32 and not f.requires_grad batch_size, num_channels, in_height, in_width = x.shape upx, upy = _parse_scaling(up) downx, downy = _parse_scaling(down) padx0, padx1, pady0, pady1 = _parse_padding(padding) # Check that upsampled buffer is not smaller than the filter. upW = in_width * upx + padx0 + padx1 upH = in_height * upy + pady0 + pady1 assert upW >= f.shape[-1] and upH >= f.shape[0] # Upsample by inserting zeros. x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1]) x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1]) x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) # Pad or crop. x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)] # Setup filter. f = f * (gain ** (f.ndim / 2)) f = f.to(x.dtype) if not flip_filter: f = f.flip(list(range(f.ndim))) # Convolve with the filter. f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim) if f.ndim == 4: x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels) else: x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels) x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels) # Downsample by throwing away pixels. x = x[:, :, ::downy, ::downx] return x #---------------------------------------------------------------------------- _upfirdn2d_cuda_cache = dict() def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1): """Fast CUDA implementation of `upfirdn2d()` using custom ops. """ # Parse arguments. upx, upy = _parse_scaling(up) downx, downy = _parse_scaling(down) padx0, padx1, pady0, pady1 = _parse_padding(padding) # Lookup from cache. key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) if key in _upfirdn2d_cuda_cache: return _upfirdn2d_cuda_cache[key] # Forward op. class Upfirdn2dCuda(torch.autograd.Function): @staticmethod def forward(ctx, x, f): # pylint: disable=arguments-differ assert isinstance(x, torch.Tensor) and x.ndim == 4 if f is None: f = torch.ones([1, 1], dtype=torch.float32, device=x.device) if f.ndim == 1 and f.shape[0] == 1: f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1. assert isinstance(f, torch.Tensor) and f.ndim in [1, 2] y = x if f.ndim == 2: y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain) else: y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0) y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain) ctx.save_for_backward(f) ctx.x_shape = x.shape return y @staticmethod def backward(ctx, dy): # pylint: disable=arguments-differ f, = ctx.saved_tensors _, _, ih, iw = ctx.x_shape _, _, oh, ow = dy.shape fw, fh = _get_filter_size(f) p = [ fw - padx0 - 1, iw * upx - ow * downx + padx0 - upx + 1, fh - pady0 - 1, ih * upy - oh * downy + pady0 - upy + 1, ] dx = None df = None if ctx.needs_input_grad[0]: dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f) assert not ctx.needs_input_grad[1] return dx, df # Add to cache. 
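    # Added note (not in the original source): the cache above is keyed on every
    # scalar parameter, so each distinct (up, down, padding, flip, gain) combo
    # reuses a single autograd.Function subclass. Note also that the backward of
    # upfirdn2d is again an upfirdn2d, with up/down swapped, the filter applied
    # in flipped orientation, and padding chosen to undo the forward padding.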
_upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
    return Upfirdn2dCuda

#----------------------------------------------------------------------------

def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Filter a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape matches the input.
    User-specified padding is applied on top of that, with negative values
    indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity).
        padding:     Padding with respect to the output. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + fw // 2,
        padx1 + (fw - 1) // 2,
        pady0 + fh // 2,
        pady1 + (fh - 1) // 2,
    ]
    return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)

#----------------------------------------------------------------------------

def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Upsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a multiple of the
    input. User-specified padding is applied on top of that, with negative
    values indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity).
        up:          Integer upsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 2).
        padding:     Padding with respect to the output. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    upx, upy = _parse_scaling(up)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + (fw + upx - 1) // 2,
        padx1 + (fw - upx) // 2,
        pady0 + (fh + upy - 1) // 2,
        pady1 + (fh - upy) // 2,
    ]
    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)

#----------------------------------------------------------------------------

def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
    r"""Downsample a batch of 2D images using the given 2D FIR filter.

    By default, the result is padded so that its shape is a fraction of the
    input. User-specified padding is applied on top of that, with negative
    values indicating cropping. Pixels outside the image are assumed to be zero.

    Args:
        x:           Float32/float64/float16 input tensor of the shape `[batch_size, num_channels, in_height, in_width]`.
        f:           Float32 FIR filter of the shape `[filter_height, filter_width]` (non-separable), `[filter_taps]` (separable), or `None` (identity).
        down:        Integer downsampling factor. Can be a single int or a list/tuple `[x, y]` (default: 2).
        padding:     Padding with respect to the input. Can be a single number or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` (default: 0).
        flip_filter: False = convolution, True = correlation (default: False).
        gain:        Overall scaling factor for signal magnitude (default: 1).
        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).

    Returns:
        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
    """
    downx, downy = _parse_scaling(down)
    padx0, padx1, pady0, pady1 = _parse_padding(padding)
    fw, fh = _get_filter_size(f)
    p = [
        padx0 + (fw - downx + 1) // 2,
        padx1 + (fw - downx) // 2,
        pady0 + (fh - downy + 1) // 2,
        pady1 + (fh - downy) // 2,
    ]
    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)

#----------------------------------------------------------------------------

================================================
FILE: ADD/utils/util_net.py
================================================
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Power by Zongsheng Yue 2021-11-24 20:29:36

import math
import torch
from pathlib import Path
from collections import OrderedDict
import torch.nn.functional as F
from copy import deepcopy

def calculate_parameters(net):
    out = 0
    for param in net.parameters():
        out += param.numel()
    return out

def pad_input(x, mod):
    h, w = x.shape[-2:]
    bottom = int(math.ceil(h/mod)*mod - h)
    right = int(math.ceil(w/mod)*mod - w)
    x_pad = F.pad(x, pad=(0, right, 0, bottom), mode='reflect')
    return x_pad

def forward_chop(net, x, net_kwargs=None, scale=1, shave=10, min_size=160000):
    n_GPUs = 1
    b, c, h, w = x.size()
    h_half, w_half = h // 2, w // 2
    h_size, w_size = h_half + shave, w_half + shave
    lr_list = [
        x[:, :, 0:h_size, 0:w_size],
        x[:, :, 0:h_size, (w - w_size):w],
        x[:, :, (h - h_size):h, 0:w_size],
        x[:, :, (h - h_size):h, (w - w_size):w]]

    if w_size * h_size < min_size:
        sr_list = []
        for i in range(0, 4, n_GPUs):
            lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)
            if net_kwargs is None:
                sr_batch = net(lr_batch)
            else:
                sr_batch = net(lr_batch, **net_kwargs)
            sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))
    else:
        # Recurse on each quadrant, passing the network and its settings through.
        sr_list = [
            forward_chop(net, patch, net_kwargs=net_kwargs, scale=scale, shave=shave, min_size=min_size)
            for patch in lr_list
        ]

    h, w = scale * h, scale * w
    h_half, w_half = scale * h_half, scale * w_half
    h_size, w_size = scale * h_size, scale * w_size
    shave *= scale

    output = x.new(b, c, h, w)
    output[:, :, 0:h_half, 0:w_half] = sr_list[0][:, :, 0:h_half, 0:w_half]
    output[:, :, 0:h_half, w_half:w] = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]
    output[:, :, h_half:h, 0:w_half] = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half]
    output[:, :, h_half:h, w_half:w] = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]

    return output

def measure_time(net, inputs, num_forward=100):
    ''' Measuring the average running time (seconds) for pytorch.
        out = net(*inputs)
    '''
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    with torch.set_grad_enabled(False):
        for _ in range(num_forward):
            out = net(*inputs)
    end.record()

    torch.cuda.synchronize()

    return start.elapsed_time(end) / 1000

def reload_model(model, ckpt):
    if list(model.state_dict().keys())[0].startswith('module.'):
        if list(ckpt.keys())[0].startswith('module.'):
            ckpt = ckpt
        else:
            ckpt = OrderedDict({f'module.{key}':value for key, value in ckpt.items()})
    else:
        if list(ckpt.keys())[0].startswith('module.'):
            ckpt = OrderedDict({key[7:]:value for key, value in ckpt.items()})
        else:
            ckpt = ckpt
    model.load_state_dict(ckpt, True)

def compute_hinge_loss(real_output, fake_output, x_start_, r1_lambda):
    if r1_lambda == 0:
        real_loss_total = torch.relu(torch.ones_like(real_output) - real_output).mean()
        fake_loss_total = torch.relu(torch.ones_like(fake_output) + fake_output).mean()
    else:
        real_loss_ = torch.relu(torch.ones_like(real_output) - real_output).mean()
        # Gradient of the discriminator output w.r.t. the real samples.
        grad_real = torch.autograd.grad(outputs=real_output.sum(), inputs=x_start_, create_graph=True)[0]
        # R1 gradient penalty.
        grad_penalty = (grad_real.contiguous().view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean() * r1_lambda
        real_loss_total = real_loss_ + grad_penalty
        fake_loss_total = torch.relu(torch.ones_like(fake_output) + fake_output).mean()

    real_loss = real_loss_total
    fake_loss = fake_loss_total
    loss_d = real_loss + fake_loss
    return loss_d

def reload_model_(model, ckpt):
    if list(model.state_dict().keys())[0].startswith('model.'):
        if list(ckpt.keys())[0].startswith('model.'):
            ckpt = ckpt
        else:
            ckpt = OrderedDict({f'model.{key}':value for key, value in ckpt.items()})
    else:
        if list(ckpt.keys())[0].startswith('model.'):
            ckpt = OrderedDict({key[6:]:value for key, value in ckpt.items()})  # strip the 6-character 'model.' prefix
        else:
            ckpt = ckpt
    model.load_state_dict(ckpt, True)

def reload_model_IDE(model, ckpt):
    extracted_dict = OrderedDict()
    for key, value in ckpt.items():
        if key.startswith('E_st'):
            new_key = key.replace('E_st.', '')
            extracted_dict[new_key] = value
    model.load_state_dict(extracted_dict, True)

class EMA():
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. 
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================

# Improving the Stability and Efficiency of Diffusion Models for Content Consistent Super-Resolution

[Lingchen Sun](https://scholar.google.com/citations?hl=zh-CN&tzom=-480&user=ZCDjTn8AAAAJ)<sup>1,2</sup> | [Rongyuan Wu](https://scholar.google.com/citations?user=A-U8zE8AAAAJ&hl=zh-CN)<sup>1,2</sup> | [Jie Liang](https://scholar.google.com.sg/citations?user=REWxLZsAAAAJ&hl)<sup>2</sup> | [Zhengqiang Zhang](https://scholar.google.com/citations?hl=zh-CN&user=UX26wSMAAAAJ&view_op=list_works&sortby=pubdate)<sup>1,2</sup> | [Hongwei Yong](https://scholar.google.com.hk/citations?user=Xii74qQAAAAJ&hl=zh-CN)<sup>1</sup> | [Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)<sup>1,2</sup>

<sup>1</sup>The Hong Kong Polytechnic University, <sup>2</sup>OPPO Research Institute
:star: If CCSR is helpful to your images or projects, please help star this repo. Thanks! :hugs:

## 🧡ྀི What's New in CCSR-v2?

We have implemented the CCSR-v2 code based on [Diffusers](https://github.com/huggingface/diffusers). Compared to CCSR-v1, CCSR-v2 brings a host of upgrades:

- 🛠️ **Step Flexibility**: Offers flexibility in diffusion step selection, **allowing users to freely adjust the number of steps to suit their specific requirements**. This adaptability **requires no additional re-training**, ensuring seamless integration into diverse workflows.
- ⚡ **Efficiency**: Supports highly efficient inference with **as few as 2 or even 1 diffusion step**, drastically reducing computation time without compromising quality.
- 📈 **Enhanced Clarity**: With upgraded algorithms, CCSR-v2 restores images with crisper details while maintaining fidelity.
- ⚖️ **Result Stability**: CCSR-v2 exhibits significantly improved stability in synthesizing fine image details, ensuring higher-quality outputs.
- 🔄 **Stage 2 Refinement**: In CCSR-v2, the output $\hat{x}_{0 \gets T}$ from Stage 1 is now directly fed into Stage 2, streamlining the restoration process into an efficient one-step diffusion workflow. This strategy boosts both speed and performance.

![ccsr](figs/fig.png)

Visual comparisons between the SR outputs of different DM-based methods given the same low-quality input image but two different noise samples. `S` denotes the number of diffusion sampling timesteps. Existing DM-based methods, including StableSR, PASD, SeeSR, SUPIR and AddSR, **show noticeable instability across different noise samples**. OSEDiff directly takes the low-quality image as input without noise sampling; it is deterministic and stable, but **cannot perform multi-step diffusion** for higher generative capacity. In contrast, **our proposed CCSR method is flexible for both multi-step and single-step diffusion, while producing stable results with high fidelity and visual quality**.

## ⏰ Update

- **2024.12.12**: Code and models for CCSR-v2 are released. 👀 Please refer to this [branch](https://github.com/csslc/CCSR/tree/CCSR-v2.0).
- **2024.9.25**: ⭐ [CCSR-v2](https://arxiv.org/pdf/2401.00877) is released, offering reduced step requirements and supporting flexible diffusion step selection (2 or even 1 step) at the inference stage without the need for re-training.
- **2023.12.23**: Code and models for [CCSR-v1](https://arxiv.org/pdf/2401.00877v1) are released. Please refer to this [branch](https://github.com/csslc/CCSR/tree/CCSR-v1.0).

## 🌟 Overview Framework

![ccsr](figs/framework.png)

## 😍 Visual Results

### Demo on Real-world SR

[](https://imgsli.com/MzI2MTg5)
[](https://imgsli.com/MzI2MTky/1/3)
[](https://imgsli.com/MzI2MTk0/0/2)
[](https://imgsli.com/MzI2MTk1/0/2)

![ccsr](figs/compare_standard.png)
![ccsr](figs/compare_efficient.png)

For more comparisons, please refer to our paper for details.

## 📝 Quantitative comparisons

We propose new stability metrics, namely global standard deviation (G-STD) and local standard deviation (L-STD), to respectively measure the image-level and pixel-level variations of the SR results of diffusion-based methods. More details about G-STD and L-STD can be found in our paper.
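For a rough intuition of what the two metrics measure, here is a minimal sketch under our own simplifying assumptions about the data layout; it is an illustration only, not the official `cal_iqa` implementation referenced in the Evaluation section below:

```python
import numpy as np

def g_std(scores):
    # G-STD sketch (image-level variation): `scores` has shape (N, M), one
    # quality score (e.g. PSNR) per image, for N restored groups of the same
    # M test images. Std across the N groups, averaged over the M images.
    return np.std(scores, axis=0).mean()

def l_std(samples):
    # L-STD sketch (pixel-level variation): `samples` has shape (N, H, W),
    # N restored versions of the same image. Per-pixel std across the N
    # versions, averaged over all pixels.
    return np.std(samples, axis=0).mean()
```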
![ccsr](figs/table.png)

## ⚙ Dependencies and Installation

```shell
## git clone this repository
git clone https://github.com/csslc/CCSR.git
cd CCSR

# create an environment with python >= 3.9
conda create -n ccsr python=3.9
conda activate ccsr
pip install -r requirements.txt
```

## 🍭 Quick Inference

**For ease of comparison, we have provided the test results of CCSR-v2 on the DIV2K, RealSR, and DrealSR benchmarks with varying diffusion steps, which can be accessed via [Google Drive](https://drive.google.com/drive/folders/1xjURQZgKAlENzMnAJA2PDG9h_UxfZzio?usp=sharing).**

#### Step 1: Download the pretrained models

- Download the pretrained SD-2.1-base models from [HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-2-1-base).
- Download the CCSR-v2 models and put them in `preset/models`:

| Model Name | Description | GoogleDrive | BaiduNetdisk |
|:-----------|:------------|:------------|:-------------|
| Controlnet | Trained in stage 1. | [download](https://drive.google.com/drive/folders/1aHwgodKwKYZJBKs0QlFzanSjMDhrNyRA?usp=sharing) | [download](https://pan.baidu.com/s/1SKS70iE4GhhHGxqY1KS8mw) (pwd: ccsr) |
| VAE | Trained in stage 2. | [download](https://drive.google.com/drive/folders/1yHfMV81Md6db4StHTP5MC-eSeLFeBKm8?usp=sharing) | [download](https://pan.baidu.com/s/1fxOIeL6Hk6Muq9h8itAIKQ) (pwd: ccsr) |
| Pre-trained Controlnet | The pre-trained model of stage 1. | [download](https://drive.google.com/drive/folders/1LTtBRuObITOJwbW-sTDnHtp8xIUZFDHh?usp=sharing) | [download](https://pan.baidu.com/s/1mDeuHBqNj_Iol7PCY_Xfww) (pwd: ccsr) |
| Dino models | The pre-trained models for the discriminator. | [download](https://drive.google.com/drive/folders/1PcuZGUTJlltdPz2yk2ZIa4GCtb1yk_y6?usp=sharing) | [download](https://pan.baidu.com/s/1nPdNwgua91mDDRApWUm39Q) (pwd: ccsr) |

#### Step 2: Prepare testing data

You can put the testing images in `preset/test_datasets`.

#### Step 3: Run the testing command

For the one-step diffusion process:
```
python test_ccsr_tile.py \
--pretrained_model_path preset/models/stable-diffusion-2-1-base \
--controlnet_model_path preset/models \
--vae_model_path preset/models \
--baseline_name ccsr-v2 \
--image_path preset/test_datasets \
--output_dir experiments/test \
--sample_method ddpm \
--num_inference_steps 1 \
--t_min 0.0 \
--start_point lr \
--start_steps 999 \
--process_size 512 \
--guidance_scale 1.0 \
--sample_times 1 \
--use_vae_encode_condition \
--upscale 4
```

For the multi-step diffusion process:
```
python test_ccsr_tile.py \
--pretrained_model_path preset/models/stable-diffusion-2-1-base \
--controlnet_model_path preset/models \
--vae_model_path preset/models \
--baseline_name ccsr-v2 \
--image_path preset/test_datasets \
--output_dir experiments/test \
--sample_method ddpm \
--num_inference_steps 6 \
--t_max 0.6667 \
--t_min 0.5 \
--start_point lr \
--start_steps 999 \
--process_size 512 \
--guidance_scale 4.5 \
--sample_times 1 \
--use_vae_encode_condition \
--upscale 4
```

We integrate [tile_diffusion](https://github.com/albarji/mixture-of-diffusers) and [tile_vae](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/tree/main) into [test_ccsr_tile.py](test_ccsr_tile.py) to save GPU memory during inference.
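Conceptually, tiled diffusion denoises the latent in overlapping windows and blends the per-tile outputs with smooth weights so that seams average out. The sketch below only illustrates that idea (all names are ours; it is not the implementation used in [test_ccsr_tile.py](test_ccsr_tile.py)); `tile` and `stride` play the roles of `--tile_diffusion_size` and `--tile_diffusion_stride`:

```python
import numpy as np

def gaussian_weight(h, w):
    # Smooth per-tile weight map: high at the tile centre, low at the edges.
    ys = np.linspace(-1, 1, h)[:, None]
    xs = np.linspace(-1, 1, w)[None, :]
    return np.exp(-(ys ** 2 + xs ** 2) / 0.5)

def run_tiled(latent, denoise_tile, tile=64, stride=32):
    # Denoise `latent` (C, H, W) window by window and blend the overlaps.
    # Assumes (H - tile) and (W - tile) are multiples of `stride`.
    C, H, W = latent.shape
    out = np.zeros_like(latent)
    acc = np.zeros((H, W))
    for y in range(0, H - tile + 1, stride):
        for x in range(0, W - tile + 1, stride):
            patch = latent[:, y:y + tile, x:x + tile]
            w = gaussian_weight(tile, tile)
            out[:, y:y + tile, x:x + tile] += denoise_tile(patch) * w
            acc[y:y + tile, x:x + tile] += w
    return out / np.maximum(acc, 1e-8)

# usage sketch: a stand-in "denoiser" applied to a random latent
latent = np.random.randn(4, 128, 128).astype(np.float32)
result = run_tiled(latent, lambda p: 0.9 * p, tile=64, stride=32)
```

A larger overlap (smaller stride relative to the tile size) hides seams better but costs proportionally more compute, which is the trade-off behind the tile settings in the command below.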
You can change the tile size and stride according to the VRAM of your device:
```
python test_ccsr_tile.py \
--pretrained_model_path preset/models/stable-diffusion-2-1-base \
--controlnet_model_path preset/models \
--vae_model_path preset/models \
--baseline_name ccsr-v2 \
--image_path preset/test_datasets \
--output_dir experiments/test \
--sample_method ddpm \
--num_inference_steps 6 \
--t_max 0.6667 \
--t_min 0.5 \
--start_point lr \
--start_steps 999 \
--process_size 512 \
--guidance_scale 4.5 \
--sample_times 1 \
--use_vae_encode_condition \
--upscale 4 \
--tile_diffusion \
--tile_diffusion_size 512 \
--tile_diffusion_stride 256 \
--tile_vae \
--vae_decoder_tile_size 224 \
--vae_encoder_tile_size 1024
```

You can obtain `N` different SR results by setting `sample_times` to `N` to test the stability of CCSR. The data folder should look like this:

```
experiments/test
├── sample00 # the first group of SR results
├── sample01 # the second group of SR results
...
└── sampleN # the N-th group of SR results
```

## 📏 Evaluation

1. Calculate the image quality assessment (IQA) metrics for each restored group. Fill in the required information in [cal_iqa.py](cal_iqa/cal_iqa.py) and run it; you will obtain the evaluation results in a folder like this:

   ```
   log_path
   ├── log_name_npy # the IQA values of each restored group, saved as npy files
   └── log_name.log # the log record
   ```

2. Calculate the G-STD value for the diffusion-based SR method. Fill in the required information in [iqa_G-STD.py](cal_iqa/iqa_G-STD.py) and run it; you will obtain the mean IQA values of the N restored groups and the G-STD value.
3. Calculate the L-STD value for the diffusion-based SR method. Fill in the required information in [iqa_L-STD.py](cal_iqa/iqa_L-STD.py) and run it; you will obtain the L-STD value.

## 🚋 Train

#### Step 1: Prepare training data

Generate a txt file recording the paths of the ground-truth images of the training set: fill in the required information in [get_path.py](scripts/get_path.py) and run it. You can save the txt file as `preset/gt_path.txt`.

#### Step 2: Train the Stage 1 Model

1. Download the pretrained [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) model to provide generative capabilities.

   ```shell
   wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt --no-check-certificate
   ```

2. Start training.

   ```shell
   CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_ccsr_stage1.py \
   --pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
   --controlnet_model_name_or_path='preset/models/pretrained_controlnet' \
   --enable_xformers_memory_efficient_attention \
   --output_dir="./experiments/ccsrv2_stage1" \
   --mixed_precision="fp16" \
   --resolution=512 \
   --learning_rate=5e-5 \
   --train_batch_size=4 \
   --gradient_accumulation_steps=6 \
   --dataloader_num_workers=0 \
   --checkpointing_steps=500 \
   --t_max=0.6667 \
   --max_train_steps=20000 \
   --dataset_root_folders 'preset/gt_path.txt'
   ```

#### Step 3: Train the Stage 2 Model

1. Put the model obtained from stage 1 into `controlnet_model_name_or_path`.
2. Start training.
   ```shell
   CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_ccsr_stage2.py \
   --pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
   --controlnet_model_name_or_path='preset/models/model_stage1' \
   --enable_xformers_memory_efficient_attention \
   --output_dir="./experiments/ccsrv2_stage2" \
   --mixed_precision="fp16" \
   --resolution=512 \
   --learning_rate=5e-6 \
   --train_batch_size=2 \
   --gradient_accumulation_steps=8 \
   --checkpointing_steps=500 \
   --is_start_lr=True \
   --t_max=0.6667 \
   --num_inference_steps=1 \
   --is_module \
   --lambda_l2=1.0 \
   --lambda_lpips=1.0 \
   --lambda_disc=0.05 \
   --lambda_disc_train=0.5 \
   --begin_disc=100 \
   --max_train_steps=2000 \
   --dataset_root_folders 'preset/gt_path.txt'
   ```

   (See the note near the end of this README for what the discriminator-related options control.)

### Citations

If our code helps your research or work, please consider citing our paper. The following is a BibTeX reference:

```
@article{sun2023ccsr,
  title={Improving the Stability of Diffusion Models for Content Consistent Super-Resolution},
  author={Sun, Lingchen and Wu, Rongyuan and Zhang, Zhengqiang and Yong, Hongwei and Zhang, Lei},
  journal={arXiv preprint arXiv:2401.00877},
  year={2024}
}
```

### License

This project is released under the [Apache 2.0 license](LICENSE).

### Acknowledgement

This project is based on [ControlNet](https://github.com/lllyasviel/ControlNet), [BasicSR](https://github.com/XPixelGroup/BasicSR) and [SeeSR](https://github.com/cswry/SeeSR). Some code is borrowed from [AddSR](https://github.com/NJU-PCALab/AddSR). Thanks for their awesome work.

### Contact

If you have any questions, please contact: ling-chen.sun@connect.polyu.hk
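### A note on the discriminator options

The stage-2 command above exposes several adversarial-training options (`--lambda_disc`, `--lambda_disc_train`, `--begin_disc`). The discriminator objective implemented in this repository (`compute_hinge_loss` in `ADD/utils/util_net.py`) is a hinge GAN loss with an optional R1 gradient penalty. The following is a minimal, self-contained sketch of that computation with stand-in tensors; it is an illustration, not the training code itself:

```python
import torch

def hinge_d_loss(real_out, fake_out, real_images=None, r1_lambda=0.0):
    # Hinge loss: push D(real) above +1 and D(fake) below -1.
    loss = torch.relu(1.0 - real_out).mean() + torch.relu(1.0 + fake_out).mean()
    if r1_lambda > 0 and real_images is not None:
        # R1 penalty: squared gradient norm of D's output w.r.t. the real inputs.
        grad = torch.autograd.grad(real_out.sum(), real_images, create_graph=True)[0]
        loss = loss + r1_lambda * grad.flatten(1).norm(2, dim=1).pow(2).mean()
    return loss

# usage sketch: the real batch must require grad for the R1 term
real_images = torch.randn(4, 3, 64, 64, requires_grad=True)
real_out = real_images.mean(dim=(1, 2, 3))  # stand-in for D(real_images)
fake_out = torch.randn(4)                   # stand-in for D(fake_images)
loss_d = hinge_d_loss(real_out, fake_out, real_images, r1_lambda=1.0)
```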
### Statistics

![visitors](https://visitor-badge.laobi.icu/badge?page_id=csslc/CCSR)
================================================ FILE: dataloaders/paired_dataset_txt.py ================================================

import glob
import os
from PIL import Image
import random
import numpy as np
from torch import nn
from torchvision import transforms
from torch.utils import data as data
import torch.nn.functional as F

from .realesrgan import RealESRGAN_degradation

class PairedCaptionDataset(data.Dataset):
    def __init__(
            self,
            root_folders=None,
            tokenizer=None,
            gt_ratio=0,  # probability of using the GT image itself as the LQ input
    ):
        super(PairedCaptionDataset, self).__init__()

        self.gt_ratio = gt_ratio
        with open(root_folders, 'r') as f:
            self.gt_list = [line.strip() for line in f.readlines()]

        self.img_preproc = transforms.Compose([
            transforms.RandomCrop((512, 512)),
            transforms.Resize((512, 512)),
            transforms.RandomHorizontalFlip(),
        ])

        self.degradation = RealESRGAN_degradation('dataloaders/params_ccsr.yml', device='cuda')
        self.tokenizer = tokenizer

    def tokenize_caption(self, caption=""):
        inputs = self.tokenizer(
            caption, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True,
            return_tensors="pt"
        )
        return inputs.input_ids

    def __getitem__(self, index):
        gt_path = self.gt_list[index]
        gt_img = Image.open(gt_path).convert('RGB')
        gt_img = self.img_preproc(gt_img)
        gt_img, img_t = self.degradation.degrade_process(np.asarray(gt_img)/255., resize_bak=True)

        if random.random() < self.gt_ratio:
            lq_img = gt_img
        else:
            lq_img = img_t

        # no caption used
        lq_caption = ''

        example = dict()
        example["conditioning_pixel_values"] = lq_img.squeeze(0)  # [0, 1]
        example["pixel_values"] = gt_img.squeeze(0) * 2.0 - 1.0  # [-1, 1]
        example["input_caption"] = self.tokenize_caption(caption=lq_caption).squeeze(0)

        return example

    def __len__(self):
        return len(self.gt_list)

================================================ FILE: dataloaders/params_ccsr.yml ================================================

scale: 4
color_jitter_prob: 0.0
gray_prob: 0.0

# the first degradation process
resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
resize_range: [0.3, 1.5]
gaussian_noise_prob: 0.5
noise_range: [1, 15]
poisson_scale_range: [0.05, 2.0]
gray_noise_prob: 0.4
jpeg_range: [60, 95]

# the second degradation process
second_blur_prob: 0.5
resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
resize_range2: [0.6, 1.2]
gaussian_noise_prob2: 0.5
noise_range2: [1, 12]
poisson_scale_range2: [0.05, 1.0]
gray_noise_prob2: 0.4
jpeg_range2: [60, 100]

kernel_info:
  blur_kernel_size: 21
  kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
  kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
  sinc_prob: 0.1
  blur_sigma: [0.2, 1.5]
  betag_range: [0.5, 2.0]
  betap_range: [1, 1.5]

  blur_kernel_size2: 11
  kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
  kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
  sinc_prob2: 0.1
  blur_sigma2: [0.2, 1.0]
  betag_range2: [0.5, 2.0]
  betap_range2: [1, 1.5]

  final_sinc_prob: 0.8

================================================ FILE: dataloaders/realesrgan.py ================================================

import os
import numpy as np
import cv2
import glob
import math
import yaml
import random
from collections import OrderedDict
import torch
import torch.nn.functional as F
from basicsr.data.transforms import augment
from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
from basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img
from basicsr.utils.img_process_util import filter2D
from
basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, normalize, rgb_to_grayscale) cur_path = os.path.dirname(os.path.abspath(__file__)) def ordered_yaml(): """Support OrderedDict for yaml. Returns: yaml Loader and Dumper. """ try: from yaml import CDumper as Dumper from yaml import CLoader as Loader except ImportError: from yaml import Dumper, Loader _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG def dict_representer(dumper, data): return dumper.represent_dict(data.items()) def dict_constructor(loader, node): return OrderedDict(loader.construct_pairs(node)) Dumper.add_representer(OrderedDict, dict_representer) Loader.add_constructor(_mapping_tag, dict_constructor) return Loader, Dumper def opt_parse(opt_path): with open(opt_path, mode='r') as f: Loader, _ = ordered_yaml() opt = yaml.load(f, Loader=Loader) return opt class RealESRGAN_degradation(object): def __init__(self, opt_path='', device='cpu'): self.opt = opt_parse(opt_path) self.device = device #torch.device('cpu') optk = self.opt['kernel_info'] # blur settings for the first degradation self.blur_kernel_size = optk['blur_kernel_size'] self.kernel_list = optk['kernel_list'] self.kernel_prob = optk['kernel_prob'] self.blur_sigma = optk['blur_sigma'] self.betag_range = optk['betag_range'] self.betap_range = optk['betap_range'] self.sinc_prob = optk['sinc_prob'] # blur settings for the second degradation self.blur_kernel_size2 = optk['blur_kernel_size2'] self.kernel_list2 = optk['kernel_list2'] self.kernel_prob2 = optk['kernel_prob2'] self.blur_sigma2 = optk['blur_sigma2'] self.betag_range2 = optk['betag_range2'] self.betap_range2 = optk['betap_range2'] self.sinc_prob2 = optk['sinc_prob2'] # a final sinc filter self.final_sinc_prob = optk['final_sinc_prob'] self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect self.pulse_tensor[10, 10] = 1 self.jpeger = DiffJPEG(differentiable=False).to(self.device) self.usm_shaper = USMSharp().to(self.device) def color_jitter_pt(self, img, brightness, contrast, saturation, hue): fn_idx = torch.randperm(4) for fn_id in fn_idx: if fn_id == 0 and brightness is not None: brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() img = adjust_brightness(img, brightness_factor) if fn_id == 1 and contrast is not None: contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() img = adjust_contrast(img, contrast_factor) if fn_id == 2 and saturation is not None: saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() img = adjust_saturation(img, saturation_factor) if fn_id == 3 and hue is not None: hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() img = adjust_hue(img, hue_factor) return img def random_augment(self, img_gt): # random horizontal flip img_gt, status = augment(img_gt, hflip=True, rotation=False, return_status=True) """ # random color jitter if np.random.uniform() < self.opt['color_jitter_prob']: jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) img_gt = img_gt + jitter_val img_gt = np.clip(img_gt, 0, 1) # random grayscale if np.random.uniform() < self.opt['gray_prob']: #img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) img_gt = cv2.cvtColor(img_gt, cv2.COLOR_RGB2GRAY) img_gt = np.tile(img_gt[:, :, 
None], [1, 1, 3]) """ # BGR to RGB, HWC to CHW, numpy to tensor img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0) return img_gt def random_kernels(self): # ------------------------ Generate kernels (used in the first degradation) ------------------------ # kernel_size = random.choice(self.kernel_range) if np.random.uniform() < self.sinc_prob: # this sinc filter setting is for kernels ranging from [7, 21] if kernel_size < 13: omega_c = np.random.uniform(np.pi / 3, np.pi) else: omega_c = np.random.uniform(np.pi / 5, np.pi) kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) else: kernel = random_mixed_kernels( self.kernel_list, self.kernel_prob, kernel_size, self.blur_sigma, self.blur_sigma, [-math.pi, math.pi], self.betag_range, self.betap_range, noise_range=None) # pad kernel pad_size = (21 - kernel_size) // 2 kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) # ------------------------ Generate kernels (used in the second degradation) ------------------------ # kernel_size = random.choice(self.kernel_range) if np.random.uniform() < self.sinc_prob2: if kernel_size < 13: omega_c = np.random.uniform(np.pi / 3, np.pi) else: omega_c = np.random.uniform(np.pi / 5, np.pi) kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) else: kernel2 = random_mixed_kernels( self.kernel_list2, self.kernel_prob2, kernel_size, self.blur_sigma2, self.blur_sigma2, [-math.pi, math.pi], self.betag_range2, self.betap_range2, noise_range=None) # pad kernel pad_size = (21 - kernel_size) // 2 kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) # ------------------------------------- sinc kernel ------------------------------------- # if np.random.uniform() < self.final_sinc_prob: kernel_size = random.choice(self.kernel_range) omega_c = np.random.uniform(np.pi / 3, np.pi) sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) sinc_kernel = torch.FloatTensor(sinc_kernel) else: sinc_kernel = self.pulse_tensor kernel = torch.FloatTensor(kernel) kernel2 = torch.FloatTensor(kernel2) return kernel, kernel2, sinc_kernel @torch.no_grad() def degrade_process(self, img_gt, resize_bak=False): img_gt = self.random_augment(img_gt) kernel1, kernel2, sinc_kernel = self.random_kernels() img_gt, kernel1, kernel2, sinc_kernel = img_gt.to(self.device), kernel1.to(self.device), kernel2.to(self.device), sinc_kernel.to(self.device) #img_gt = self.usm_shaper(img_gt) # shaper gt ori_h, ori_w = img_gt.size()[2:4] #scale_final = random.randint(4, 16) scale_final = 4 # ----------------------- The first degradation process ----------------------- # # blur out = filter2D(img_gt, kernel1) # random resize updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] if updown_type == 'up': scale = np.random.uniform(1, self.opt['resize_range'][1]) elif updown_type == 'down': scale = np.random.uniform(self.opt['resize_range'][0], 1) else: scale = 1 mode = random.choice(['area', 'bilinear', 'bicubic']) out = F.interpolate(out, scale_factor=scale, mode=mode) # noise gray_noise_prob = self.opt['gray_noise_prob'] if np.random.uniform() < self.opt['gaussian_noise_prob']: out = random_add_gaussian_noise_pt( out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) else: out = random_add_poisson_noise_pt( out, scale_range=self.opt['poisson_scale_range'], gray_prob=gray_noise_prob, clip=True, rounds=False) # JPEG compression jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) 
out = torch.clamp(out, 0, 1) out = self.jpeger(out, quality=jpeg_p) # ----------------------- The second degradation process ----------------------- # # blur if np.random.uniform() < self.opt['second_blur_prob']: out = filter2D(out, kernel2) # random resize updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] if updown_type == 'up': scale = np.random.uniform(1, self.opt['resize_range2'][1]) elif updown_type == 'down': scale = np.random.uniform(self.opt['resize_range2'][0], 1) else: scale = 1 mode = random.choice(['area', 'bilinear', 'bicubic']) out = F.interpolate( out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode) # noise gray_noise_prob = self.opt['gray_noise_prob2'] if np.random.uniform() < self.opt['gaussian_noise_prob2']: out = random_add_gaussian_noise_pt( out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) else: out = random_add_poisson_noise_pt( out, scale_range=self.opt['poisson_scale_range2'], gray_prob=gray_noise_prob, clip=True, rounds=False) # JPEG compression + the final sinc filter # We also need to resize images to desired sizes. We group [resize back + sinc filter] together # as one operation. # We consider two orders: # 1. [resize back + sinc filter] + JPEG compression # 2. JPEG compression + [resize back + sinc filter] # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. if np.random.uniform() < 0.5: # resize back + the final sinc filter mode = random.choice(['area', 'bilinear', 'bicubic']) out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode) out = filter2D(out, sinc_kernel) # JPEG compression jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) out = torch.clamp(out, 0, 1) out = self.jpeger(out, quality=jpeg_p) else: # JPEG compression jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) out = torch.clamp(out, 0, 1) out = self.jpeger(out, quality=jpeg_p) # resize back + the final sinc filter mode = random.choice(['area', 'bilinear', 'bicubic']) out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode) out = filter2D(out, sinc_kernel) if np.random.uniform() < self.opt['gray_prob']: out = rgb_to_grayscale(out, num_output_channels=1) if np.random.uniform() < self.opt['color_jitter_prob']: brightness = self.opt.get('brightness', (0.5, 1.5)) contrast = self.opt.get('contrast', (0.5, 1.5)) saturation = self.opt.get('saturation', (0, 1.5)) hue = self.opt.get('hue', (-0.1, 0.1)) out = self.color_jitter_pt(out, brightness, contrast, saturation, hue) if resize_bak: mode = random.choice(['area', 'bilinear', 'bicubic']) out = F.interpolate(out, size=(ori_h, ori_w), mode=mode) # clamp and round img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. return img_gt, img_lq ================================================ FILE: models/DiffAugment.py ================================================ # BSD 2-Clause "Simplified" License # Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # * Redistributions of source code must retain the above copyright notice, this # list of conditions and the following disclaimer. 
# # * Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # # Code from https://github.com/mit-han-lab/data-efficient-gans """Training GANs with DiffAugment.""" import numpy as np import torch import torch.nn.functional as F def DiffAugment(x: torch.Tensor, policy: str = '', channels_first: bool = True) -> torch.Tensor: if policy: if not channels_first: x = x.permute(0, 3, 1, 2) for p in policy.split(','): for f in AUGMENT_FNS[p]: x = f(x) if not channels_first: x = x.permute(0, 2, 3, 1) x = x.contiguous() return x def rand_brightness(x: torch.Tensor) -> torch.Tensor: x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) return x def rand_saturation(x: torch.Tensor) -> torch.Tensor: x_mean = x.mean(dim=1, keepdim=True) x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean return x def rand_contrast(x: torch.Tensor) -> torch.Tensor: x_mean = x.mean(dim=[1, 2, 3], keepdim=True) x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean return x def rand_translation(x: torch.Tensor, ratio: float = 0.125) -> torch.Tensor: shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) grid_batch, grid_x, grid_y = torch.meshgrid( torch.arange(x.size(0), dtype=torch.long, device=x.device), torch.arange(x.size(2), dtype=torch.long, device=x.device), torch.arange(x.size(3), dtype=torch.long, device=x.device), ) grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) return x def rand_cutout(x: torch.Tensor, ratio: float = 0.2) -> torch.Tensor: cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) grid_batch, grid_x, grid_y = torch.meshgrid( torch.arange(x.size(0), dtype=torch.long, device=x.device), torch.arange(cutout_size[0], dtype=torch.long, device=x.device), torch.arange(cutout_size[1], dtype=torch.long, device=x.device), ) grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) grid_y = torch.clamp(grid_y 
+ offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) mask[grid_batch, grid_x, grid_y] = 0 x = x * mask.unsqueeze(1) return x def rand_resize(x: torch.Tensor, min_ratio: float = 0.8, max_ratio: float = 1.2) -> torch.Tensor: resize_ratio = np.random.rand()*(max_ratio-min_ratio) + min_ratio resized_img = F.interpolate(x, size=int(resize_ratio*x.shape[3]), mode='bilinear') org_size = x.shape[3] if int(resize_ratio*x.shape[3]) < x.shape[3]: left_pad = (x.shape[3]-int(resize_ratio*x.shape[3]))/2. left_pad = int(left_pad) right_pad = x.shape[3] - left_pad - resized_img.shape[3] x = F.pad(resized_img, (left_pad, right_pad, left_pad, right_pad), "constant", 0.) else: left = (int(resize_ratio*x.shape[3])-x.shape[3])/2. left = int(left) x = resized_img[:, :, left:(left+x.shape[3]), left:(left+x.shape[3])] assert x.shape[2] == org_size assert x.shape[3] == org_size return x AUGMENT_FNS = { 'color': [rand_brightness, rand_saturation, rand_contrast], 'translation': [rand_translation], 'resize': [rand_resize], 'cutout': [rand_cutout], } ================================================ FILE: models/controlnet.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn from torch.nn import functional as F from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalControlnetMixin from diffusers.utils import BaseOutput, logging from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from diffusers.models.modeling_utils import ModelMixin from .unet_2d_blocks import ( CrossAttnDownBlock2D, DownBlock2D, UNetMidBlock2DCrossAttn, get_down_block, ) from .unet_2d_condition import UNet2DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class ControlNetOutput(BaseOutput): """ The output of [`ControlNetModel`]. Args: down_block_res_samples (`tuple[torch.Tensor]`): A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be used to condition the original UNet's downsampling activations. mid_down_block_re_sample (`torch.Tensor`): The activation of the midde block (the lowest sample resolution). Each tensor should be of shape `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. Output can be used to condition the original UNet's middle block activation. 
""" down_block_res_samples: Tuple[torch.Tensor] mid_block_res_sample: torch.Tensor class ControlNetConditioningEmbedding(nn.Module): """ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full model) to encode image-space conditions ... into feature maps ..." """ def __init__( self, conditioning_embedding_channels: int, conditioning_channels: int = 3, block_out_channels: Tuple[int] = (16, 32, 96, 256), ): super().__init__() self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) self.blocks = nn.ModuleList([]) for i in range(len(block_out_channels) - 1): channel_in = block_out_channels[i] channel_out = block_out_channels[i + 1] self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) self.conv_out = zero_module( nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) ) def forward(self, conditioning): embedding = self.conv_in(conditioning) embedding = F.silu(embedding) for block in self.blocks: embedding = block(embedding) embedding = F.silu(embedding) embedding = self.conv_out(embedding) return embedding class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): """ A ControlNet model. Args: in_channels (`int`, defaults to 4): The number of channels in the input sample. flip_sin_to_cos (`bool`, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, defaults to 2): The number of layers per block. downsample_padding (`int`, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, defaults to 1): The scale factor to use for the mid block. act_fn (`str`, defaults to "silu"): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If None, normalization and activation layers is skipped in post-processing. norm_eps (`float`, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 
encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): The dimension of the attention heads. use_linear_projection (`bool`, defaults to `False`): class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. num_class_embeds (`int`, *optional*, defaults to 0): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. upcast_attention (`bool`, defaults to `False`): resnet_time_scale_shift (`str`, defaults to `"default"`): Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): The tuple of output channel for each block in the `conditioning_embedding` layer. 
global_pool_conditions (`bool`, defaults to `False`): """ _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channels: int = 4, conditioning_channels: int = 3, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: int = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", projection_class_embeddings_input_dim: Optional[int] = None, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), global_pool_conditions: bool = False, addition_embed_type_num_heads=64, use_vae_encode_condition=False, ): super().__init__() # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # use_vae_encode_condition self.use_vae_encode_condition = use_vae_encode_condition if self.use_vae_encode_condition: print(f'============================') print(f'use vae encode condition in CONTROLNET!!!') print(f'============================') self.condition_conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) else: print(f'============================') print(f'Not !!! use vae encode condition in CONTROLNET') print(f'============================') # control net conditioning embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], block_out_channels=conditioning_embedding_out_channels, conditioning_channels=conditioning_channels, ) # time time_embed_dim = block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." ) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") self.down_blocks = nn.ModuleList([]) self.controlnet_down_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): only_cross_attention = [only_cross_attention] * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) # down output_channel = block_out_channels[0] controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block, transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[i], attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, downsample_padding=downsample_padding, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, ) self.down_blocks.append(down_block) for _ in range(layers_per_block): controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) if not is_final_block: controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_down_blocks.append(controlnet_block) # mid mid_block_channel = block_out_channels[-1] controlnet_block = nn.Conv2d(mid_block_channel, 
mid_block_channel, kernel_size=1) controlnet_block = zero_module(controlnet_block) self.controlnet_mid_block = controlnet_block self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=mid_block_channel, temb_channels=time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, ) @classmethod def from_unet( cls, unet: UNet2DConditionModel, controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), load_weights_from_unet: bool = True, use_vae_encode_condition: bool = False, ): r""" Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. Parameters: unet (`UNet2DConditionModel`): The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied where applicable. """ transformer_layers_per_block = ( unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 ) encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None addition_time_embed_dim = ( unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None ) controlnet = cls( encoder_hid_dim=encoder_hid_dim, encoder_hid_dim_type=encoder_hid_dim_type, addition_embed_type=addition_embed_type, addition_time_embed_dim=addition_time_embed_dim, transformer_layers_per_block=transformer_layers_per_block, in_channels=unet.config.in_channels, flip_sin_to_cos=unet.config.flip_sin_to_cos, freq_shift=unet.config.freq_shift, down_block_types=unet.config.down_block_types, only_cross_attention=unet.config.only_cross_attention, block_out_channels=unet.config.block_out_channels, layers_per_block=unet.config.layers_per_block, downsample_padding=unet.config.downsample_padding, mid_block_scale_factor=unet.config.mid_block_scale_factor, act_fn=unet.config.act_fn, norm_num_groups=unet.config.norm_num_groups, norm_eps=unet.config.norm_eps, cross_attention_dim=unet.config.cross_attention_dim, attention_head_dim=unet.config.attention_head_dim, num_attention_heads=unet.config.num_attention_heads, use_linear_projection=unet.config.use_linear_projection, class_embed_type=unet.config.class_embed_type, num_class_embeds=unet.config.num_class_embeds, upcast_attention=unet.config.upcast_attention, resnet_time_scale_shift=unet.config.resnet_time_scale_shift, projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, conditioning_embedding_out_channels=conditioning_embedding_out_channels, use_vae_encode_condition=use_vae_encode_condition, ) if load_weights_from_unet: controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) if controlnet.class_embedding: controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) 
controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False) controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False) return controlnet @property # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ self.set_attn_processor(AttnProcessor()) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. 
""" sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, controlnet_cond: torch.FloatTensor, conditioning_scale: float = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guess_mode: bool = False, return_dict: bool = True, image_encoder_hidden_states: torch.Tensor = None, vae_encode_condition_hidden_states: torch.Tensor = None, use_vae_encode_condition = False, ) -> Union[ControlNetOutput, Tuple]: """ The [`ControlNetModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor. timestep (`Union[torch.Tensor, float, int]`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states. controlnet_cond (`torch.FloatTensor`): The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. conditioning_scale (`float`, defaults to `1.0`): The scale factor for ControlNet outputs. class_labels (`torch.Tensor`, *optional*, defaults to `None`): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): attention_mask (`torch.Tensor`, *optional*, defaults to `None`): added_cond_kwargs (`dict`): Additional conditions for the Stable Diffusion XL UNet. 
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. guess_mode (`bool`, defaults to `False`): In this mode, the ControlNet encoder tries its best to recognize the content of the input even if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. return_dict (`bool`, defaults to `True`): Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. Returns: [`~models.controlnet.ControlNetOutput`] **or** `tuple`: If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. """ # check channel order channel_order = self.config.controlnet_conditioning_channel_order if channel_order == "rgb": # in rgb order by default ... elif channel_order == "bgr": controlnet_cond = torch.flip(controlnet_cond, dims=[1]) else: raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") # prepare attention_mask if attention_mask is not None: attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this.
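# For example (illustrative shapes and dtypes, not repo code): with fp16 weights,
# `self.time_proj` still emits float32 sinusoids, so the cast below keeps the
# `time_embedding` MLP's inputs and weights in a single dtype:
#
#   timesteps = torch.full((batch,), 250, device=device)   # int64 timesteps
#   t_emb = self.time_proj(timesteps)                       # float32 sinusoids
#   t_emb = t_emb.to(dtype=sample.dtype)                    # e.g. torch.float16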
t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb if self.config.addition_embed_type is not None: if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_time": if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) emb = emb + aug_emb if aug_emb is not None else emb # 2. pre-process sample = self.conv_in(sample) if not self.use_vae_encode_condition: controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) else: controlnet_cond = self.condition_conv_in(vae_encode_condition_hidden_states) sample = sample + controlnet_cond # 3. down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, image_encoder_hidden_states=image_encoder_hidden_states, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) down_block_res_samples += res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, image_encoder_hidden_states=image_encoder_hidden_states, ) # 5. Control net blocks controlnet_down_block_res_samples = () for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = controlnet_down_block_res_samples mid_block_res_sample = self.controlnet_mid_block(sample) # 6. 
scaling if guess_mode and not self.config.global_pool_conditions: scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 scales = scales * conditioning_scale down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: down_block_res_samples = [ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples ] mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) if not return_dict: return (down_block_res_samples, mid_block_res_sample) return ControlNetOutput( down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample ) def zero_module(module): for p in module.parameters(): nn.init.zeros_(p) return module ================================================ FILE: models/losses/__init__.py ================================================ from models.losses.contperceptual import LPIPSWithDiscriminator ================================================ FILE: models/losses/contperceptual.py ================================================ import torch import torch.nn as nn from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? from diffusers.models.modeling_utils import ModelMixin from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import FromOriginalControlnetMixin class LPIPSWithDiscriminator(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, disc_loss="hinge"): super().__init__() assert disc_loss in ["hinge", "vanilla"] self.kl_weight = kl_weight self.pixel_weight = pixelloss_weight self.perceptual_loss = LPIPS().eval() self.perceptual_weight = perceptual_weight # output log variance self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm ).apply(weights_init) self.discriminator_iter_start = disc_start self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.disc_conditional = disc_conditional def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): if last_layer is not None: nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] else: nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = d_weight * self.discriminator_weight return d_weight def forward(self, inputs, reconstructions, optimizer_idx, global_step, posteriors=None, last_layer=None, cond=None, split="train", weights=None, return_dic=False): rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) if self.perceptual_weight > 0: p_loss = 
self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) rec_loss = rec_loss + self.perceptual_weight * p_loss nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar weighted_nll_loss = nll_loss if weights is not None: weighted_nll_loss = weights*nll_loss weighted_nll_loss = torch.mean(weighted_nll_loss) / weighted_nll_loss.shape[0] nll_loss = torch.mean(nll_loss) / nll_loss.shape[0] if self.kl_weight>0: kl_loss = posteriors.kl() kl_loss = torch.mean(kl_loss) / kl_loss.shape[0] # now the GAN part if optimizer_idx == 0: # generator update if cond is None: assert not self.disc_conditional logits_fake = self.discriminator(reconstructions.contiguous()) else: assert self.disc_conditional logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) g_loss = -torch.mean(logits_fake) if self.disc_factor > 0.0: try: d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) except RuntimeError: # assert not self.training d_weight = torch.tensor(1.0) * self.discriminator_weight else: # d_weight = torch.tensor(0.0) d_weight = torch.tensor(0.0) disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) if self.kl_weight>0: loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), "{}/rec_loss".format(split): rec_loss.detach().mean(), "{}/d_weight".format(split): d_weight.detach(), "{}/disc_factor".format(split): torch.tensor(disc_factor), "{}/g_loss".format(split): g_loss.detach().mean(), } if return_dic: loss_dic = {} loss_dic['total_loss'] = loss.clone().detach().mean() loss_dic['logvar'] = self.logvar.detach() loss_dic['kl_loss'] = kl_loss.detach().mean() loss_dic['nll_loss'] = nll_loss.detach().mean() loss_dic['rec_loss'] = rec_loss.detach().mean() loss_dic['d_weight'] = d_weight.detach() loss_dic['disc_factor'] = torch.tensor(disc_factor) loss_dic['g_loss'] = g_loss.detach().mean() else: loss = weighted_nll_loss + d_weight * disc_factor * g_loss log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), "{}/nll_loss".format(split): nll_loss.detach().mean(), "{}/rec_loss".format(split): rec_loss.detach().mean(), "{}/d_weight".format(split): d_weight.detach(), "{}/disc_factor".format(split): torch.tensor(disc_factor), "{}/g_loss".format(split): g_loss.detach().mean(), } if return_dic: loss_dic = {} loss_dic["{}/total_loss".format(split)] = loss.clone().detach().mean() loss_dic["{}/logvar".format(split)] = self.logvar.detach() loss_dic["{}/nll_loss".format(split)] = nll_loss.detach().mean() loss_dic["{}/rec_loss".format(split)] = rec_loss.detach().mean() loss_dic["{}/d_weight".format(split)] = d_weight.detach() loss_dic["{}/disc_factor".format(split)] = torch.tensor(disc_factor) loss_dic["{}/g_loss".format(split)] = g_loss.detach().mean() if return_dic: return loss, log, loss_dic return loss, log if optimizer_idx == 1: # second pass for discriminator update if cond is None: logits_real = self.discriminator(inputs.contiguous().detach()) logits_fake = self.discriminator(reconstructions.contiguous().detach()) else: logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
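# This loss module follows the taming-transformers two-optimizer recipe: the
# same forward() is called once per training step with optimizer_idx=0 for the
# generator/autoencoder and once with optimizer_idx=1 for the discriminator.
# A condensed driver sketch (illustrative; `opt_g`, `opt_d` and the surrounding
# wiring are assumptions of this note, not code from this repo):
#
#   loss_g, log_g = loss_fn(inputs, recons, optimizer_idx=0, global_step=step,
#                           posteriors=posteriors, last_layer=last_layer)
#   opt_g.zero_grad(); loss_g.backward(); opt_g.step()
#   loss_d, log_d = loss_fn(inputs, recons, optimizer_idx=1, global_step=step)
#   opt_d.zero_grad(); loss_d.backward(); opt_d.step()
#
# Note that adopt_weight() returns a zero disc_factor until global_step reaches
# disc_start, so both adversarial terms stay disabled during the warm-up phase.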
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), "{}/logits_real".format(split): logits_real.detach().mean(), "{}/logits_fake".format(split): logits_fake.detach().mean() } if return_dic: loss_dic = {} loss_dic["{}/disc_loss".format(split)] = d_loss.clone().detach().mean() loss_dic["{}/logits_real".format(split)] = logits_real.detach().mean() loss_dic["{}/logits_fake".format(split)] = logits_fake.detach().mean() return d_loss, log, loss_dic return d_loss, log ================================================ FILE: models/losses/vqperceptual.py ================================================ import torch from torch import nn import torch.nn.functional as F from einops import repeat from taming.modules.discriminator.model import NLayerDiscriminator, weights_init from taming.modules.losses.lpips import LPIPS from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) loss_real = (weights * loss_real).sum() / weights.sum() loss_fake = (weights * loss_fake).sum() / weights.sum() d_loss = 0.5 * (loss_real + loss_fake) return d_loss def adopt_weight(weight, global_step, threshold=0, value=0.): if global_step < threshold: weight = value return weight def measure_perplexity(predicted_indices, n_embed): # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) avg_probs = encodings.mean(0) perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() cluster_use = torch.sum(avg_probs > 0) return perplexity, cluster_use def l1(x, y): return torch.abs(x-y) def l2(x, y): return torch.pow((x-y), 2) class VQLPIPSWithDiscriminator(nn.Module): def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", pixel_loss="l1"): super().__init__() assert disc_loss in ["hinge", "vanilla"] assert perceptual_loss in ["lpips", "clips", "dists"] assert pixel_loss in ["l1", "l2"] self.codebook_weight = codebook_weight self.pixel_weight = pixelloss_weight if perceptual_loss == "lpips": print(f"{self.__class__.__name__}: Running with LPIPS.") self.perceptual_loss = LPIPS().eval() else: raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") self.perceptual_weight = perceptual_weight if pixel_loss == "l1": self.pixel_loss = l1 else: self.pixel_loss = l2 self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, n_layers=disc_num_layers, use_actnorm=use_actnorm, ndf=disc_ndf ).apply(weights_init) self.discriminator_iter_start = disc_start if disc_loss == "hinge": self.disc_loss = hinge_d_loss elif disc_loss == "vanilla": self.disc_loss = vanilla_d_loss else: raise ValueError(f"Unknown GAN loss '{disc_loss}'.") print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") self.disc_factor = disc_factor self.discriminator_weight = disc_weight self.disc_conditional = disc_conditional self.n_classes = n_classes # def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): # if last_layer is not None: # nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] # g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] # else: # nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] # g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] # d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) # d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() # d_weight = d_weight * self.discriminator_weight # return d_weight def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): # if last_layer is not None: # nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] # g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] # else: # nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] # g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] # d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) # d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() d_weight = 1.0 * self.discriminator_weight return d_weight def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None, cond=None, split="train", predicted_indices=None): if codebook_loss is None: codebook_loss = torch.tensor([0.]).to(inputs.device) #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) if self.perceptual_weight > 0: p_loss =
self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) rec_loss = rec_loss + self.perceptual_weight * p_loss else: p_loss = torch.tensor([0.0]) nll_loss = rec_loss #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] nll_loss = torch.mean(nll_loss) # now the GAN part if optimizer_idx == 0: # generator update if cond is None: assert not self.disc_conditional logits_fake = self.discriminator(reconstructions.contiguous()) else: assert self.disc_conditional logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) g_loss = -torch.mean(logits_fake) try: d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) except RuntimeError: assert not self.training d_weight = torch.tensor(0.0) disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/quant_loss".format(split): codebook_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), "{}/rec_loss".format(split): rec_loss.detach().mean(), "{}/p_loss".format(split): p_loss.detach().mean(), "{}/d_weight".format(split): d_weight.detach(), "{}/disc_factor".format(split): torch.tensor(disc_factor), "{}/g_loss".format(split): g_loss.detach().mean(), } if predicted_indices is not None: assert self.n_classes is not None with torch.no_grad(): perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) log[f"{split}/perplexity"] = perplexity log[f"{split}/cluster_usage"] = cluster_usage return loss, log if optimizer_idx == 1: # second pass for discriminator update if cond is None: logits_real = self.discriminator(inputs.contiguous().detach()) logits_fake = self.discriminator(reconstructions.contiguous().detach()) else: logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), "{}/logits_real".format(split): logits_real.detach().mean(), "{}/logits_fake".format(split): logits_fake.detach().mean() } return d_loss, log ================================================ FILE: models/shared.py ================================================ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION is strictly prohibited. """Shared architecture blocks.""" from typing import Callable import numpy as np import torch import torch.nn as nn from ADD.th_utils.ops import bias_act class ResidualBlock(nn.Module): def __init__(self, fn: Callable): super().__init__() self.fn = fn def forward(self, x: torch.Tensor) -> torch.Tensor: return (self.fn(x) + x) / np.sqrt(2) class FullyConnectedLayer(nn.Module): def __init__( self, in_features: int, # Number of input features. out_features: int, # Number of output features. 
bias: bool = True, # Apply additive bias before the activation function? activation: str = 'linear', # Activation function: 'relu', 'lrelu', etc. lr_multiplier: float = 1.0, # Learning rate multiplier. weight_init: float = 1.0, # Initial standard deviation of the weight tensor. bias_init: float = 0.0, # Initial value for the additive bias. ): super().__init__() self.in_features = in_features self.out_features = out_features self.activation = activation self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier)) bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features]) self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None self.weight_gain = lr_multiplier / np.sqrt(in_features) self.bias_gain = lr_multiplier def forward(self, x: torch.Tensor) -> torch.Tensor: w = self.weight.to(x.dtype) * self.weight_gain b = self.bias if b is not None: b = b.to(x.dtype) if self.bias_gain != 1: b = b * self.bias_gain if self.activation == 'linear' and b is not None: x = torch.addmm(b.unsqueeze(0), x, w.t()) else: x = x.matmul(w.t()) x = bias_act.bias_act(x, b, act=self.activation) return x def extra_repr(self) -> str: return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}' class MLP(nn.Module): def __init__( self, features_list: list[int], # Number of features in each layer of the MLP. activation: str = 'linear', # Activation function: 'relu', 'lrelu', etc. lr_multiplier: float = 1.0, # Learning rate multiplier. linear_out: bool = False # Use the 'linear' activation function for the output layer? ): super().__init__() num_layers = len(features_list) - 1 self.num_layers = num_layers self.out_dim = features_list[-1] for idx in range(num_layers): in_features = features_list[idx] out_features = features_list[idx + 1] if linear_out and idx == num_layers-1: activation = 'linear' layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier) setattr(self, f'fc{idx}', layer) def forward(self, x: torch.Tensor) -> torch.Tensor: ''' if x is sequence of tokens, shift tokens to batch and apply MLP to all''' shift2batch = (x.ndim == 3) if shift2batch: B, K, C = x.shape x = x.flatten(0,1) for idx in range(self.num_layers): layer = getattr(self, f'fc{idx}') x = layer(x) if shift2batch: x = x.reshape(B, K, -1) return x ================================================ FILE: models/unet_2d_blocks.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
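# This file mirrors diffusers' unet_2d_blocks.py, extended so the cross-attention
# blocks can also consume `image_encoder_hidden_states` (see the
# `image_cross_attention_dim` arguments below). `get_down_block`/`get_up_block`
# dispatch on a block-type string; a sketch of a typical call (the values are
# illustrative Stable-Diffusion-style settings, not taken from this repo's
# configs):
#
#   block = get_down_block(
#       "CrossAttnDownBlock2D",
#       num_layers=2, in_channels=320, out_channels=640, temb_channels=1280,
#       add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
#       cross_attention_dim=1024, num_attention_heads=10,
#   )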
from typing import Any, Dict, Optional, Tuple import numpy as np import torch import torch.nn.functional as F from torch import nn from diffusers.utils import is_torch_version, logging from diffusers.models.activations import get_activation from diffusers.models.attention import AdaGroupNorm from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from diffusers.models.dual_transformer_2d import DualTransformer2DModel from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D from diffusers.models.transformer_2d import Transformer2DModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name def get_down_block( down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample, resnet_eps, resnet_act_fn, transformer_layers_per_block=1, num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, downsample_padding=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", attention_type="default", resnet_skip_time_act=False, resnet_out_scale_factor=1.0, cross_attention_norm=None, attention_head_dim=None, downsample_type=None, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: logger.warn( f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." ) attention_head_dim = num_attention_heads down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": return DownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "ResnetDownsampleBlock2D": return ResnetDownsampleBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, ) elif down_block_type == "AttnDownBlock2D": if add_downsample is False: downsample_type = None else: downsample_type = downsample_type or "conv" # default to 'conv' return AttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, downsample_type=downsample_type, ) elif down_block_type == "CrossAttnDownBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") return CrossAttnDownBlock2D( num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, 
cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, ) elif down_block_type == "SimpleCrossAttnDownBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") return SimpleCrossAttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif down_block_type == "SkipDownBlock2D": return SkipDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnSkipDownBlock2D": return AttnSkipDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "DownEncoderBlock2D": return DownEncoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "AttnDownEncoderBlock2D": return AttnDownEncoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, ) elif down_block_type == "KDownBlock2D": return KDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, ) elif down_block_type == "KCrossAttnDownBlock2D": return KCrossAttnDownBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_downsample=add_downsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, add_self_attention=True if not add_downsample else False, ) raise ValueError(f"{down_block_type} does not exist.") def get_up_block( up_block_type, num_layers, in_channels, out_channels, prev_output_channel, temb_channels, add_upsample, resnet_eps, resnet_act_fn, transformer_layers_per_block=1, num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, 
upcast_attention=False, resnet_time_scale_shift="default", attention_type="default", resnet_skip_time_act=False, resnet_out_scale_factor=1.0, cross_attention_norm=None, attention_head_dim=None, upsample_type=None, ): # If attn head dim is not defined, we default it to the number of heads if attention_head_dim is None: logger.warn( f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." ) attention_head_dim = num_attention_heads up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock2D": return UpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "ResnetUpsampleBlock2D": return ResnetUpsampleBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, ) elif up_block_type == "CrossAttnUpBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") return CrossAttnUpBlock2D( num_layers=num_layers, transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, ) elif up_block_type == "SimpleCrossAttnUpBlock2D": if cross_attention_dim is None: raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") return SimpleCrossAttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, output_scale_factor=resnet_out_scale_factor, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif up_block_type == "AttnUpBlock2D": if add_upsample is False: upsample_type = None else: upsample_type = upsample_type or "conv" # default to 'conv' return AttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, upsample_type=upsample_type, ) elif up_block_type == 
"SkipUpBlock2D": return SkipUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "AttnSkipUpBlock2D": return AttnSkipUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, ) elif up_block_type == "UpDecoderBlock2D": return UpDecoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, temb_channels=temb_channels, ) elif up_block_type == "AttnUpDecoderBlock2D": return AttnUpDecoderBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, temb_channels=temb_channels, ) elif up_block_type == "KUpBlock2D": return KUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, ) elif up_block_type == "KCrossAttnUpBlock2D": return KCrossAttnUpBlock2D( num_layers=num_layers, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, add_upsample=add_upsample, resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, cross_attention_dim=cross_attention_dim, attention_head_dim=attention_head_dim, ) raise ValueError(f"{up_block_type} does not exist.") class AutoencoderTinyBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int, act_fn: str): super().__init__() act_fn = get_activation(act_fn) self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), act_fn, nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), act_fn, nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), ) self.skip = ( nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) if in_channels != out_channels else nn.Identity() ) self.fuse = nn.ReLU() def forward(self, x): return self.fuse(self.conv(x) + self.skip(x)) class UNetMidBlock2D(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, add_attention: bool = True, attention_head_dim=1, output_scale_factor=1.0, ): super().__init__() resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.add_attention = add_attention # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] attentions = [] if 
attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." ) attention_head_dim = in_channels for _ in range(num_layers): if self.add_attention: attentions.append( Attention( in_channels, heads=in_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None, spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) else: attentions.append(None) resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) def forward(self, hidden_states, temb=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: hidden_states = attn(hidden_states, temb=temb) hidden_states = resnet(hidden_states, temb) return hidden_states class UNetMidBlock2DCrossAttn(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, upcast_attention=False, attention_type="default", image_cross_attention_dim=512, ): super().__init__() self.has_cross_attention = True self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ] attentions = [] for _ in range(num_layers): if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, ) ) else: attentions.append( DualTransformer2DModel( num_attention_heads, in_channels // num_attention_heads, in_channels=in_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) self.gradient_checkpointing = False def 
forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, image_encoder_hidden_states: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) else: hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] hidden_states = resnet(hidden_states, temb) return hidden_states class UNetMidBlock2DSimpleCrossAttn(nn.Module): def __init__( self, in_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=1.0, cross_attention_dim=1280, skip_time_act=False, only_cross_attention=False, cross_attention_norm=None, ): super().__init__() self.has_cross_attention = True self.attention_head_dim = attention_head_dim resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.num_heads = in_channels // self.attention_head_dim # there is always at least one resnet resnets = [ ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ] attentions = [] for _ in range(num_layers): processor = ( AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() ) attentions.append( Attention( query_dim=in_channels, cross_attention_dim=in_channels, heads=self.num_heads, dim_head=self.attention_head_dim, added_kv_proj_dim=cross_attention_dim, norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, processor=processor, ) ) resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) def forward( 
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. mask = None if encoder_hidden_states is None else encoder_attention_mask else: # when attention_mask is defined: we don't even check for encoder_attention_mask. # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. # then we can simplify this whole if/else block to: # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=mask, **cross_attention_kwargs, ) # resnet hidden_states = resnet(hidden_states, temb) return hidden_states class AttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=1.0, downsample_padding=1, downsample_type="conv", ): super().__init__() resnets = [] attentions = [] self.downsample_type = downsample_type if attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
) attention_head_dim = out_channels for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if downsample_type == "conv": self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) elif downsample_type == "resnet": self.downsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, down=True, ) ] ) else: self.downsamplers = None def forward(self, hidden_states, temb=None, upsample_size=None): output_states = () for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: if self.downsample_type == "resnet": hidden_states = downsampler(hidden_states, temb=temb) else: hidden_states = downsampler(hidden_states) output_states += (hidden_states,) return hidden_states, output_states class CrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, cross_attention_dim=1280, output_scale_factor=1.0, downsample_padding=1, add_downsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, attention_type="default", image_cross_attention_dim=512, ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, attention_type=attention_type, ) ) else: 
attentions.append( DualTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, additional_residuals=None, image_encoder_hidden_states: Optional[torch.FloatTensor] = None, ): output_states = () blocks = list(zip(self.resnets, self.attentions)) for i, (resnet, attn) in enumerate(blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] # apply additional residuals to the output of the last pair of resnet and attention blocks if i == len(blocks) - 1 and additional_residuals is not None: hidden_states = hidden_states + additional_residuals output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states class DownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, downsample_padding=1, ): super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None 
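        # Editor's note (hedged): `gradient_checkpointing` below is a plain attribute that callers
        # toggle externally -- in diffusers this is typically done through the model-level
        # `enable_gradient_checkpointing()` helper rather than by hand. When True and the module is
        # in training mode, the forward pass re-computes each resnet's activations during backward
        # instead of storing them, trading compute for memory. A minimal illustrative sketch
        # (the tensor shapes are made-up assumptions, not taken from this repo's configs):
        #
        #   import torch
        #   block = DownBlock2D(in_channels=320, out_channels=320, temb_channels=1280)
        #   block.train()
        #   block.gradient_checkpointing = True   # recompute activations during backward
        #   x = torch.randn(1, 320, 64, 64)
        #   temb = torch.randn(1, 1280)
        #   hidden_states, output_states = block(x, temb)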
self.gradient_checkpointing = False def forward(self, hidden_states, temb=None): output_states = () for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) output_states = output_states + (hidden_states,) return hidden_states, output_states class DownEncoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, downsample_padding=1, ): super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None def forward(self, hidden_states): for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states class AttnDownEncoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=1.0, add_downsample=True, downsample_padding=1, ): super().__init__() resnets = [] attentions = [] if attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
) attention_head_dim = out_channels for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None def forward(self, hidden_states): for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=None) hidden_states = attn(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states class AttnSkipDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=np.sqrt(2.0), add_downsample=True, ): super().__init__() self.attentions = nn.ModuleList([]) self.resnets = nn.ModuleList([]) if attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
) attention_head_dim = out_channels for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels self.resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(in_channels // 4, 32), groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=32, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) if add_downsample: self.resnet_down = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_in_shortcut=True, down=True, kernel="fir", ) self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) else: self.resnet_down = None self.downsamplers = None self.skip_conv = None def forward(self, hidden_states, temb=None, skip_sample=None): output_states = () for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states) output_states += (hidden_states,) if self.downsamplers is not None: hidden_states = self.resnet_down(hidden_states, temb) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) hidden_states = self.skip_conv(skip_sample) + hidden_states output_states += (hidden_states,) return hidden_states, output_states, skip_sample class SkipDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, output_scale_factor=np.sqrt(2.0), add_downsample=True, downsample_padding=1, ): super().__init__() self.resnets = nn.ModuleList([]) for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels self.resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(in_channels // 4, 32), groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if add_downsample: self.resnet_down = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_in_shortcut=True, down=True, kernel="fir", ) self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) else: self.resnet_down = None self.downsamplers = None self.skip_conv = None def forward(self, hidden_states, temb=None, skip_sample=None): 
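        # Editor's note (hedged): `skip_sample` is an image-space (3-channel) side input that this
        # block downsamples in lockstep with the feature maps and re-injects through the 1x1
        # `skip_conv` (3 -> out_channels), in the spirit of FIR-based skip pyramids. Illustrative
        # shape sketch only (the concrete values are assumptions, not from this repo):
        #
        #   hidden_states: (B, C_in, H, W)  --resnets-->          (B, C_out, H, W)
        #   skip_sample:   (B, 3, H, W)     --FirDownsample2D-->  (B, 3, H/2, W/2)
        #   fused:         skip_conv(skip_sample) + resnet_down(hidden_states)   # (B, C_out, H/2, W/2)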
output_states = () for resnet in self.resnets: hidden_states = resnet(hidden_states, temb) output_states += (hidden_states,) if self.downsamplers is not None: hidden_states = self.resnet_down(hidden_states, temb) for downsampler in self.downsamplers: skip_sample = downsampler(skip_sample) hidden_states = self.skip_conv(skip_sample) + hidden_states output_states += (hidden_states,) return hidden_states, output_states, skip_sample class ResnetDownsampleBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_downsample=True, skip_time_act=False, ): super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, down=True, ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, temb=None): output_states = () for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states, temb) output_states = output_states + (hidden_states,) return hidden_states, output_states class SimpleCrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim=1, cross_attention_dim=1280, output_scale_factor=1.0, add_downsample=True, skip_time_act=False, only_cross_attention=False, cross_attention_norm=None, ): super().__init__() self.has_cross_attention = True resnets = [] attentions = [] self.attention_head_dim = attention_head_dim self.num_heads = out_channels // self.attention_head_dim for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, 
time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) processor = ( AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() ) attentions.append( Attention( query_dim=out_channels, cross_attention_dim=out_channels, heads=self.num_heads, dim_head=attention_head_dim, added_kv_proj_dim=cross_attention_dim, norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, processor=processor, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, down=True, ) ] ) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. mask = None if encoder_hidden_states is None else encoder_attention_mask else: # when attention_mask is defined: we don't even check for encoder_attention_mask. # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
# then we can simplify this whole if/else block to: # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=mask, **cross_attention_kwargs, ) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=mask, **cross_attention_kwargs, ) output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states, temb) output_states = output_states + (hidden_states,) return hidden_states, output_states class KDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 4, resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", resnet_group_size: int = 32, add_downsample=False, ): super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels groups = in_channels // resnet_group_size groups_out = out_channels // resnet_group_size resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, dropout=dropout, temb_channels=temb_channels, groups=groups, groups_out=groups_out, eps=resnet_eps, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: # YiYi's comments- might be able to use FirDownsample2D, look into details later self.downsamplers = nn.ModuleList([KDownsample2D()]) else: self.downsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, temb=None): output_states = () for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states, output_states class KCrossAttnDownBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, cross_attention_dim: int, dropout: float = 0.0, num_layers: int = 4, resnet_group_size: int = 32, add_downsample=True, attention_head_dim: int = 64, add_self_attention: bool = False, resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels groups = in_channels // resnet_group_size groups_out = out_channels // resnet_group_size resnets.append( ResnetBlock2D( 
in_channels=in_channels, out_channels=out_channels, dropout=dropout, temb_channels=temb_channels, groups=groups, groups_out=groups_out, eps=resnet_eps, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, ) ) attentions.append( KAttentionBlock( out_channels, out_channels // attention_head_dim, attention_head_dim, cross_attention_dim=cross_attention_dim, temb_channels=temb_channels, attention_bias=True, add_self_attention=add_self_attention, cross_attention_norm="layer_norm", group_size=resnet_group_size, ) ) self.resnets = nn.ModuleList(resnets) self.attentions = nn.ModuleList(attentions) if add_downsample: self.downsamplers = nn.ModuleList([KDownsample2D()]) else: self.downsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) if self.downsamplers is None: output_states += (None,) else: output_states += (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states, output_states class AttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=1.0, upsample_type="conv", ): super().__init__() resnets = [] attentions = [] self.upsample_type = upsample_type if attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
) attention_head_dim = out_channels for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if upsample_type == "conv": self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) elif upsample_type == "resnet": self.upsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, up=True, ) ] ) else: self.upsamplers = None def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) hidden_states = attn(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: if self.upsample_type == "resnet": hidden_states = upsampler(hidden_states, temb=temb) else: hidden_states = upsampler(hidden_states) return hidden_states class CrossAttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, num_attention_heads=1, cross_attention_dim=1280, output_scale_factor=1.0, add_upsample=True, dual_cross_attention=False, use_linear_projection=False, only_cross_attention=False, upcast_attention=False, attention_type="default", image_cross_attention_dim=512, ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True self.num_attention_heads = num_attention_heads for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if not dual_cross_attention: attentions.append( Transformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, 
num_layers=transformer_layers_per_block, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, attention_type=attention_type, ) ) else: attentions.append( DualTransformer2DModel( num_attention_heads, out_channels // num_attention_heads, in_channels=out_channels, num_layers=1, cross_attention_dim=cross_attention_dim, norm_num_groups=resnet_groups, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, image_encoder_hidden_states: Optional[torch.FloatTensor] = None, ): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, return_dict=False, )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states class UpBlock2D(nn.Module): def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, ): super().__init__() resnets = [] for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) 
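        # Editor's note (hedged): the `resnet_in_channels + res_skip_channels` sum above encodes the
        # U-Net skip concatenation: `forward` pops one tensor from `res_hidden_states_tuple` and
        # concatenates it on the channel dim before every resnet, so each resnet's input width is
        # its feature width plus the matching encoder skip width. Worked example with assumed
        # (illustrative, non-checkpoint) values num_layers=3, in_channels=320,
        # prev_output_channel=640, out_channels=640:
        #
        #   i=0: cat(640 prev, 640 skip) -> resnet(1280 -> 640)
        #   i=1: cat(640,      640 skip) -> resnet(1280 -> 640)
        #   i=2: cat(640,      320 skip) -> resnet( 960 -> 640)   # last skip uses in_channels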
self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, upsample_size) return hidden_states class UpDecoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, temb_channels=None, ): super().__init__() resnets = [] for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None def forward(self, hidden_states, temb=None): for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states class AttnUpDecoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=1.0, add_upsample=True, temb_channels=None, ): super().__init__() resnets = [] attentions = [] if attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." 
) attention_head_dim = out_channels for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None def forward(self, hidden_states, temb=None): for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=temb) hidden_states = attn(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states class AttnSkipUpBlock2D(nn.Module): def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, attention_head_dim=1, output_scale_factor=np.sqrt(2.0), add_upsample=True, ): super().__init__() self.attentions = nn.ModuleList([]) self.resnets = nn.ModuleList([]) for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels self.resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(resnet_in_channels + res_skip_channels // 4, 32), groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) if attention_head_dim is None: logger.warn( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." 
) attention_head_dim = out_channels self.attentions.append( Attention( out_channels, heads=out_channels // attention_head_dim, dim_head=attention_head_dim, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=32, residual_connection=True, bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, ) ) self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: self.resnet_up = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(out_channels // 4, 32), groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_in_shortcut=True, up=True, kernel="fir", ) self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) self.skip_norm = torch.nn.GroupNorm( num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True ) self.act = nn.SiLU() else: self.resnet_up = None self.skip_conv = None self.skip_norm = None self.act = None def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) hidden_states = self.attentions[0](hidden_states) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) else: skip_sample = 0 if self.resnet_up is not None: skip_sample_states = self.skip_norm(hidden_states) skip_sample_states = self.act(skip_sample_states) skip_sample_states = self.skip_conv(skip_sample_states) skip_sample = skip_sample + skip_sample_states hidden_states = self.resnet_up(hidden_states, temb) return hidden_states, skip_sample class SkipUpBlock2D(nn.Module): def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_pre_norm: bool = True, output_scale_factor=np.sqrt(2.0), add_upsample=True, upsample_padding=1, ): super().__init__() self.resnets = nn.ModuleList([]) for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels self.resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min((resnet_in_channels + res_skip_channels) // 4, 32), groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) if add_upsample: self.resnet_up = ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=min(out_channels // 4, 32), groups_out=min(out_channels // 4, 32), dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, use_in_shortcut=True, up=True, kernel="fir", ) self.skip_conv = nn.Conv2d(out_channels, 3, 
kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) self.skip_norm = torch.nn.GroupNorm( num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True ) self.act = nn.SiLU() else: self.resnet_up = None self.skip_conv = None self.skip_norm = None self.act = None def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) hidden_states = resnet(hidden_states, temb) if skip_sample is not None: skip_sample = self.upsampler(skip_sample) else: skip_sample = 0 if self.resnet_up is not None: skip_sample_states = self.skip_norm(hidden_states) skip_sample_states = self.act(skip_sample_states) skip_sample_states = self.skip_conv(skip_sample_states) skip_sample = skip_sample + skip_sample_states hidden_states = self.resnet_up(hidden_states, temb) return hidden_states, skip_sample class ResnetUpsampleBlock2D(nn.Module): def __init__( self, in_channels: int, prev_output_channel: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, skip_time_act=False, ): super().__init__() resnets = [] for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, up=True, ) ] ) else: self.upsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet in self.resnets: # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, temb) return hidden_states class SimpleCrossAttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, 
out_channels: int, prev_output_channel: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim=1, cross_attention_dim=1280, output_scale_factor=1.0, add_upsample=True, skip_time_act=False, only_cross_attention=False, cross_attention_norm=None, ): super().__init__() resnets = [] attentions = [] self.has_cross_attention = True self.attention_head_dim = attention_head_dim self.num_heads = out_channels // self.attention_head_dim for i in range(num_layers): res_skip_channels = in_channels if (i == num_layers - 1) else out_channels resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, ) ) processor = ( AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() ) attentions.append( Attention( query_dim=out_channels, cross_attention_dim=out_channels, heads=self.num_heads, dim_head=self.attention_head_dim, added_kv_proj_dim=cross_attention_dim, norm_num_groups=resnet_groups, bias=True, upcast_softmax=True, only_cross_attention=only_cross_attention, cross_attention_norm=cross_attention_norm, processor=processor, ) ) self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList( [ ResnetBlock2D( in_channels=out_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, skip_time_act=skip_time_act, up=True, ) ] ) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} if attention_mask is None: # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. mask = None if encoder_hidden_states is None else encoder_attention_mask else: # when attention_mask is defined: we don't even check for encoder_attention_mask. # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
# then we can simplify this whole if/else block to: # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask mask = attention_mask for resnet, attn in zip(self.resnets, self.attentions): # resnet # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=mask, **cross_attention_kwargs, ) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=mask, **cross_attention_kwargs, ) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states, temb) return hidden_states class KUpBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 5, resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", resnet_group_size: Optional[int] = 32, add_upsample=True, ): super().__init__() resnets = [] k_in_channels = 2 * out_channels k_out_channels = in_channels num_layers = num_layers - 1 for i in range(num_layers): in_channels = k_in_channels if i == 0 else out_channels groups = in_channels // resnet_group_size groups_out = out_channels // resnet_group_size resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=k_out_channels if (i == num_layers - 1) else out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=groups, groups_out=groups_out, dropout=dropout, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, ) ) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([KUpsample2D()]) else: self.upsamplers = None self.gradient_checkpointing = False def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) for resnet in self.resnets: if self.training and self.gradient_checkpointing: def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward if is_torch_version(">=", "1.11.0"): hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, use_reentrant=False ) else: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb ) else: hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states class KCrossAttnUpBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, temb_channels: int, dropout: float = 0.0, num_layers: int = 4, resnet_eps: float = 1e-5, resnet_act_fn: str = "gelu", resnet_group_size: int = 32, attention_head_dim=1, # attention dim_head cross_attention_dim: int = 768, add_upsample: bool = True, upcast_attention: bool = False, ): 
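        # Editor's note (hedged): unlike CrossAttnUpBlock2D, which pops one skip tensor per resnet,
        # this k-diffusion style block consumes a single skip tensor once, up front, which is why
        # `k_in_channels` is set to 2 * out_channels below for every block except the innermost
        # (first) one. Illustrative sketch of the wiring in `forward`:
        #
        #   skip = res_hidden_states_tuple[-1]                       # exactly one skip tensor
        #   hidden_states = torch.cat([hidden_states, skip], dim=1)  # channels: C + C = 2C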
super().__init__() resnets = [] attentions = [] is_first_block = in_channels == out_channels == temb_channels is_middle_block = in_channels != out_channels add_self_attention = True if is_first_block else False self.has_cross_attention = True self.attention_head_dim = attention_head_dim # in_channels, and out_channels for the block (k-unet) k_in_channels = out_channels if is_first_block else 2 * out_channels k_out_channels = in_channels num_layers = num_layers - 1 for i in range(num_layers): in_channels = k_in_channels if i == 0 else out_channels groups = in_channels // resnet_group_size groups_out = out_channels // resnet_group_size if is_middle_block and (i == num_layers - 1): conv_2d_out_channels = k_out_channels else: conv_2d_out_channels = None resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, conv_2d_out_channels=conv_2d_out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=groups, groups_out=groups_out, dropout=dropout, non_linearity=resnet_act_fn, time_embedding_norm="ada_group", conv_shortcut_bias=False, ) ) attentions.append( KAttentionBlock( k_out_channels if (i == num_layers - 1) else out_channels, k_out_channels // attention_head_dim if (i == num_layers - 1) else out_channels // attention_head_dim, attention_head_dim, cross_attention_dim=cross_attention_dim, temb_channels=temb_channels, attention_bias=True, add_self_attention=add_self_attention, cross_attention_norm="layer_norm", upcast_attention=upcast_attention, ) ) self.resnets = nn.ModuleList(resnets) self.attentions = nn.ModuleList(attentions) if add_upsample: self.upsamplers = nn.ModuleList([KUpsample2D()]) else: self.upsamplers = None self.gradient_checkpointing = False def forward( self, hidden_states: torch.FloatTensor, res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], temb: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, upsample_size: Optional[int] = None, attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) for resnet, attn in zip(self.resnets, self.attentions): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(resnet), hidden_states, temb, **ckpt_kwargs, ) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) else: hidden_states = resnet(hidden_states, temb) hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, ) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states # can potentially later be renamed to `No-feed-forward` attention class KAttentionBlock(nn.Module): 
r""" A basic Transformer block. Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. num_embeds_ada_norm (: obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, dropout: float = 0.0, cross_attention_dim: Optional[int] = None, attention_bias: bool = False, upcast_attention: bool = False, temb_channels: int = 768, # for ada_group_norm add_self_attention: bool = False, cross_attention_norm: Optional[str] = None, group_size: int = 32, ): super().__init__() self.add_self_attention = add_self_attention # 1. Self-Attn if add_self_attention: self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, cross_attention_dim=None, cross_attention_norm=None, ) # 2. Cross-Attn self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, cross_attention_norm=cross_attention_norm, ) def _to_3d(self, hidden_states, height, weight): return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) def _to_4d(self, hidden_states, height, weight): return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) def forward( self, hidden_states: torch.FloatTensor, encoder_hidden_states: Optional[torch.FloatTensor] = None, # TODO: mark emb as non-optional (self.norm2 requires it). # requires assessing impact of change to positional param interface. emb: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} # 1. Self-Attention if self.add_self_attention: norm_hidden_states = self.norm1(hidden_states, emb) height, weight = norm_hidden_states.shape[2:] norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=None, attention_mask=attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) hidden_states = attn_output + hidden_states # 2. 
Cross-Attention/None norm_hidden_states = self.norm2(hidden_states, emb) height, weight = norm_hidden_states.shape[2:] norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) hidden_states = attn_output + hidden_states return hidden_states ================================================ FILE: models/unet_2d_condition.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn import torch.utils.checkpoint from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.loaders import UNet2DConditionLoadersMixin from diffusers.utils import BaseOutput, logging from diffusers.models.activations import get_activation from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor from diffusers.models.embeddings import ( GaussianFourierProjection, ImageHintTimeEmbedding, ImageProjection, ImageTimeEmbedding, PositionNet, TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps, ) from diffusers.models.modeling_utils import ModelMixin from .unet_2d_blocks import ( UNetMidBlock2DCrossAttn, UNetMidBlock2DSimpleCrossAttn, get_down_block, get_up_block, ) import os, json logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class UNet2DConditionOutput(BaseOutput): """ The output of [`UNet2DConditionModel`]. Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): The hidden states output conditioned on `encoder_hidden_states` input. Output of the last layer of the model. """ sample: torch.FloatTensor = None class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample-shaped output. This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): Block type for the middle of the UNet, it can be either `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`): Whether to include self-attention in the basic transformer blocks, see [`~models.attention.BasicTransformerBlock`]. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. If `None`, normalization and activation layers are skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. encoder_hid_dim (`int`, *optional*, defaults to None): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. encoder_hid_dim_type (`str`, *optional*, defaults to `None`): If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim`. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. addition_time_embed_dim (`int`, *optional*, defaults to `None`): Dimension for the timestep embeddings.
num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. time_embedding_act_fn (`str`, *optional*, defaults to `None`): Optional activation function to use only once on the time embeddings before they are passed to the rest of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. time_cond_proj_dim (`int`, *optional*, defaults to `None`): The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False` otherwise.
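        Example:
            A minimal usage sketch under the defaults above (shapes are illustrative and
            assume the block implementations in `models/unet_2d_blocks.py`; this is not a
            configuration taken from this repository's training scripts):

            ```py
            import torch

            unet = UNet2DConditionModel()  # defaults: 4 latent channels, cross_attention_dim=1280
            sample = torch.randn(1, 4, 64, 64)                # noisy latents, H/W divisible by 8
            timestep = torch.tensor([10])                     # diffusion step
            encoder_hidden_states = torch.randn(1, 77, 1280)  # e.g. text features
            out = unet(sample, timestep, encoder_hidden_states).sample  # -> (1, 4, 64, 64)
            ```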
""" _supports_gradient_checkpointing = True @register_to_config def __init__( self, sample_size: Optional[int] = None, in_channels: int = 4, out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, transformer_layers_per_block: Union[int, Tuple[int]] = 1, encoder_hid_dim: Optional[int] = None, encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, addition_embed_type: Optional[str] = None, addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", resnet_skip_time_act: bool = False, resnet_out_scale_factor: int = 1.0, time_embedding_type: str = "positional", time_embedding_dim: Optional[int] = None, time_embedding_act_fn: Optional[str] = None, timestep_post_act: Optional[str] = None, time_cond_proj_dim: Optional[int] = None, conv_in_kernel: int = 3, conv_out_kernel: int = 3, projection_class_embeddings_input_dim: Optional[int] = None, attention_type: str = "default", class_embeddings_concat: bool = False, mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads=64, ): super().__init__() self.sample_size = sample_size if num_attention_heads is not None: raise ValueError( "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." ) # If `num_attention_heads` is not defined (which is the case for most models) # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. # The reason for this behavior is to correct for incorrectly named variables that were introduced # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking # which is why we correct for the naming here. num_attention_heads = num_attention_heads or attention_head_dim # Check inputs if len(down_block_types) != len(up_block_types): raise ValueError( f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." ) if len(block_out_channels) != len(down_block_types): raise ValueError( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. 
`down_block_types`: {down_block_types}." ) if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." ) self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, ) if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." ) if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) elif encoder_hid_dim_type == "text_image_proj": # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1) self.encoder_hid_proj = TextImageProjection( text_embed_dim=encoder_hid_dim, image_embed_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 self.encoder_hid_proj = ImageProjection( image_embed_dim=encoder_hid_dim, cross_attention_dim=cross_attention_dim, ) elif encoder_hid_dim_type is not None: raise ValueError( f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'." ) else: self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" ) # The projection `class_embed_type` is the same as the timestep `class_embed_type` except # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings # 2. it projects from an arbitrary input dimension. # # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None if addition_embed_type == "text": if encoder_hid_dim is not None: text_time_embedding_from_dim = encoder_hid_dim else: text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`.
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kandinsky 2.1) self.add_embedding = TextImageTimeEmbedding( text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.") if time_embedding_act_fn is None: self.time_embed_act = None else: self.time_embed_act = get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): if mid_block_only_cross_attention is None: mid_block_only_cross_attention = only_cross_attention only_cross_attention = [only_cross_attention] * len(down_block_types) if mid_block_only_cross_attention is None: mid_block_only_cross_attention = False if isinstance(num_attention_heads, int): num_attention_heads = (num_attention_heads,) * len(down_block_types) if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) if isinstance(cross_attention_dim, int): cross_attention_dim = (cross_attention_dim,) * len(down_block_types) if isinstance(layers_per_block, int): layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings.
The dimension of the # time embeddings passed to the down, middle, and up blocks is twice the dimension of the # regular time embeddings blocks_time_embed_dim = time_embed_dim * 2 else: blocks_time_embed_dim = time_embed_dim # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 down_block = get_down_block( down_block_type, num_layers=layers_per_block[i], transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=cross_attention_dim[i], num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, ) self.down_blocks.append(down_block) # mid if mid_block_type == "UNetMidBlock2DCrossAttn": self.mid_block = UNetMidBlock2DCrossAttn( transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, cross_attention_dim=cross_attention_dim[-1], num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, attention_type=attention_type, ) elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": self.mid_block = UNetMidBlock2DSimpleCrossAttn( in_channels=block_out_channels[-1], temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, cross_attention_dim=cross_attention_dim[-1], attention_head_dim=attention_head_dim[-1], resnet_groups=norm_num_groups, resnet_time_scale_shift=resnet_time_scale_shift, skip_time_act=resnet_skip_time_act, only_cross_attention=mid_block_only_cross_attention, cross_attention_norm=cross_attention_norm, ) elif mid_block_type is None: self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) reversed_num_attention_heads = list(reversed(num_attention_heads)) reversed_layers_per_block = list(reversed(layers_per_block)) reversed_cross_attention_dim = list(reversed(cross_attention_dim)) reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): is_final_block = i == len(block_out_channels) - 1 prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, 
len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: add_upsample = True self.num_upsamplers += 1 else: add_upsample = False up_block = get_up_block( up_block_type, num_layers=reversed_layers_per_block[i] + 1, transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, cross_attention_dim=reversed_cross_attention_dim[i], num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out if norm_num_groups is not None: self.conv_norm_out = nn.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) self.conv_act = get_activation(act_fn) else: self.conv_norm_out = None self.conv_act = None conv_out_padding = (conv_out_kernel - 1) // 2 self.conv_out = nn.Conv2d( block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) if attention_type == "gated": positive_len = 768 if isinstance(cross_attention_dim, int): positive_len = cross_attention_dim elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): positive_len = cross_attention_dim[0] self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim) @property def attn_processors(self) -> Dict[str, AttentionProcessor]: r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model, indexed by their weight names. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}.
Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ self.set_attn_processor(AttnProcessor()) def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. When this option is enabled, the attention module splits the input tensor in slices to compute attention in several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): fn_recursive_retrieve_sliceable_dims(module) num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between # speed and memory slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible slice_size = num_sliceable_layers * [1] slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." ) for i in range(len(slice_size)): size = slice_size[i] dim = sliceable_head_dims[i] if size is not None and size > dim: raise ValueError(f"size {size} has to be smaller or equal to {dim}.") # Recursively walk through all the children. 
# Any child that exposes the set_attention_slice method # gets the message def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) for child in module.children(): fn_recursive_set_attention_slice(child, slice_size) reversed_slice_size = list(reversed(slice_size)) for module in self.children(): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, image_encoder_hidden_states: Optional[torch.Tensor] = None, ) -> Union[UNet2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. Args: sample (`torch.FloatTensor`): The noisy input tensor with the following shape `(batch, channel, height, width)`. timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. encoder_hidden_states (`torch.FloatTensor`): The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. encoder_attention_mask (`torch.Tensor`): A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. added_cond_kwargs (`dict`, *optional*): A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that are passed along to the UNet blocks. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary.
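        # Example (illustrative): with the default four down/up blocks there are three
        # upsamplers, so `default_overall_up_factor` below is 2**3 = 8; a 512x512 sample
        # needs no forced `upsample_size`, while e.g. a 500x500 one does.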
default_overall_up_factor = 2**self.num_upsamplers # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True # ensure attention_mask is a bias, and give it a singleton query_tokens dimension # expects mask of shape: # [batch, key_tokens] # adds singleton query_tokens dimension: # [batch, 1, key_tokens] # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: # assume that mask is expressed as: # (1 = keep, 0 = discard) # convert mask into a bias that can be added to attention scores: # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 # 1. time timesteps = timestep if not torch.is_tensor(timesteps): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): dtype = torch.float32 if is_mps else torch.float64 else: dtype = torch.int32 if is_mps else torch.int64 timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) elif len(timesteps.shape) == 0: timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timesteps = timesteps.expand(sample.shape[0]) t_emb = self.time_proj(timesteps) # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None if self.class_embedding is not None: if class_labels is None: raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) # `Timesteps` does not contain any weights and will always return f32 tensors # there might be better ways to encapsulate this. 
class_labels = class_labels.to(dtype=sample.dtype) class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) elif self.config.addition_embed_type == "text_image": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) aug_emb = self.add_embedding(text_embs, image_embs) elif self.config.addition_embed_type == "text_time": # SDXL - style if "text_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" ) text_embeds = added_cond_kwargs.get("text_embeds") if "time_ids" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") time_embeds = self.add_time_proj(time_ids.flatten()) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) add_embeds = add_embeds.to(emb.dtype) aug_emb = self.add_embedding(add_embeds) elif self.config.addition_embed_type == "image": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) image_embs = added_cond_kwargs.get("image_embeds") hint = added_cond_kwargs.get("hint") aug_emb, hint = self.add_embedding(image_embs, hint) sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) # 2. pre-process sample = self.conv_in(sample) # 2.5 GLIGEN position net if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: cross_attention_kwargs = cross_attention_kwargs.copy() gligen_args = cross_attention_kwargs.pop("gligen") cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_block_additional_residuals) > 0: additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, image_encoder_hidden_states=image_encoder_hidden_states, **additional_residuals, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) if is_adapter and len(down_block_additional_residuals) > 0: sample += down_block_additional_residuals.pop(0) down_block_res_samples += res_samples if is_controlnet: new_down_block_res_samples = () for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples # 4. mid if self.mid_block is not None: sample = self.mid_block( sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, encoder_attention_mask=encoder_attention_mask, image_encoder_hidden_states=image_encoder_hidden_states, ) # To support T2I-Adapter-XL if ( is_adapter and len(down_block_additional_residuals) > 0 and sample.shape == down_block_additional_residuals[0].shape ): sample += down_block_additional_residuals.pop(0) if is_controlnet: sample = sample + mid_block_additional_residual # 5.
up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, encoder_attention_mask=encoder_attention_mask, image_encoder_hidden_states=image_encoder_hidden_states, ) else: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) # 6. post-process if self.conv_norm_out: sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) if not return_dict: return (sample,) return UNet2DConditionOutput(sample=sample) @classmethod def from_pretrained_orig(cls, pretrained_model_path, subfolder=None, **kwargs): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) config_file = os.path.join(pretrained_model_path, 'config.json') if not os.path.isfile(config_file): raise RuntimeError(f"{config_file} does not exist") with open(config_file, "r") as f: config = json.load(f) from diffusers.utils import WEIGHTS_NAME from diffusers.utils import SAFETENSORS_WEIGHTS_NAME model = cls.from_config(config) ## for .bin file # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) # if not os.path.isfile(model_file): # raise RuntimeError(f"{model_file} does not exist") # state_dict = torch.load(model_file, map_location="cpu") # model.load_state_dict(state_dict, strict=False) ## for .safetensors file import safetensors.torch model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME) if not os.path.isfile(model_file): raise RuntimeError(f"{model_file} does not exist") state_dict = safetensors.torch.load_file(model_file, device="cpu") model.load_state_dict(state_dict, strict=False) return model @classmethod def from_pretrained_safetensor(cls, pretrained_model_path, subfolder=None, **kwargs): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) config_file = os.path.join(pretrained_model_path, 'config.json') if not os.path.isfile(config_file): raise RuntimeError(f"{config_file} does not exist") with open(config_file, "r") as f: config = json.load(f) from diffusers.utils import SAFETENSORS_WEIGHTS_NAME model = cls.from_config(config) import safetensors.torch model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME) if not os.path.isfile(model_file): raise RuntimeError(f"{model_file} does not exist") state_dict = safetensors.torch.load_file(model_file, device="cpu") for k, v in model.state_dict().items(): if 'attn2_plus' in k: print(k) state_dict.update({k: v}) model.load_state_dict(state_dict, strict=False) return model ================================================ FILE: models/vit_utils.py ================================================ # MIT License # # Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab) # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated
documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # # The above copyright notice and this permission notice shall be included in all # copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # # Based on code from https://github.com/isl-org/DPT """Flexible configuration and feature extraction of timm VisionTransformers.""" import types import math from typing import Callable import torch import torch.nn as nn import torch.nn.functional as F class AddReadout(nn.Module): def __init__(self, start_index: bool = 1): super(AddReadout, self).__init__() self.start_index = start_index def forward(self, x: torch.Tensor) -> torch.Tensor: if self.start_index == 2: readout = (x[:, 0] + x[:, 1]) / 2 else: readout = x[:, 0] return x[:, self.start_index:] + readout.unsqueeze(1) class Transpose(nn.Module): def __init__(self, dim0: int, dim1: int): super(Transpose, self).__init__() self.dim0 = dim0 self.dim1 = dim1 def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.transpose(self.dim0, self.dim1) return x.contiguous() def forward_vit(pretrained: nn.Module, x: torch.Tensor) -> dict: _, _, H, W = x.size() _ = pretrained.model.forward_flex(x) return {k: pretrained.rearrange(v) for k, v in activations.items()} def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor: posemb_tok, posemb_grid = ( posemb[:, : self.start_index], posemb[0, self.start_index :], ) gs_old = int(math.sqrt(len(posemb_grid))) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) return posemb def forward_flex(self, x: torch.Tensor) -> torch.Tensor: # patch proj and dynamically resize B, C, H, W = x.size() x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) pos_embed = self._resize_pos_embed( self.pos_embed, H // self.patch_size[1], W // self.patch_size[0] ) # add cls token cls_tokens = self.cls_token.expand( x.size(0), -1, -1 ) x = torch.cat((cls_tokens, x), dim=1) # forward pass x = x + pos_embed x = self.pos_drop(x) for blk in self.blocks: x = blk(x) x = self.norm(x) return x activations = {} def get_activation(name: str) -> Callable: def hook(model, input, output): activations[name] = output return hook def make_sd_backbone( model: nn.Module, hooks: list[int] = [2, 5, 8, 11], hook_patch: bool = True, start_index: list[int] = 1, ): assert len(hooks) == 4 pretrained = nn.Module() pretrained.model = model # add hooks pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0')) pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1')) 
pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2')) pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3')) if hook_patch: pretrained.model.pos_drop.register_forward_hook(get_activation('4')) # configure readout pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2)) pretrained.model.start_index = start_index pretrained.model.patch_size = patch_size # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) pretrained.model._resize_pos_embed = types.MethodType( _resize_pos_embed, pretrained.model ) return pretrained def make_vit_backbone( model: nn.Module, patch_size: list[int] = [16, 16], hooks: list[int] = [2, 5, 8, 11], hook_patch: bool = True, start_index: int = 1, ): assert len(hooks) == 4 pretrained = nn.Module() pretrained.model = model # add hooks pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0')) pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1')) pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2')) pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3')) if hook_patch: pretrained.model.pos_drop.register_forward_hook(get_activation('4')) # configure readout pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2)) pretrained.model.start_index = start_index pretrained.model.patch_size = patch_size # We inject this function into the VisionTransformer instances so that # we can use it with interpolated position embeddings without modifying the library source. pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) pretrained.model._resize_pos_embed = types.MethodType( _resize_pos_embed, pretrained.model ) return pretrained ================================================ FILE: myutils/devices.py ================================================ import sys import contextlib from functools import lru_cache import torch #from modules import errors if sys.platform == "darwin": from modules import mac_specific def has_mps() -> bool: if sys.platform != "darwin": return False else: return mac_specific.has_mps def get_cuda_device_string(): return "cuda" def get_optimal_device_name(): if torch.cuda.is_available(): return get_cuda_device_string() if has_mps(): return "mps" return "cpu" def get_optimal_device(): return torch.device(get_optimal_device_name()) def get_device_for(task): return get_optimal_device() def torch_gc(): if torch.cuda.is_available(): with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() if has_mps(): mac_specific.torch_mps_gc() def enable_tf32(): if torch.cuda.is_available(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True enable_tf32() #errors.run(enable_tf32, "Enabling TF32") cpu = torch.device("cpu") device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda") dtype = torch.float16 dtype_vae = torch.float16 dtype_unet = torch.float16
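# Usage sketch (illustrative): the globals above feed the helpers below, e.g.
#   x = torch.randn(2, 3, device=device, dtype=dtype)  # fp16 CUDA tensor
#   x = cond_cast_float(x)  # upcasts to fp32 only when unet_needs_upcast is True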
unet_needs_upcast = False def cond_cast_unet(input): return input.to(dtype_unet) if unet_needs_upcast else input def cond_cast_float(input): return input.float() if unet_needs_upcast else input def randn(seed, shape): torch.manual_seed(seed) return torch.randn(shape, device=device) def randn_without_seed(shape): return torch.randn(shape, device=device) def autocast(disable=False): if disable: return contextlib.nullcontext() return torch.autocast("cuda") def without_autocast(disable=False): return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() class NansException(Exception): pass def test_for_nans(x, where): if not torch.all(torch.isnan(x)).item(): return if where == "unet": message = "A tensor with all NaNs was produced in Unet." elif where == "vae": message = "A tensor with all NaNs was produced in VAE." else: message = "A tensor with all NaNs was produced." message += " Use --disable-nan-check commandline argument to disable this check." raise NansException(message) @lru_cache def first_time_calculation(): """ just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and spends about 2.7 seconds doing that, at least with NVIDIA. """ x = torch.zeros((1, 1)).to(device, dtype) linear = torch.nn.Linear(1, 1).to(device, dtype) linear(x) x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) ================================================ FILE: myutils/img_util.py ================================================ import os import PIL import cv2 import math import numpy as np import torch import torchvision import imageio from einops import rearrange def save_videos_grid(videos, path=None, rescale=True, n_rows=4, fps=8, discardN=0): videos = rearrange(videos, "b c t h w -> t b c h w").cpu() outputs = [] for x in videos: x = torchvision.utils.make_grid(x, nrow=n_rows) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) if rescale: x = (x / 2.0 + 0.5).clamp(0, 1) # -1,1 -> 0,1 x = (x * 255).numpy().astype(np.uint8) #x = adjust_gamma(x, 0.5) outputs.append(x) outputs = outputs[discardN:] if path is not None: #os.makedirs(os.path.dirname(path), exist_ok=True) imageio.mimsave(path, outputs, duration=1000/fps, loop=0) return outputs def convert_image_to_fn(img_type, minsize, image, eps=0.02): width, height = image.size if min(width, height) < minsize: scale = minsize/min(width, height) + eps image = image.resize((math.ceil(width*scale), math.ceil(height*scale))) if image.mode != img_type: return image.convert(img_type) return image ================================================ FILE: myutils/misc.py ================================================ import os import binascii from safetensors import safe_open import torch from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint def rand_name(length=8, suffix=''): name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') if suffix: if not suffix.startswith('.'): suffix = '.'
+ suffix name += suffix return name def cycle(dl): while True: for data in dl: yield data def exists(x): return x is not None def identity(x): return x def load_dreambooth_lora(unet, vae=None, model_path=None, alpha=1.0, model_base=""): if model_path is None: return unet if model_path.endswith(".ckpt"): base_state_dict = torch.load(model_path)['state_dict'] elif model_path.endswith(".safetensors"): state_dict = {} with safe_open(model_path, framework="pt", device="cpu") as f: for key in f.keys(): state_dict[key] = f.get_tensor(key) is_lora = all("lora" in k for k in state_dict.keys()) if not is_lora: base_state_dict = state_dict else: base_state_dict = {} with safe_open(model_base, framework="pt", device="cpu") as f: for key in f.keys(): base_state_dict[key] = f.get_tensor(key) converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, unet.config) unet_state_dict = unet.state_dict() for key in converted_unet_checkpoint: converted_unet_checkpoint[key] = alpha * converted_unet_checkpoint[key] + (1.0-alpha) * unet_state_dict[key] unet.load_state_dict(converted_unet_checkpoint, strict=False) if vae is not None: converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, vae.config) vae.load_state_dict(converted_vae_checkpoint) return unet, vae ================================================ FILE: myutils/vaehook.py ================================================ # ------------------------------------------------------------------------ # # Ultimate VAE Tile Optimization # # Introducing a revolutionary new optimization designed to make # the VAE work with giant images on limited VRAM! # Say goodbye to the frustration of OOM and hello to seamless output! # # ------------------------------------------------------------------------ # # This script is a wild hack that splits the image into tiles, # encodes each tile separately, and merges the result back together. # # Advantages: # - The VAE can now work with giant images on limited VRAM # (~10 GB for 8K images!) # - The merged output is completely seamless without any post-processing. # # Drawbacks: # - Giant RAM needed. To store the intermediate results for a 4096x4096 # images, you need 32 GB RAM it consumes ~20GB); for 8192x8192 # you need 128 GB RAM machine (it consumes ~100 GB) # - NaNs always appear in for 8k images when you use fp16 (half) VAE # You must use --no-half-vae to disable half VAE for that giant image. # - Slow speed. With default tile size, it takes around 50/200 seconds # to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode # a 8192x8192 image. (The speed is limited by both the GPU and the CPU.) # - The gradient calculation is not compatible with this hack. It # will break any backward() or torch.autograd.grad() that passes VAE. # (But you can still use the VAE to generate training data.) # # How it works: # 1) The image is split into tiles. # - To ensure perfect results, each tile is padded with 32 pixels # on each side. # - Then the conv2d/silu/upsample/downsample can produce identical # results to the original image without splitting. # 2) The original forward is decomposed into a task queue and a task worker. # - The task queue is a list of functions that will be executed in order. # - The task worker is a loop that executes the tasks in the queue. # 3) The task queue is executed for each tile. # - Current tile is sent to GPU. # - local operations are directly executed. # - Group norm calculation is temporarily suspended until the mean # and var of all tiles are calculated. 
# - The residual is pre-calculated and stored and added back later. # - When we need to move on to the next tile, the current tile is sent to the CPU. # 4) After all tiles are processed, they are merged on the CPU and returned. # # Enjoy! # # @author: LI YI @ Nanyang Technological University - Singapore # @date: 2023-03-02 # @license: MIT License # # Please give me a star if you like this project! # # ------------------------------------------------------------------------- import gc from time import time import math from tqdm import tqdm import torch import torch.version import torch.nn.functional as F from einops import rearrange import sys sys.path.append('/home/notebook/code/personal/S9048295/code/PASD') import myutils.devices as devices #from modules.shared import state #from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock try: import xformers import xformers.ops except ImportError: pass sd_flag = False def get_recommend_encoder_tile_size(): if torch.cuda.is_available(): total_memory = torch.cuda.get_device_properties( devices.device).total_memory // 2**20 if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072 elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048 elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536 else: ENCODER_TILE_SIZE = 960 else: ENCODER_TILE_SIZE = 512 return ENCODER_TILE_SIZE def get_recommend_decoder_tile_size(): if torch.cuda.is_available(): total_memory = torch.cuda.get_device_properties( devices.device).total_memory // 2**20 if total_memory > 30*1000: DECODER_TILE_SIZE = 256 elif total_memory > 16*1000: DECODER_TILE_SIZE = 192 elif total_memory > 12*1000: DECODER_TILE_SIZE = 128 elif total_memory > 8*1000: DECODER_TILE_SIZE = 96 else: DECODER_TILE_SIZE = 64 else: DECODER_TILE_SIZE = 64 return DECODER_TILE_SIZE if 'global const': DEFAULT_ENABLED = False DEFAULT_MOVE_TO_GPU = False DEFAULT_FAST_ENCODER = True DEFAULT_FAST_DECODER = True DEFAULT_COLOR_FIX = 0 DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size() DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size() # inplace version of silu def inplace_nonlinearity(x): # Test: fix for Nans return F.silu(x, inplace=True) # extracted from ldm.modules.diffusionmodules.model # from diffusers lib def attn_forward_new(self, h_): batch_size, channel, height, width = h_.shape hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2) attention_mask = None encoder_hidden_states = None batch_size, sequence_length, _ = hidden_states.shape attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = self.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif self.norm_cross: encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) key = self.to_k(encoder_hidden_states) value = self.to_v(encoder_hidden_states) query = self.head_to_batch_dim(query) key = self.head_to_batch_dim(key) value = self.head_to_batch_dim(value) attention_probs = self.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = self.batch_to_head_dim(hidden_states) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) return hidden_states def attn_forward(self, h_): q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h*w) q = q.permute(0, 2, 1) #
b,hw,c k = k.reshape(b, c, h*w) # b,c,hw w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h*w) w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] h_ = torch.bmm(v, w_) h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) return h_ def xformer_attn_forward(self, h_): q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention B, C, H, W = q.shape q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) q, k, v = map( lambda t: t.unsqueeze(3) .reshape(B, t.shape[1], 1, C) .permute(0, 2, 1, 3) .reshape(B * 1, t.shape[1], C) .contiguous(), (q, k, v), ) out = xformers.ops.memory_efficient_attention( q, k, v, attn_bias=None, op=self.attention_op) out = ( out.unsqueeze(0) .reshape(B, 1, out.shape[1], C) .permute(0, 2, 1, 3) .reshape(B, out.shape[1], C) ) out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) out = self.proj_out(out) return out def attn2task(task_queue, net): if False: #isinstance(net, AttnBlock): task_queue.append(('store_res', lambda x: x)) task_queue.append(('pre_norm', net.norm)) task_queue.append(('attn', lambda x, net=net: attn_forward(net, x))) task_queue.append(['add_res', None]) elif False: #isinstance(net, MemoryEfficientAttnBlock): task_queue.append(('store_res', lambda x: x)) task_queue.append(('pre_norm', net.norm)) task_queue.append( ('attn', lambda x, net=net: xformer_attn_forward(net, x))) task_queue.append(['add_res', None]) else: task_queue.append(('store_res', lambda x: x)) task_queue.append(('pre_norm', net.group_norm)) task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x))) task_queue.append(['add_res', None]) def resblock2task(queue, block): """ Turn a ResNetBlock into a sequence of tasks and append to the task queue @param queue: the target task queue @param block: ResNetBlock """ if block.in_channels != block.out_channels: if sd_flag: if block.use_conv_shortcut: queue.append(('store_res', block.conv_shortcut)) else: queue.append(('store_res', block.nin_shortcut)) else: if block.use_in_shortcut: queue.append(('store_res', block.conv_shortcut)) else: queue.append(('store_res', block.nin_shortcut)) else: queue.append(('store_res', lambda x: x)) queue.append(('pre_norm', block.norm1)) queue.append(('silu', inplace_nonlinearity)) queue.append(('conv1', block.conv1)) queue.append(('pre_norm', block.norm2)) queue.append(('silu', inplace_nonlinearity)) queue.append(('conv2', block.conv2)) queue.append(['add_res', None]) def build_sampling(task_queue, net, is_decoder): """ Build the sampling part of a task queue @param task_queue: the target task queue @param net: the network @param is_decoder: currently building decoder or encoder """ if is_decoder: if sd_flag: resblock2task(task_queue, net.mid.block_1) attn2task(task_queue, net.mid.attn_1) print(task_queue) resblock2task(task_queue, net.mid.block_2) resolution_iter = reversed(range(net.num_resolutions)) block_ids = net.num_res_blocks + 1 condition = 0 module = net.up func_name = 'upsample' else: resblock2task(task_queue, net.mid_block.resnets[0]) attn2task(task_queue, net.mid_block.attentions[0]) resblock2task(task_queue, net.mid_block.resnets[1]) resolution_iter = (range(len(net.up_blocks))) # net.num_resolutions = 3 block_ids = 2 + 1 condition = len(net.up_blocks) - 1 module = net.up_blocks func_name = 'upsamplers' else: resolution_iter = range(net.num_resolutions) block_ids = 
net.num_res_blocks condition = net.num_resolutions - 1 module = net.down func_name = 'downsample' for i_level in resolution_iter: for i_block in range(block_ids): if sd_flag: resblock2task(task_queue, module[i_level].block[i_block]) else: resblock2task(task_queue, module[i_level].resnets[i_block]) if i_level != condition: if sd_flag: task_queue.append((func_name, getattr(module[i_level], func_name))) else: task_queue.append((func_name, module[i_level].upsamplers[0])) if not is_decoder: if sd_flag: resblock2task(task_queue, net.mid.block_1) attn2task(task_queue, net.mid.attn_1) resblock2task(task_queue, net.mid.block_2) else: resblock2task(task_queue, net.mid_block.resnets[0]) attn2task(task_queue, net.mid_block.attentions[0]) resblock2task(task_queue, net.mid_block.resnets[1]) def build_task_queue(net, is_decoder): """ Build a single task queue for the encoder or decoder @param net: the VAE decoder or encoder network @param is_decoder: currently building decoder or encoder @return: the task queue """ task_queue = [] task_queue.append(('conv_in', net.conv_in)) # construct the sampling part of the task queue # because encoder and decoder share the same architecture, we extract the sampling part build_sampling(task_queue, net, is_decoder) if is_decoder and not sd_flag: net.give_pre_end = False net.tanh_out = False if not is_decoder or not net.give_pre_end: if sd_flag: task_queue.append(('pre_norm', net.norm_out)) else: task_queue.append(('pre_norm', net.conv_norm_out)) task_queue.append(('silu', inplace_nonlinearity)) task_queue.append(('conv_out', net.conv_out)) if is_decoder and net.tanh_out: task_queue.append(('tanh', torch.tanh)) return task_queue def clone_task_queue(task_queue): """ Clone a task queue @param task_queue: the task queue to be cloned @return: the cloned task queue """ return [[item for item in task] for task in task_queue] def get_var_mean(input, num_groups, eps=1e-6): """ Get mean and var for group norm """ b, c = input.size(0), input.size(1) channel_in_group = int(c/num_groups) input_reshaped = input.contiguous().view( 1, int(b * num_groups), channel_in_group, *input.size()[2:]) var, mean = torch.var_mean( input_reshaped, dim=[0, 2, 3, 4], unbiased=False) return var, mean def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6): """ Custom group norm with fixed mean and var @param input: input tensor @param num_groups: number of groups. 
by default, num_groups = 32 @param mean: mean, must be pre-calculated by get_var_mean @param var: var, must be pre-calculated by get_var_mean @param weight: weight, should be fetched from the original group norm @param bias: bias, should be fetched from the original group norm @param eps: epsilon, by default, eps = 1e-6 to match the original group norm @return: normalized tensor """ b, c = input.size(0), input.size(1) channel_in_group = int(c/num_groups) input_reshaped = input.contiguous().view( 1, int(b * num_groups), channel_in_group, *input.size()[2:]) out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps) out = out.view(b, c, *input.size()[2:]) # post affine transform if weight is not None: out *= weight.view(1, -1, 1, 1) if bias is not None: out += bias.view(1, -1, 1, 1) return out def crop_valid_region(x, input_bbox, target_bbox, is_decoder): """ Crop the valid region from the tile @param x: input tile @param input_bbox: original input bounding box @param target_bbox: output bounding box @param is_decoder: whether the tile comes from the decoder (bboxes are scaled up 8x for the decoder, down 8x for the encoder) @return: cropped tile """ padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox] margin = [target_bbox[i] - padded_bbox[i] for i in range(4)] return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]] # ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓ def perfcount(fn): def wrapper(*args, **kwargs): ts = time() if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats(devices.device) devices.torch_gc() gc.collect() ret = fn(*args, **kwargs) devices.torch_gc() gc.collect() if torch.cuda.is_available(): vram = torch.cuda.max_memory_allocated(devices.device) / 2**20 torch.cuda.reset_peak_memory_stats(devices.device) print( f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB') else: print(f'[Tiled VAE]: Done in {time() - ts:.3f}s') return ret return wrapper # copy end :) class GroupNormParam: def __init__(self): self.var_list = [] self.mean_list = [] self.pixel_list = [] self.weight = None self.bias = None def add_tile(self, tile, layer): var, mean = get_var_mean(tile, 32) # For giant images, the variance can be larger than max float16 # In this case we create a copy to float32 if var.dtype == torch.float16 and var.isinf().any(): fp32_tile = tile.float() var, mean = get_var_mean(fp32_tile, 32) # ============= DEBUG: test for infinite ============= # if torch.isinf(var).any(): # print('var: ', var) # ==================================================== self.var_list.append(var) self.mean_list.append(mean) self.pixel_list.append( tile.shape[2]*tile.shape[3]) if hasattr(layer, 'weight'): self.weight = layer.weight self.bias = layer.bias else: self.weight = None self.bias = None def summary(self): """ summarize the mean and var and return a function that applies group norm on each tile """ if len(self.var_list) == 0: return None var = torch.vstack(self.var_list) mean = torch.vstack(self.mean_list) max_value = max(self.pixel_list) pixels = torch.tensor( self.pixel_list, dtype=torch.float32, device=devices.device) / max_value sum_pixels = torch.sum(pixels) pixels = pixels.unsqueeze( 1) / sum_pixels var = torch.sum( var * pixels, dim=0) mean = torch.sum( mean * pixels, dim=0) return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias) @staticmethod def from_tile(tile, norm): """ create a function from a single tile without summary """ var, mean = get_var_mean(tile, 32) if var.dtype == torch.float16 and var.isinf().any(): fp32_tile =
tile.float() var, mean = get_var_mean(fp32_tile, 32) # on Apple MPS devices we need to convert back to float16 if var.device.type == 'mps': # clamp to avoid overflow var = torch.clamp(var, 0, 60000) var = var.half() mean = mean.half() if hasattr(norm, 'weight'): weight = norm.weight bias = norm.bias else: weight = None bias = None def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias): return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6) return group_norm_func class VAEHook: def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False): self.net = net # encoder | decoder self.tile_size = tile_size self.is_decoder = is_decoder self.fast_mode = (fast_encoder and not is_decoder) or ( fast_decoder and is_decoder) self.color_fix = color_fix and not is_decoder self.to_gpu = to_gpu self.pad = 11 if is_decoder else 32 def __call__(self, x): B, C, H, W = x.shape original_device = next(self.net.parameters()).device try: if self.to_gpu: self.net.to(devices.get_optimal_device()) if max(H, W) <= self.pad * 2 + self.tile_size: print("[Tiled VAE]: the input size is small; no tiling is needed.") return self.net.original_forward(x) else: return self.vae_tile_forward(x) finally: self.net.to(original_device) def get_best_tile_size(self, lowerbound, upperbound): """ Get the best tile size for GPU memory """ divider = 32 while divider >= 2: remainder = lowerbound % divider if remainder == 0: return lowerbound candidate = lowerbound - remainder + divider if candidate <= upperbound: return candidate divider //= 2 return lowerbound def split_tiles(self, h, w): """ Tool function to split the image into tiles @param h: height of the image @param w: width of the image @return: tile_input_bboxes, tile_output_bboxes """ tile_input_bboxes, tile_output_bboxes = [], [] tile_size = self.tile_size pad = self.pad num_height_tiles = math.ceil((h - 2 * pad) / tile_size) num_width_tiles = math.ceil((w - 2 * pad) / tile_size) # If any of the numbers are 0, we let it be 1 # This is to deal with long and thin images num_height_tiles = max(num_height_tiles, 1) num_width_tiles = max(num_width_tiles, 1) # Suggestions from https://github.com/Kahsolt: auto shrink the tile size real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles) real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles) real_tile_height = self.get_best_tile_size(real_tile_height, tile_size) real_tile_width = self.get_best_tile_size(real_tile_width, tile_size) print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' + f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}') for i in range(num_height_tiles): for j in range(num_width_tiles): # bbox: [x1, x2, y1, y2] # the padding is unnecessary for image borders.
So we directly start from (32, 32) input_bbox = [ pad + j * real_tile_width, min(pad + (j + 1) * real_tile_width, w), pad + i * real_tile_height, min(pad + (i + 1) * real_tile_height, h), ] # if the output bbox is close to the image boundary, we extend it to the image boundary output_bbox = [ input_bbox[0] if input_bbox[0] > pad else 0, input_bbox[1] if input_bbox[1] < w - pad else w, input_bbox[2] if input_bbox[2] > pad else 0, input_bbox[3] if input_bbox[3] < h - pad else h, ] # scale to get the final output bbox output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox] tile_output_bboxes.append(output_bbox) # unconditionally expand the input bbox by pad pixels tile_input_bboxes.append([ max(0, input_bbox[0] - pad), min(w, input_bbox[1] + pad), max(0, input_bbox[2] - pad), min(h, input_bbox[3] + pad), ]) return tile_input_bboxes, tile_output_bboxes @torch.no_grad() def estimate_group_norm(self, z, task_queue, color_fix): device = z.device tile = z last_id = len(task_queue) - 1 while last_id >= 0 and task_queue[last_id][0] != 'pre_norm': last_id -= 1 if last_id <= 0 or task_queue[last_id][0] != 'pre_norm': raise ValueError('No group norm found in the task queue') # estimate until the last group norm for i in range(last_id + 1): task = task_queue[i] if task[0] == 'pre_norm': group_norm_func = GroupNormParam.from_tile(tile, task[1]) task_queue[i] = ('apply_norm', group_norm_func) if i == last_id: return True tile = group_norm_func(tile) elif task[0] == 'store_res': task_id = i + 1 while task_id < last_id and task_queue[task_id][0] != 'add_res': task_id += 1 if task_id >= last_id: continue task_queue[task_id][1] = task[1](tile) elif task[0] == 'add_res': tile += task[1].to(device) task[1] = None elif color_fix and task[0] == 'downsample': for j in range(i, last_id + 1): if task_queue[j][0] == 'store_res': task_queue[j] = ('store_res_cpu', task_queue[j][1]) return True else: tile = task[1](tile) try: devices.test_for_nans(tile, "vae") except: print('Nan detected in fast mode estimation. Fast mode disabled.') return False raise IndexError('Should not reach here') @perfcount @torch.no_grad() def vae_tile_forward(self, z): """ Decode a latent vector z into an image in a tiled manner.
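A minimal wiring sketch (illustrative only; it mirrors `_init_tiled_vae` in pipelines/pipeline_ccsr.py and assumes `vae` is a loaded diffusers `AutoencoderKL`):
>>> vae.decoder.original_forward = vae.decoder.forward # keep the untiled path
>>> vae.decoder.forward = VAEHook(vae.decoder, tile_size=192, is_decoder=True,
... fast_decoder=True, fast_encoder=False, color_fix=False)
>>> # vae.decode(latents) now runs this tiled forward, tile by tile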
@param z: latent vector @return: image """ device = next(self.net.parameters()).device net = self.net tile_size = self.tile_size is_decoder = self.is_decoder z = z.detach() # detach the input to avoid backprop N, height, width = z.shape[0], z.shape[2], z.shape[3] net.last_z_shape = z.shape # Split the input into tiles and build a task queue for each tile print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}') in_bboxes, out_bboxes = self.split_tiles(height, width) # Prepare tiles by splitting the input latents tiles = [] for input_bbox in in_bboxes: tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu() tiles.append(tile) num_tiles = len(tiles) num_completed = 0 # Build task queues single_task_queue = build_task_queue(net, is_decoder) #print(single_task_queue) if self.fast_mode: # Fast mode: downsample the input image to the tile size, # then estimate the group norm parameters on the downsampled image scale_factor = tile_size / max(height, width) z = z.to(device) downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact') # use nearest-exact to keep statistics as close as possible print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image') # ======= Special thanks to @Kahsolt for distribution shift issue ======= # # The downsampling will heavily distort its mean and std, so we need to recover it. std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True) std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True) downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old del std_old, mean_old, std_new, mean_new # occasionally the std_new is too small or too large, which exceeds the range of float16 # so we need to clamp it to max z's range.
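# As a worked example with hypothetical numbers: if a latent channel of z had
# mean_old = 1.0 and std_old = 2.0, but the nearest-exact downsample shifted it
# to mean_new = 0.8 and std_new = 1.5, the recovery above maps each value v to
# (v - 0.8) / 1.5 * 2.0 + 1.0, restoring the original statistics; the clamp
# below then keeps any float16 outliers inside z's observed value range.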
downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max()) estimate_task_queue = clone_task_queue(single_task_queue) if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix): single_task_queue = estimate_task_queue del downsampled_z task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)] # Dummy result result = None result_approx = None #try: # with devices.autocast(): # result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu() #except: pass # Free memory of input latent tensor del z # Task queue execution pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ") # execute the tasks back and forth when switching tiles so that we always # keep one tile on the GPU to reduce unnecessary data transfer forward = True interrupted = False #state.interrupted = interrupted while True: #if state.interrupted: interrupted = True ; break group_norm_param = GroupNormParam() for i in range(num_tiles) if forward else reversed(range(num_tiles)): #if state.interrupted: interrupted = True ; break tile = tiles[i].to(device) input_bbox = in_bboxes[i] task_queue = task_queues[i] interrupted = False while len(task_queue) > 0: #if state.interrupted: interrupted = True ; break # DEBUG: current task # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape) task = task_queue.pop(0) if task[0] == 'pre_norm': group_norm_param.add_tile(tile, task[1]) break elif task[0] == 'store_res' or task[0] == 'store_res_cpu': task_id = 0 res = task[1](tile) if not self.fast_mode or task[0] == 'store_res_cpu': res = res.cpu() while task_queue[task_id][0] != 'add_res': task_id += 1 task_queue[task_id][1] = res elif task[0] == 'add_res': tile += task[1].to(device) task[1] = None else: tile = task[1](tile) #print(tiles[i].shape, tile.shape, task) pbar.update(1) if interrupted: break # check for NaNs in the tile. # If there are NaNs, we abort the process to save the user's time #devices.test_for_nans(tile, "vae") #print(tiles[i].shape, tile.shape, i, num_tiles) if len(task_queue) == 0: tiles[i] = None num_completed += 1 if result is None: # NOTE: dim C varies across cases, so it can only be initialized dynamically result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False) result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder) del tile elif i == num_tiles - 1 and forward: forward = False tiles[i] = tile elif i == 0 and not forward: forward = True tiles[i] = tile else: tiles[i] = tile.cpu() del tile if interrupted: break if num_completed == num_tiles: break # insert the group norm task at the head of each task queue group_norm_func = group_norm_param.summary() if group_norm_func is not None: for i in range(num_tiles): task_queue = task_queues[i] task_queue.insert(0, ('apply_norm', group_norm_func)) # Done!
pbar.close() return result if result is not None else result_approx.to(device) ================================================ FILE: myutils/wavelet_color_fix.py ================================================ ''' # -------------------------------------------------------------------------------- # Color fix script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py) # -------------------------------------------------------------------------------- ''' import torch from PIL import Image from torch import Tensor from torch.nn import functional as F from torchvision.transforms import ToTensor, ToPILImage def adain_color_fix(target: Image, source: Image): # Convert images to tensors to_tensor = ToTensor() target_tensor = to_tensor(target).unsqueeze(0) source_tensor = to_tensor(source).unsqueeze(0) # Apply adaptive instance normalization result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) # Convert tensor back to image to_image = ToPILImage() result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) return result_image def wavelet_color_fix(target: Image, source: Image): # Convert images to tensors to_tensor = ToTensor() target_tensor = to_tensor(target).unsqueeze(0) source_tensor = to_tensor(source).unsqueeze(0) # Apply wavelet reconstruction result_tensor = wavelet_reconstruction(target_tensor, source_tensor) # Convert tensor back to image to_image = ToPILImage() result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) return result_image def calc_mean_std(feat: Tensor, eps=1e-5): """Calculate mean and std for adaptive_instance_normalization. Args: feat (Tensor): 4D tensor. eps (float): A small value added to the variance to avoid divide-by-zero. Default: 1e-5. """ size = feat.size() assert len(size) == 4, 'The input feature should be 4D tensor.' b, c = size[:2] feat_var = feat.reshape(b, c, -1).var(dim=2) + eps feat_std = feat_var.sqrt().reshape(b, c, 1, 1) feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) return feat_mean, feat_std def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): """Adaptive instance normalization. Adjust the reference features to have similar color and illumination to those in the degraded features. Args: content_feat (Tensor): The reference feature. style_feat (Tensor): The degraded features. """ size = content_feat.size() style_mean, style_std = calc_mean_std(style_feat) content_mean, content_std = calc_mean_std(content_feat) normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) return normalized_feat * style_std.expand(size) + style_mean.expand(size) def wavelet_blur(image: Tensor, radius: int): """ Apply wavelet blur to the input tensor. """ # input shape: (1, 3, H, W) # convolution kernel kernel_vals = [ [0.0625, 0.125, 0.0625], [0.125, 0.25, 0.125], [0.0625, 0.125, 0.0625], ] kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) # add channel dimensions to the kernel to make it a 4D tensor kernel = kernel[None, None] # repeat the kernel across all input channels kernel = kernel.repeat(3, 1, 1, 1) image = F.pad(image, (radius, radius, radius, radius), mode='replicate') # apply convolution output = F.conv2d(image, kernel, groups=3, dilation=radius) return output def wavelet_decomposition(image: Tensor, levels=5): """ Apply wavelet decomposition to the input tensor. This function only returns the low frequency & the high frequency.
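Because high_freq accumulates (image - low_freq) at each level while image is replaced by its blurred version, the sum telescopes, so the two outputs always satisfy high_freq + low_freq == image up to floating-point error. A quick illustrative check (assuming a float32 input):
>>> img = torch.rand(1, 3, 64, 64)
>>> hi, lo = wavelet_decomposition(img)
>>> torch.allclose(hi + lo, img, atol=1e-5)
True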
""" high_freq = torch.zeros_like(image) for i in range(levels): radius = 2 ** i low_freq = wavelet_blur(image, radius) high_freq += (image - low_freq) image = low_freq return high_freq, low_freq def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): """ Apply wavelet decomposition, so that the content will have the same color as the style. """ # calculate the wavelet decomposition of the content feature content_high_freq, content_low_freq = wavelet_decomposition(content_feat) del content_low_freq # calculate the wavelet decomposition of the style feature style_high_freq, style_low_freq = wavelet_decomposition(style_feat) del style_high_freq # reconstruct the content feature with the style's high frequency return content_high_freq + style_low_freq ================================================ FILE: pipelines/pipeline_ccsr.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import inspect import os import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image import torch import torch.nn.functional as F from torchvision.utils import save_image from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import TextualInversionLoaderMixin # from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel from models.controlnet import ControlNetModel from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( PIL_INTERPOLATION, is_accelerate_available, is_accelerate_version, logging, replace_example_docstring, ) from diffusers.utils.torch_utils import is_compiled_module, randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel from utils.vaehook import VAEHook, perfcount from tqdm import tqdm from torch import FloatTensor from PIL import Image import time logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> # !pip install opencv-python transformers accelerate >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler >>> from diffusers.utils import load_image >>> import numpy as np >>> import torch >>> import cv2 >>> from PIL import Image >>> # download an image >>> image = load_image( ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" ... 
) >>> image = np.array(image) >>> # get canny image >>> image = cv2.Canny(image, 100, 200) >>> image = image[:, :, None] >>> image = np.concatenate([image, image, image], axis=2) >>> canny_image = Image.fromarray(image) >>> # load control net and stable diffusion v1-5 >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 ... ) >>> # speed up diffusion process with faster scheduler and memory optimization >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) >>> # remove following line if xformers is not installed >>> pipe.enable_xformers_memory_efficient_attention() >>> pipe.enable_model_cpu_offload() >>> # generate image >>> generator = torch.manual_seed(0) >>> image = pipe( ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image ... ).images[0] ``` """ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): Frozen text-encoder. Stable Diffusion uses the text portion of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets as a list, the outputs from each ControlNet are added together to create one combined additional conditioning. scheduler ([`SchedulerMixin`]): A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. feature_extractor ([`CLIPImageProcessor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. 
""" _optional_components = ["safety_checker", "feature_extractor"] def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True, ): super().__init__() if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" " it only for use-cases that involve analyzing network behavior or auditing its results. For more" " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) if safety_checker is not None and feature_extractor is None: raise ValueError( "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) if isinstance(controlnet, (list, tuple)): controlnet = MultiControlNetModel(controlnet) self.register_modules( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, ) self.scheduler = scheduler self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def _init_tiled_vae(self, encoder_tile_size = 256, decoder_tile_size = 256, fast_decoder = False, fast_encoder = False, color_fix = False, vae_to_gpu = True): # save original forward (only once) if not hasattr(self.vae.encoder, 'original_forward'): setattr(self.vae.encoder, 'original_forward', self.vae.encoder.forward) if not hasattr(self.vae.decoder, 'original_forward'): setattr(self.vae.decoder, 'original_forward', self.vae.decoder.forward) encoder = self.vae.encoder decoder = self.vae.decoder self.vae.encoder.forward = VAEHook( encoder, encoder_tile_size, is_decoder=False, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu) self.vae.decoder.forward = VAEHook( decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. 
""" self.vae.disable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling def enable_vae_tiling(self): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow the processing of larger images. """ self.vae.enable_tiling() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling def disable_vae_tiling(self): r""" Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_tiling() def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. Note that offloading happens on a submodule basis. Memory savings are higher than with `enable_model_cpu_offload`, but performance is lower. """ if is_accelerate_available(): from accelerate import cpu_offload else: raise ImportError("Please install accelerate via `pip install accelerate`") device = torch.device(f"cuda:{gpu_id}") for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: cpu_offload(cpu_offloaded_model, device) if self.safety_checker is not None: cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. """ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): from accelerate import cpu_offload_with_hook else: raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") hook = None for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) if self.safety_checker is not None: # the safety checker can offload the vae again _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) # control net hook has be manually offloaded as it alternates with unet cpu_offload_with_hook(self.controlnet, device) # We'll offload the last model manually. self.final_offload_hook = hook @property # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. 
""" if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None ): return torch.device(module._hf_hook.execution_device) return self.device # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ): r""" Encodes the prompt into text encoder hidden states. Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded device: (`torch.device`): torch device num_images_per_prompt (`int`): number of images that should be generated per prompt do_classifier_free_guidance (`bool`): whether to use classifier free guidance or not negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. """ if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] if prompt_embeds is None: # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): removed_text = self.tokenizer.batch_decode( untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None prompt_embeds = self.text_encoder( text_input_ids.to(device), attention_mask=attention_mask, ) prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) # get unconditional embeddings for classifier free guidance if 
do_classifier_free_guidance and negative_prompt_embeds is None: uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) else: uncond_tokens = negative_prompt # textual inversion: procecss multi-vector tokens if necessary if isinstance(self, TextualInversionLoaderMixin): uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) max_length = prompt_embeds.shape[1] uncond_input = self.tokenizer( uncond_tokens, padding="max_length", max_length=max_length, truncation=True, return_tensors="pt", ) if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: attention_mask = uncond_input.attention_mask.to(device) else: attention_mask = None negative_prompt_embeds = self.text_encoder( uncond_input.input_ids.to(device), attention_mask=attention_mask, ) negative_prompt_embeds = negative_prompt_embeds[0] if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch # to avoid doing two forward passes prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) return prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): if self.safety_checker is None: has_nsfw_concept = None else: if torch.is_tensor(image): feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") else: feature_extractor_input = self.image_processor.numpy_to_pil(image) safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): warnings.warn( "The decode_latents method is deprecated and will be removed in a future version. 
Please" " use VaeImageProcessor instead", FutureWarning, ) latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator #extra_step_kwargs["generator"] = generator return extra_step_kwargs def check_inputs( self, prompt, image, height, width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1.0, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f" {type(callback_steps)}." ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and negative_prompt_embeds is not None: raise ValueError( f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" f" {negative_prompt_embeds}. Please make sure to only forward one of the two." ) if prompt_embeds is not None and negative_prompt_embeds is not None: if prompt_embeds.shape != negative_prompt_embeds.shape: raise ValueError( "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" {negative_prompt_embeds.shape}." ) # `prompt` needs more sophisticated handling when there are multiple # conditionings. if isinstance(self.controlnet, MultiControlNetModel): if isinstance(prompt, list): logger.warning( f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" " prompts. The conditionings will be fixed across the prompts." 
) # Check `image` is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( self.controlnet, torch._dynamo.eval_frame.OptimizedModule ) if ( isinstance(self.controlnet, ControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel) ): self.check_image(image, prompt, prompt_embeds) elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel) ): if not isinstance(image, list): raise TypeError("For multiple controlnets: `image` must be type `list`") # When `image` is a nested list: # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) elif any(isinstance(i, list) for i in image): raise ValueError("A single batch of multiple conditionings are supported at the moment.") elif len(image) != len(self.controlnet.nets): raise ValueError( "For multiple controlnets: `image` must have the same length as the number of controlnets." ) for image_ in image: self.check_image(image_, prompt, prompt_embeds) else: assert False # Check `controlnet_conditioning_scale` if ( isinstance(self.controlnet, ControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, ControlNetModel) ): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") elif ( isinstance(self.controlnet, MultiControlNetModel) or is_compiled and isinstance(self.controlnet._orig_mod, MultiControlNetModel) ): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): raise ValueError("A single batch of multiple conditionings are supported at the moment.") elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( self.controlnet.nets ): raise ValueError( "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" " the same length as the number of controlnets" ) else: assert False def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: raise TypeError( "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" ) if image_is_pil: image_batch_size = 1 elif image_is_tensor: image_batch_size = image.shape[0] elif image_is_pil_list: image_batch_size = len(image) elif image_is_tensor_list: image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): prompt_batch_size = 1 elif prompt is not None and isinstance(prompt, list): prompt_batch_size = len(prompt) elif prompt_embeds is not None: prompt_batch_size = prompt_embeds.shape[0] if image_batch_size != 1 and image_batch_size != prompt_batch_size: raise ValueError( f"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" ) def prepare_image( self, image, width, height, batch_size, num_images_per_prompt, device, dtype, do_classifier_free_guidance=False, guess_mode=False, ): if not isinstance(image, torch.Tensor): if isinstance(image, PIL.Image.Image): image = [image] if isinstance(image[0], PIL.Image.Image): images = [] for image_ in image: image_ = image_.convert("RGB") #image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) image_ = np.array(image_) image_ = image_[None, :] images.append(image_) image = images image = np.concatenate(image, axis=0) image = np.array(image).astype(np.float32) / 255.0 image = image.transpose(0, 3, 1, 2) image = torch.from_numpy(image)#.flip(1) elif isinstance(image[0], torch.Tensor): image = torch.cat(image, dim=0) image_batch_size = image.shape[0] if image_batch_size == 1: repeat_by = batch_size else: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt image = image.repeat_interleave(repeat_by, dim=0) image = image.to(device=device, dtype=dtype) if do_classifier_free_guidance and not guess_mode: image = torch.cat([image] * 2) return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents def _default_height_width(self, height, width, image): # NOTE: It is possible that a list of images have different # dimensions for each image, so just checking the first image # is not _exactly_ correct, but it is simple. 
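# Worked example (hypothetical sizes): a 1023x767 PIL image falls through to
# height = (767 // 8) * 8 = 760 and width = (1023 // 8) * 8 = 1016, so the
# defaults stay divisible by 8, as check_inputs later requires.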
while isinstance(image, list): image = image[0] if height is None: if isinstance(image, PIL.Image.Image): height = image.height elif isinstance(image, torch.Tensor): height = image.shape[2] height = (height // 8) * 8 # round down to nearest multiple of 8 if width is None: if isinstance(image, PIL.Image.Image): width = image.width elif isinstance(image, torch.Tensor): width = image.shape[3] width = (width // 8) * 8 # round down to nearest multiple of 8 return height, width # override DiffusionPipeline def save_pretrained( self, save_directory: Union[str, os.PathLike], safe_serialization: bool = False, variant: Optional[str] = None, ): if isinstance(self.controlnet, ControlNetModel): super().save_pretrained(save_directory, safe_serialization, variant) else: raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") def previous_timestep(self, timestep): if self.scheduler.custom_timesteps: index = (self.scheduler.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == self.scheduler.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = self.scheduler.timesteps[index + 1] else: num_inference_steps = ( self.scheduler.num_inference_steps if self.scheduler.num_inference_steps else self.scheduler.config.num_train_timesteps ) prev_t = timestep - self.scheduler.config.num_train_timesteps // num_inference_steps return prev_t def predict_start_from_noise(self, sample, t, model_output): t = t.to(self.scheduler.alphas_cumprod.device) prev_t = self.previous_timestep(t) # 1. compute alphas, betas alpha_prod_t = self.scheduler.alphas_cumprod[t].to(sample.device) alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else self.scheduler.one alpha_prod_t_prev = alpha_prod_t_prev.to(sample.device) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev current_alpha_t = alpha_prod_t / alpha_prod_t_prev current_beta_t = 1 - current_alpha_t # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf if self.scheduler.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif self.scheduler.config.prediction_type == "sample": pred_original_sample = model_output elif self.scheduler.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output else: raise ValueError( f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample` or" " `v_prediction` for the DDPMScheduler."
) return pred_original_sample def _sliding_windows(self,h: int, w: int, tile_size: int, tile_stride: int) -> Tuple[int, int, int, int]: hi_list = list(range(0, h - tile_size + 1, tile_stride)) if (h - tile_size) % tile_stride != 0: hi_list.append(h - tile_size) wi_list = list(range(0, w - tile_size + 1, tile_stride)) if (w - tile_size) % tile_stride != 0: wi_list.append(w - tile_size) coords = [] for hi in hi_list: for wi in wi_list: coords.append((hi, hi + tile_size, wi, wi + tile_size)) return coords # Helper methods within the class def _prepare_controlnet_inputs(self, latent_model_input, latents, prompt_embeds, do_classifier_free_guidance, guess_mode): if guess_mode and do_classifier_free_guidance: return latents, prompt_embeds.chunk(2)[1] return latent_model_input, prompt_embeds def _predict_noise(self, latent_model_input, t, image, prompt_embeds, cross_attention_kwargs, vae_conditions, tile_diffusion, tile_size, tile_stride, conditioning_scale, guess_mode): if not tile_diffusion: noise_pred = self._unet_predict(latent_model_input, t, image, prompt_embeds, cross_attention_kwargs, vae_conditions) else: noise_pred = self._tile_predict(latent_model_input, t, image, prompt_embeds, cross_attention_kwargs, vae_conditions, tile_size, tile_stride, conditioning_scale, guess_mode) return noise_pred def _unet_predict(self, latent_model_input, t, image, prompt_embeds, cross_attention_kwargs, vae_conditions): down_res_samples, mid_res_sample = self.controlnet( latent_model_input, t, encoder_hidden_states=prompt_embeds, controlnet_cond=image, conditioning_scale=1.0, guess_mode=False, return_dict=False, vae_encode_condition_hidden_states=vae_conditions ) noise_pred = self.unet( latent_model_input, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_res_samples, mid_block_additional_residual=mid_res_sample, return_dict=False, )[0] return noise_pred def _tile_predict(self, latent_model_input, t, image, prompt_embeds, cross_attention_kwargs, vae_conditions, tile_size, tile_stride, conditioning_scale, guess_mode): tile_weight = self.gaussian_weights(int(tile_size//8), int(tile_size//8), 1).to(latent_model_input.device) noise_pred = torch.zeros_like(latent_model_input, dtype=torch.float32) count = torch.zeros_like(latent_model_input, dtype=torch.float32) h, w = latent_model_input.shape[2:4] for hi, hi_end, wi, wi_end in self._sliding_windows(h, w, int(tile_size // 8), int(tile_stride // 8)): tile = latent_model_input[:, :, hi:hi_end, wi:wi_end] tile_cond = vae_conditions[:, :, hi:hi_end, wi:wi_end] if vae_conditions is not None else None tile_image = image[:, :, hi*8:hi_end*8, wi*8:wi_end*8] # tile_cond = self.vae.encode(tile_image * 2 - 1).latent_dist.sample() * self.vae.config.scaling_factor down_block_res_samples, mid_block_res_sample = [None]*10, None down_res_samples, mid_res_sample = self.controlnet( tile, t, encoder_hidden_states=prompt_embeds, controlnet_cond=tile_image, conditioning_scale=1.0, guess_mode=False, return_dict=False, vae_encode_condition_hidden_states=tile_cond ) tile_noise = self.unet( tile, t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, down_block_additional_residuals=down_res_samples, mid_block_additional_residual=mid_res_sample, return_dict=False, )[0] noise_pred[:, :, hi:hi_end, wi:wi_end] += tile_noise * tile_weight count[:, :, hi:hi_end, wi:wi_end] += tile_weight noise_pred /= count return noise_pred.to(torch.float16) def _initial_step(self, 
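    # --- Illustrative aside (not part of the original pipeline) ----------------
    # The tiled prediction above relies on two ingredients: `_sliding_windows`,
    # which guarantees every latent pixel is covered even when the stride does
    # not divide the image evenly, and `gaussian_weights`, which down-weights
    # tile borders so overlapping predictions blend smoothly. A minimal
    # standalone sketch of the same accumulate-and-normalize scheme (names and
    # shapes are ours, not the repository's):
    #
    #     import torch
    #
    #     def blend_tiles(h, w, tile, stride, windows_fn, weights_fn):
    #         out = torch.zeros(1, 4, h, w)
    #         count = torch.zeros(1, 4, h, w)
    #         weight = weights_fn(tile, tile, 1)        # (1, 4, tile, tile)
    #         for hi, hi_end, wi, wi_end in windows_fn(h, w, tile, stride):
    #             pred = torch.randn(1, 4, tile, tile)  # stand-in for a UNet call
    #             out[:, :, hi:hi_end, wi:wi_end] += pred * weight
    #             count[:, :, hi:hi_end, wi:wi_end] += weight
    #         return out / count                        # every pixel has count > 0
    # ----------------------------------------------------------------------------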
    def _initial_step(self, do_classifier_free_guidance, latents, t, timesteps, prompt_embeds, image, vae_conditions,
                      tile_diffusion, tile_size, tile_stride):
        if do_classifier_free_guidance:
            prompt_embeds = prompt_embeds.chunk(2)[0]
            image = image.chunk(2)[0]
            vae_conditions = vae_conditions.chunk(2)[0]
        noise_pred = self._predict_noise(latents, t, image, prompt_embeds, None, vae_conditions,
                                         tile_diffusion, tile_size, tile_stride, 1.0, False)
        x0_T = self.predict_start_from_noise(latents, t, noise_pred)
        noise_tao = torch.randn_like(latents)
        latents = self.scheduler.add_noise(x0_T, noise_tao, timesteps)
        return latents, x0_T

    def _postprocess_latents(self, latents, output_type, do_denormalize):
        latents = latents.to(torch.float16)
        if output_type != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0].to(torch.float32)
            image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
        else:
            image = latents
        return image

    def gaussian_weights(self, tile_width: int, tile_height: int, nbatches: int) -> torch.Tensor:
        """Generates a gaussian mask of weights for tile contributions."""
        from numpy import pi, exp, sqrt
        import numpy as np

        latent_width = tile_width
        latent_height = tile_height
        var = 0.01
        midpoint = (latent_width - 1) / 2  # -1 because index goes from 0 to latent_width - 1
        x_probs = [
            exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var)
            for x in range(latent_width)
        ]
        # note: y uses height / 2 while x uses (width - 1) / 2, as in the original implementation
        midpoint = latent_height / 2
        y_probs = [
            exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var)
            for y in range(latent_height)
        ]

        weights = np.outer(y_probs, x_probs)
        return torch.tile(torch.tensor(weights, device=self.device), (nbatches, 4, 1, 1))

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        t_max: float,
        t_min: float,
        tile_diffusion: bool,
        tile_size: float,
        tile_stride: float,
        prompt: Union[str, List[str]] = None,
        image: Union[FloatTensor, Image.Image, List[FloatTensor], List[Image.Image]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[FloatTensor] = None,
        prompt_embeds: Optional[FloatTensor] = None,
        negative_prompt_embeds: Optional[FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        conditioning_scale: Union[float, List[float]] = 1.0,
        guess_mode: bool = False,
        start_steps: int = 999,
        use_vae_encode_condition: bool = False,
        start_point: str = 'noise',
    ) -> Union[StableDiffusionPipelineOutput, tuple]:
        r"""
        Optimized diffusion pipeline call for image super-resolution, for 'Improving the Stability and
        Efficiency of Diffusion Models for Content Consistent Super-Resolution'.

        Examples:
            # pipeline(t_max=0.6667, t_min=0.5, tile_diffusion=True, tile_size=256, tile_stride=128, prompt="", num_inference_steps=6)
        """
        # 0. Set default height and width
        height, width = self._default_height_width(height, width, image)
        # 1. Determine batch size
        if prompt is not None:
            batch_size = 1 if isinstance(prompt, str) else len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0
        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

        # 2. Prepare image
        image = self.prepare_image(
            image=image,
            width=width,
            height=height,
            batch_size=batch_size * num_images_per_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            dtype=controlnet.dtype,
            do_classifier_free_guidance=do_classifier_free_guidance,
            guess_mode=guess_mode,
        )

        # 3. Prepare scheduler timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 4. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # calculate the running time for each inference step
        torch.cuda.synchronize()
        start_time = time.time()

        # 5. Encode prompts
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 6. Prepare latent variables
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            self.unet.config.in_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 7. Initialize latent variables based on start_point
        latents_condition_image = self.vae.encode(image * 2 - 1).latent_dist.sample() * self.vae.config.scaling_factor
        if start_point != 'noise':
            start_steps_tensor = torch.randint(start_steps, start_steps + 1, (latents.shape[0],), device=latents.device).long()
            latents = self.scheduler.add_noise(latents_condition_image[0:1, ...], latents, start_steps_tensor)

        # 8. Optionally prepare VAE-encoded condition
        vae_encode_condition_hidden_states = (
            latents_condition_image if use_vae_encode_condition else None
        )

        # 9. Initial prediction at t_max if needed
        total_steps = len(timesteps)
        t_tao = timesteps[-round(total_steps * t_max)]
        if t_max != 1:
            t = torch.randint(start_steps, start_steps + 1, (batch_size,), device=latents.device)
            latents = latents.to(torch.float16)
            # we do not do the classifier-free guidance in this step
            latent_model_input = self.scheduler.scale_model_input(latents, t)
            latents, x0_T = self._initial_step(do_classifier_free_guidance, latent_model_input, t, t_tao,
                                               prompt_embeds, image, vae_encode_condition_hidden_states,
                                               tile_diffusion, tile_size, tile_stride)
            # redefine timesteps
            timesteps = timesteps[-round(total_steps * t_max):]
            timesteps = timesteps[:-round(total_steps * t_min)] if t_min > 0 else timesteps
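        # --- Illustrative aside (not part of the original pipeline) ------------
        # Worked example of the timestep-window truncation above, with the
        # multistep test settings (num_inference_steps=6, t_max=0.6667, t_min=0.5):
        #   total_steps = 6
        #   timesteps[-round(6 * 0.6667):]  -> keep the last 4 timesteps
        #   timesteps[:-round(6 * 0.5)]     -> then drop the last 3, leaving 1
        # so the loop below runs once; together with the initial prediction at
        # t_max this matches num_inference_steps * (t_max - t_min) + 1 = 2.
        # ------------------------------------------------------------------------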
        # 10. Denoising loop
        if num_inference_steps == 1:
            latents = x0_T
        else:
            with self.progress_bar(total=len(timesteps)) as progress_bar:
                for i, t in enumerate(timesteps):
                    latents = latents.to(torch.float16)
                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                    controlnet_latent_model_input, controlnet_prompt_embeds = self._prepare_controlnet_inputs(
                        latent_model_input, latents, prompt_embeds, do_classifier_free_guidance, guess_mode
                    )

                    noise_pred = self._predict_noise(
                        controlnet_latent_model_input,
                        t,
                        image,
                        controlnet_prompt_embeds,
                        cross_attention_kwargs,
                        vae_encode_condition_hidden_states,
                        tile_diffusion,
                        tile_size,
                        tile_stride,
                        conditioning_scale,
                        guess_mode,
                    )

                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                    latents_old = latents
                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                    progress_bar.update()
                    if i == len(timesteps) - 1:
                        # call the callback, if provided
                        if callback is not None and i % callback_steps == 0:
                            callback(i, t, latents)
                        # Predict x0 for t_min
                        if t_min:
                            x0_tmin = self.predict_start_from_noise(latents_old, t, noise_pred)
                            latents = x0_tmin

        # 11. Post-processing
        has_nsfw_concept = None
        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
        image = self._postprocess_latents(latents, output_type, do_denormalize)

        # calculate the inference time for each inference step
        torch.cuda.synchronize()
        end_time = time.time()
        total_time = end_time - start_time

        return total_time, StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


================================================
FILE: requirements.txt
================================================
diffusers==0.21.0
torch==2.0.1
pytorch_lightning
accelerate==1.2.0
transformers==4.25.0
xformers==0.0.22
loralib
fairscale==0.4.13
basicsr==1.4.2
timm==0.9.5
pydantic==1.10.11
huggingface_hub==0.25.2
opencv-python-headless
lpips


================================================
FILE: scripts/get_path.py
================================================
import os


def write_png_paths(folder_path, txt_path):
    with open(txt_path, 'w') as f:
        for root, dirs, files in os.walk(folder_path):
            for file in files:
                if file.endswith('.png'):
                    f.write(os.path.join(root, file) + '\n')


# Example usage:
folder_path = ''
txt_path = '/gt_path.txt'
write_png_paths(folder_path, txt_path)


================================================
FILE: scripts/test/test_ccsr_multistep.sh
================================================
python test_ccsr_tile.py \
--pretrained_model_path preset/models/stable-diffusion-2-1-base \
--controlnet_model_path preset/models \
--vae_model_path preset/models \
--baseline_name ccsr-v2 \
--image_path preset/test_datasets \
--output_dir experiments/test \
--sample_method ddpm \
--num_inference_steps 6 \
--t_max 0.6667 \
--t_min 0.5 \
--start_point lr \
--start_steps 999 \
--process_size 512 \
--guidance_scale 4.5 \
--sample_times 1 \
--use_vae_encode_condition \
--upscale 4
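# --- Note on the script above (illustrative, not part of the repository) -------
# --start_point lr with --start_steps 999 makes the pipeline diffuse the
# VAE-encoded LR image to t=999 and denoise from there, instead of starting from
# pure Gaussian noise (step 7 of the pipeline's __call__ above). A sketch of that
# initialization, assuming `scheduler`, `lr_latent` and `noise` tensors exist:
#
#     t = torch.full((lr_latent.shape[0],), 999, dtype=torch.long)
#     start_latents = scheduler.add_noise(lr_latent, noise, t)
# --------------------------------------------------------------------------------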
================================================
FILE: scripts/test/test_ccsr_onestep.sh
================================================
python test_ccsr_tile.py \
--pretrained_model_path preset/models/stable-diffusion-2-1-base \
--controlnet_model_path preset/models \
--vae_model_path preset/models \
--baseline_name ccsr-v2 \
--image_path preset/test_datasets \
--output_dir experiments/test \
--sample_method ddpm \
--num_inference_steps 1 \
--t_min 0.0 \
--start_point lr \
--start_steps 999 \
--process_size 512 \
--guidance_scale 1.0 \
--sample_times 1 \
--use_vae_encode_condition \
--upscale 4


================================================
FILE: scripts/test/test_ccsr_tile.sh
================================================
python test_ccsr_tile.py \
--pretrained_model_path preset/models/stable-diffusion-2-1-base \
--controlnet_model_path preset/models \
--vae_model_path preset/models \
--baseline_name ccsr-v2 \
--image_path preset/test_datasets \
--output_dir experiments/test \
--sample_method ddpm \
--num_inference_steps 6 \
--t_max 0.6667 \
--t_min 0.5 \
--start_point lr \
--start_steps 999 \
--process_size 512 \
--guidance_scale 4.5 \
--sample_times 1 \
--use_vae_encode_condition \
--upscale 4 \
--tile_diffusion \
--tile_diffusion_size 512 \
--tile_diffusion_stride 256 \
--tile_vae \
--vae_decoder_tile_size 224 \
--vae_encoder_tile_size 1024


================================================
FILE: scripts/train/train_ccsr_stage1.sh
================================================
CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_ccsr_stage1.py \
--pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
--controlnet_model_name_or_path='preset/models/pretrained_controlnet' \
--enable_xformers_memory_efficient_attention \
--output_dir="./experiments/ccsrv2_stage1" \
--mixed_precision="fp16" \
--resolution=512 \
--learning_rate=5e-5 \
--train_batch_size=4 \
--gradient_accumulation_steps=6 \
--dataloader_num_workers=0 \
--checkpointing_steps=500 \
--t_max=0.6667 \
--max_train_steps=20000 \
--dataset_root_folders 'preset/gt_path.txt'


================================================
FILE: scripts/train/train_ccsr_stage2.sh
================================================
CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_ccsr_stage2.py \
--pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
--controlnet_model_name_or_path='preset/models/model_stage1' \
--enable_xformers_memory_efficient_attention \
--output_dir="./experiments/ccsrv2_stage2" \
--mixed_precision="fp16" \
--resolution=512 \
--learning_rate=5e-6 \
--train_batch_size=2 \
--gradient_accumulation_steps=8 \
--checkpointing_steps=500 \
--is_start_lr=True \
--t_max=0.6667 \
--num_inference_steps=1 \
--is_module \
--lambda_l2=1.0 \
--lambda_lpips=1.0 \
--lambda_disc=0.05 \
--lambda_disc_train=0.5 \
--begin_disc=100 \
--max_train_steps=2000 \
--dataset_root_folders 'preset/gt_path.txt'


================================================
FILE: scripts/train/train_controlnet.sh
================================================
CUDA_VISIBLE_DEVICES="0,1,2,3," accelerate launch train_controlnet.py \
--pretrained_model_name_or_path="preset/models/stable-diffusion-2-1-base" \
--controlnet_model_name_or_path='' \
--enable_xformers_memory_efficient_attention \
--output_dir="./experiments/pretrained_controlnet" \
--mixed_precision="fp16" \
--resolution=512 \
--learning_rate=5e-5 \
--train_batch_size=4 \
--gradient_accumulation_steps=6 \
--dataloader_num_workers=0 \
--checkpointing_steps=5000 \
--max_train_steps=40000 \
--dataset_root_folders 'preset/gt_path.txt'


================================================
FILE: test_ccsr_tile.py
================================================
import os
import glob
import math
import time
import argparse
import numpy as np
from PIL import
Image import safetensors.torch import torch from torchvision import transforms import torchvision.transforms.functional as F from accelerate import Accelerator from accelerate.utils import set_seed from diffusers import ( AutoencoderKL, UniPCMultistepScheduler, DPMSolverMultistepScheduler, DDPMScheduler, UNet2DConditionModel, ) from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor from pipelines.pipeline_ccsr import StableDiffusionControlNetPipeline from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix from models.controlnet import ControlNetModel def load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention): scheduler_mapping = { 'unipcmultistep': UniPCMultistepScheduler, 'ddpm': DDPMScheduler, 'dpmmultistep': DPMSolverMultistepScheduler, } try: scheduler_cls = scheduler_mapping[args.sample_method] except KeyError: raise ValueError(f"Invalid sample_method: {args.sample_method}") scheduler = scheduler_cls.from_pretrained(args.pretrained_model_path, subfolder="scheduler") text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder") tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer") feature_extractor = CLIPImageProcessor.from_pretrained(os.path.join(args.pretrained_model_path, "feature_extractor")) unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_path, subfolder="unet") controlnet = ControlNetModel.from_pretrained(args.controlnet_model_path, subfolder="controlnet") vae_path = args.vae_model_path if args.vae_model_path else args.pretrained_model_path vae = AutoencoderKL.from_pretrained(vae_path, subfolder="vae") # Freeze models for model in [vae, text_encoder, unet, controlnet]: model.requires_grad_(False) # Enable xformers if available if enable_xformers_memory_efficient_attention: if is_xformers_available(): unet.enable_xformers_memory_efficient_attention() controlnet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Ensure it is installed correctly.") # Initialize pipeline validation_pipeline = StableDiffusionControlNetPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, feature_extractor=feature_extractor, unet=unet, controlnet=controlnet, scheduler=scheduler, safety_checker=None, requires_safety_checker=False, ) if args.tile_vae: validation_pipeline._init_tiled_vae( encoder_tile_size=args.vae_encoder_tile_size, decoder_tile_size=args.vae_decoder_tile_size ) # Set weight dtype based on mixed precision dtype_mapping = { "fp16": torch.float16, "bf16": torch.bfloat16, } weight_dtype = dtype_mapping.get(accelerator.mixed_precision, torch.float32) # Move models to accelerator device with appropriate dtype for model in [text_encoder, vae, unet, controlnet]: model.to(accelerator.device, dtype=weight_dtype) return validation_pipeline def main(args, enable_xformers_memory_efficient_attention=True,): detailed_output_dir = os.path.join( args.output_dir, f"sr_{args.baseline_name}_{args.sample_method}_{str(args.num_inference_steps).zfill(3)}steps_{args.start_point}{args.start_steps}_size{args.process_size}_cfg{args.guidance_scale}" ) accelerator = Accelerator( mixed_precision=args.mixed_precision, ) # If passed along, set the training seed now. 
    if args.seed is not None:
        set_seed(args.seed)

    # Handle the output folder creation
    # We need to initialize the trackers we use, and also store our configuration.
    # Trackers are initialized automatically on the main process.
    if accelerator.is_main_process:
        os.makedirs(detailed_output_dir, exist_ok=True)
        accelerator.init_trackers("Controlnet")

    pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention)

    if accelerator.is_main_process:
        generator = torch.Generator(device=accelerator.device)
        if args.seed is not None:
            generator.manual_seed(args.seed)

        image_paths = sorted(glob.glob(os.path.join(args.image_path, "*.*"))) if os.path.isdir(args.image_path) else [args.image_path]

        time_records = []
        for image_path in image_paths:
            validation_image = Image.open(image_path).convert("RGB")
            negative_prompt = args.negative_prompt
            validation_prompt = args.added_prompt

            ori_width, ori_height = validation_image.size
            resize_flag = False
            rscale = args.upscale
            if ori_width < args.process_size // rscale or ori_height < args.process_size // rscale:
                scale = (args.process_size // rscale) / min(ori_width, ori_height)
                tmp_image = validation_image.resize((round(scale * ori_width), round(scale * ori_height)))
                validation_image = tmp_image
                resize_flag = True

            validation_image = validation_image.resize((validation_image.size[0] * rscale, validation_image.size[1] * rscale))
            validation_image = validation_image.resize((validation_image.size[0] // 8 * 8, validation_image.size[1] // 8 * 8))
            width, height = validation_image.size
            resize_flag = True

            for sample_idx in range(args.sample_times):
                os.makedirs(f'{detailed_output_dir}/sample{str(sample_idx).zfill(2)}/', exist_ok=True)

            for sample_idx in range(args.sample_times):
                inference_time, image = pipeline(
                    args.t_max,
                    args.t_min,
                    args.tile_diffusion,
                    args.tile_diffusion_size,
                    args.tile_diffusion_stride,
                    args.added_prompt,
                    validation_image,
                    num_inference_steps=args.num_inference_steps,
                    generator=generator,
                    height=height,
                    width=width,
                    guidance_scale=args.guidance_scale,
                    negative_prompt=args.negative_prompt,
                    conditioning_scale=args.conditioning_scale,
                    start_steps=args.start_steps,
                    start_point=args.start_point,
                    use_vae_encode_condition=args.use_vae_encode_condition,
                )
                image = image.images[0]
                print(f"Inference time: {inference_time:.4f} seconds")
                time_records.append(inference_time)

                # Apply color fixing if specified
                if args.align_method != 'nofix':
                    fix_func = wavelet_color_fix if args.align_method == 'wavelet' else adain_color_fix
                    image = fix_func(image, validation_image)

                if resize_flag:
                    image = image.resize((ori_width * rscale, ori_height * rscale))

                image_tensor = torch.clamp(F.to_tensor(image), 0, 1)
                final_image = transforms.ToPILImage()(image_tensor)

                base_name = os.path.splitext(os.path.basename(image_path))[0]
                save_path = os.path.join(detailed_output_dir, f"sample{str(sample_idx).zfill(2)}", f"{base_name}.png")
                final_image.save(save_path)

        # Calculate the average inference time, excluding the first few for stabilization
        if len(time_records) > 3:
            average_time = np.mean(time_records[3:])
        else:
            average_time = np.mean(time_records)
        if accelerator.is_main_process:
            print(f"Average inference time: {average_time:.4f} seconds")

        # Save the run settings to a file
        settings_path = os.path.join(detailed_output_dir, "settings.txt")
        with open(settings_path, 'w') as f:
            f.write("------------------ start ------------------\n")
            for key, value in vars(args).items():
                f.write(f"{key} : {value}\n")
            f.write("------------------- end -------------------\n")
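# --- Illustrative aside (hypothetical helper, not part of the original script) --
# main() above first enlarges tiny LR inputs, upscales by `upscale`, and then
# snaps both sides down to multiples of 8 so the VAE/UNet see well-shaped
# tensors. The same logic, isolated into a standalone sketch:
def _example_prepare_size(w: int, h: int, upscale: int, process_size: int):
    # enlarge tiny inputs so min(side) * upscale >= process_size
    if w < process_size // upscale or h < process_size // upscale:
        scale = (process_size // upscale) / min(w, h)
        w, h = round(scale * w), round(scale * h)
    w, h = w * upscale, h * upscale
    return (w // 8) * 8, (h // 8) * 8  # round down to multiples of 8

# e.g. _example_prepare_size(100, 120, upscale=4, process_size=512) -> (512, 616)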
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Stable Diffusion ControlNet Pipeline for Super-Resolution")
    parser.add_argument("--controlnet_model_path", type=str, default="", help="Path to ControlNet model")
    parser.add_argument("--pretrained_model_path", type=str, default="", help="Path to pretrained model")
    parser.add_argument("--vae_model_path", type=str, default="", help="Path to VAE model")
    parser.add_argument("--added_prompt", type=str, default="clean, high-resolution, 8k", help="Additional prompt for generation")
    parser.add_argument("--negative_prompt", type=str, default="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed", help="Negative prompt to avoid certain features")
    parser.add_argument("--image_path", type=str, default="", help="Path to input image or directory")
    parser.add_argument("--output_dir", type=str, default="", help="Directory to save outputs")
    parser.add_argument("--mixed_precision", type=str, choices=["no", "fp16", "bf16"], default="fp16", help="Mixed precision mode")
    parser.add_argument("--guidance_scale", type=float, default=1.0, help="Guidance scale for generation")
    parser.add_argument("--conditioning_scale", type=float, default=1.0, help="Conditioning scale")
    parser.add_argument("--num_inference_steps", type=int, default=1, help="Number of inference steps (not the final inference time)")
    # final_inference_time = num_inference_steps * (t_max - t_min) + 1
    parser.add_argument("--t_max", type=float, default=0.6666, help="Maximum timestep")
    parser.add_argument("--t_min", type=float, default=0.0, help="Minimum timestep")
    parser.add_argument("--process_size", type=int, default=512, help="Processing size of the image")
    parser.add_argument("--upscale", type=int, default=1, help="Upscaling factor")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--sample_times", type=int, default=5, help="Number of samples to generate per image")
    parser.add_argument("--sample_method", type=str, choices=['unipcmultistep', 'ddpm', 'dpmmultistep'], default='ddpm', help="Sampling method")
    parser.add_argument("--align_method", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain', help="Alignment method for color fixing")
    parser.add_argument("--start_steps", type=int, default=999, help="Starting steps")
    parser.add_argument("--start_point", type=str, choices=['lr', 'noise'], default='lr', help="Starting point for generation")
    parser.add_argument("--baseline_name", type=str, default='ccsr-v2', help="Baseline name for output naming")
    parser.add_argument("--use_vae_encode_condition", action='store_true', help="Use VAE encoding of the LQ condition")

    # Tiling settings for high-resolution SR
    parser.add_argument("--tile_diffusion", action="store_true", help="Optional: enable tile-based diffusion")
    parser.add_argument("--tile_diffusion_size", type=int, default=512, help="Tile size for diffusion")
    parser.add_argument("--tile_diffusion_stride", type=int, default=256, help="Stride size for diffusion tiles")
    parser.add_argument("--tile_vae", action="store_true", help="Optional: enable tiling for VAE")
    parser.add_argument("--vae_decoder_tile_size", type=int, default=224, help="Tile size for VAE decoder")
    parser.add_argument("--vae_encoder_tile_size", type=int, default=1024, help="Tile size for VAE encoder")

    args = parser.parse_args()
    main(args)
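# --- Illustrative aside (not part of the repository) ---------------------------
# The --align_method option maps to the color-fix helpers imported at the top of
# test_ccsr_tile.py; both take (generated_image, reference_image) PIL images and
# return a PIL image carrying the reference's color statistics, as used in
# main() above. A hedged sketch of a direct call (file paths are hypothetical):
#
#     from PIL import Image
#     from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix
#
#     sr = Image.open("sr.png").convert("RGB")     # model output
#     lr = Image.open("lr_up.png").convert("RGB")  # upscaled LR reference
#     fixed = wavelet_color_fix(sr, lr)            # or adain_color_fix(sr, lr)
#     fixed.save("sr_colorfixed.png")
# --------------------------------------------------------------------------------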
================================================
FILE: train_ccsr_stage1.py
================================================
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import math
import os
import random
import shutil
from pathlib import Path

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from models.controlnet import ControlNetModel
from models.unet_2d_condition import UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from dataloaders.paired_dataset_txt import PairedCaptionDataset

if is_wandb_available():
    import wandb

# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.21.0.dev0")

logger = get_logger(__name__)


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
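# --- Illustrative aside (not part of the original script) ----------------------
# image_grid() pastes `rows * cols` equally sized PIL images onto one canvas in
# row-major order; log_validation() below uses it to build comparison strips.
# A hedged usage sketch (the tile images are hypothetical):
def _example_image_grid_usage():
    from PIL import Image

    tiles = [Image.new("RGB", (64, 64), color) for color in ("red", "green", "blue", "white")]
    grid = image_grid(tiles, rows=2, cols=2)  # -> one 128x128 canvas
    return grid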
") controlnet = accelerator.unwrap_model(controlnet) pipeline = StableDiffusionControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, safety_checker=None, revision=args.revision, torch_dtype=weight_dtype, ) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) if args.enable_xformers_memory_efficient_attention: pipeline.enable_xformers_memory_efficient_attention() if args.seed is None: generator = None else: generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if len(args.validation_image) == len(args.validation_prompt): validation_images = args.validation_image validation_prompts = args.validation_prompt elif len(args.validation_image) == 1: validation_images = args.validation_image * len(args.validation_prompt) validation_prompts = args.validation_prompt elif len(args.validation_prompt) == 1: validation_images = args.validation_image validation_prompts = args.validation_prompt * len(args.validation_image) else: raise ValueError( "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" ) image_logs = [] for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") images = [] for _ in range(args.num_validation_images): with torch.autocast("cuda"): image = pipeline( validation_prompt, validation_image, num_inference_steps=20, generator=generator ).images[0] images.append(image) image_logs.append( {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} ) for tracker in accelerator.trackers: if tracker.name == "tensorboard": for log in image_logs: images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] formatted_images = [] formatted_images.append(np.asarray(validation_image)) for image in images: formatted_images.append(np.asarray(image)) formatted_images = np.stack(formatted_images) tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") elif tracker.name == "wandb": formatted_images = [] for log in image_logs: images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) for image in images: image = wandb.Image(image, caption=validation_prompt) formatted_images.append(image) tracker.log({"validation": formatted_images}) else: logger.warn(f"image logging not implemented for {tracker.name}") return image_logs def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, ) model_class = text_encoder_config.architectures[0] if model_class == "CLIPTextModel": from transformers import CLIPTextModel return CLIPTextModel elif model_class == "RobertaSeriesModelWithTransformation": from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation return RobertaSeriesModelWithTransformation else: raise ValueError(f"{model_class} is not supported.") def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): img_str = "" if 
image_logs is not None: img_str = "You can find some example images below.\n" for i, log in enumerate(image_logs): images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] validation_image.save(os.path.join(repo_folder, "image_control.png")) img_str += f"prompt: {validation_prompt}\n" images = [validation_image] + images image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) img_str += f"![images_{i})](./images_{i}.png)\n" yaml = f""" --- license: creativeml-openrail-m base_model: {base_model} tags: - stable-diffusion - stable-diffusion-diffusers - text-to-image - diffusers - controlnet inference: true --- """ model_card = f""" # controlnet-{repo_id} These are controlnet weights trained on {base_model} with new type of conditioning. {img_str} """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") parser.add_argument('--dataset_root_folders', type=str, default="") parser.add_argument("--t_max", type=float, default=0.6667) parser.add_argument("--start_timesteps", type=int, default=999) parser.add_argument( "--pretrained_model_name_or_path", type=str, default="", help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--controlnet_model_name_or_path", type=str, default='', help="Path to pretrained controlnet model." " If not specified controlnet weights are initialized from unet.", ) parser.add_argument( "--output_dir", type=str, default="./experiments/test", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--resolution", type=int, default=512, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" " resolution" ), ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument("--num_train_epochs", type=int, default=1000) parser.add_argument( "--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( "--checkpointing_steps", type=int, default=500, help=( "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" "instructions." ), ) parser.add_argument( "--revision", type=str, default=None, required=False, help=( "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" " float32 precision." 
), ) parser.add_argument( "--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name", ) parser.add_argument( "--cache_dir", type=str, default=None, help="The directory where the downloaded models and datasets will be stored.", ) parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.") parser.add_argument( "--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, help=( "Whether training should be resumed from a previous checkpoint. Use a path saved by" ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument( "--learning_rate", type=float, default=5e-5, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( "--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) parser.add_argument( "--dataloader_num_workers", type=int, default=0, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) parser.add_argument( "--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument( "--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. 
Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument( "--report_to", type=str, default="tensorboard", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument( "--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) parser.add_argument( "--set_grads_to_none", action="store_true", help=( "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" " behaviors, so disable this argument if it causes any problems. More info:" " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" ), ) parser.add_argument( "--dataset_name", type=str, default=None, help=( "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," " or to a folder containing files that 🤗 Datasets can understand." ), ) parser.add_argument( "--dataset_config_name", type=str, default=None, help="The config of the Dataset, leave as None if there's only one config.", ) parser.add_argument( "--image_column", type=str, default="image", help="The column of the dataset containing the target image." ) parser.add_argument( "--conditioning_image_column", type=str, default="conditioning_image", help="The column of the dataset containing the controlnet conditioning image.", ) parser.add_argument( "--caption_column", type=str, default="text", help="The column of the dataset containing a caption or a list of captions.", ) parser.add_argument( "--max_train_samples", type=int, default=None, help=( "For debugging purposes or quicker training, truncate the number of training examples to this " "value if set." ), ) parser.add_argument( "--proportion_empty_prompts", type=float, default=0, help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", ) parser.add_argument( "--validation_prompt", type=str, default=[""], nargs="+", help=( "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." " Provide either a matching number of `--validation_image`s, a single `--validation_image`" " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." ), ) parser.add_argument( "--validation_image", type=str, default=[""], nargs="+", help=( "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" " `--validation_image` that will be used with all `--validation_prompt`s." 
        ),
    )
    parser.add_argument(
        "--is_start_lr",
        type=bool,
        default=False,
        help="If True, noise at the starting timestep is added on the VAE-encoded LR latent instead of the GT latent.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=100,
        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=1,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="train_ccsr_stage1",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers."
            " For more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.resolution % 8 != 0:
        raise ValueError(
            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
        )

    return args


def previous_timestep(timestep):
    if noise_scheduler.custom_timesteps:
        index = (noise_scheduler.timesteps == timestep).nonzero(as_tuple=True)[0][0]
        if index == noise_scheduler.timesteps.shape[0] - 1:
            prev_t = torch.tensor(-1)
        else:
            prev_t = noise_scheduler.timesteps[index + 1]
    else:
        num_inference_steps = (
            noise_scheduler.num_inference_steps
            if noise_scheduler.num_inference_steps
            else noise_scheduler.config.num_train_timesteps
        )
        prev_t = timestep - noise_scheduler.config.num_train_timesteps // num_inference_steps
    return prev_t


def predict_start_from_noise(sample, t, model_output):
    t = t.to(noise_scheduler.alphas_cumprod.device)
    prev_t = previous_timestep(t)

    # 1. compute alphas, betas
    alpha_prod_t = noise_scheduler.alphas_cumprod[t].to(sample.device)
    alpha_prod_t_prev = noise_scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else noise_scheduler.one
    alpha_prod_t_prev = alpha_prod_t_prev.to(sample.device)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute the predicted original sample ("predicted x_0") from the predicted
    # noise, following formula (15) of https://arxiv.org/pdf/2006.11239.pdf
    if noise_scheduler.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif noise_scheduler.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif noise_scheduler.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )
    return pred_original_sample


# def main(args):
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    mixed_precision=args.mixed_precision,
    log_with=args.report_to,
    project_config=accelerator_project_config,
)

# Make one log on every process with the configuration for debugging.
logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: repo_id = create_repo( repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False, ) # import correct text encoder class text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) if args.controlnet_model_name_or_path: logger.info("Loading existing controlnet weights") controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, subfolder="controlnet") else: logger.info("Initializing controlnet weights from unet") controlnet = ControlNetModel.from_unet(unet, use_vae_encode_condition=True) # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): i = len(weights) - 1 for i, model in enumerate(models): sub_dir = "unet" if isinstance(model, UNet2DConditionModel) else "controlnet" model.save_pretrained(os.path.join(output_dir, sub_dir)) # make sure to pop weight so that corresponding model is not saved again weights.pop() def load_model_hook(models, input_dir): assert len(models) == 2 for i in range(len(models)): # pop models so that they are not loaded again model = models.pop() # load diffusers style into model if not isinstance(model, UNet2DConditionModel): load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True else: load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) del load_model accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) vae.requires_grad_(False) unet.requires_grad_(False) text_encoder.requires_grad_(False) 
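# --- Illustrative aside (not part of the original script) ----------------------
# predict_start_from_noise() above implements the "epsilon" parameterization of
# Eq. (15) in the DDPM paper: x0_hat = (x_t - sqrt(1 - a_bar_t) * eps) / sqrt(a_bar_t).
# A self-contained numeric sketch with a toy alphas_cumprod schedule (all values
# below are made up for illustration):
def _example_predict_x0():
    import torch

    alphas_cumprod = torch.linspace(0.9999, 0.0001, 1000)  # toy schedule
    t = 999
    x0 = torch.randn(1, 4, 8, 8)
    eps = torch.randn_like(x0)
    a_bar = alphas_cumprod[t]
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps        # forward diffusion
    x0_hat = (x_t - (1 - a_bar).sqrt() * eps) / a_bar.sqrt()  # epsilon branch
    assert torch.allclose(x0_hat, x0, atol=1e-4)              # exact inversion
    return x0_hat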
controlnet.train() if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warn( "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() controlnet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() controlnet.enable_gradient_checkpointing() # Check that all trainable models are in full precision low_precision_error_string = ( " Please make sure to always have all model weights in full float32 precision when starting training - even if" " doing mixed precision training, copy of the weights should still be float32." ) if accelerator.unwrap_model(controlnet).dtype != torch.float32: raise ValueError( f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" ) if accelerator.unwrap_model(unet).dtype != torch.float32: raise ValueError( f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." ) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW params_to_optimize = list(controlnet.parameters()) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) train_dataset = PairedCaptionDataset(root_folders=args.dataset_root_folders, tokenizer=tokenizer, gt_ratio=0) # let lr is gt train_dataloader = torch.utils.data.DataLoader( train_dataset, num_workers=args.dataloader_num_workers, batch_size=args.train_batch_size, shuffle=False ) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move vae, unet and text_encoder to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) # Scheduler and math around the number of training steps. 
overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, num_training_steps=args.max_train_steps * accelerator.num_processes, num_cycles=args.lr_num_cycles, power=args.lr_power, ) # Prepare everything with our `accelerator`. controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( controlnet, optimizer, train_dataloader, lr_scheduler ) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: tracker_config = dict(vars(args)) # tensorboard cannot handle list types for config tracker_config.pop("validation_prompt") tracker_config.pop("validation_image") accelerator.init_trackers(args.tracker_project_name, config=tracker_config) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, desc="Steps", # Only show the progress bar once on each machine. 
disable=not accelerator.is_local_main_process, ) for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(controlnet): pixel_values = batch["pixel_values"].to(accelerator.device, dtype=weight_dtype) # Convert images to latent space latents = vae.encode(pixel_values).latent_dist.sample() latents = latents * vae.config.scaling_factor # Sample noise that we'll add to the latents noise = torch.randn_like(latents) bsz = latents.shape[0] # Sample a random timestep for each image t_max = round(noise_scheduler.config.num_train_timesteps*args.t_max) timesteps = torch.randint(0, t_max, (bsz,), device=latents.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_caption"].to(accelerator.device))[0] controlnet_image = batch["conditioning_pixel_values"].to(accelerator.device, dtype=weight_dtype) vae_encode_condition_hidden_states = vae.encode(2*controlnet_image-1).latent_dist.sample() vae_encode_condition_hidden_states = vae_encode_condition_hidden_states * vae.config.scaling_factor down_block_res_samples, mid_block_res_sample = controlnet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, controlnet_cond=controlnet_image, return_dict=False, vae_encode_condition_hidden_states=vae_encode_condition_hidden_states, ) # Predict the noise residual model_pred = unet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, down_block_additional_residuals=[ sample.to(dtype=weight_dtype) for sample in down_block_res_samples ], mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") loss_ori = F.mse_loss(model_pred.float(), target.float(), reduction="mean") # original loss ### loss for t_max ### noise2 = torch.randn_like(latents) timesteps = args.start_timesteps * torch.ones(model_pred.shape[0]).to(accelerator.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.start_timesteps==1: noisy_latents = noise_scheduler.add_noise(vae_encode_condition_hidden_states, noise2, timesteps) down_block_res_samples, mid_block_res_sample = controlnet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, controlnet_cond=controlnet_image, return_dict=False, vae_encode_condition_hidden_states=vae_encode_condition_hidden_states, ) # Predict the noise residual model_pred = unet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, down_block_additional_residuals=[ sample.to(dtype=weight_dtype) for sample in down_block_res_samples ], mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample # Predict x0 for T x0_T = noisy_latents - model_pred else: if args.is_start_lr: noisy_latents = noise_scheduler.add_noise(vae_encode_condition_hidden_states, noise2, timesteps) else: noisy_latents = noise_scheduler.add_noise(latents, 
noise2, timesteps) down_block_res_samples, mid_block_res_sample = controlnet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, controlnet_cond=controlnet_image, return_dict=False, vae_encode_condition_hidden_states=vae_encode_condition_hidden_states, ) # Predict the noise residual model_pred = unet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, down_block_additional_residuals=[ sample.to(dtype=weight_dtype) for sample in down_block_res_samples ], mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample # Predict x0 for T x0_T = predict_start_from_noise(noisy_latents, timesteps[0], model_pred) # Re-add noise on x0_tmax noise3 = torch.randn_like(latents) timesteps = t_max * torch.ones(model_pred.shape[0]).to(accelerator.device) timesteps = timesteps.long() noisy_latents = noise_scheduler.add_noise(x0_T, noise3, timesteps[0]) down_block_res_samples, mid_block_res_sample = controlnet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, controlnet_cond=controlnet_image, return_dict=False, vae_encode_condition_hidden_states=vae_encode_condition_hidden_states, ) # Predict the noise residual model_pred = unet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, down_block_additional_residuals=[ sample.to(dtype=weight_dtype) for sample in down_block_res_samples ], mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample # Predict x0 for t_max x0_tmax = predict_start_from_noise(noisy_latents, timesteps[0], model_pred) loss_x0 = F.mse_loss(x0_T.float(), latents.float(), reduction="mean") loss_x0_from_tao = F.mse_loss(x0_tmax.float(), latents.float(), reduction="mean") loss = loss_ori + loss_x0 + loss_x0_from_tao accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = controlnet.parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=args.set_grads_to_none) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") # if args.validation_prompt is not None and global_step % args.validation_steps == 0: if False: image_logs = log_validation( vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, global_step, ) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: controlnet = accelerator.unwrap_model(controlnet) controlnet.save_pretrained(args.output_dir) if args.push_to_hub: save_model_card( repo_id, image_logs=image_logs, base_model=args.pretrained_model_name_or_path, repo_folder=args.output_dir, ) upload_folder( repo_id=repo_id, folder_path=args.output_dir, commit_message="End of training", ignore_patterns=["step_*", "epoch_*"], ) accelerator.end_training() ================================================ FILE: train_ccsr_stage2.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2023 The HuggingFace Inc.
team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import logging import math import os import random import shutil from pathlib import Path from ADD.models.discriminator import ProjectedDiscriminator import accelerate import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import diffusers from diffusers import ( AutoencoderKL, DDPMScheduler, StableDiffusionControlNetPipeline, UniPCMultistepScheduler, ) from models.controlnet import ControlNetModel from models.unet_2d_condition import UNet2DConditionModel # from models.losses import LPIPSWithDiscriminator from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available from accelerate import DistributedDataParallelKwargs from dataloaders.paired_dataset_txt import PairedCaptionDataset from ADD.models.vit import vit_large, vit_small import ADD.utils.util_net as util_net if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risk. check_min_version("0.21.0.dev0") logger = get_logger(__name__) def image_grid(imgs, rows, cols): assert len(imgs) == rows * cols w, h = imgs[0].size grid = Image.new("RGB", size=(cols * w, rows * h)) for i, img in enumerate(imgs): grid.paste(img, box=(i % cols * w, i // cols * h)) return grid def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step): logger.info("Running validation... 
") controlnet = accelerator.unwrap_model(controlnet) pipeline = StableDiffusionControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, safety_checker=None, revision=args.revision, torch_dtype=weight_dtype, ) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) if args.enable_xformers_memory_efficient_attention: pipeline.enable_xformers_memory_efficient_attention() if args.seed is None: generator = None else: generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if len(args.validation_image) == len(args.validation_prompt): validation_images = args.validation_image validation_prompts = args.validation_prompt elif len(args.validation_image) == 1: validation_images = args.validation_image * len(args.validation_prompt) validation_prompts = args.validation_prompt elif len(args.validation_prompt) == 1: validation_images = args.validation_image validation_prompts = args.validation_prompt * len(args.validation_image) else: raise ValueError( "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" ) image_logs = [] for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") images = [] for _ in range(args.num_validation_images): with torch.autocast("cuda"): image = pipeline( validation_prompt, validation_image, num_inference_steps=20, generator=generator ).images[0] images.append(image) image_logs.append( {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} ) for tracker in accelerator.trackers: if tracker.name == "tensorboard": for log in image_logs: images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] formatted_images = [] formatted_images.append(np.asarray(validation_image)) for image in images: formatted_images.append(np.asarray(image)) formatted_images = np.stack(formatted_images) tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") elif tracker.name == "wandb": formatted_images = [] for log in image_logs: images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) for image in images: image = wandb.Image(image, caption=validation_prompt) formatted_images.append(image) tracker.log({"validation": formatted_images}) else: logger.warn(f"image logging not implemented for {tracker.name}") return image_logs def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, ) model_class = text_encoder_config.architectures[0] if model_class == "CLIPTextModel": from transformers import CLIPTextModel return CLIPTextModel elif model_class == "RobertaSeriesModelWithTransformation": from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation return RobertaSeriesModelWithTransformation else: raise ValueError(f"{model_class} is not supported.") def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): img_str = "" if 
image_logs is not None: img_str = "You can find some example images below.\n" for i, log in enumerate(image_logs): images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] validation_image.save(os.path.join(repo_folder, "image_control.png")) img_str += f"prompt: {validation_prompt}\n" images = [validation_image] + images image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) img_str += f"![images_{i})](./images_{i}.png)\n" yaml = f""" --- license: creativeml-openrail-m base_model: {base_model} tags: - stable-diffusion - stable-diffusion-diffusers - text-to-image - diffusers - controlnet inference: true --- """ model_card = f""" # controlnet-{repo_id} These are controlnet weights trained on {base_model} with new type of conditioning. {img_str} """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") parser.add_argument('--dataset_root_folders', type=str, default="") parser.add_argument("--is_module", action="store_true") parser.add_argument("--t_max", type=float, default=0.6666) parser.add_argument("--t_min", type=float, default=0.5) parser.add_argument("--num_inference_steps", type=int, default=1) parser.add_argument("--start_timesteps", type=int, default=999) parser.add_argument("--lambda_l2", type=float, default=1.0) parser.add_argument("--lambda_lpips", type=float, default=1.0) parser.add_argument("--lambda_disc", type=float, default=0.05) parser.add_argument("--lambda_disc_train", type=float, default=0.5) parser.add_argument("--begin_disc", type=float, default=100) parser.add_argument( "--is_start_lr", type=bool, default=True, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--vae_model_name_or_path", type=str, default='', help="Path to pretrained vae model." " If not specified vae weights are initialized from pre-trained model.", ) parser.add_argument( "--pretrained_model_name_or_path", type=str, default="", help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--controlnet_model_name_or_path", type=str, default='', help="Path to pretrained controlnet model." " If not specified controlnet weights are initialized from unet.", ) parser.add_argument( "--output_dir", type=str, default="./experiments/test", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--resolution", type=int, default=512, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" " resolution" ), ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument("--num_train_epochs", type=int, default=1000) parser.add_argument( "--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( "--checkpointing_steps", type=int, default=500, help=( "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." 
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" "instructions." ), ) parser.add_argument( "--revision", type=str, default=None, required=False, help=( "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" " float32 precision." ), ) parser.add_argument( "--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name", ) parser.add_argument( "--cache_dir", type=str, default=None, help="The directory where the downloaded models and datasets will be stored.", ) parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.") parser.add_argument( "--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, help=( "Whether training should be resumed from a previous checkpoint. Use a path saved by" ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument( "--learning_rate", type=float, default=1e-5, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( "--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) parser.add_argument( "--dataloader_num_workers", type=int, default=0, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
), ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) parser.add_argument( "--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument( "--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument( "--report_to", type=str, default="tensorboard", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument( "--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) parser.add_argument( "--set_grads_to_none", action="store_true", help=( "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" " behaviors, so disable this argument if it causes any problems. More info:" " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" ), ) parser.add_argument( "--dataset_name", type=str, default=None, help=( "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," " or to a folder containing files that 🤗 Datasets can understand." ), ) parser.add_argument( "--dataset_config_name", type=str, default=None, help="The config of the Dataset, leave as None if there's only one config.", ) parser.add_argument( "--image_column", type=str, default="image", help="The column of the dataset containing the target image." 
) parser.add_argument( "--conditioning_image_column", type=str, default="conditioning_image", help="The column of the dataset containing the controlnet conditioning image.", ) parser.add_argument( "--caption_column", type=str, default="text", help="The column of the dataset containing a caption or a list of captions.", ) parser.add_argument( "--max_train_samples", type=int, default=None, help=( "For debugging purposes or quicker training, truncate the number of training examples to this " "value if set." ), ) parser.add_argument( "--proportion_empty_prompts", type=float, default=0, help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", ) parser.add_argument( "--validation_prompt", type=str, default=[""], nargs="+", help=( "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." " Provide either a matching number of `--validation_image`s, a single `--validation_image`" " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." ), ) parser.add_argument( "--validation_image", type=str, default=[""], nargs="+", help=( "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" " `--validation_image` that will be used with all `--validation_prompt`s." ), ) parser.add_argument( "--num_validation_images", type=int, default=100, help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", ) parser.add_argument( "--validation_steps", type=int, default=1, help=( "Run validation every X steps. Validation consists of running the prompt" " `args.validation_prompt` multiple times: `args.num_validation_images`" " and logging the images." ), ) parser.add_argument( "--tracker_project_name", type=str, default="train_ccsr_stage2", help=( "The `project_name` argument passed to Accelerator.init_trackers for" " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" ), ) if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() if args.resolution % 8 != 0: raise ValueError( "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." ) return args def previous_timestep(timestep): if noise_scheduler.custom_timesteps: index = (noise_scheduler.timesteps == timestep).nonzero(as_tuple=True)[0][0] if index == noise_scheduler.timesteps.shape[0] - 1: prev_t = torch.tensor(-1) else: prev_t = noise_scheduler.timesteps[index + 1] else: num_inference_steps = ( noise_scheduler.num_inference_steps if noise_scheduler.num_inference_steps else noise_scheduler.config.num_train_timesteps ) prev_t = timestep - noise_scheduler.config.num_train_timesteps // num_inference_steps return prev_t def predict_start_from_noise(sample, t, model_output): t = t.to(noise_scheduler.alphas_cumprod.device) prev_t = previous_timestep(t) # 1. 
compute alphas, betas alpha_prod_t = noise_scheduler.alphas_cumprod[t].to(sample.device) alpha_prod_t_prev = noise_scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else noise_scheduler.one alpha_prod_t_prev = alpha_prod_t_prev.to(sample.device) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev current_alpha_t = alpha_prod_t / alpha_prod_t_prev current_beta_t = 1 - current_alpha_t # 2. compute predicted original sample from predicted noise also called # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf if noise_scheduler.config.prediction_type == "epsilon": pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) elif noise_scheduler.config.prediction_type == "sample": pred_original_sample = model_output elif noise_scheduler.config.prediction_type == "v_prediction": pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output else: raise ValueError( f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or" " `v_prediction` for the DDPMScheduler." ) return pred_original_sample # def main(args): args = parse_args() logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config, ) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. 
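# ------------------------------------------------------------------------
# For prediction_type == "epsilon", predict_start_from_noise (defined
# above) inverts the forward process:
#     x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t).
# A minimal standalone self-check of that identity (hypothetical snippet,
# not part of the training flow):
#     import torch
#     alpha_bar_t = torch.tensor(0.7)
#     x0, eps = torch.randn(4), torch.randn(4)
#     xt = alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * eps
#     x0_rec = (xt - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
#     assert torch.allclose(x0, x0_rec, atol=1e-6)
# ------------------------------------------------------------------------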
if args.seed is not None: set_seed(args.seed) # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: repo_id = create_repo( repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False, ) # import correct text encoder class text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") noise_scheduler.set_timesteps(args.num_inference_steps) text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) # Load VAE model if args.vae_model_name_or_path: logger.info("Loading existing vae weights") vae = AutoencoderKL.from_pretrained(args.vae_model_name_or_path, subfolder="vae", revision=args.revision) else: logger.info("Loading pretrained vae weights") vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) # Load Controlnet model if args.controlnet_model_name_or_path: logger.info("Loading existing controlnet weights") controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, subfolder="controlnet") else: logger.info("Initializing controlnet weights from unet") controlnet = ControlNetModel.from_unet(unet, use_vae_encode_condition=True) # # Load discriminator model # discriminatornet = LPIPSWithDiscriminator(disc_start=1.0, kl_weight=0, perceptual_weight=1.0, disc_weight=0.5, disc_factor=1.0) # Load discriminator model discriminatornet = ProjectedDiscriminator(c_dim=384).train() criterion_GAN = torch.nn.BCEWithLogitsLoss() # Instantiate the feature network used to extract cls_lr model_fea = vit_small(patch_size=14, img_size=518, block_chunks=0, init_values=1.0) util_net.reload_model(model_fea, torch.load('preset/models/dino/dinov2_vits14_pretrain.pth')) model_fea.requires_grad_(False) # load lpips model import lpips net_lpips = lpips.LPIPS(net='vgg').cuda() # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): i = len(weights) - 1 assert len(models) == 2 and len(weights) == 2 for i, model in enumerate(models): if i==0: sub_dir = 'vae' model.save_pretrained(os.path.join(output_dir, sub_dir)) # make sure to pop weight so that corresponding model is not saved again weights.pop() def load_model_hook(models, input_dir): assert len(models) == 2 for i in range(len(models)): # pop models so that they are not loaded again model = models.pop() # load diffusers style into model if not isinstance(model, UNet2DConditionModel): load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True else: load_model =
UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) del load_model accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) vae.requires_grad_(False) unet.requires_grad_(False) text_encoder.requires_grad_(False) controlnet.requires_grad_(False) discriminatornet.train() vae.train() # unfreeze the vae decoder for training for name, params in vae.named_parameters(): if 'decoder' in name: params.requires_grad = True if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warn( "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() controlnet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") if args.gradient_checkpointing: vae.enable_gradient_checkpointing() discriminatornet.enable_gradient_checkpointing() # Check that all trainable models are in full precision low_precision_error_string = ( " Please make sure to always have all model weights in full float32 precision when starting training - even if" " doing mixed precision training, copy of the weights should still be float32." ) if accelerator.unwrap_model(vae).dtype != torch.float32: raise ValueError( f"vae loaded as datatype {accelerator.unwrap_model(vae).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." ) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW # Optimizer creation params_to_optimize = list(vae.parameters()) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) params_to_optimize_disc = list(discriminatornet.parameters()) optimizer_disc = optimizer_class( params_to_optimize_disc, lr=args.learning_rate, betas=(0.9, 0.999), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) train_dataset = PairedCaptionDataset(root_folders=args.dataset_root_folders, tokenizer=tokenizer, gt_ratio=0) # gt_ratio=0: never substitute the GT image for the LR input train_dataloader = torch.utils.data.DataLoader( train_dataset, num_workers=args.dataloader_num_workers, batch_size=args.train_batch_size, shuffle=False ) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required.
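# ------------------------------------------------------------------------
# At this point only the VAE decoder (plus the discriminator) should be
# trainable; the UNet, ControlNet and text encoder stay frozen. A
# hypothetical sanity check, assuming the models set up above:
#     trainable = [n for n, p in vae.named_parameters() if p.requires_grad]
#     assert trainable and all("decoder" in n for n in trainable)
#     assert not any(p.requires_grad for p in unet.parameters())
#     assert not any(p.requires_grad for p in controlnet.parameters())
# ------------------------------------------------------------------------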
weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move controlnet, unet and text_encoder to device and cast to weight_dtype controlnet.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) model_fea.to(accelerator.device, dtype=weight_dtype) # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, num_training_steps=args.max_train_steps * accelerator.num_processes, num_cycles=args.lr_num_cycles, power=args.lr_power, ) lr_scheduler_disc = get_scheduler( args.lr_scheduler, optimizer=optimizer_disc, num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, num_training_steps=args.max_train_steps * accelerator.num_processes, num_cycles=args.lr_num_cycles, power=args.lr_power, ) # Prepare everything with our `accelerator`. vae, discriminatornet, optimizer, optimizer_disc, train_dataloader, lr_scheduler, lr_scheduler_disc = accelerator.prepare( vae, discriminatornet, optimizer, optimizer_disc, train_dataloader, lr_scheduler, lr_scheduler_disc ) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: tracker_config = dict(vars(args)) # tensorboard cannot handle list types for config tracker_config.pop("validation_prompt") tracker_config.pop("validation_image") accelerator.init_trackers(args.tracker_project_name, config=tracker_config) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
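# ------------------------------------------------------------------------
# Resuming (handled here) parses the step count out of the checkpoint
# directory name, e.g. path = "checkpoint-1500" gives
# global_step = int("1500"); with num_update_steps_per_epoch = 250 the
# loop restarts at first_epoch = 1500 // 250 = 6 (numbers hypothetical).
# ------------------------------------------------------------------------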
) args.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, desc="Steps", # Only show the progress bar once on each machine. disable=not accelerator.is_local_main_process, ) for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): l_acc = [vae, discriminatornet] with accelerator.accumulate(*l_acc): with torch.no_grad(): total_time_steps = noise_scheduler.timesteps num_time_steps = len(total_time_steps) if num_time_steps != 1: timesteps_loop = total_time_steps[-round(num_time_steps*args.t_max):] timesteps_loop = timesteps_loop[:-round(num_time_steps*args.t_min)] t_max = timesteps_loop[0] t_min = timesteps_loop[-1] pixel_values = batch["pixel_values"].to(accelerator.device) if args.is_module: latents_gt = vae.module.encode(pixel_values).latent_dist.sample() latents_gt = latents_gt * vae.module.config.scaling_factor # Convert images to latent space else: latents_gt = vae.encode(pixel_values).latent_dist.sample() latents_gt = latents_gt * vae.config.scaling_factor # Convert images to latent space encoder_hidden_states = text_encoder(batch["input_caption"].to(accelerator.device))[0] controlnet_image = batch["conditioning_pixel_values"].to(accelerator.device) controlnet_image_encode = 2*controlnet_image-1 if args.is_module: vae_encode_condition_hidden_states = vae.module.encode(controlnet_image_encode).latent_dist.sample() vae_encode_condition_hidden_states = vae_encode_condition_hidden_states * vae.module.config.scaling_factor else: vae_encode_condition_hidden_states = vae.encode(controlnet_image_encode).latent_dist.sample() vae_encode_condition_hidden_states = vae_encode_condition_hidden_states * vae.config.scaling_factor # Convert images to latent space if global_step > args.begin_disc: lambda_l2 = args.lambda_l2 lambda_lpips = args.lambda_lpips lambda_disc = args.lambda_disc lambda_disc_train = args.lambda_disc_train else: lambda_l2 = args.lambda_l2 lambda_lpips = 0 lambda_disc = 0 lambda_disc_train = args.lambda_disc_train noise = torch.randn_like(latents_gt) bsz = latents_gt.shape[0] timesteps = args.start_timesteps * torch.ones(latents_gt.shape[0]).to(accelerator.device) timesteps = timesteps.long() # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) if args.start_timesteps==1: noisy_latents = noise_scheduler.add_noise(vae_encode_condition_hidden_states, noise, timesteps) noisy_latents = noisy_latents.to(dtype=weight_dtype) encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) controlnet_image = controlnet_image.to(dtype=weight_dtype) vae_encode_condition_hidden_states = vae_encode_condition_hidden_states.to(dtype=weight_dtype) down_block_res_samples, mid_block_res_sample = controlnet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, controlnet_cond=controlnet_image, return_dict=False, vae_encode_condition_hidden_states=vae_encode_condition_hidden_states, ) # Predict the noise residual model_pred = unet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, down_block_additional_residuals=[ sample.to(dtype=weight_dtype) for sample in down_block_res_samples ], 
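# ------------------------------------------------------------------------
# How the t_max/t_min window above selects timesteps, assuming
# num_inference_steps = 4 yields total_time_steps = [750, 500, 250, 0]:
#     t_max = 0.6666 -> keep the last round(4 * 0.6666) = 3: [500, 250, 0]
#     t_min = 0.5    -> drop the last round(4 * 0.5)    = 2: [500]
# so the inner denoising loop below would run for t = 500 only.
# ------------------------------------------------------------------------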
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample # Predict x0 for T x0_T = noisy_latents - model_pred else: # Sample noise based on LR (controlnet_image) or a Random Noise? if args.is_start_lr: noisy_latents = noise_scheduler.add_noise(vae_encode_condition_hidden_states, noise, timesteps) noisy_latents = noisy_latents.to(dtype=weight_dtype) else: noisy_latents = noise_scheduler.add_noise(latents_gt, noise, timesteps) noisy_latents = noisy_latents.to(dtype=weight_dtype) encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype) controlnet_image = controlnet_image.to(dtype=weight_dtype) vae_encode_condition_hidden_states = vae_encode_condition_hidden_states.to(dtype=weight_dtype) down_block_res_samples, mid_block_res_sample = controlnet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, controlnet_cond=controlnet_image, return_dict=False, vae_encode_condition_hidden_states=vae_encode_condition_hidden_states, ) # Predict the noise residual model_pred = unet( noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # Predict x0 for T x0_T = predict_start_from_noise(noisy_latents, timesteps[0], model_pred) if num_time_steps!=1: # Re-add noise on x0_tmax noise2 = torch.randn_like(latents_gt) timesteps = t_max * torch.ones(model_pred.shape[0]).to(accelerator.device) timesteps = timesteps.long() latents = noise_scheduler.add_noise(x0_T, noise2, timesteps[0]) # Denoising loop for i, t in enumerate(timesteps_loop): # controlnet_latent_model_input = noise_scheduler.scale_model_input(latents, t) latents = latents.to(dtype=weight_dtype) down_block_res_samples, mid_block_res_sample = controlnet( latents, t, encoder_hidden_states=encoder_hidden_states, controlnet_cond=controlnet_image, return_dict=False, vae_encode_condition_hidden_states=vae_encode_condition_hidden_states, ) # predict the noise residual noise_pred = unet( latents, t, encoder_hidden_states=encoder_hidden_states, down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample, return_dict=False, )[0] # compute the previous noisy sample x_t -> x_t-1 latents_old = latents latents = noise_scheduler.step(noise_pred, t, latents, return_dict=False)[0] x0_tmin = predict_start_from_noise(latents_old, t, noise_pred) latents = x0_tmin latents = latents.to(dtype=torch.float32) else: latents = x0_T.to(dtype=torch.float32) # optimize the generator: vae decoder discriminatornet.requires_grad_(False) if args.is_module: image = vae.module.decode(latents / vae.module.config.scaling_factor, return_dict=False)[0].clamp(-1, 1) else: image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0].clamp(-1, 1) # compute the discriminator loss & update parameters _, cls_lr = model_fea(F.interpolate(controlnet_image, size=518, mode='bilinear')) # compute the generator loss pred_fake, _ = discriminatornet(image, cls_lr.detach()) pred_fake = torch.cat(pred_fake, dim=1) gan_loss = -torch.mean(pred_fake) loss_x0 = F.mse_loss(image.float(), pixel_values.float(), reduction="mean") * lambda_l2 if lambda_lpips != 0: loss_lpips = net_lpips(image.float(), pixel_values.float()).mean() * lambda_lpips loss_x0 = loss_lpips + loss_x0 loss = loss_x0 + lambda_disc * gan_loss accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = vae.parameters() accelerator.clip_grad_norm_(params_to_clip, 
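# ------------------------------------------------------------------------
# The adversarial terms used in this block follow the standard hinge
# formulation:
#     generator:      L_G = -E[D(x_fake)]
#     discriminator:  L_D = E[relu(1 - D(x_real))] + E[relu(1 + D(x_fake))]
# which is what gan_loss above and loss_real/loss_fake below compute.
# ------------------------------------------------------------------------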
args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad(set_to_none=args.set_grads_to_none) # update discriminator discriminatornet.requires_grad_(True) if args.is_module: discriminatornet.module.dino.requires_grad_(False) else: discriminatornet.dino.requires_grad_(False) pred_real, features = discriminatornet(pixel_values, cls_lr.detach()) pred_fake, _ = discriminatornet(image.detach(), cls_lr.detach()) pred_fake = torch.cat(pred_fake, dim=1) pred_real = torch.cat(pred_real, dim=1) loss_real = torch.mean(torch.relu(1.0 - pred_real)) * lambda_disc_train loss_fake = torch.mean(torch.relu(1.0 + pred_fake)) * lambda_disc_train loss_disc = loss_real + loss_fake accelerator.backward(loss_disc) if accelerator.sync_gradients: params_to_clip = discriminatornet.parameters() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer_disc.step() lr_scheduler_disc.step() optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none) model_fea.zero_grad(set_to_none=args.set_grads_to_none) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") # if args.validation_prompt is not None and global_step % args.validation_steps == 0: if False: image_logs = log_validation( vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, global_step, ) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: vae = accelerator.unwrap_model(vae) vae.save_pretrained(args.output_dir) if args.push_to_hub: save_model_card( repo_id, image_logs=image_logs, base_model=args.pretrained_model_name_or_path, repo_folder=args.output_dir, ) upload_folder( repo_id=repo_id, folder_path=args.output_dir, commit_message="End of training", ignore_patterns=["step_*", "epoch_*"], ) accelerator.end_training() ================================================ FILE: train_controlnet.py ================================================ #!/usr/bin/env python # coding=utf-8 # Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # limitations under the License. import argparse import logging import math import os import random import shutil from pathlib import Path import accelerate import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig import diffusers from diffusers import ( AutoencoderKL, DDPMScheduler, StableDiffusionControlNetPipeline, UniPCMultistepScheduler, ) from models.controlnet import ControlNetModel from models.unet_2d_condition import UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available from dataloaders.paired_dataset_txt import PairedCaptionDataset if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risk. check_min_version("0.21.0.dev0") logger = get_logger(__name__) def image_grid(imgs, rows, cols): assert len(imgs) == rows * cols w, h = imgs[0].size grid = Image.new("RGB", size=(cols * w, rows * h)) for i, img in enumerate(imgs): grid.paste(img, box=(i % cols * w, i // cols * h)) return grid def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step): logger.info("Running validation... ") controlnet = accelerator.unwrap_model(controlnet) pipeline = StableDiffusionControlNetPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, safety_checker=None, revision=args.revision, torch_dtype=weight_dtype, ) pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) if args.enable_xformers_memory_efficient_attention: pipeline.enable_xformers_memory_efficient_attention() if args.seed is None: generator = None else: generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if len(args.validation_image) == len(args.validation_prompt): validation_images = args.validation_image validation_prompts = args.validation_prompt elif len(args.validation_image) == 1: validation_images = args.validation_image * len(args.validation_prompt) validation_prompts = args.validation_prompt elif len(args.validation_prompt) == 1: validation_images = args.validation_image validation_prompts = args.validation_prompt * len(args.validation_image) else: raise ValueError( "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" ) image_logs = [] for validation_prompt, validation_image in zip(validation_prompts, validation_images): validation_image = Image.open(validation_image).convert("RGB") images = [] for _ in range(args.num_validation_images): with torch.autocast("cuda"): image = pipeline( validation_prompt, validation_image, num_inference_steps=20, generator=generator ).images[0] images.append(image) image_logs.append( {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} ) for tracker in accelerator.trackers:
if tracker.name == "tensorboard": for log in image_logs: images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] formatted_images = [] formatted_images.append(np.asarray(validation_image)) for image in images: formatted_images.append(np.asarray(image)) formatted_images = np.stack(formatted_images) tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") elif tracker.name == "wandb": formatted_images = [] for log in image_logs: images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) for image in images: image = wandb.Image(image, caption=validation_prompt) formatted_images.append(image) tracker.log({"validation": formatted_images}) else: logger.warn(f"image logging not implemented for {tracker.name}") return image_logs def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder="text_encoder", revision=revision, ) model_class = text_encoder_config.architectures[0] if model_class == "CLIPTextModel": from transformers import CLIPTextModel return CLIPTextModel elif model_class == "RobertaSeriesModelWithTransformation": from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation return RobertaSeriesModelWithTransformation else: raise ValueError(f"{model_class} is not supported.") def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): img_str = "" if image_logs is not None: img_str = "You can find some example images below.\n" for i, log in enumerate(image_logs): images = log["images"] validation_prompt = log["validation_prompt"] validation_image = log["validation_image"] validation_image.save(os.path.join(repo_folder, "image_control.png")) img_str += f"prompt: {validation_prompt}\n" images = [validation_image] + images image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) img_str += f"![images_{i})](./images_{i}.png)\n" yaml = f""" --- license: creativeml-openrail-m base_model: {base_model} tags: - stable-diffusion - stable-diffusion-diffusers - text-to-image - diffusers - controlnet inference: true --- """ model_card = f""" # controlnet-{repo_id} These are controlnet weights trained on {base_model} with new type of conditioning. {img_str} """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") parser.add_argument('--dataset_root_folders', type=str, default="") parser.add_argument("--t_max", type=float, default=0.6667) parser.add_argument("--start_timesteps", type=int, default=999) parser.add_argument( "--pretrained_model_name_or_path", type=str, default="", help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--controlnet_model_name_or_path", type=str, default='', help="Path to pretrained controlnet model." 
" If not specified controlnet weights are initialized from unet.", ) parser.add_argument( "--output_dir", type=str, default="./experiments/test", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument( "--resolution", type=int, default=512, help=( "The resolution for input images, all the images in the train/validation dataset will be resized to this" " resolution" ), ) parser.add_argument( "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." ) parser.add_argument("--num_train_epochs", type=int, default=1000) parser.add_argument( "--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( "--checkpointing_steps", type=int, default=500, help=( "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" "instructions." ), ) parser.add_argument( "--revision", type=str, default=None, required=False, help=( "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" " float32 precision." ), ) parser.add_argument( "--tokenizer_name", type=str, default=None, help="Pretrained tokenizer name or path if not the same as model_name", ) parser.add_argument( "--cache_dir", type=str, default=None, help="The directory where the downloaded models and datasets will be stored.", ) parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.") parser.add_argument( "--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, help=( "Whether training should be resumed from a previous checkpoint. Use a path saved by" ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument( "--learning_rate", type=float, default=5e-5, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
) parser.add_argument( "--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) parser.add_argument( "--dataloader_num_workers", type=int, default=0, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") parser.add_argument( "--hub_model_id", type=str, default=None, help="The name of the repository to keep in sync with the local `output_dir`.", ) parser.add_argument( "--logging_dir", type=str, default="logs", help=( "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." ), ) parser.add_argument( "--allow_tf32", action="store_true", help=( "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" ), ) parser.add_argument( "--report_to", type=str, default="tensorboard", help=( 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' ), ) parser.add_argument( "--mixed_precision", type=str, default="fp16", choices=["no", "fp16", "bf16"], help=( "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." ), ) parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) parser.add_argument( "--set_grads_to_none", action="store_true", help=( "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" " behaviors, so disable this argument if it causes any problems. More info:" " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" ), ) parser.add_argument( "--dataset_name", type=str, default=None, help=( "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," " or to a folder containing files that 🤗 Datasets can understand." 
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
    )
    parser.add_argument(
        "--conditioning_image_column",
        type=str,
        default="conditioning_image",
        help="The column of the dataset containing the controlnet conditioning image.",
    )
    parser.add_argument(
        "--caption_column",
        type=str,
        default="text",
        help="The column of the dataset containing a caption or a list of captions.",
    )
    parser.add_argument(
        "--max_train_samples",
        type=int,
        default=None,
        help=(
            "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        ),
    )
    parser.add_argument(
        "--proportion_empty_prompts",
        type=float,
        default=0,
        help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=[""],
        nargs="+",
        help=(
            "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
            " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
            " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
        ),
    )
    parser.add_argument(
        "--validation_image",
        type=str,
        default=[""],
        nargs="+",
        help=(
            "A set of paths to the controlnet conditioning images to be evaluated every `--validation_steps`"
            " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
            " single `--validation_prompt` to be used with all `--validation_image`s, or a single"
            " `--validation_image` that will be used with all `--validation_prompt`s."
        ),
    )
    parser.add_argument(
        "--is_start_lr",
        type=bool,
        default=False,
        help="Whether to start training from the initial learning rate.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=100,
        help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
    )
    parser.add_argument(
        "--validation_steps",
        type=int,
        default=1,
        help=(
            "Run validation every X steps. Validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`"
            " and logging the images."
        ),
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="train_ccsr_stage1",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers. For"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    if args.resolution % 8 != 0:
        raise ValueError(
            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
        )

    return args


def previous_timestep(timestep):
    if noise_scheduler.custom_timesteps:
        index = (noise_scheduler.timesteps == timestep).nonzero(as_tuple=True)[0][0]
        if index == noise_scheduler.timesteps.shape[0] - 1:
            prev_t = torch.tensor(-1)
        else:
            prev_t = noise_scheduler.timesteps[index + 1]
    else:
        num_inference_steps = (
            noise_scheduler.num_inference_steps
            if noise_scheduler.num_inference_steps
            else noise_scheduler.config.num_train_timesteps
        )
        prev_t = timestep - noise_scheduler.config.num_train_timesteps // num_inference_steps
    return prev_t


def predict_start_from_noise(sample, t, model_output):
    t = t.to(noise_scheduler.alphas_cumprod.device)
    prev_t = previous_timestep(t)

    # 1. compute alphas, betas
    alpha_prod_t = noise_scheduler.alphas_cumprod[t].to(sample.device)
    alpha_prod_t_prev = noise_scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else noise_scheduler.one
    alpha_prod_t_prev = alpha_prod_t_prev.to(sample.device)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if noise_scheduler.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif noise_scheduler.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif noise_scheduler.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    return pred_original_sample


# def main(args)
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    mixed_precision=args.mixed_precision,
    log_with=args.report_to,
    project_config=accelerator_project_config,
)

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()
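
# A minimal sanity-check sketch (not called anywhere in this script): for
# prediction_type="epsilon", predict_start_from_noise() inverts
# noise_scheduler.add_noise(), since x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
# implies x_0 = (x_t - sqrt(1 - a_bar_t) * eps) / sqrt(a_bar_t).
def _check_x0_recovery(x0, t):
    # t is a scalar timestep tensor, e.g. torch.tensor(500)
    eps = torch.randn_like(x0)
    x_t = noise_scheduler.add_noise(x0, eps, t)
    return torch.allclose(predict_start_from_noise(x_t, t, eps), x0, atol=1e-4)


# If passed along, set the training seed now.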
if args.seed is not None: set_seed(args.seed) # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) if args.push_to_hub: repo_id = create_repo( repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token ).repo_id # Load the tokenizer if args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) elif args.pretrained_model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False, ) # import correct text encoder class text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) # Load scheduler and models noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) if args.controlnet_model_name_or_path: logger.info("Loading existing controlnet weights") controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, subfolder="controlnet") else: logger.info("Initializing controlnet weights from unet") controlnet = ControlNetModel.from_unet(unet, use_vae_encode_condition=True) # `accelerate` 0.16.0 will have better support for customized saving if version.parse(accelerate.__version__) >= version.parse("0.16.0"): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): i = len(weights) - 1 for i, model in enumerate(models): sub_dir = "unet" if isinstance(model, UNet2DConditionModel) else "controlnet" model.save_pretrained(os.path.join(output_dir, sub_dir)) # make sure to pop weight so that corresponding model is not saved again weights.pop() def load_model_hook(models, input_dir): assert len(models) == 2 for i in range(len(models)): # pop models so that they are not loaded again model = models.pop() # load diffusers style into model if not isinstance(model, UNet2DConditionModel): load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True else: load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) del load_model accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) vae.requires_grad_(False) unet.requires_grad_(False) text_encoder.requires_grad_(False) controlnet.train() if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): import xformers xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warn( "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 
) unet.enable_xformers_memory_efficient_attention() controlnet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") if args.gradient_checkpointing: unet.enable_gradient_checkpointing() controlnet.enable_gradient_checkpointing() # Check that all trainable models are in full precision low_precision_error_string = ( " Please make sure to always have all model weights in full float32 precision when starting training - even if" " doing mixed precision training, copy of the weights should still be float32." ) if accelerator.unwrap_model(controlnet).dtype != torch.float32: raise ValueError( f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" ) if accelerator.unwrap_model(unet).dtype != torch.float32: raise ValueError( f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: torch.backends.cuda.matmul.allow_tf32 = True # Optimizer creation if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." ) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW params_to_optimize = list(controlnet.parameters()) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) # Training dataset creation train_dataset = PairedCaptionDataset(root_folders=args.dataset_root_folders, tokenizer=tokenizer, gt_ratio=0) # let lr is gt train_dataloader = torch.utils.data.DataLoader( train_dataset, num_workers=args.dataloader_num_workers, batch_size=args.train_batch_size, shuffle=True ) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 # Move vae, unet and text_encoder to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, num_training_steps=args.max_train_steps * accelerator.num_processes, num_cycles=args.lr_num_cycles, power=args.lr_power, ) # Prepare everything with our `accelerator`. 
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( controlnet, optimizer, train_dataloader, lr_scheduler ) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: tracker_config = dict(vars(args)) # tensorboard cannot handle list types for config tracker_config.pop("validation_prompt") tracker_config.pop("validation_image") accelerator.init_trackers(args.tracker_project_name, config=tracker_config) # Begin to train total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, desc="Steps", # Only show the progress bar once on each machine. 
    disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        with accelerator.accumulate(controlnet):
            pixel_values = batch["pixel_values"].to(accelerator.device, dtype=weight_dtype)

            # Convert images to latent space
            latents = vae.encode(pixel_values).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            encoder_hidden_states = text_encoder(batch["input_caption"].to(accelerator.device))[0]

            controlnet_image = batch["conditioning_pixel_values"].to(accelerator.device, dtype=weight_dtype)
            vae_encode_condition_hidden_states = vae.encode(2 * controlnet_image - 1).latent_dist.sample()
            vae_encode_condition_hidden_states = vae_encode_condition_hidden_states * vae.config.scaling_factor

            down_block_res_samples, mid_block_res_sample = controlnet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=controlnet_image,
                return_dict=False,
                vae_encode_condition_hidden_states=vae_encode_condition_hidden_states,
            )

            # Predict the noise residual
            model_pred = unet(
                noisy_latents,
                timesteps,
                encoder_hidden_states=encoder_hidden_states,
                down_block_additional_residuals=[
                    sample.to(dtype=weight_dtype) for sample in down_block_res_samples
                ],
                mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
            ).sample

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")  # original loss

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                params_to_clip = controlnet.parameters()
                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=args.set_grads_to_none)

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            if accelerator.is_main_process:
                if global_step % args.checkpointing_steps == 0:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)
                    logger.info(f"Saved state to {save_path}")

                # if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                if False:  # validation is disabled here; re-enable the condition above to run it
                    image_logs = log_validation(
                        vae,
                        text_encoder,
                        tokenizer,
                        unet,
                        controlnet,
                        args,
                        accelerator,
                        weight_dtype,
                        global_step,
                    )

        logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        progress_bar.set_postfix(**logs)
        accelerator.log(logs, step=global_step)

        if global_step >= args.max_train_steps:
            break
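
# A minimal sketch (the checkpoint path is a hypothetical example) of reloading
# a state saved above: save_model_hook writes the ControlNet weights into a
# "controlnet" subfolder of each checkpoint-<step> directory, which is also
# where load_model_hook expects them.
def _load_checkpointed_controlnet(ckpt_dir="./experiments/test/checkpoint-500"):
    return ControlNetModel.from_pretrained(ckpt_dir, subfolder="controlnet")


# Create the pipeline using the trained modules and save it.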
accelerator.wait_for_everyone() if accelerator.is_main_process: controlnet = accelerator.unwrap_model(controlnet) controlnet.save_pretrained(args.output_dir) unet = accelerator.unwrap_model(unet) unet.save_pretrained(args.output_dir) if args.push_to_hub: save_model_card( repo_id, image_logs=image_logs, base_model=args.pretrained_model_name_or_path, repo_folder=args.output_dir, ) upload_folder( repo_id=repo_id, folder_path=args.output_dir, commit_message="End of training", ignore_patterns=["step_*", "epoch_*"], ) accelerator.end_training() # if __name__ == "__main__": # args = parse_args() # main(args) ================================================ FILE: utils/devices.py ================================================ import sys import contextlib from functools import lru_cache import torch #from modules import errors if sys.platform == "darwin": from modules import mac_specific def has_mps() -> bool: if sys.platform != "darwin": return False else: return mac_specific.has_mps def get_cuda_device_string(): return "cuda" def get_optimal_device_name(): if torch.cuda.is_available(): return get_cuda_device_string() if has_mps(): return "mps" return "cpu" def get_optimal_device(): return torch.device(get_optimal_device_name()) def get_device_for(task): return get_optimal_device() def torch_gc(): if torch.cuda.is_available(): with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() if has_mps(): mac_specific.torch_mps_gc() def enable_tf32(): if torch.cuda.is_available(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True enable_tf32() #errors.run(enable_tf32, "Enabling TF32") cpu = torch.device("cpu") device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device("cuda") dtype = torch.float16 dtype_vae = torch.float16 dtype_unet = torch.float16 unet_needs_upcast = False def cond_cast_unet(input): return input.to(dtype_unet) if unet_needs_upcast else input def cond_cast_float(input): return input.float() if unet_needs_upcast else input def randn(seed, shape): torch.manual_seed(seed) return torch.randn(shape, device=device) def randn_without_seed(shape): return torch.randn(shape, device=device) def autocast(disable=False): if disable: return contextlib.nullcontext() return torch.autocast("cuda") def without_autocast(disable=False): return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() class NansException(Exception): pass def test_for_nans(x, where): if not torch.all(torch.isnan(x)).item(): return if where == "unet": message = "A tensor with all NaNs was produced in Unet." elif where == "vae": message = "A tensor with all NaNs was produced in VAE." else: message = "A tensor with all NaNs was produced." message += " Use --disable-nan-check commandline argument to disable this check." raise NansException(message) @lru_cache def first_time_calculation(): """ just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and spends about 2.7 seconds doing that, at least wih NVidia. 
""" x = torch.zeros((1, 1)).to(device, dtype) linear = torch.nn.Linear(1, 1).to(device, dtype) linear(x) x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) ================================================ FILE: utils/img_util.py ================================================ import os import PIL import cv2 import math import numpy as np import torch import torchvision import imageio from einops import rearrange def save_videos_grid(videos, path=None, rescale=True, n_rows=4, fps=8, discardN=0): videos = rearrange(videos, "b c t h w -> t b c h w").cpu() outputs = [] for x in videos: x = torchvision.utils.make_grid(x, nrow=n_rows) x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) if rescale: x = (x / 2.0 + 0.5).clamp(0, 1) # -1,1 -> 0,1 x = (x * 255).numpy().astype(np.uint8) #x = adjust_gamma(x, 0.5) outputs.append(x) outputs = outputs[discardN:] if path is not None: #os.makedirs(os.path.dirname(path), exist_ok=True) imageio.mimsave(path, outputs, duration=1000/fps, loop=0) return outputs def convert_image_to_fn(img_type, minsize, image, eps=0.02): width, height = image.size if min(width, height) < minsize: scale = minsize/min(width, height) + eps image = image.resize((math.ceil(width*scale), math.ceil(height*scale))) if image.mode != img_type: return image.convert(img_type) return image ================================================ FILE: utils/misc.py ================================================ import os import binascii from safetensors import safe_open import torch from diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint def rand_name(length=8, suffix=''): name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') if suffix: if not suffix.startswith('.'): suffix = '.' + suffix name += suffix return name def cycle(dl): while True: for data in dl: yield data def exists(x): return x is not None def identity(x): return x def load_dreambooth_lora(unet, vae=None, model_path=None, alpha=1.0, model_base=""): if model_path is None: return unet if model_path.endswith(".ckpt"): base_state_dict = torch.load(model_path)['state_dict'] elif model_path.endswith(".safetensors"): state_dict = {} with safe_open(model_path, framework="pt", device="cpu") as f: for key in f.keys(): state_dict[key] = f.get_tensor(key) is_lora = all("lora" in k for k in state_dict.keys()) if not is_lora: base_state_dict = state_dict else: base_state_dict = {} with safe_open(model_base, framework="pt", device="cpu") as f: for key in f.keys(): base_state_dict[key] = f.get_tensor(key) converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, unet.config) unet_state_dict = unet.state_dict() for key in converted_unet_checkpoint: converted_unet_checkpoint[key] = alpha * converted_unet_checkpoint[key] + (1.0-alpha) * unet_state_dict[key] unet.load_state_dict(converted_unet_checkpoint, strict=False) if vae is not None: converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, vae.config) vae.load_state_dict(converted_vae_checkpoint) return unet, vae ================================================ FILE: utils/vaehook.py ================================================ # ------------------------------------------------------------------------ # # Ultimate VAE Tile Optimization # # Introducing a revolutionary new optimization designed to make # the VAE work with giant images on limited VRAM! # Say goodbye to the frustration of OOM and hello to seamless output! 
#
# ------------------------------------------------------------------------
#
#   This script is a wild hack that splits the image into tiles,
#   encodes each tile separately, and merges the result back together.
#
#   Advantages:
#   - The VAE can now work with giant images on limited VRAM
#       (~10 GB for 8K images!)
#   - The merged output is completely seamless without any post-processing.
#
#   Drawbacks:
#   - Giant RAM needed. To store the intermediate results for a 4096x4096
#       image, you need 32 GB RAM (it consumes ~20 GB); for 8192x8192
#       you need a 128 GB RAM machine (it consumes ~100 GB)
#   - NaNs always appear for 8k images when you use fp16 (half) VAE.
#       You must use --no-half-vae to disable half VAE for such giant images.
#   - Slow speed. With the default tile size, it takes around 50/200 seconds
#       to encode/decode a 4096x4096 image, and 200/900 seconds to encode/decode
#       an 8192x8192 image. (The speed is limited by both the GPU and the CPU.)
#   - The gradient calculation is not compatible with this hack. It
#       will break any backward() or torch.autograd.grad() that passes the VAE.
#       (But you can still use the VAE to generate training data.)
#
#   How it works:
#   1) The image is split into tiles.
#       - To ensure perfect results, each tile is padded with 32 pixels
#           on each side.
#       - Then the conv2d/silu/upsample/downsample can produce identical
#           results to the original image without splitting.
#   2) The original forward is decomposed into a task queue and a task worker.
#       - The task queue is a list of functions that will be executed in order.
#       - The task worker is a loop that executes the tasks in the queue.
#   3) The task queue is executed for each tile.
#       - The current tile is sent to the GPU.
#       - Local operations are executed directly.
#       - Group norm calculation is temporarily suspended until the mean
#           and var of all tiles are calculated.
#       - The residual is pre-calculated, stored, and added back later.
#       - When moving on to the next tile, the current tile is sent to the CPU.
#   4) After all tiles are processed, they are merged on the CPU and returned.
#
#   Enjoy!
#
#   @author: LI YI @ Nanyang Technological University - Singapore
#   @date: 2023-03-02
#   @license: MIT License
#
#   Please give me a star if you like this project!
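#
#   A worked example of the tiling arithmetic used below (assuming a GPU with
#   more than 16 GB of VRAM, so ENCODER_TILE_SIZE = 3072): encoding a
#   4096x4096 image gives ceil((4096 - 2*32) / 3072) = 2 tiles per side,
#   i.e. a 2x2 grid of 4 tiles, each padded by 32 px so that neighbouring
#   tiles overlap and the merged result stays seamless.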
# # ------------------------------------------------------------------------- import gc from time import time import math from tqdm import tqdm import torch import torch.version import torch.nn.functional as F from einops import rearrange import sys import myutils.devices as devices #from modules.shared import state #from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock try: import xformers import xformers.ops except ImportError: pass sd_flag = False def get_recommend_encoder_tile_size(): if torch.cuda.is_available(): total_memory = torch.cuda.get_device_properties( devices.device).total_memory // 2**20 if total_memory > 16*1000: ENCODER_TILE_SIZE = 3072 elif total_memory > 12*1000: ENCODER_TILE_SIZE = 2048 elif total_memory > 8*1000: ENCODER_TILE_SIZE = 1536 else: ENCODER_TILE_SIZE = 960 else: ENCODER_TILE_SIZE = 512 return ENCODER_TILE_SIZE def get_recommend_decoder_tile_size(): if torch.cuda.is_available(): total_memory = torch.cuda.get_device_properties( devices.device).total_memory // 2**20 if total_memory > 30*1000: DECODER_TILE_SIZE = 256 elif total_memory > 16*1000: DECODER_TILE_SIZE = 192 elif total_memory > 12*1000: DECODER_TILE_SIZE = 128 elif total_memory > 8*1000: DECODER_TILE_SIZE = 96 else: DECODER_TILE_SIZE = 64 else: DECODER_TILE_SIZE = 64 return DECODER_TILE_SIZE if 'global const': DEFAULT_ENABLED = False DEFAULT_MOVE_TO_GPU = False DEFAULT_FAST_ENCODER = True DEFAULT_FAST_DECODER = True DEFAULT_COLOR_FIX = 0 DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size() DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size() # inplace version of silu def inplace_nonlinearity(x): # Test: fix for Nans return F.silu(x, inplace=True) # extracted from ldm.modules.diffusionmodules.model # from diffusers lib def attn_forward_new(self, h_): batch_size, channel, height, width = h_.shape hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2) attention_mask = None encoder_hidden_states = None batch_size, sequence_length, _ = hidden_states.shape attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) query = self.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif self.norm_cross: encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) key = self.to_k(encoder_hidden_states) value = self.to_v(encoder_hidden_states) query = self.head_to_batch_dim(query) key = self.head_to_batch_dim(key) value = self.head_to_batch_dim(value) attention_probs = self.get_attention_scores(query, key, attention_mask) hidden_states = torch.bmm(attention_probs, value) hidden_states = self.batch_to_head_dim(hidden_states) # linear proj hidden_states = self.to_out[0](hidden_states) # dropout hidden_states = self.to_out[1](hidden_states) hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) return hidden_states def attn_forward(self, h_): q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention b, c, h, w = q.shape q = q.reshape(b, c, h*w) q = q.permute(0, 2, 1) # b,hw,c k = k.reshape(b, c, h*w) # b,c,hw w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w_ = w_ * (int(c)**(-0.5)) w_ = torch.nn.functional.softmax(w_, dim=2) # attend to values v = v.reshape(b, c, h*w) w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] h_ = torch.bmm(v, w_) h_ = h_.reshape(b, c, h, w) h_ = self.proj_out(h_) return h_ def xformer_attn_forward(self, 
h_): q = self.q(h_) k = self.k(h_) v = self.v(h_) # compute attention B, C, H, W = q.shape q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) q, k, v = map( lambda t: t.unsqueeze(3) .reshape(B, t.shape[1], 1, C) .permute(0, 2, 1, 3) .reshape(B * 1, t.shape[1], C) .contiguous(), (q, k, v), ) out = xformers.ops.memory_efficient_attention( q, k, v, attn_bias=None, op=self.attention_op) out = ( out.unsqueeze(0) .reshape(B, 1, out.shape[1], C) .permute(0, 2, 1, 3) .reshape(B, out.shape[1], C) ) out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) out = self.proj_out(out) return out def attn2task(task_queue, net): if False: #isinstance(net, AttnBlock): task_queue.append(('store_res', lambda x: x)) task_queue.append(('pre_norm', net.norm)) task_queue.append(('attn', lambda x, net=net: attn_forward(net, x))) task_queue.append(['add_res', None]) elif False: #isinstance(net, MemoryEfficientAttnBlock): task_queue.append(('store_res', lambda x: x)) task_queue.append(('pre_norm', net.norm)) task_queue.append( ('attn', lambda x, net=net: xformer_attn_forward(net, x))) task_queue.append(['add_res', None]) else: task_queue.append(('store_res', lambda x: x)) task_queue.append(('pre_norm', net.group_norm)) task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x))) task_queue.append(['add_res', None]) def resblock2task(queue, block): """ Turn a ResNetBlock into a sequence of tasks and append to the task queue @param queue: the target task queue @param block: ResNetBlock """ if block.in_channels != block.out_channels: if sd_flag: if block.use_conv_shortcut: queue.append(('store_res', block.conv_shortcut)) else: queue.append(('store_res', block.nin_shortcut)) else: if block.use_in_shortcut: queue.append(('store_res', block.conv_shortcut)) else: queue.append(('store_res', block.nin_shortcut)) else: queue.append(('store_res', lambda x: x)) queue.append(('pre_norm', block.norm1)) queue.append(('silu', inplace_nonlinearity)) queue.append(('conv1', block.conv1)) queue.append(('pre_norm', block.norm2)) queue.append(('silu', inplace_nonlinearity)) queue.append(('conv2', block.conv2)) queue.append(['add_res', None]) def build_sampling(task_queue, net, is_decoder): """ Build the sampling part of a task queue @param task_queue: the target task queue @param net: the network @param is_decoder: currently building decoder or encoder """ if is_decoder: # resblock2task(task_queue, net.mid.block_1) # attn2task(task_queue, net.mid.attn_1) # resblock2task(task_queue, net.mid.block_2) # resolution_iter = reversed(range(net.num_resolutions)) # block_ids = net.num_res_blocks + 1 # condition = 0 # module = net.up # func_name = 'upsample' resblock2task(task_queue, net.mid_block.resnets[0]) attn2task(task_queue, net.mid_block.attentions[0]) resblock2task(task_queue, net.mid_block.resnets[1]) resolution_iter = (range(len(net.up_blocks))) # range(0,4) block_ids = 2 + 1 condition = len(net.up_blocks) - 1 module = net.up_blocks func_name = 'upsamplers' else: # resolution_iter = range(net.num_resolutions) # block_ids = net.num_res_blocks # condition = net.num_resolutions - 1 # module = net.down # func_name = 'downsample' resolution_iter = (range(len(net.down_blocks))) # range(0,4) block_ids = 2 condition = len(net.down_blocks) - 1 module = net.down_blocks func_name = 'downsamplers' for i_level in resolution_iter: for i_block in range(block_ids): resblock2task(task_queue, module[i_level].resnets[i_block]) if i_level != condition: if is_decoder: task_queue.append((func_name, 
module[i_level].upsamplers[0])) else: task_queue.append((func_name, module[i_level].downsamplers[0])) if not is_decoder: resblock2task(task_queue, net.mid_block.resnets[0]) attn2task(task_queue, net.mid_block.attentions[0]) resblock2task(task_queue, net.mid_block.resnets[1]) def build_task_queue(net, is_decoder): """ Build a single task queue for the encoder or decoder @param net: the VAE decoder or encoder network @param is_decoder: currently building decoder or encoder @return: the task queue """ task_queue = [] task_queue.append(('conv_in', net.conv_in)) # construct the sampling part of the task queue # because encoder and decoder share the same architecture, we extract the sampling part build_sampling(task_queue, net, is_decoder) if is_decoder and not sd_flag: net.give_pre_end = False net.tanh_out = False if not is_decoder or not net.give_pre_end: if sd_flag: task_queue.append(('pre_norm', net.norm_out)) else: task_queue.append(('pre_norm', net.conv_norm_out)) task_queue.append(('silu', inplace_nonlinearity)) task_queue.append(('conv_out', net.conv_out)) if is_decoder and net.tanh_out: task_queue.append(('tanh', torch.tanh)) return task_queue def clone_task_queue(task_queue): """ Clone a task queue @param task_queue: the task queue to be cloned @return: the cloned task queue """ return [[item for item in task] for task in task_queue] def get_var_mean(input, num_groups, eps=1e-6): """ Get mean and var for group norm """ b, c = input.size(0), input.size(1) channel_in_group = int(c/num_groups) input_reshaped = input.contiguous().view( 1, int(b * num_groups), channel_in_group, *input.size()[2:]) var, mean = torch.var_mean( input_reshaped, dim=[0, 2, 3, 4], unbiased=False) return var, mean def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6): """ Custom group norm with fixed mean and var @param input: input tensor @param num_groups: number of groups. 
by default, num_groups = 32 @param mean: mean, must be pre-calculated by get_var_mean @param var: var, must be pre-calculated by get_var_mean @param weight: weight, should be fetched from the original group norm @param bias: bias, should be fetched from the original group norm @param eps: epsilon, by default, eps = 1e-6 to match the original group norm @return: normalized tensor """ b, c = input.size(0), input.size(1) channel_in_group = int(c/num_groups) input_reshaped = input.contiguous().view( 1, int(b * num_groups), channel_in_group, *input.size()[2:]) out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None, training=False, momentum=0, eps=eps) out = out.view(b, c, *input.size()[2:]) # post affine transform if weight is not None: out *= weight.view(1, -1, 1, 1) if bias is not None: out += bias.view(1, -1, 1, 1) return out def crop_valid_region(x, input_bbox, target_bbox, is_decoder): """ Crop the valid region from the tile @param x: input tile @param input_bbox: original input bounding box @param target_bbox: output bounding box @param scale: scale factor @return: cropped tile """ padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox] margin = [target_bbox[i] - padded_bbox[i] for i in range(4)] return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]] # ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓ def perfcount(fn): def wrapper(*args, **kwargs): ts = time() if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats(devices.device) devices.torch_gc() gc.collect() ret = fn(*args, **kwargs) devices.torch_gc() gc.collect() if torch.cuda.is_available(): vram = torch.cuda.max_memory_allocated(devices.device) / 2**20 torch.cuda.reset_peak_memory_stats(devices.device) print( f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB') else: print(f'[Tiled VAE]: Done in {time() - ts:.3f}s') return ret return wrapper # copy end :) class GroupNormParam: def __init__(self): self.var_list = [] self.mean_list = [] self.pixel_list = [] self.weight = None self.bias = None def add_tile(self, tile, layer): var, mean = get_var_mean(tile, 32) # For giant images, the variance can be larger than max float16 # In this case we create a copy to float32 if var.dtype == torch.float16 and var.isinf().any(): fp32_tile = tile.float() var, mean = get_var_mean(fp32_tile, 32) # ============= DEBUG: test for infinite ============= # if torch.isinf(var).any(): # print('var: ', var) # ==================================================== self.var_list.append(var) self.mean_list.append(mean) self.pixel_list.append( tile.shape[2]*tile.shape[3]) if hasattr(layer, 'weight'): self.weight = layer.weight self.bias = layer.bias else: self.weight = None self.bias = None def summary(self): """ summarize the mean and var and return a function that apply group norm on each tile """ if len(self.var_list) == 0: return None var = torch.vstack(self.var_list) mean = torch.vstack(self.mean_list) max_value = max(self.pixel_list) pixels = torch.tensor( self.pixel_list, dtype=torch.float32, device=devices.device) / max_value sum_pixels = torch.sum(pixels) pixels = pixels.unsqueeze( 1) / sum_pixels var = torch.sum( var * pixels, dim=0) mean = torch.sum( mean * pixels, dim=0) return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias) @staticmethod def from_tile(tile, norm): """ create a function from a single tile without summary """ var, mean = get_var_mean(tile, 32) if var.dtype == torch.float16 and var.isinf().any(): fp32_tile = 
tile.float() var, mean = get_var_mean(fp32_tile, 32) # if it is a macbook, we need to convert back to float16 if var.device.type == 'mps': # clamp to avoid overflow var = torch.clamp(var, 0, 60000) var = var.half() mean = mean.half() if hasattr(norm, 'weight'): weight = norm.weight bias = norm.bias else: weight = None bias = None def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias): return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6) return group_norm_func class VAEHook: def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False): self.net = net # encoder | decoder self.tile_size = tile_size self.is_decoder = is_decoder self.fast_mode = (fast_encoder and not is_decoder) or ( fast_decoder and is_decoder) self.color_fix = color_fix and not is_decoder self.to_gpu = to_gpu self.pad = 11 if is_decoder else 32 def __call__(self, x): B, C, H, W = x.shape original_device = next(self.net.parameters()).device try: if self.to_gpu: self.net.to(devices.get_optimal_device()) if max(H, W) <= self.pad * 2 + self.tile_size: print("[Tiled VAE]: the input size is tiny and unnecessary to tile.") x_type = x.dtype x = self.net.original_forward(x) x = x.to(dtype=x_type) return x else: x_type = x.dtype x = self.vae_tile_forward(x) x = x.to(dtype=x_type) return x finally: self.net.to(original_device) def get_best_tile_size(self, lowerbound, upperbound): """ Get the best tile size for GPU memory """ divider = 32 while divider >= 2: remainer = lowerbound % divider if remainer == 0: return lowerbound candidate = lowerbound - remainer + divider if candidate <= upperbound: return candidate divider //= 2 return lowerbound def split_tiles(self, h, w): """ Tool function to split the image into tiles @param h: height of the image @param w: width of the image @return: tile_input_bboxes, tile_output_bboxes """ tile_input_bboxes, tile_output_bboxes = [], [] tile_size = self.tile_size pad = self.pad num_height_tiles = math.ceil((h - 2 * pad) / tile_size) num_width_tiles = math.ceil((w - 2 * pad) / tile_size) # If any of the numbers are 0, we let it be 1 # This is to deal with long and thin images num_height_tiles = max(num_height_tiles, 1) num_width_tiles = max(num_width_tiles, 1) # Suggestions from https://github.com/Kahsolt: auto shrink the tile size real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles) real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles) real_tile_height = self.get_best_tile_size(real_tile_height, tile_size) real_tile_width = self.get_best_tile_size(real_tile_width, tile_size) print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' + f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}') for i in range(num_height_tiles): for j in range(num_width_tiles): # bbox: [x1, x2, y1, y2] # the padding is is unnessary for image borders. 
So we directly start from (32, 32) input_bbox = [ pad + j * real_tile_width, min(pad + (j + 1) * real_tile_width, w), pad + i * real_tile_height, min(pad + (i + 1) * real_tile_height, h), ] # if the output bbox is close to the image boundary, we extend it to the image boundary output_bbox = [ input_bbox[0] if input_bbox[0] > pad else 0, input_bbox[1] if input_bbox[1] < w - pad else w, input_bbox[2] if input_bbox[2] > pad else 0, input_bbox[3] if input_bbox[3] < h - pad else h, ] # scale to get the final output bbox output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox] tile_output_bboxes.append(output_bbox) # indistinguishable expand the input bbox by pad pixels tile_input_bboxes.append([ max(0, input_bbox[0] - pad), min(w, input_bbox[1] + pad), max(0, input_bbox[2] - pad), min(h, input_bbox[3] + pad), ]) return tile_input_bboxes, tile_output_bboxes @torch.no_grad() def estimate_group_norm(self, z, task_queue, color_fix): device = z.device tile = z last_id = len(task_queue) - 1 while last_id >= 0 and task_queue[last_id][0] != 'pre_norm': last_id -= 1 if last_id <= 0 or task_queue[last_id][0] != 'pre_norm': raise ValueError('No group norm found in the task queue') # estimate until the last group norm for i in range(last_id + 1): task = task_queue[i] if task[0] == 'pre_norm': group_norm_func = GroupNormParam.from_tile(tile, task[1]) task_queue[i] = ('apply_norm', group_norm_func) if i == last_id: return True tile = group_norm_func(tile) elif task[0] == 'store_res': task_id = i + 1 while task_id < last_id and task_queue[task_id][0] != 'add_res': task_id += 1 if task_id >= last_id: continue task_queue[task_id][1] = task[1](tile) elif task[0] == 'add_res': tile += task[1].to(device) task[1] = None elif color_fix and task[0] == 'downsample': for j in range(i, last_id + 1): if task_queue[j][0] == 'store_res': task_queue[j] = ('store_res_cpu', task_queue[j][1]) return True else: tile = task[1](tile) try: devices.test_for_nans(tile, "vae") except: print(f'Nan detected in fast mode estimation. Fast mode disabled.') return False raise IndexError('Should not reach here') @perfcount @torch.no_grad() def vae_tile_forward(self, z): """ Decode a latent vector z into an image in a tiled manner. 
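        The overall procedure: split z into overlapping padded tiles, run each
        tile's task queue (synchronising group-norm statistics across tiles),
        then crop the valid region of every tile and paste it into the output.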
        @param z: latent vector
        @return: image
        """
        device = next(self.net.parameters()).device
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()  # detach the input to avoid backprop

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        # Split the input into tiles and build a task queue for each tile
        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Prepare tiles by splitting the input latents
        tiles = []
        for input_bbox in in_bboxes:
            tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
            tiles.append(tile)

        num_tiles = len(tiles)
        num_completed = 0

        # Build task queues
        single_task_queue = build_task_queue(net, is_decoder)
        #print(single_task_queue)
        if self.fast_mode:
            # Fast mode: downsample the input image to the tile size,
            # then estimate the group norm parameters on the downsampled image
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')  # use nearest-exact to keep statistics as close as possible
            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #
            # The downsampling will heavily distort its mean and std, so we need to recover it.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            # occasionally the std_new is too small or too large, which exceeds the range of float16
            # so we need to clamp it to max z's range.
downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max()) estimate_task_queue = clone_task_queue(single_task_queue) if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix): single_task_queue = estimate_task_queue del downsampled_z task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)] # Dummy result result = None result_approx = None #try: # with devices.autocast(): # result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu() #except: pass # Free memory of input latent tensor del z # Task queue execution pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ") # execute the task back and forth when switch tiles so that we always # keep one tile on the GPU to reduce unnecessary data transfer forward = True interrupted = False #state.interrupted = interrupted while True: #if state.interrupted: interrupted = True ; break group_norm_param = GroupNormParam() for i in range(num_tiles) if forward else reversed(range(num_tiles)): #if state.interrupted: interrupted = True ; break tile = tiles[i].to(device) input_bbox = in_bboxes[i] task_queue = task_queues[i] interrupted = False while len(task_queue) > 0: #if state.interrupted: interrupted = True ; break # DEBUG: current task # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape) task = task_queue.pop(0) if task[0] == 'pre_norm': group_norm_param.add_tile(tile, task[1]) break elif task[0] == 'store_res' or task[0] == 'store_res_cpu': task_id = 0 res = task[1](tile) if not self.fast_mode or task[0] == 'store_res_cpu': res = res.cpu() while task_queue[task_id][0] != 'add_res': task_id += 1 task_queue[task_id][1] = res elif task[0] == 'add_res': tile += task[1].to(device) task[1] = None else: tile = task[1](tile) #print(tiles[i].shape, tile.shape, task) pbar.update(1) if interrupted: break # check for NaNs in the tile. # If there are NaNs, we abort the process to save user's time #devices.test_for_nans(tile, "vae") #print(tiles[i].shape, tile.shape, i, num_tiles) if len(task_queue) == 0: tiles[i] = None num_completed += 1 if result is None: # NOTE: dim C varies from different cases, can only be inited dynamically result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False) result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder) del tile elif i == num_tiles - 1 and forward: forward = False tiles[i] = tile elif i == 0 and not forward: forward = True tiles[i] = tile else: tiles[i] = tile.cpu() del tile if interrupted: break if num_completed == num_tiles: break # insert the group norm task to the head of each task queue group_norm_func = group_norm_param.summary() if group_norm_func is not None: for i in range(num_tiles): task_queue = task_queues[i] task_queue.insert(0, ('apply_norm', group_norm_func)) # Done! 
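        # A note on the back-and-forth sweep above: with e.g. 4 tiles the loop
        # visits 0,1,2,3 and then 3,2,1,0, so the tile at each turning point is
        # kept on the GPU across the group-norm synchronisation instead of
        # making an extra CPU round trip.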
        pbar.close()
        # note: result_approx stays None here because the cheap approximation above is commented out
        return result if result is not None else result_approx.to(device)


================================================
FILE: utils/wavelet_color_fix.py
================================================
'''
# --------------------------------------------------------------------------------
#   Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)
# --------------------------------------------------------------------------------
'''

import torch
from PIL import Image
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms import ToTensor, ToPILImage

def adain_color_fix(target: Image, source: Image):
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply adaptive instance normalization
    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image

def wavelet_color_fix(target: Image, source: Image):
    # Convert images to tensors
    to_tensor = ToTensor()
    target_tensor = to_tensor(target).unsqueeze(0)
    source_tensor = to_tensor(source).unsqueeze(0)

    # Apply wavelet reconstruction
    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)

    # Convert tensor back to image
    to_image = ToPILImage()
    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))

    return result_image

def calc_mean_std(feat: Tensor, eps=1e-5):
    """Calculate mean and std for adaptive_instance_normalization.
    Args:
        feat (Tensor): 4D tensor.
        eps (float): A small value added to the variance to avoid
            divide-by-zero. Default: 1e-5.
    """
    size = feat.size()
    assert len(size) == 4, 'The input feature should be 4D tensor.'
    b, c = size[:2]
    feat_var = feat.reshape(b, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().reshape(b, c, 1, 1)
    feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat: Tensor, style_feat: Tensor):
    """Adaptive instance normalization.
    Adjust the reference features to have a similar color and illumination
    to those in the degraded features.
    Args:
        content_feat (Tensor): The reference features.
        style_feat (Tensor): The degraded features.
    """
    size = content_feat.size()
    style_mean, style_std = calc_mean_std(style_feat)
    content_mean, content_std = calc_mean_std(content_feat)
    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
    return normalized_feat * style_std.expand(size) + style_mean.expand(size)

def wavelet_blur(image: Tensor, radius: int):
    """
    Apply wavelet blur to the input tensor.
    """
    # input shape: (1, 3, H, W)
    # convolution kernel
    kernel_vals = [
        [0.0625, 0.125, 0.0625],
        [0.125, 0.25, 0.125],
        [0.0625, 0.125, 0.0625],
    ]
    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
    # add channel dimensions to the kernel to make it a 4D tensor
    kernel = kernel[None, None]
    # repeat the kernel across all input channels
    kernel = kernel.repeat(3, 1, 1, 1)
    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
    # apply convolution
    output = F.conv2d(image, kernel, groups=3, dilation=radius)
    return output

def wavelet_decomposition(image: Tensor, levels=5):
    """
    Apply wavelet decomposition to the input tensor.
    This function only returns the low frequency & the high frequency.
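    At level i the 3x3 blur kernel is dilated by 2**i, so with the default
    levels=5 structures coarser than roughly 2**5 = 32 px end up in low_freq.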
""" high_freq = torch.zeros_like(image) for i in range(levels): radius = 2 ** i low_freq = wavelet_blur(image, radius) high_freq += (image - low_freq) image = low_freq return high_freq, low_freq def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): """ Apply wavelet decomposition, so that the content will have the same color as the style. """ # calculate the wavelet decomposition of the content feature content_high_freq, content_low_freq = wavelet_decomposition(content_feat) del content_low_freq # calculate the wavelet decomposition of the style feature style_high_freq, style_low_freq = wavelet_decomposition(style_feat) del style_high_freq # reconstruct the content feature with the style's high frequency return content_high_freq + style_low_freq