[
  {
    "path": ".idea/CCSR.iml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<module type=\"PYTHON_MODULE\" version=\"4\">\n  <component name=\"NewModuleRootManager\">\n    <content url=\"file://$MODULE_DIR$\" />\n    <orderEntry type=\"inheritedJdk\" />\n    <orderEntry type=\"sourceFolder\" forTests=\"false\" />\n  </component>\n  <component name=\"PyDocumentationSettings\">\n    <option name=\"format\" value=\"PLAIN\" />\n    <option name=\"myDocStringFormat\" value=\"Plain\" />\n  </component>\n</module>"
  },
  {
    "path": ".idea/inspectionProfiles/Project_Default.xml",
    "content": "<component name=\"InspectionProjectProfileManager\">\n  <profile version=\"1.0\">\n    <option name=\"myName\" value=\"Project Default\" />\n    <inspection_tool class=\"PyPackageRequirementsInspection\" enabled=\"true\" level=\"WARNING\" enabled_by_default=\"true\">\n      <option name=\"ignoredPackages\">\n        <value>\n          <list size=\"1\">\n            <item index=\"0\" class=\"java.lang.String\" itemvalue=\"opencv-python\" />\n          </list>\n        </value>\n      </option>\n    </inspection_tool>\n  </profile>\n</component>"
  },
  {
    "path": ".idea/inspectionProfiles/profiles_settings.xml",
    "content": "<component name=\"InspectionProjectProfileManager\">\n  <settings>\n    <option name=\"USE_PROJECT_PROFILE\" value=\"false\" />\n    <version value=\"1.0\" />\n  </settings>\n</component>"
  },
  {
    "path": ".idea/modules.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ProjectModuleManager\">\n    <modules>\n      <module fileurl=\"file://$PROJECT_DIR$/.idea/CCSR.iml\" filepath=\"$PROJECT_DIR$/.idea/CCSR.iml\" />\n    </modules>\n  </component>\n</project>"
  },
  {
    "path": ".idea/vcs.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"VcsDirectoryMappings\">\n    <mapping directory=\"$PROJECT_DIR$\" vcs=\"Git\" />\n  </component>\n</project>"
  },
  {
    "path": ".idea/workspace.xml",
    "content": "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n<project version=\"4\">\n  <component name=\"ChangeListManager\">\n    <list default=\"true\" id=\"05c337e8-f34e-4a2e-abd3-dc65a3efd14a\" name=\"Changes\" comment=\"\" />\n    <option name=\"SHOW_DIALOG\" value=\"false\" />\n    <option name=\"HIGHLIGHT_CONFLICTS\" value=\"true\" />\n    <option name=\"HIGHLIGHT_NON_ACTIVE_CHANGELIST\" value=\"false\" />\n    <option name=\"LAST_RESOLUTION\" value=\"IGNORE\" />\n  </component>\n  <component name=\"Git.Settings\">\n    <option name=\"RECENT_GIT_ROOT_PATH\" value=\"$PROJECT_DIR$\" />\n  </component>\n  <component name=\"MarkdownSettingsMigration\">\n    <option name=\"stateVersion\" value=\"1\" />\n  </component>\n  <component name=\"ProjectId\" id=\"2qOgvG2MNzxubwwA97EeHYZVB9s\" />\n  <component name=\"ProjectLevelVcsManager\" settingsEditedManually=\"true\" />\n  <component name=\"ProjectViewState\">\n    <option name=\"hideEmptyMiddlePackages\" value=\"true\" />\n    <option name=\"showLibraryContents\" value=\"true\" />\n  </component>\n  <component name=\"SpellCheckerSettings\" RuntimeDictionaries=\"0\" Folders=\"0\" CustomDictionaries=\"0\" DefaultDictionary=\"application-level\" UseSingleDictionary=\"true\" transferred=\"true\" />\n  <component name=\"TaskManager\">\n    <task active=\"true\" id=\"Default\" summary=\"Default task\">\n      <changelist id=\"05c337e8-f34e-4a2e-abd3-dc65a3efd14a\" name=\"Changes\" comment=\"\" />\n      <created>1734539270044</created>\n      <option name=\"number\" value=\"Default\" />\n      <option name=\"presentableId\" value=\"Default\" />\n      <updated>1734539270044</updated>\n    </task>\n    <servers />\n  </component>\n</project>"
  },
  {
    "path": "ADD/dnnlib/__init__.py",
    "content": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nfrom .util import EasyDict, make_cache_dir_path\n"
  },
  {
    "path": "ADD/dnnlib/util.py",
    "content": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n\"\"\"Miscellaneous utility classes and functions.\"\"\"\n\nimport ctypes\nimport fnmatch\nimport importlib\nimport inspect\nimport os\nimport sys\nimport types\nimport io\nimport pickle\nimport re\nimport requests\nimport html\nimport hashlib\nimport glob\nimport tempfile\nimport urllib\nimport urllib.request\nimport uuid\nfrom typing import Any, List, Tuple, Union, Optional\nfrom distutils.util import strtobool\nimport shutil\n\nimport numpy as np\n\n\n# Util classes\n# ------------------------------------------------------------------------------------------\n\nclass EasyDict(dict):\n    \"\"\"Convenience class that behaves like a dict but allows access with the attribute syntax.\"\"\"\n\n    def __getattr__(self, name: str) -> Any:\n        try:\n            return self[name]\n        except KeyError:\n            raise AttributeError(name)\n\n    def __setattr__(self, name: str, value: Any) -> None:\n        self[name] = value\n\n    def __delattr__(self, name: str) -> None:\n        del self[name]\n\n\nclass Logger(object):\n    \"\"\"Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file.\"\"\"\n\n    def __init__(self, file_name: Optional[str] = None, file_mode: str = \"w\", should_flush: bool = True):\n        self.file = None\n\n        if file_name is not None:\n            self.file = open(file_name, file_mode)\n\n        self.should_flush = should_flush\n        self.stdout = sys.stdout\n        self.stderr = sys.stderr\n\n        
sys.stdout = self\n        sys.stderr = self\n\n    def __enter__(self) -> \"Logger\":\n        return self\n\n    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:\n        self.close()\n\n    def write(self, text: Union[str, bytes]) -> None:\n        \"\"\"Write text to stdout (and a file) and optionally flush.\"\"\"\n        if isinstance(text, bytes):\n            text = text.decode()\n        if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash\n            return\n\n        if self.file is not None:\n            self.file.write(text)\n\n        self.stdout.write(text)\n\n        if self.should_flush:\n            self.flush()\n\n    def flush(self) -> None:\n        \"\"\"Flush written text to both stdout and a file, if open.\"\"\"\n        if self.file is not None:\n            self.file.flush()\n\n        self.stdout.flush()\n\n    def close(self) -> None:\n        \"\"\"Flush, close possible files, and remove stdout/stderr mirroring.\"\"\"\n        self.flush()\n\n        # if using multiple loggers, prevent closing in wrong order\n        if sys.stdout is self:\n            sys.stdout = self.stdout\n        if sys.stderr is self:\n            sys.stderr = self.stderr\n\n        if self.file is not None:\n            self.file.close()\n            self.file = None\n\n\n# Cache directories\n# ------------------------------------------------------------------------------------------\n\n_dnnlib_cache_dir = None\n\ndef set_cache_dir(path: str) -> None:\n    global _dnnlib_cache_dir\n    _dnnlib_cache_dir = path\n\n\ndef make_cache_dir_path(*paths: str) -> str:\n    if _dnnlib_cache_dir is not None:\n        return os.path.join(_dnnlib_cache_dir, *paths)\n    if 'DNNLIB_CACHE_DIR' in os.environ:\n        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)\n    if 'HOME' in os.environ:\n        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)\n    if 
'USERPROFILE' in os.environ:\n        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)\n    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)\n\n\n# Small util functions\n# ------------------------------------------------------------------------------------------\n\ndef format_time(seconds: Union[int, float]) -> str:\n    \"\"\"Convert the seconds to human readable string with days, hours, minutes and seconds.\"\"\"\n    s = int(np.rint(seconds))\n\n    if s < 60:\n        return \"{0}s\".format(s)\n    elif s < 60 * 60:\n        return \"{0}m {1:02}s\".format(s // 60, s % 60)\n    elif s < 24 * 60 * 60:\n        return \"{0}h {1:02}m {2:02}s\".format(s // (60 * 60), (s // 60) % 60, s % 60)\n    else:\n        return \"{0}d {1:02}h {2:02}m\".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)\n\n\ndef format_time_brief(seconds: Union[int, float]) -> str:\n    \"\"\"Convert the seconds to human readable string with days, hours, minutes and seconds.\"\"\"\n    s = int(np.rint(seconds))\n\n    if s < 60:\n        return \"{0}s\".format(s)\n    elif s < 60 * 60:\n        return \"{0}m {1:02}s\".format(s // 60, s % 60)\n    elif s < 24 * 60 * 60:\n        return \"{0}h {1:02}m\".format(s // (60 * 60), (s // 60) % 60)\n    else:\n        return \"{0}d {1:02}h\".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)\n\n\ndef ask_yes_no(question: str) -> bool:\n    \"\"\"Ask the user the question until the user inputs a valid answer.\"\"\"\n    while True:\n        try:\n            print(\"{0} [y/n]\".format(question))\n            return strtobool(input().lower())\n        except ValueError:\n            pass\n\n\ndef tuple_product(t: Tuple) -> Any:\n    \"\"\"Calculate the product of the tuple elements.\"\"\"\n    result = 1\n\n    for v in t:\n        result *= v\n\n    return result\n\n\n_str_to_ctype = {\n    \"uint8\": ctypes.c_ubyte,\n    \"uint16\": ctypes.c_uint16,\n    \"uint32\": ctypes.c_uint32,\n    
\"uint64\": ctypes.c_uint64,\n    \"int8\": ctypes.c_byte,\n    \"int16\": ctypes.c_int16,\n    \"int32\": ctypes.c_int32,\n    \"int64\": ctypes.c_int64,\n    \"float32\": ctypes.c_float,\n    \"float64\": ctypes.c_double\n}\n\n\ndef get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:\n    \"\"\"Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes.\"\"\"\n    type_str = None\n\n    if isinstance(type_obj, str):\n        type_str = type_obj\n    elif hasattr(type_obj, \"__name__\"):\n        type_str = type_obj.__name__\n    elif hasattr(type_obj, \"name\"):\n        type_str = type_obj.name\n    else:\n        raise RuntimeError(\"Cannot infer type name from input\")\n\n    assert type_str in _str_to_ctype.keys()\n\n    my_dtype = np.dtype(type_str)\n    my_ctype = _str_to_ctype[type_str]\n\n    assert my_dtype.itemsize == ctypes.sizeof(my_ctype)\n\n    return my_dtype, my_ctype\n\n\ndef is_pickleable(obj: Any) -> bool:\n    try:\n        with io.BytesIO() as stream:\n            pickle.dump(obj, stream)\n        return True\n    except:\n        return False\n\n\n# Functionality to import modules/objects by name, and call functions by name\n# ------------------------------------------------------------------------------------------\n\ndef get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:\n    \"\"\"Searches for the underlying module behind the name to some python object.\n    Returns the module and the object name (original name with module part removed).\"\"\"\n\n    # allow convenience shorthands, substitute them by full names\n    obj_name = re.sub(\"^np.\", \"numpy.\", obj_name)\n    obj_name = re.sub(\"^tf.\", \"tensorflow.\", obj_name)\n\n    # list alternatives for (module_name, local_obj_name)\n    parts = obj_name.split(\".\")\n    name_pairs = [(\".\".join(parts[:i]), \".\".join(parts[i:])) for i in range(len(parts), 0, -1)]\n\n    # 
try each alternative in turn\n    for module_name, local_obj_name in name_pairs:\n        try:\n            module = importlib.import_module(module_name) # may raise ImportError\n            get_obj_from_module(module, local_obj_name) # may raise AttributeError\n            return module, local_obj_name\n        except:\n            pass\n\n    # maybe some of the modules themselves contain errors?\n    for module_name, _local_obj_name in name_pairs:\n        try:\n            importlib.import_module(module_name) # may raise ImportError\n        except ImportError:\n            if not str(sys.exc_info()[1]).startswith(\"No module named '\" + module_name + \"'\"):\n                raise\n\n    # maybe the requested attribute is missing?\n    for module_name, local_obj_name in name_pairs:\n        try:\n            module = importlib.import_module(module_name) # may raise ImportError\n            get_obj_from_module(module, local_obj_name) # may raise AttributeError\n        except ImportError:\n            pass\n\n    # we are out of luck, but we have no idea why\n    raise ImportError(obj_name)\n\n\ndef get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:\n    \"\"\"Traverses the object name and returns the last (rightmost) python object.\"\"\"\n    if obj_name == '':\n        return module\n    obj = module\n    for part in obj_name.split(\".\"):\n        obj = getattr(obj, part)\n    return obj\n\n\ndef get_obj_by_name(name: str) -> Any:\n    \"\"\"Finds the python object with the given name.\"\"\"\n    module, obj_name = get_module_from_obj_name(name)\n    return get_obj_from_module(module, obj_name)\n\n\ndef call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:\n    \"\"\"Finds the python object with the given name and calls it as a function.\"\"\"\n    assert func_name is not None\n    func_obj = get_obj_by_name(func_name)\n    assert callable(func_obj)\n    return func_obj(*args, **kwargs)\n\n\ndef construct_class_by_name(*args, 
class_name: str = None, **kwargs) -> Any:\n    \"\"\"Finds the python class with the given name and constructs it with the given arguments.\"\"\"\n    return call_func_by_name(*args, func_name=class_name, **kwargs)\n\n\ndef get_module_dir_by_obj_name(obj_name: str) -> str:\n    \"\"\"Get the directory path of the module containing the given object name.\"\"\"\n    module, _ = get_module_from_obj_name(obj_name)\n    return os.path.dirname(inspect.getfile(module))\n\n\ndef is_top_level_function(obj: Any) -> bool:\n    \"\"\"Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'.\"\"\"\n    return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__\n\n\ndef get_top_level_function_name(obj: Any) -> str:\n    \"\"\"Return the fully-qualified name of a top-level function.\"\"\"\n    assert is_top_level_function(obj)\n    module = obj.__module__\n    if module == '__main__':\n        module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]\n    return module + \".\" + obj.__name__\n\n\n# File system helpers\n# ------------------------------------------------------------------------------------------\n\ndef list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:\n    \"\"\"List all files recursively in a given directory while ignoring given file and directory names.\n    Returns list of tuples containing both absolute and relative paths.\"\"\"\n    assert os.path.isdir(dir_path)\n    base_name = os.path.basename(os.path.normpath(dir_path))\n\n    if ignores is None:\n        ignores = []\n\n    result = []\n\n    for root, dirs, files in os.walk(dir_path, topdown=True):\n        for ignore_ in ignores:\n            dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]\n\n            # dirs need to be edited in-place\n            for d in dirs_to_remove:\n                dirs.remove(d)\n\n          
  files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]\n\n        absolute_paths = [os.path.join(root, f) for f in files]\n        relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]\n\n        if add_base_to_relative:\n            relative_paths = [os.path.join(base_name, p) for p in relative_paths]\n\n        assert len(absolute_paths) == len(relative_paths)\n        result += zip(absolute_paths, relative_paths)\n\n    return result\n\n\ndef copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:\n    \"\"\"Takes in a list of tuples of (src, dst) paths and copies files.\n    Will create all necessary directories.\"\"\"\n    for file in files:\n        target_dir_name = os.path.dirname(file[1])\n\n        # will create all intermediate-level directories\n        if not os.path.exists(target_dir_name):\n            os.makedirs(target_dir_name)\n\n        shutil.copyfile(file[0], file[1])\n\n\n# URL helpers\n# ------------------------------------------------------------------------------------------\n\ndef is_url(obj: Any, allow_file_urls: bool = False) -> bool:\n    \"\"\"Determine whether the given object is a valid URL string.\"\"\"\n    if not isinstance(obj, str) or not \"://\" in obj:\n        return False\n    if allow_file_urls and obj.startswith('file://'):\n        return True\n    try:\n        res = requests.compat.urlparse(obj)\n        if not res.scheme or not res.netloc or not \".\" in res.netloc:\n            return False\n        res = requests.compat.urlparse(requests.compat.urljoin(obj, \"/\"))\n        if not res.scheme or not res.netloc or not \".\" in res.netloc:\n            return False\n    except:\n        return False\n    return True\n\n\ndef open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:\n    \"\"\"Download the given URL and return a binary-mode file object to access the data.\"\"\"\n    assert 
num_attempts >= 1\n    assert not (return_filename and (not cache))\n\n    # Doesn't look like an URL scheme so interpret it as a local filename.\n    if not re.match('^[a-z]+://', url):\n        return url if return_filename else open(url, \"rb\")\n\n    # Handle file URLs.  This code handles unusual file:// patterns that\n    # arise on Windows:\n    #\n    # file:///c:/foo.txt\n    #\n    # which would translate to a local '/c:/foo.txt' filename that's\n    # invalid.  Drop the forward slash for such pathnames.\n    #\n    # If you touch this code path, you should test it on both Linux and\n    # Windows.\n    #\n    # Some internet resources suggest using urllib.request.url2pathname() but\n    # but that converts forward slashes to backslashes and this causes\n    # its own set of problems.\n    if url.startswith('file://'):\n        filename = urllib.parse.urlparse(url).path\n        if re.match(r'^/[a-zA-Z]:', filename):\n            filename = filename[1:]\n        return filename if return_filename else open(filename, \"rb\")\n\n    assert is_url(url)\n\n    # Lookup from cache.\n    if cache_dir is None:\n        cache_dir = make_cache_dir_path('downloads')\n\n    url_md5 = hashlib.md5(url.encode(\"utf-8\")).hexdigest()\n    if cache:\n        cache_files = glob.glob(os.path.join(cache_dir, url_md5 + \"_*\"))\n        if len(cache_files) == 1:\n            filename = cache_files[0]\n            return filename if return_filename else open(filename, \"rb\")\n\n    # Download.\n    url_name = None\n    url_data = None\n    with requests.Session() as session:\n        if verbose:\n            print(\"Downloading %s ...\" % url, end=\"\", flush=True)\n        for attempts_left in reversed(range(num_attempts)):\n            try:\n                with session.get(url) as res:\n                    res.raise_for_status()\n                    if len(res.content) == 0:\n                        raise IOError(\"No data received\")\n\n                    if 
len(res.content) < 8192:\n                        content_str = res.content.decode(\"utf-8\")\n                        if \"download_warning\" in res.headers.get(\"Set-Cookie\", \"\"):\n                            links = [html.unescape(link) for link in content_str.split('\"') if \"export=download\" in link]\n                            if len(links) == 1:\n                                url = requests.compat.urljoin(url, links[0])\n                                raise IOError(\"Google Drive virus checker nag\")\n                        if \"Google Drive - Quota exceeded\" in content_str:\n                            raise IOError(\"Google Drive download quota exceeded -- please try again later\")\n\n                    match = re.search(r'filename=\"([^\"]*)\"', res.headers.get(\"Content-Disposition\", \"\"))\n                    url_name = match[1] if match else url\n                    url_data = res.content\n                    if verbose:\n                        print(\" done\")\n                    break\n            except KeyboardInterrupt:\n                raise\n            except:\n                if not attempts_left:\n                    if verbose:\n                        print(\" failed\")\n                    raise\n                if verbose:\n                    print(\".\", end=\"\", flush=True)\n\n    # Save to cache.\n    if cache:\n        safe_name = re.sub(r\"[^0-9a-zA-Z-._]\", \"_\", url_name)\n        safe_name = safe_name[:min(len(safe_name), 128)]\n        cache_file = os.path.join(cache_dir, url_md5 + \"_\" + safe_name)\n        temp_file = os.path.join(cache_dir, \"tmp_\" + uuid.uuid4().hex + \"_\" + url_md5 + \"_\" + safe_name)\n        os.makedirs(cache_dir, exist_ok=True)\n        with open(temp_file, \"wb\") as f:\n            f.write(url_data)\n        os.replace(temp_file, cache_file) # atomic\n        if return_filename:\n            return cache_file\n\n    # Return data as file object.\n    assert not return_filename\n    
return io.BytesIO(url_data)\n"
  },
  {
    "path": "ADD/layers/__init__.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\nfrom .dino_head import DINOHead\nfrom .mlp import Mlp\nfrom .patch_embed import PatchEmbed\nfrom .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused\nfrom .block import NestedTensorBlock\nfrom .attention import MemEffAttention\n"
  },
  {
    "path": "ADD/layers/attention.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py\n\nimport logging\nimport os\nimport warnings\n\nfrom torch import Tensor\nfrom torch import nn\n\n\nlogger = logging.getLogger(\"dinov2\")\n\n\nXFORMERS_ENABLED = os.environ.get(\"XFORMERS_DISABLED\") is None\ntry:\n    if XFORMERS_ENABLED:\n        from xformers.ops import memory_efficient_attention, unbind\n\n        XFORMERS_AVAILABLE = True\n        warnings.warn(\"xFormers is available (Attention)\")\n    else:\n        warnings.warn(\"xFormers is disabled (Attention)\")\n        raise ImportError\nexcept ImportError:\n    XFORMERS_AVAILABLE = False\n    warnings.warn(\"xFormers is not available (Attention)\")\n\n\nclass Attention(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int = 8,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        attn_drop: float = 0.0,\n        proj_drop: float = 0.0,\n    ) -> None:\n        super().__init__()\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim, bias=proj_bias)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x: Tensor) -> Tensor:\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)\n\n        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]\n        attn = q @ k.transpose(-2, -1)\n\n        attn = attn.softmax(dim=-1)\n        attn = 
self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\nclass MemEffAttention(Attention):\n    def forward(self, x: Tensor, attn_bias=None) -> Tensor:\n        if not XFORMERS_AVAILABLE:\n            if attn_bias is not None:\n                raise AssertionError(\"xFormers is required for using nested tensors\")\n            return super().forward(x)\n\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)\n\n        q, k, v = unbind(qkv, 2)\n\n        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)\n        x = x.reshape([B, N, C])\n\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n"
  },
  {
    "path": "ADD/layers/block.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py\n\nimport logging\nimport os\nfrom typing import Callable, List, Any, Tuple, Dict\nimport warnings\n\nimport torch\nfrom torch import nn, Tensor\n\nfrom .attention import Attention, MemEffAttention\nfrom .drop_path import DropPath\nfrom .layer_scale import LayerScale\nfrom .mlp import Mlp\n\n\nlogger = logging.getLogger(\"dinov2\")\n\n\nXFORMERS_ENABLED = os.environ.get(\"XFORMERS_DISABLED\") is None\ntry:\n    if XFORMERS_ENABLED:\n        from xformers.ops import fmha, scaled_index_add, index_select_cat\n\n        XFORMERS_AVAILABLE = True\n        warnings.warn(\"xFormers is available (Block)\")\n    else:\n        warnings.warn(\"xFormers is disabled (Block)\")\n        raise ImportError\nexcept ImportError:\n    XFORMERS_AVAILABLE = False\n\n    warnings.warn(\"xFormers is not available (Block)\")\n\n\nclass Block(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        num_heads: int,\n        mlp_ratio: float = 4.0,\n        qkv_bias: bool = False,\n        proj_bias: bool = True,\n        ffn_bias: bool = True,\n        drop: float = 0.0,\n        attn_drop: float = 0.0,\n        init_values=None,\n        drop_path: float = 0.0,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,\n        attn_class: Callable[..., nn.Module] = Attention,\n        ffn_layer: Callable[..., nn.Module] = Mlp,\n    ) -> None:\n        super().__init__()\n        # print(f\"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}\")\n        self.norm1 = norm_layer(dim)\n        self.attn = 
attn_class(\n            dim,\n            num_heads=num_heads,\n            qkv_bias=qkv_bias,\n            proj_bias=proj_bias,\n            attn_drop=attn_drop,\n            proj_drop=drop,\n        )\n        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n        self.norm2 = norm_layer(dim)\n        mlp_hidden_dim = int(dim * mlp_ratio)\n        self.mlp = ffn_layer(\n            in_features=dim,\n            hidden_features=mlp_hidden_dim,\n            act_layer=act_layer,\n            drop=drop,\n            bias=ffn_bias,\n        )\n        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()\n        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n\n        self.sample_drop_ratio = drop_path\n\n    def forward(self, x: Tensor) -> Tensor:\n        def attn_residual_func(x: Tensor) -> Tensor:\n            return self.ls1(self.attn(self.norm1(x)))\n\n        def ffn_residual_func(x: Tensor) -> Tensor:\n            return self.ls2(self.mlp(self.norm2(x)))\n\n        if self.training and self.sample_drop_ratio > 0.1:\n            # the overhead is compensated only for a drop path rate larger than 0.1\n            x = drop_add_residual_stochastic_depth(\n                x,\n                residual_func=attn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n            )\n            x = drop_add_residual_stochastic_depth(\n                x,\n                residual_func=ffn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n            )\n        elif self.training and self.sample_drop_ratio > 0.0:\n            x = x + self.drop_path1(attn_residual_func(x))\n            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2\n        else:\n            x = x + attn_residual_func(x)\n            x = x + 
ffn_residual_func(x)\n        return x\n\n\ndef drop_add_residual_stochastic_depth(\n    x: Tensor,\n    residual_func: Callable[[Tensor], Tensor],\n    sample_drop_ratio: float = 0.0,\n) -> Tensor:\n    # 1) extract subset using permutation\n    b, n, d = x.shape\n    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)\n    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]\n    x_subset = x[brange]\n\n    # 2) apply residual_func to get residual\n    residual = residual_func(x_subset)\n\n    x_flat = x.flatten(1)\n    residual = residual.flatten(1)\n\n    residual_scale_factor = b / sample_subset_size\n\n    # 3) add the residual\n    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)\n    return x_plus_residual.view_as(x)\n\n\ndef get_branges_scales(x, sample_drop_ratio=0.0):\n    b, n, d = x.shape\n    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)\n    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]\n    residual_scale_factor = b / sample_subset_size\n    return brange, residual_scale_factor\n\n\ndef add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):\n    if scaling_vector is None:\n        x_flat = x.flatten(1)\n        residual = residual.flatten(1)\n        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)\n    else:\n        x_plus_residual = scaled_index_add(\n            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor\n        )\n    return x_plus_residual\n\n\nattn_bias_cache: Dict[Tuple, Any] = {}\n\n\ndef get_attn_bias_and_cat(x_list, branges=None):\n    \"\"\"\n    this will perform the index select, cat the tensors, and provide the attn_bias from cache\n    \"\"\"\n    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]\n    all_shapes = tuple((b, 
x.shape[1]) for b, x in zip(batch_sizes, x_list))\n    if all_shapes not in attn_bias_cache.keys():\n        seqlens = []\n        for b, x in zip(batch_sizes, x_list):\n            for _ in range(b):\n                seqlens.append(x.shape[1])\n        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)\n        attn_bias._batch_sizes = batch_sizes\n        attn_bias_cache[all_shapes] = attn_bias\n\n    if branges is not None:\n        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])\n    else:\n        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)\n        cat_tensors = torch.cat(tensors_bs1, dim=1)\n\n    return attn_bias_cache[all_shapes], cat_tensors\n\n\ndef drop_add_residual_stochastic_depth_list(\n    x_list: List[Tensor],\n    residual_func: Callable[[Tensor, Any], Tensor],\n    sample_drop_ratio: float = 0.0,\n    scaling_vector=None,\n) -> Tensor:\n    # 1) generate random set of indices for dropping samples in the batch\n    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]\n    branges = [s[0] for s in branges_scales]\n    residual_scale_factors = [s[1] for s in branges_scales]\n\n    # 2) get attention bias and index+concat the tensors\n    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)\n\n    # 3) apply residual_func to get residual, and split the result\n    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore\n\n    outputs = []\n    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):\n        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))\n    return outputs\n\n\nclass NestedTensorBlock(Block):\n    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:\n        \"\"\"\n        x_list contains a list of tensors to nest together and run\n        \"\"\"\n  
      assert isinstance(self.attn, MemEffAttention)\n\n        if self.training and self.sample_drop_ratio > 0.0:\n\n            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.attn(self.norm1(x), attn_bias=attn_bias)\n\n            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.mlp(self.norm2(x))\n\n            x_list = drop_add_residual_stochastic_depth_list(\n                x_list,\n                residual_func=attn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,\n            )\n            x_list = drop_add_residual_stochastic_depth_list(\n                x_list,\n                residual_func=ffn_residual_func,\n                sample_drop_ratio=self.sample_drop_ratio,\n                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,\n            )\n            return x_list\n        else:\n\n            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))\n\n            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:\n                return self.ls2(self.mlp(self.norm2(x)))\n\n            attn_bias, x = get_attn_bias_and_cat(x_list)\n            x = x + attn_residual_func(x, attn_bias=attn_bias)\n            x = x + ffn_residual_func(x)\n            return attn_bias.split(x)\n\n    def forward(self, x_or_x_list):\n        if isinstance(x_or_x_list, Tensor):\n            return super().forward(x_or_x_list)\n        elif isinstance(x_or_x_list, list):\n            if not XFORMERS_AVAILABLE:\n                raise AssertionError(\"xFormers is required for using nested tensors\")\n            return self.forward_nested(x_or_x_list)\n        else:\n            raise AssertionError\n"
  },
  {
    "path": "ADD/layers/dino_head.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\nimport torch\nimport torch.nn as nn\nfrom torch.nn.init import trunc_normal_\nfrom torch.nn.utils import weight_norm\n\n\nclass DINOHead(nn.Module):\n    def __init__(\n        self,\n        in_dim,\n        out_dim,\n        use_bn=False,\n        nlayers=3,\n        hidden_dim=2048,\n        bottleneck_dim=256,\n        mlp_bias=True,\n    ):\n        super().__init__()\n        nlayers = max(nlayers, 1)\n        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)\n        self.apply(self._init_weights)\n        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))\n        self.last_layer.weight_g.data.fill_(1)\n\n    def _init_weights(self, m):\n        if isinstance(m, nn.Linear):\n            trunc_normal_(m.weight, std=0.02)\n            if isinstance(m, nn.Linear) and m.bias is not None:\n                nn.init.constant_(m.bias, 0)\n\n    def forward(self, x):\n        x = self.mlp(x)\n        eps = 1e-6 if x.dtype == torch.float16 else 1e-12\n        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)\n        x = self.last_layer(x)\n        return x\n\n\ndef _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):\n    if nlayers == 1:\n        return nn.Linear(in_dim, bottleneck_dim, bias=bias)\n    else:\n        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]\n        if use_bn:\n            layers.append(nn.BatchNorm1d(hidden_dim))\n        layers.append(nn.GELU())\n        for _ in range(nlayers - 2):\n            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))\n            if use_bn:\n                layers.append(nn.BatchNorm1d(hidden_dim))\n            layers.append(nn.GELU())\n        
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))\n        return nn.Sequential(*layers)\n"
  },
  {
    "path": "ADD/layers/drop_path.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py\n\n\nfrom torch import nn\n\n\ndef drop_path(x, drop_prob: float = 0.0, training: bool = False):\n    if drop_prob == 0.0 or not training:\n        return x\n    keep_prob = 1 - drop_prob\n    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets\n    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n    if keep_prob > 0.0:\n        random_tensor.div_(keep_prob)\n    output = x * random_tensor\n    return output\n\n\nclass DropPath(nn.Module):\n    \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).\"\"\"\n\n    def __init__(self, drop_prob=None):\n        super(DropPath, self).__init__()\n        self.drop_prob = drop_prob\n\n    def forward(self, x):\n        return drop_path(x, self.drop_prob, self.training)\n"
  },
  {
    "path": "ADD/layers/layer_scale.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110\n\nfrom typing import Union\n\nimport torch\nfrom torch import Tensor\nfrom torch import nn\n\n\nclass LayerScale(nn.Module):\n    def __init__(\n        self,\n        dim: int,\n        init_values: Union[float, Tensor] = 1e-5,\n        inplace: bool = False,\n    ) -> None:\n        super().__init__()\n        self.inplace = inplace\n        self.gamma = nn.Parameter(init_values * torch.ones(dim))\n\n    def forward(self, x: Tensor) -> Tensor:\n        return x.mul_(self.gamma) if self.inplace else x * self.gamma\n"
  },
  {
    "path": "ADD/layers/mlp.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py\n\n\nfrom typing import Callable, Optional\n\nfrom torch import Tensor, nn\n\n\nclass Mlp(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = nn.GELU,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)\n        self.act = act_layer()\n        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)\n        self.drop = nn.Dropout(drop)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.fc1(x)\n        x = self.act(x)\n        x = self.drop(x)\n        x = self.fc2(x)\n        x = self.drop(x)\n        return x\n"
  },
  {
    "path": "ADD/layers/patch_embed.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py\n\nfrom typing import Callable, Optional, Tuple, Union\n\nfrom torch import Tensor\nimport torch.nn as nn\n\n\ndef make_2tuple(x):\n    if isinstance(x, tuple):\n        assert len(x) == 2\n        return x\n\n    assert isinstance(x, int)\n    return (x, x)\n\n\nclass PatchEmbed(nn.Module):\n    \"\"\"\n    2D image to patch embedding: (B,C,H,W) -> (B,N,D)\n\n    Args:\n        img_size: Image size.\n        patch_size: Patch token size.\n        in_chans: Number of input image channels.\n        embed_dim: Number of linear projection output channels.\n        norm_layer: Normalization layer.\n    \"\"\"\n\n    def __init__(\n        self,\n        img_size: Union[int, Tuple[int, int]] = 224,\n        patch_size: Union[int, Tuple[int, int]] = 16,\n        in_chans: int = 3,\n        embed_dim: int = 768,\n        norm_layer: Optional[Callable] = None,\n        flatten_embedding: bool = True,\n    ) -> None:\n        super().__init__()\n\n        image_HW = make_2tuple(img_size)\n        patch_HW = make_2tuple(patch_size)\n        patch_grid_size = (\n            image_HW[0] // patch_HW[0],\n            image_HW[1] // patch_HW[1],\n        )\n\n        self.img_size = image_HW\n        self.patch_size = patch_HW\n        self.patches_resolution = patch_grid_size\n        self.num_patches = patch_grid_size[0] * patch_grid_size[1]\n\n        #self.in_chans = in_chans\n        self.embed_dim = embed_dim\n\n        self.flatten_embedding = flatten_embedding\n\n        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)\n        self.norm = 
norm_layer(embed_dim) if norm_layer else nn.Identity()\n\n    def forward(self, x: Tensor) -> Tensor:\n        _, _, H, W = x.shape\n        patch_H, patch_W = self.patch_size\n\n        assert H % patch_H == 0, f\"Input image height {H} is not a multiple of patch height {patch_H}\"\n        assert W % patch_W == 0, f\"Input image width {W} is not a multiple of patch width: {patch_W}\"\n\n        x = self.proj(x)  # B C H W\n        H, W = x.size(2), x.size(3)\n        x = x.flatten(2).transpose(1, 2)  # B HW C\n        x = self.norm(x)\n        if not self.flatten_embedding:\n            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C\n        return x\n\n    #def flops(self) -> float:\n        #Ho, Wo = self.patches_resolution\n        #flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])\n        #if self.norm is not None:\n         #   flops += Ho * Wo * self.embed_dim\n        #return flops\n"
  },
  {
    "path": "ADD/layers/swiglu_ffn.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\nimport os\nfrom typing import Callable, Optional\nimport warnings\n\nfrom torch import Tensor, nn\nimport torch.nn.functional as F\n\n\nclass SwiGLUFFN(nn.Module):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = None,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        super().__init__()\n        out_features = out_features or in_features\n        hidden_features = hidden_features or in_features\n        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)\n        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x12 = self.w12(x)\n        x1, x2 = x12.chunk(2, dim=-1)\n        hidden = F.silu(x1) * x2\n        return self.w3(hidden)\n\n\nXFORMERS_ENABLED = os.environ.get(\"XFORMERS_DISABLED\") is None\ntry:\n    if XFORMERS_ENABLED:\n        from xformers.ops import SwiGLU\n\n        XFORMERS_AVAILABLE = True\n        warnings.warn(\"xFormers is available (SwiGLU)\")\n    else:\n        warnings.warn(\"xFormers is disabled (SwiGLU)\")\n        raise ImportError\nexcept ImportError:\n    SwiGLU = SwiGLUFFN\n    XFORMERS_AVAILABLE = False\n\n    warnings.warn(\"xFormers is not available (SwiGLU)\")\n\n\nclass SwiGLUFFNFused(SwiGLU):\n    def __init__(\n        self,\n        in_features: int,\n        hidden_features: Optional[int] = None,\n        out_features: Optional[int] = None,\n        act_layer: Callable[..., nn.Module] = None,\n        drop: float = 0.0,\n        bias: bool = True,\n    ) -> None:\n        out_features = out_features or in_features\n        hidden_features = 
hidden_features or in_features\n        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8\n        super().__init__(\n            in_features=in_features,\n            hidden_features=hidden_features,\n            out_features=out_features,\n            bias=bias,\n        )\n"
  },
  {
    "path": "ADD/models/discriminator.py",
    "content": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n\"\"\"\nProjected discriminator architecture from\n\"StyleGAN-T: Unlocking the Power of GANs for Fast Large-Scale Text-to-Image Synthesis\".\n\"\"\"\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn.utils.spectral_norm import SpectralNorm\nfrom torchvision.transforms import RandomCrop, Normalize\nimport timm\nfrom timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD\n\nfrom ADD.th_utils import misc\nfrom models.shared import ResidualBlock, FullyConnectedLayer\nfrom models.vit_utils import make_vit_backbone, forward_vit, make_sd_backbone\nfrom models.DiffAugment import DiffAugment\nfrom ADD.utils.util_net import reload_model_\n\nfrom functools import partial\n\nclass SpectralConv1d(nn.Conv1d):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12)\n\n\nclass BatchNormLocal(nn.Module):\n    def __init__(self, num_features: int, affine: bool = True, virtual_bs: int = 3, eps: float = 1e-5):\n        super().__init__()\n        self.virtual_bs = virtual_bs\n        self.eps = eps\n        self.affine = affine\n\n        if self.affine:\n            self.weight = nn.Parameter(torch.ones(num_features))\n            self.bias = nn.Parameter(torch.zeros(num_features))\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        shape = x.size()\n\n        # Reshape batch into groups.\n        G = 
np.ceil(x.size(0)/self.virtual_bs).astype(int)\n        x = x.view(G, -1, x.size(-2), x.size(-1))\n\n        # Calculate stats.\n        mean = x.mean([1, 3], keepdim=True)\n        var = x.var([1, 3], keepdim=True, unbiased=False)\n        x = (x - mean) / (torch.sqrt(var + self.eps))\n\n        if self.affine:\n            x = x * self.weight[None, :, None] + self.bias[None, :, None]\n\n        return x.view(shape)\n\n\ndef make_block(channels: int, kernel_size: int) -> nn.Module:\n    return nn.Sequential(\n        SpectralConv1d(\n            channels,\n            channels,\n            kernel_size = kernel_size,\n            padding = kernel_size//2,\n            padding_mode = 'circular',\n        ),\n        #BatchNormLocal(channels),\n        nn.GroupNorm(4, channels),\n        nn.LeakyReLU(0.2, True),\n    )\n\nclass DiscHead(nn.Module):\n    def __init__(self, channels: int, c_dim: int, cmap_dim: int = 64):\n        super().__init__()\n        self.channels = channels\n        self.c_dim = c_dim\n        self.cmap_dim = cmap_dim\n\n        self.main = nn.Sequential(\n            make_block(channels, kernel_size=1),\n            ResidualBlock(make_block(channels, kernel_size=9))\n        )\n\n        if self.c_dim > 0:\n            self.cmapper = FullyConnectedLayer(self.c_dim, cmap_dim)\n            self.cls = SpectralConv1d(channels, cmap_dim, kernel_size=1, padding=0)\n        else:\n            self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0)\n\n    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:\n        h = self.main(x)\n        out = self.cls(h)\n\n        if self.c_dim > 0:\n            cmap = self.cmapper(c).unsqueeze(-1)\n            out = (out * cmap).sum(1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))\n\n        return out\n\nclass DINO(torch.nn.Module):\n    def __init__(self, hooks: list[int] = [2,5,8,11], hook_patch: bool = True):\n        super().__init__()\n        self.n_hooks = len(hooks) + 
int(hook_patch)\n\n        self.model = make_vit_backbone(\n            timm.create_model('vit_small_patch16_224.dino', pretrained=False),\n            patch_size=[16,16], hooks=hooks, hook_patch=hook_patch,\n        )\n        reload_model_(self.model, torch.load('preset/models/dino/dino_deitsmall16_pretrain.pth'))\n        self.model = self.model.eval().requires_grad_(False)\n\n\n        self.img_resolution = self.model.model.patch_embed.img_size[0]\n        self.embed_dim = self.model.model.embed_dim\n        self.norm = Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        ''' input: x in [0, 1]; output: dict of activations '''\n        x = F.interpolate(x, self.img_resolution, mode='area')\n        x = self.norm(x)\n        features = forward_vit(self.model, x)\n        return features\n\n\nclass ProjectedDiscriminator(nn.Module):\n    def __init__(self, c_dim: int, diffaug: bool = True, p_crop: float = 0.5):\n        super().__init__()\n        self.c_dim = c_dim\n        self.diffaug = diffaug\n        self.p_crop = p_crop\n\n        self.dino = DINO()\n\n        heads = []\n        for i in range(self.dino.n_hooks):\n            heads += [str(i), DiscHead(self.dino.embed_dim, c_dim)],\n        self.heads = nn.ModuleDict(heads)\n\n    def train(self, mode: bool = True):\n        self.dino = self.dino.train(False)\n        self.heads = self.heads.train(mode)\n        return self\n\n    def eval(self):\n        return self.train(False)\n\n    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:\n        # Apply augmentation (x in [-1, 1]).\n        if self.diffaug:\n            x = DiffAugment(x, policy='translation,cutout')\n\n        # Transform to [0, 1].\n        x = x.add(1).div(2)\n\n        # Take crops with probablity p_crop if the image is larger.\n        if x.size(-1) > self.dino.img_resolution and np.random.random() < self.p_crop:\n            x = 
RandomCrop(self.dino.img_resolution)(x)\n\n        # Forward pass through DINO ViT.\n        features = self.dino(x)\n\n        # Apply discriminator heads.\n        logits = []\n        for k, head in self.heads.items():\n            features[k].requires_grad_(True)\n            logits.append(head(features[k], c).view(x.size(0), -1))\n        #logits = torch.cat(logits, dim=1)\n\n        return logits, features\n\n"
  },
  {
    "path": "ADD/models/vit.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n#\n# This source code is licensed under the Apache License, Version 2.0\n# found in the LICENSE file in the root directory of this source tree.\n\n# References:\n#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py\n#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py\n\nfrom functools import partial\nimport math\nimport logging\nfrom typing import Sequence, Tuple, Union, Callable\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\nfrom torch.nn.init import trunc_normal_\n\nfrom ADD.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block\n\n\nlogger = logging.getLogger(\"dinov2\")\n\n\ndef named_apply(fn: Callable, module: nn.Module, name=\"\", depth_first=True, include_root=False) -> nn.Module:\n    if not depth_first and include_root:\n        fn(module=module, name=name)\n    for child_name, child_module in module.named_children():\n        child_name = \".\".join((name, child_name)) if name else child_name\n        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)\n    if depth_first and include_root:\n        fn(module=module, name=name)\n    return module\n\n\nclass BlockChunk(nn.ModuleList):\n    def forward(self, x):\n        for b in self:\n            x = b(x)\n        return x\n\n\nclass DinoVisionTransformer(nn.Module):\n    def __init__(\n        self,\n        img_size=224,\n        patch_size=16,\n        in_chans=3,\n        embed_dim=768,\n        depth=12,\n        num_heads=12,\n        mlp_ratio=4.0,\n        qkv_bias=True,\n        ffn_bias=True,\n        proj_bias=True,\n        drop_path_rate=0.0,\n        drop_path_uniform=False,\n        init_values=None,  # for layerscale: None or 0 => no layerscale\n        embed_layer=PatchEmbed,\n        act_layer=nn.GELU,\n        block_fn=Block,\n        
ffn_layer=\"mlp\",\n        block_chunks=1,\n        num_register_tokens=0,\n        interpolate_antialias=False,\n        interpolate_offset=0.1,\n    ):\n        \"\"\"\n        Args:\n            img_size (int, tuple): input image size\n            patch_size (int, tuple): patch size\n            in_chans (int): number of input channels\n            embed_dim (int): embedding dimension\n            depth (int): depth of transformer\n            num_heads (int): number of attention heads\n            mlp_ratio (int): ratio of mlp hidden dim to embedding dim\n            qkv_bias (bool): enable bias for qkv if True\n            proj_bias (bool): enable bias for proj in attn if True\n            ffn_bias (bool): enable bias for ffn if True\n            drop_path_rate (float): stochastic depth rate\n            drop_path_uniform (bool): apply uniform drop rate across blocks\n            weight_init (str): weight init scheme\n            init_values (float): layer-scale init values\n            embed_layer (nn.Module): patch embedding layer\n            act_layer (nn.Module): MLP activation layer\n            block_fn (nn.Module): transformer block class\n            ffn_layer (str): \"mlp\", \"swiglu\", \"swiglufused\" or \"identity\"\n            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap\n            num_register_tokens: (int) number of extra cls tokens (so-called \"registers\")\n            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings\n            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings\n        \"\"\"\n        super().__init__()\n        norm_layer = partial(nn.LayerNorm, eps=1e-6)\n\n        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models\n        self.num_tokens = 1\n        self.n_blocks = depth\n        self.num_heads = num_heads\n        self.patch_size = 
patch_size\n        self.num_register_tokens = num_register_tokens\n        self.interpolate_antialias = interpolate_antialias\n        self.interpolate_offset = interpolate_offset\n\n        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)\n        num_patches = self.patch_embed.num_patches\n\n        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))\n        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))\n        assert num_register_tokens >= 0\n        self.register_tokens = (\n            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None\n        )\n\n        if drop_path_uniform is True:\n            dpr = [drop_path_rate] * depth\n        else:\n            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule\n\n        if ffn_layer == \"mlp\":\n            logger.info(\"using MLP layer as FFN\")\n            ffn_layer = Mlp\n        elif ffn_layer == \"swiglufused\" or ffn_layer == \"swiglu\":\n            logger.info(\"using SwiGLU layer as FFN\")\n            ffn_layer = SwiGLUFFNFused\n        elif ffn_layer == \"identity\":\n            logger.info(\"using Identity layer as FFN\")\n\n            def f(*args, **kwargs):\n                return nn.Identity()\n\n            ffn_layer = f\n        else:\n            raise NotImplementedError\n\n        blocks_list = [\n            block_fn(\n                dim=embed_dim,\n                num_heads=num_heads,\n                mlp_ratio=mlp_ratio,\n                qkv_bias=qkv_bias,\n                proj_bias=proj_bias,\n                ffn_bias=ffn_bias,\n                drop_path=dpr[i],\n                norm_layer=norm_layer,\n                act_layer=act_layer,\n                ffn_layer=ffn_layer,\n                init_values=init_values,\n            )\n            for i in range(depth)\n        ]\n    
    if block_chunks > 0:\n            self.chunked_blocks = True\n            chunked_blocks = []\n            chunksize = depth // block_chunks\n            for i in range(0, depth, chunksize):\n                # this is to keep the block index consistent if we chunk the block list\n                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])\n            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])\n        else:\n            self.chunked_blocks = False\n            self.blocks = nn.ModuleList(blocks_list)\n\n        self.norm = norm_layer(embed_dim)\n        self.head = nn.Identity()\n\n        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))\n\n        self.init_weights()\n\n    def init_weights(self):\n        trunc_normal_(self.pos_embed, std=0.02)\n        nn.init.normal_(self.cls_token, std=1e-6)\n        if self.register_tokens is not None:\n            nn.init.normal_(self.register_tokens, std=1e-6)\n        named_apply(init_weights_vit_timm, self)\n\n    def interpolate_pos_encoding(self, x, w, h):\n        previous_dtype = x.dtype\n        npatch = x.shape[1] - 1\n        N = self.pos_embed.shape[1] - 1\n        if npatch == N and w == h:\n            return self.pos_embed\n        pos_embed = self.pos_embed.float()\n        class_pos_embed = pos_embed[:, 0]\n        patch_pos_embed = pos_embed[:, 1:]\n        dim = x.shape[-1]\n        w0 = w // self.patch_size\n        h0 = h // self.patch_size\n        # we add a small number to avoid floating point error in the interpolation\n        # see discussion at https://github.com/facebookresearch/dino/issues/8\n        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset\n\n        sqrt_N = math.sqrt(N)\n        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N\n        patch_pos_embed = nn.functional.interpolate(\n            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),\n            
scale_factor=(sx, sy),\n            mode=\"bicubic\",\n            antialias=self.interpolate_antialias,\n        )\n\n        assert int(w0) == patch_pos_embed.shape[-2]\n        assert int(h0) == patch_pos_embed.shape[-1]\n        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)\n        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)\n\n    def prepare_tokens_with_masks(self, x, masks=None):\n        B, nc, w, h = x.shape\n        x = self.patch_embed(x)\n        if masks is not None:\n            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)\n\n        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)\n        x = x + self.interpolate_pos_encoding(x, w, h)\n\n        if self.register_tokens is not None:\n            x = torch.cat(\n                (\n                    x[:, :1],\n                    self.register_tokens.expand(x.shape[0], -1, -1),\n                    x[:, 1:],\n                ),\n                dim=1,\n            )\n\n        return x\n\n    def forward_features_list(self, x_list, masks_list):\n        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]\n        for blk in self.blocks:\n            x = blk(x)\n\n        all_x = x\n        output = []\n        for x, masks in zip(all_x, masks_list):\n            x_norm = self.norm(x)\n            output.append(\n                {\n                    \"x_norm_clstoken\": x_norm[:, 0],\n                    \"x_norm_regtokens\": x_norm[:, 1 : self.num_register_tokens + 1],\n                    \"x_norm_patchtokens\": x_norm[:, self.num_register_tokens + 1 :],\n                    \"x_prenorm\": x,\n                    \"masks\": masks,\n                }\n            )\n        return output\n\n    def forward_features(self, x, masks=None):\n        fea_list = []\n        counter = 0\n        if isinstance(x, list):\n            return 
self.forward_features_list(x, masks)\n\n        x = self.prepare_tokens_with_masks(x, masks)\n        fea_list.append(x[:, self.num_register_tokens + 1 :].permute(0, 2, 1))\n\n        for blk in self.blocks:\n            x = blk(x)\n            counter += 1\n            if counter % 3 == 0:\n                fea_list.append(x[:, self.num_register_tokens + 1 :].permute(0, 2, 1))\n\n        x_norm = self.norm(x)\n        return fea_list, x_norm[:, 0]\n\n    def _get_intermediate_layers_not_chunked(self, x, n=1):\n        x = self.prepare_tokens_with_masks(x)\n        # If n is an int, take the n last blocks. If it's a list, take them\n        output, total_block_len = [], len(self.blocks)\n        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n\n        for i, blk in enumerate(self.blocks):\n            x = blk(x)\n            if i in blocks_to_take:\n                output.append(x)\n        assert len(output) == len(blocks_to_take), f\"only {len(output)} / {len(blocks_to_take)} blocks found\"\n        return output\n\n    def _get_intermediate_layers_chunked(self, x, n=1):\n        x = self.prepare_tokens_with_masks(x)\n        output, i, total_block_len = [], 0, len(self.blocks[-1])\n        # If n is an int, take the n last blocks. 
If it's a list, take them\n        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n\n        for block_chunk in self.blocks:\n            for blk in block_chunk[i:]:  # Passing the nn.Identity()\n                x = blk(x)\n                if i in blocks_to_take:\n                    output.append(x)\n                i += 1\n        assert len(output) == len(blocks_to_take), f\"only {len(output)} / {len(blocks_to_take)} blocks found\"\n        return output\n\n    def get_intermediate_layers(\n        self,\n        x: torch.Tensor,\n        n: Union[int, Sequence] = 1,  # Layers or n last layers to take\n        reshape: bool = False,\n        return_class_token: bool = False,\n        norm=True,\n    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:\n        if self.chunked_blocks:\n            outputs = self._get_intermediate_layers_chunked(x, n)\n        else:\n            outputs = self._get_intermediate_layers_not_chunked(x, n)\n        if norm:\n            outputs = [self.norm(out) for out in outputs]\n        class_tokens = [out[:, 0] for out in outputs]\n        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]\n        if reshape:\n            B, _, w, h = x.shape\n            outputs = [\n                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()\n                for out in outputs\n            ]\n        if return_class_token:\n            return tuple(zip(outputs, class_tokens))\n        return tuple(outputs)\n\n    def forward(self, *args, is_training=False, **kwargs):\n        ret = self.forward_features(*args, **kwargs)\n        if is_training:\n            return ret\n        else:\n            return ret#self.head(ret[\"x_norm_clstoken\"])\n\n\ndef init_weights_vit_timm(module: nn.Module, name: str = \"\"):\n    \"\"\"ViT weight initialization, original timm impl (for reproducibility)\"\"\"\n    if isinstance(module, nn.Linear):\n  
      trunc_normal_(module.weight, std=0.02)\n        if module.bias is not None:\n            nn.init.zeros_(module.bias)\n\n\ndef vit_small(patch_size=16, num_register_tokens=0, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=384,\n        depth=12,\n        num_heads=6,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\ndef vit_large(patch_size=16, num_register_tokens=0, **kwargs):\n    model = DinoVisionTransformer(\n        patch_size=patch_size,\n        embed_dim=1024,\n        depth=24,\n        num_heads=16,\n        mlp_ratio=4,\n        block_fn=partial(Block, attn_class=MemEffAttention),\n        num_register_tokens=num_register_tokens,\n        **kwargs,\n    )\n    return model\n\n\n# net = vit_small(patch_size=14, img_size=518, block_chunks=0, init_values=1.0)\n# prefile = torch.load('../weights/dinov2_vits14_pretrain.pth')\n# net.load_state_dict(prefile, True)\n# out = net(torch.rand(1, 3, 518, 518))\n# print(out.shape)\n\n# net = vit_large(patch_size=14, img_size=526, block_chunks=0, init_values=1.0, num_register_tokens=4)\n# prefile = torch.load('../weights/dinov2_vitl14_reg4_pretrain.pth')\n# net.load_state_dict(prefile, True)\n# out = net(torch.rand(1, 3, 70, 70))\n# print(out.shape)\n"
  },
  {
    "path": "ADD/th_utils/__init__.py",
    "content": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n# empty\n"
  },
  {
    "path": "ADD/th_utils/custom_ops.py",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport glob\nimport hashlib\nimport importlib\nimport os\nimport re\nimport shutil\nimport uuid\n\nimport torch\nimport torch.utils.cpp_extension\nfrom torch.utils.file_baton import FileBaton\n\n#----------------------------------------------------------------------------\n# Global options.\n\nverbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'\n\n#----------------------------------------------------------------------------\n# Internal helper funcs.\n\ndef _find_compiler_bindir():\n    patterns = [\n        'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',\n        'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',\n        'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',\n        'C:/Program Files*/Microsoft Visual Studio */vc/bin',\n    ]\n    for pattern in patterns:\n        matches = sorted(glob.glob(pattern))\n        if len(matches):\n            return matches[-1]\n    return None\n\n#----------------------------------------------------------------------------\n\ndef _get_mangled_gpu_name():\n    name = torch.cuda.get_device_name().lower()\n    out = []\n    for c in name:\n        if re.match('[a-z0-9_-]+', c):\n            out.append(c)\n        else:\n            out.append('-')\n    return ''.join(out)\n\n#----------------------------------------------------------------------------\n# Main entry point for compiling and loading C++/CUDA plugins.\n\n_cached_plugins = 
dict()\n\ndef get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):\n    assert verbosity in ['none', 'brief', 'full']\n    if headers is None:\n        headers = []\n    if source_dir is not None:\n        sources = [os.path.join(source_dir, fname) for fname in sources]\n        headers = [os.path.join(source_dir, fname) for fname in headers]\n\n    # Already cached?\n    if module_name in _cached_plugins:\n        return _cached_plugins[module_name]\n\n    # Print status.\n    if verbosity == 'full':\n        print(f'Setting up PyTorch plugin \"{module_name}\"...')\n    elif verbosity == 'brief':\n        print(f'Setting up PyTorch plugin \"{module_name}\"... ', end='', flush=True)\n    verbose_build = (verbosity == 'full')\n\n    # Compile and load.\n    try: # pylint: disable=too-many-nested-blocks\n        # Make sure we can find the necessary compiler binaries.\n        if os.name == 'nt' and os.system(\"where cl.exe >nul 2>nul\") != 0:\n            compiler_bindir = _find_compiler_bindir()\n            if compiler_bindir is None:\n                raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in \"{__file__}\".')\n            os.environ['PATH'] += ';' + compiler_bindir\n\n        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either\n        # break the build or unnecessarily restrict what's available to nvcc.\n        # Unset it to let nvcc decide based on what's available on the\n        # machine.\n        os.environ['TORCH_CUDA_ARCH_LIST'] = ''\n\n        # Incremental build md5sum trickery.  Copies all the input source files\n        # into a cached build directory under a combined md5 digest of the input\n        # source files.  
Copying is done only if the combined digest has changed.\n        # This keeps input file timestamps and filenames the same as in previous\n        # extension builds, allowing for fast incremental rebuilds.\n        #\n        # This optimization is done only in case all the source files reside in\n        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR\n        # environment variable is set (we take this as a signal that the user\n        # actually cares about this.)\n        #\n        # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work\n        # around the *.cu dependency bug in ninja config.\n        #\n        all_source_files = sorted(sources + headers)\n        all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)\n        if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):\n\n            # Compute combined hash digest for all source files.\n            hash_md5 = hashlib.md5()\n            for src in all_source_files:\n                with open(src, 'rb') as f:\n                    hash_md5.update(f.read())\n\n            # Select cached build directory name.\n            source_digest = hash_md5.hexdigest()\n            build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access\n            cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')\n\n            if not os.path.isdir(cached_build_dir):\n                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'\n                os.makedirs(tmpdir)\n                for src in all_source_files:\n                    shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))\n                try:\n                    os.replace(tmpdir, cached_build_dir) # atomic\n                except OSError:\n                    # source directory already exists, delete tmpdir and its contents.\n         
           shutil.rmtree(tmpdir)\n                    if not os.path.isdir(cached_build_dir): raise\n\n            # Compile.\n            cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]\n            torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,\n                verbose=verbose_build, sources=cached_sources, **build_kwargs)\n        else:\n            torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)\n\n        # Load.\n        module = importlib.import_module(module_name)\n\n    except:\n        if verbosity == 'brief':\n            print('Failed!')\n        raise\n\n    # Print status and add to cache dict.\n    if verbosity == 'full':\n        print(f'Done setting up PyTorch plugin \"{module_name}\".')\n    elif verbosity == 'brief':\n        print('Done.')\n    _cached_plugins[module_name] = module\n    return module\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/misc.py",
    "content": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport re\nimport contextlib\nimport numpy as np\nimport torch\nimport warnings\nimport ADD.dnnlib as dnnlib\n\n#----------------------------------------------------------------------------\n# Cached construction of constant tensors. Avoids CPU=>GPU copy when the\n# same constant is used multiple times.\n\n_constant_cache = dict()\n\ndef constant(value, shape=None, dtype=None, device=None, memory_format=None):\n    value = np.asarray(value)\n    if shape is not None:\n        shape = tuple(shape)\n    if dtype is None:\n        dtype = torch.get_default_dtype()\n    if device is None:\n        device = torch.device('cpu')\n    if memory_format is None:\n        memory_format = torch.contiguous_format\n\n    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)\n    tensor = _constant_cache.get(key, None)\n    if tensor is None:\n        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)\n        if shape is not None:\n            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))\n        tensor = tensor.contiguous(memory_format=memory_format)\n        _constant_cache[key] = tensor\n    return tensor\n\n#----------------------------------------------------------------------------\n# Replace NaN/Inf with specified numerical values.\n\ntry:\n    nan_to_num = torch.nan_to_num # 1.8.0a0\nexcept AttributeError:\n    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin\n        assert isinstance(input, 
torch.Tensor)\n        if posinf is None:\n            posinf = torch.finfo(input.dtype).max\n        if neginf is None:\n            neginf = torch.finfo(input.dtype).min\n        assert nan == 0\n        return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)\n\n#----------------------------------------------------------------------------\n# Symbolic assert.\n\ntry:\n    symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access\nexcept AttributeError:\n    symbolic_assert = torch.Assert # 1.7.0\n\n#----------------------------------------------------------------------------\n# Context manager to temporarily suppress known warnings in torch.jit.trace().\n# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672\n\n@contextlib.contextmanager\ndef suppress_tracer_warnings():\n    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)\n    warnings.filters.insert(0, flt)\n    yield\n    warnings.filters.remove(flt)\n\n#----------------------------------------------------------------------------\n# Assert that the shape of a tensor matches the given list of integers.\n# None indicates that the size of a dimension is allowed to vary.\n# Performs symbolic assertion when used in torch.jit.trace().\n\ndef assert_shape(tensor, ref_shape):\n    if tensor.ndim != len(ref_shape):\n        raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')\n    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):\n        if ref_size is None:\n            pass\n        elif isinstance(ref_size, torch.Tensor):\n            with suppress_tracer_warnings(): # as_tensor results are registered as constants\n                symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')\n        elif isinstance(size, torch.Tensor):\n            with suppress_tracer_warnings(): # as_tensor results are registered as constants\n            
    symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')\n        elif size != ref_size:\n            raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')\n\n#----------------------------------------------------------------------------\n# Function decorator that calls torch.autograd.profiler.record_function().\n\ndef profiled_function(fn):\n    def decorator(*args, **kwargs):\n        with torch.autograd.profiler.record_function(fn.__name__):\n            return fn(*args, **kwargs)\n    decorator.__name__ = fn.__name__\n    return decorator\n\n#----------------------------------------------------------------------------\n# Sampler for torch.utils.data.DataLoader that loops over the dataset\n# indefinitely, shuffling items as it goes.\n\nclass InfiniteSampler(torch.utils.data.Sampler):\n    def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):\n        assert len(dataset) > 0\n        assert num_replicas > 0\n        assert 0 <= rank < num_replicas\n        assert 0 <= window_size <= 1\n        super().__init__(dataset)\n        self.dataset = dataset\n        self.rank = rank\n        self.num_replicas = num_replicas\n        self.shuffle = shuffle\n        self.seed = seed\n        self.window_size = window_size\n\n    def __iter__(self):\n        order = np.arange(len(self.dataset))\n        rnd = None\n        window = 0\n        if self.shuffle:\n            rnd = np.random.RandomState(self.seed)\n            rnd.shuffle(order)\n            window = int(np.rint(order.size * self.window_size))\n\n        idx = 0\n        while True:\n            i = idx % order.size\n            if idx % self.num_replicas == self.rank:\n                yield order[i]\n            if window >= 2:\n                j = (i - rnd.randint(window)) % order.size\n                order[i], order[j] = order[j], order[i]\n            idx += 
1\n\n#----------------------------------------------------------------------------\n# Utilities for operating with torch.nn.Module parameters and buffers.\ndef spectral_to_cpu(model: torch.nn.Module):\n    def wrapped_in_spectral(m): return hasattr(m, 'weight_v')\n    children = get_children(model)\n    for child in children:\n        if wrapped_in_spectral(child):\n            child.weight = child.weight.cpu()\n    return model\n\ndef get_children(model: torch.nn.Module):\n    children = list(model.children())\n    flatt_children = []\n    if children == []:\n        return model\n    else:\n       for child in children:\n            try:\n                flatt_children.extend(get_children(child))\n            except TypeError:\n                flatt_children.append(get_children(child))\n    return flatt_children\n\ndef params_and_buffers(module):\n    assert isinstance(module, torch.nn.Module)\n    return list(module.parameters()) + list(module.buffers())\n\ndef named_params_and_buffers(module):\n    assert isinstance(module, torch.nn.Module)\n    return list(module.named_parameters()) + list(module.named_buffers())\n\ndef copy_params_and_buffers(src_module, dst_module, require_all=False):\n    assert isinstance(src_module, torch.nn.Module)\n    assert isinstance(dst_module, torch.nn.Module)\n    src_tensors = dict(named_params_and_buffers(src_module))\n    for name, tensor in named_params_and_buffers(dst_module):\n        assert (name in src_tensors) or (not require_all)\n        if name in src_tensors:\n            tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)\n\n#----------------------------------------------------------------------------\n# Context manager for easily enabling/disabling DistributedDataParallel\n# synchronization.\n\n@contextlib.contextmanager\ndef ddp_sync(module, sync):\n    assert isinstance(module, torch.nn.Module)\n    if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):\n        
yield\n    else:\n        with module.no_sync():\n            yield\n\n#----------------------------------------------------------------------------\n# Check DistributedDataParallel consistency across processes.\n\ndef check_ddp_consistency(module, ignore_regex=None):\n    assert isinstance(module, torch.nn.Module)\n    for name, tensor in named_params_and_buffers(module):\n        fullname = type(module).__name__ + '.' + name\n        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):\n            continue\n        tensor = tensor.detach()\n        if tensor.is_floating_point():\n            tensor = nan_to_num(tensor)\n        other = tensor.clone()\n        torch.distributed.broadcast(tensor=other, src=0)\n        assert (tensor == other).all(), fullname\n\n#----------------------------------------------------------------------------\n# Print summary table of module hierarchy.\n\ndef print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):\n    assert isinstance(module, torch.nn.Module)\n    assert not isinstance(module, torch.jit.ScriptModule)\n    assert isinstance(inputs, (tuple, list))\n\n    # Register hooks.\n    entries = []\n    nesting = [0]\n    def pre_hook(_mod, _inputs):\n        nesting[0] += 1\n    def post_hook(mod, _inputs, outputs):\n        nesting[0] -= 1\n        if nesting[0] <= max_nesting:\n            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]\n            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]\n            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))\n    hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]\n    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]\n\n    # Run module.\n    outputs = module(*inputs)\n    for hook in hooks:\n        hook.remove()\n\n    # Identify unique outputs, parameters, and buffers.\n    tensors_seen = set()\n    for e in entries:\n        
e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]\n        e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]\n        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]\n        tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}\n\n    # Filter out redundant entries.\n    if skip_redundant:\n        entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]\n\n    # Construct table.\n    rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]\n    rows += [['---'] * len(rows[0])]\n    param_total = 0\n    buffer_total = 0\n    submodule_names = {mod: name for name, mod in module.named_modules()}\n    for e in entries:\n        name = '<top-level>' if e.mod is module else submodule_names[e.mod]\n        param_size = sum(t.numel() for t in e.unique_params)\n        buffer_size = sum(t.numel() for t in e.unique_buffers)\n        output_shapes = [str(list(t.shape)) for t in e.outputs]\n        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]\n        rows += [[\n            name + (':0' if len(e.outputs) >= 2 else ''),\n            str(param_size) if param_size else '-',\n            str(buffer_size) if buffer_size else '-',\n            (output_shapes + ['-'])[0],\n            (output_dtypes + ['-'])[0],\n        ]]\n        for idx in range(1, len(e.outputs)):\n            rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]\n        param_total += param_size\n        buffer_total += buffer_size\n    rows += [['---'] * len(rows[0])]\n    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]\n\n    # Print table.\n    widths = [max(len(cell) for cell in column) for column in zip(*rows)]\n    print()\n    for row in rows:\n        print('  '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))\n    print()\n 
   return outputs\n"
  },
  {
    "path": "ADD/th_utils/ops/__init__.py",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n# empty\n"
  },
  {
    "path": "ADD/th_utils/ops/bias_act.cpp",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include <torch/extension.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include \"bias_act.h\"\n\n//------------------------------------------------------------------------\n\nstatic bool has_same_layout(torch::Tensor x, torch::Tensor y)\n{\n    if (x.dim() != y.dim())\n        return false;\n    for (int64_t i = 0; i < x.dim(); i++)\n    {\n        if (x.size(i) != y.size(i))\n            return false;\n        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))\n            return false;\n    }\n    return true;\n}\n\n//------------------------------------------------------------------------\n\nstatic torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)\n{\n    // Validate arguments.\n    TORCH_CHECK(x.is_cuda(), \"x must reside on CUDA device\");\n    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), \"b must have the same dtype and device as x\");\n    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), \"xref must have the same shape, dtype, and device as x\");\n    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), \"yref must have the same shape, dtype, and device as x\");\n    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() 
== x.device()), \"dy must have the same dtype and device as x\");\n    TORCH_CHECK(x.numel() <= INT_MAX, \"x is too large\");\n    TORCH_CHECK(b.dim() == 1, \"b must have rank 1\");\n    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), \"dim is out of bounds\");\n    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), \"b has wrong number of elements\");\n    TORCH_CHECK(grad >= 0, \"grad must be non-negative\");\n\n    // Validate layout.\n    TORCH_CHECK(x.is_non_overlapping_and_dense(), \"x must be non-overlapping and dense\");\n    TORCH_CHECK(b.is_contiguous(), \"b must be contiguous\");\n    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), \"xref must have the same layout as x\");\n    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), \"yref must have the same layout as x\");\n    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), \"dy must have the same layout as x\");\n\n    // Create output tensor.\n    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));\n    torch::Tensor y = torch::empty_like(x);\n    TORCH_CHECK(has_same_layout(y, x), \"y must have the same layout as x\");\n\n    // Initialize CUDA kernel parameters.\n    bias_act_kernel_params p;\n    p.x     = x.data_ptr();\n    p.b     = (b.numel()) ? b.data_ptr() : NULL;\n    p.xref  = (xref.numel()) ? xref.data_ptr() : NULL;\n    p.yref  = (yref.numel()) ? yref.data_ptr() : NULL;\n    p.dy    = (dy.numel()) ? dy.data_ptr() : NULL;\n    p.y     = y.data_ptr();\n    p.grad  = grad;\n    p.act   = act;\n    p.alpha = alpha;\n    p.gain  = gain;\n    p.clamp = clamp;\n    p.sizeX = (int)x.numel();\n    p.sizeB = (int)b.numel();\n    p.stepB = (b.numel()) ? 
(int)x.stride(dim) : 1;\n\n    // Choose CUDA kernel.\n    void* kernel;\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"upfirdn2d_cuda\", [&]\n    {\n        kernel = choose_bias_act_kernel<scalar_t>(p);\n    });\n    TORCH_CHECK(kernel, \"no CUDA kernel found for the specified activation func\");\n\n    // Launch CUDA kernel.\n    p.loopX = 4;\n    int blockSize = 4 * 32;\n    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;\n    void* args[] = {&p};\n    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));\n    return y;\n}\n\n//------------------------------------------------------------------------\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m)\n{\n    m.def(\"bias_act\", &bias_act);\n}\n\n//------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/bias_act.cu",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include <c10/util/Half.h>\n#include \"bias_act.h\"\n\n//------------------------------------------------------------------------\n// Helpers.\n\ntemplate <class T> struct InternalType;\ntemplate <> struct InternalType<double>     { typedef double scalar_t; };\ntemplate <> struct InternalType<float>      { typedef float  scalar_t; };\ntemplate <> struct InternalType<c10::Half>  { typedef float  scalar_t; };\n\n//------------------------------------------------------------------------\n// CUDA kernel.\n\ntemplate <class T, int A>\n__global__ void bias_act_kernel(bias_act_kernel_params p)\n{\n    typedef typename InternalType<T>::scalar_t scalar_t;\n    int G                 = p.grad;\n    scalar_t alpha        = (scalar_t)p.alpha;\n    scalar_t gain         = (scalar_t)p.gain;\n    scalar_t clamp        = (scalar_t)p.clamp;\n    scalar_t one          = (scalar_t)1;\n    scalar_t two          = (scalar_t)2;\n    scalar_t expRange     = (scalar_t)80;\n    scalar_t halfExpRange = (scalar_t)40;\n    scalar_t seluScale    = (scalar_t)1.0507009873554804934193349852946;\n    scalar_t seluAlpha    = (scalar_t)1.6732632423543772848170429916717;\n\n    // Loop over elements.\n    int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;\n    for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)\n    {\n        // Load.\n        scalar_t x = (scalar_t)((const T*)p.x)[xi];\n        scalar_t b = (p.b) ? 
(scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;\n        scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;\n        scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;\n        scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;\n        scalar_t yy = (gain != 0) ? yref / gain : 0;\n        scalar_t y = 0;\n\n        // Apply bias.\n        ((G == 0) ? x : xref) += b;\n\n        // linear\n        if (A == 1)\n        {\n            if (G == 0) y = x;\n            if (G == 1) y = x;\n        }\n\n        // relu\n        if (A == 2)\n        {\n            if (G == 0) y = (x > 0) ? x : 0;\n            if (G == 1) y = (yy > 0) ? x : 0;\n        }\n\n        // lrelu\n        if (A == 3)\n        {\n            if (G == 0) y = (x > 0) ? x : x * alpha;\n            if (G == 1) y = (yy > 0) ? x : x * alpha;\n        }\n\n        // tanh\n        if (A == 4)\n        {\n            if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }\n            if (G == 1) y = x * (one - yy * yy);\n            if (G == 2) y = x * (one - yy * yy) * (-two * yy);\n        }\n\n        // sigmoid\n        if (A == 5)\n        {\n            if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);\n            if (G == 1) y = x * yy * (one - yy);\n            if (G == 2) y = x * yy * (one - yy) * (one - two * yy);\n        }\n\n        // elu\n        if (A == 6)\n        {\n            if (G == 0) y = (x >= 0) ? x : exp(x) - one;\n            if (G == 1) y = (yy >= 0) ? x : x * (yy + one);\n            if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);\n        }\n\n        // selu\n        if (A == 7)\n        {\n            if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);\n            if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);\n            if (G == 2) y = (yy >= 0) ? 
0 : x * (yy + seluScale * seluAlpha);\n        }\n\n        // softplus\n        if (A == 8)\n        {\n            if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);\n            if (G == 1) y = x * (one - exp(-yy));\n            if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }\n        }\n\n        // swish\n        if (A == 9)\n        {\n            if (G == 0)\n                y = (x < -expRange) ? 0 : x / (exp(-x) + one);\n            else\n            {\n                scalar_t c = exp(xref);\n                scalar_t d = c + one;\n                if (G == 1)\n                    y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);\n                else\n                    y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);\n                yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;\n            }\n        }\n\n        // Apply gain.\n        y *= gain * dy;\n\n        // Clamp.\n        if (clamp >= 0)\n        {\n            if (G == 0)\n                y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;\n            else\n                y = (yref > -clamp & yref < clamp) ? 
y : 0;\n        }\n\n        // Store.\n        ((T*)p.y)[xi] = (T)y;\n    }\n}\n\n//------------------------------------------------------------------------\n// CUDA kernel selection.\n\ntemplate <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)\n{\n    if (p.act == 1) return (void*)bias_act_kernel<T, 1>;\n    if (p.act == 2) return (void*)bias_act_kernel<T, 2>;\n    if (p.act == 3) return (void*)bias_act_kernel<T, 3>;\n    if (p.act == 4) return (void*)bias_act_kernel<T, 4>;\n    if (p.act == 5) return (void*)bias_act_kernel<T, 5>;\n    if (p.act == 6) return (void*)bias_act_kernel<T, 6>;\n    if (p.act == 7) return (void*)bias_act_kernel<T, 7>;\n    if (p.act == 8) return (void*)bias_act_kernel<T, 8>;\n    if (p.act == 9) return (void*)bias_act_kernel<T, 9>;\n    return NULL;\n}\n\n//------------------------------------------------------------------------\n// Template specializations.\n\ntemplate void* choose_bias_act_kernel<double>       (const bias_act_kernel_params& p);\ntemplate void* choose_bias_act_kernel<float>        (const bias_act_kernel_params& p);\ntemplate void* choose_bias_act_kernel<c10::Half>    (const bias_act_kernel_params& p);\n\n//------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/bias_act.h",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n//------------------------------------------------------------------------\n// CUDA kernel parameters.\n\nstruct bias_act_kernel_params\n{\n    const void* x;      // [sizeX]\n    const void* b;      // [sizeB] or NULL\n    const void* xref;   // [sizeX] or NULL\n    const void* yref;   // [sizeX] or NULL\n    const void* dy;     // [sizeX] or NULL\n    void*       y;      // [sizeX]\n\n    int         grad;\n    int         act;\n    float       alpha;\n    float       gain;\n    float       clamp;\n\n    int         sizeX;\n    int         sizeB;\n    int         stepB;\n    int         loopX;\n};\n\n//------------------------------------------------------------------------\n// CUDA kernel selection.\n\ntemplate <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);\n\n//------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/bias_act.py",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n\"\"\"Custom PyTorch ops for efficient bias and activation.\"\"\"\n\nimport os\nimport numpy as np\nimport torch\nimport ADD.dnnlib as dnnlib\n\nfrom .. import custom_ops\nfrom .. import misc\n\n#----------------------------------------------------------------------------\n\nactivation_funcs = {\n    'linear':   dnnlib.EasyDict(func=lambda x, **_:         x,                                          def_alpha=0,    def_gain=1,             cuda_idx=1, ref='',  has_2nd_grad=False),\n    'relu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.relu(x),                def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=2, ref='y', has_2nd_grad=False),\n    'lrelu':    dnnlib.EasyDict(func=lambda x, alpha, **_:  torch.nn.functional.leaky_relu(x, alpha),   def_alpha=0.2,  def_gain=np.sqrt(2),    cuda_idx=3, ref='y', has_2nd_grad=False),\n    'tanh':     dnnlib.EasyDict(func=lambda x, **_:         torch.tanh(x),                              def_alpha=0,    def_gain=1,             cuda_idx=4, ref='y', has_2nd_grad=True),\n    'sigmoid':  dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x),                           def_alpha=0,    def_gain=1,             cuda_idx=5, ref='y', has_2nd_grad=True),\n    'elu':      dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.elu(x),                 def_alpha=0,    def_gain=1,             cuda_idx=6, ref='y', has_2nd_grad=True),\n    'selu':     dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.selu(x),                def_alpha=0,   
 def_gain=1,             cuda_idx=7, ref='y', has_2nd_grad=True),\n    'softplus': dnnlib.EasyDict(func=lambda x, **_:         torch.nn.functional.softplus(x),            def_alpha=0,    def_gain=1,             cuda_idx=8, ref='y', has_2nd_grad=True),\n    'swish':    dnnlib.EasyDict(func=lambda x, **_:         torch.sigmoid(x) * x,                       def_alpha=0,    def_gain=np.sqrt(2),    cuda_idx=9, ref='x', has_2nd_grad=True),\n}\n\n#----------------------------------------------------------------------------\n\n_plugin = None\n_null_tensor = torch.empty([0])\n\ndef _init():\n    global _plugin\n    if _plugin is None:\n        _plugin = custom_ops.get_plugin(\n            module_name='bias_act_plugin',\n            sources=['bias_act.cpp', 'bias_act.cu'],\n            headers=['bias_act.h'],\n            source_dir=os.path.dirname(__file__),\n            extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],\n        )\n    return True\n\n#----------------------------------------------------------------------------\n\ndef bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):\n    r\"\"\"Fused bias and activation function.\n\n    Adds bias `b` to activation tensor `x`, evaluates activation function `act`,\n    and scales the result by `gain`. Each of the steps is optional. In most cases,\n    the fused op is considerably more efficient than performing the same calculation\n    using standard PyTorch ops. It supports first and second order gradients,\n    but not third order gradients.\n\n    Args:\n        x:      Input activation tensor. Can be of any shape.\n        b:      Bias vector, or `None` to disable. Must be a 1D tensor of the same type\n                as `x`. 
The shape must be known, and it must match the dimension of `x`\n                corresponding to `dim`.\n        dim:    The dimension in `x` corresponding to the elements of `b`.\n                The value of `dim` is ignored if `b` is not specified.\n        act:    Name of the activation function to evaluate, or `\"linear\"` to disable.\n                Can be e.g. `\"relu\"`, `\"lrelu\"`, `\"tanh\"`, `\"sigmoid\"`, `\"swish\"`, etc.\n                See `activation_funcs` for a full list. `None` is not allowed.\n        alpha:  Shape parameter for the activation function, or `None` to use the default.\n        gain:   Scaling factor for the output tensor, or `None` to use default.\n                See `activation_funcs` for the default scaling of each activation function.\n                If unsure, consider specifying 1.\n        clamp:  Clamp the output values to `[-clamp, +clamp]`, or `None` to disable\n                the clamping (default).\n        impl:   Name of the implementation to use. 
Can be `\"ref\"` or `\"cuda\"` (default).\n\n    Returns:\n        Tensor of the same shape and datatype as `x`.\n    \"\"\"\n    assert isinstance(x, torch.Tensor)\n    assert impl in ['ref', 'cuda']\n    if impl == 'cuda' and x.device.type == 'cuda' and _init():\n        return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)\n    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)\n\n#----------------------------------------------------------------------------\n\n@misc.profiled_function\ndef _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):\n    \"\"\"Slow reference implementation of `bias_act()` using standard TensorFlow ops.\n    \"\"\"\n    assert isinstance(x, torch.Tensor)\n    assert clamp is None or clamp >= 0\n    spec = activation_funcs[act]\n    alpha = float(alpha if alpha is not None else spec.def_alpha)\n    gain = float(gain if gain is not None else spec.def_gain)\n    clamp = float(clamp if clamp is not None else -1)\n\n    # Add bias.\n    if b is not None:\n        assert isinstance(b, torch.Tensor) and b.ndim == 1\n        assert 0 <= dim < x.ndim\n        assert b.shape[0] == x.shape[dim]\n        x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])\n\n    # Evaluate activation function.\n    alpha = float(alpha)\n    x = spec.func(x, alpha=alpha)\n\n    # Scale by gain.\n    gain = float(gain)\n    if gain != 1:\n        x = x * gain\n\n    # Clamp.\n    if clamp >= 0:\n        x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type\n    return x\n\n#----------------------------------------------------------------------------\n\n_bias_act_cuda_cache = dict()\n\ndef _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):\n    \"\"\"Fast CUDA implementation of `bias_act()` using custom ops.\n    \"\"\"\n    # Parse arguments.\n    assert clamp is None or clamp >= 0\n    spec = activation_funcs[act]\n    
alpha = float(alpha if alpha is not None else spec.def_alpha)\n    gain = float(gain if gain is not None else spec.def_gain)\n    clamp = float(clamp if clamp is not None else -1)\n\n    # Lookup from cache.\n    key = (dim, act, alpha, gain, clamp)\n    if key in _bias_act_cuda_cache:\n        return _bias_act_cuda_cache[key]\n\n    # Forward op.\n    class BiasActCuda(torch.autograd.Function):\n        @staticmethod\n        def forward(ctx, x, b): # pylint: disable=arguments-differ\n            ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format\n            x = x.contiguous(memory_format=ctx.memory_format)\n            b = b.contiguous() if b is not None else _null_tensor\n            y = x\n            if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:\n                y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)\n            ctx.save_for_backward(\n                x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,\n                b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,\n                y if 'y' in spec.ref else _null_tensor)\n            return y\n\n        @staticmethod\n        def backward(ctx, dy): # pylint: disable=arguments-differ\n            dy = dy.contiguous(memory_format=ctx.memory_format)\n            x, b, y = ctx.saved_tensors\n            dx = None\n            db = None\n\n            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:\n                dx = dy\n                if act != 'linear' or gain != 1 or clamp >= 0:\n                    dx = BiasActCudaGrad.apply(dy, x, b, y)\n\n            if ctx.needs_input_grad[1]:\n                db = dx.sum([i for i in range(dx.ndim) if i != dim])\n\n            return dx, db\n\n    # Backward op.\n    class BiasActCudaGrad(torch.autograd.Function):\n        @staticmethod\n        def forward(ctx, dy, x, b, y): # pylint: 
disable=arguments-differ\n            ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format\n            dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)\n            ctx.save_for_backward(\n                dy if spec.has_2nd_grad else _null_tensor,\n                x, b, y)\n            return dx\n\n        @staticmethod\n        def backward(ctx, d_dx): # pylint: disable=arguments-differ\n            d_dx = d_dx.contiguous(memory_format=ctx.memory_format)\n            dy, x, b, y = ctx.saved_tensors\n            d_dy = None\n            d_x = None\n            d_b = None\n            d_y = None\n\n            if ctx.needs_input_grad[0]:\n                d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)\n\n            if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):\n                d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)\n\n            if spec.has_2nd_grad and ctx.needs_input_grad[2]:\n                d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])\n\n            return d_dy, d_x, d_b, d_y\n\n    # Add to cache.\n    _bias_act_cuda_cache[key] = BiasActCuda\n    return BiasActCuda\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/conv2d_gradfix.py",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n\"\"\"Custom replacement for `torch.nn.functional.conv2d` that supports\narbitrarily high order gradients with zero performance penalty.\"\"\"\n\nimport contextlib\nimport torch\nfrom pkg_resources import parse_version\n\n# pylint: disable=redefined-builtin\n# pylint: disable=arguments-differ\n# pylint: disable=protected-access\n\n#----------------------------------------------------------------------------\n\nenabled = False                     # Enable the custom op by setting this to true.\nweight_gradients_disabled = False   # Forcefully disable computation of gradients with respect to the weights.\n_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11\n\n@contextlib.contextmanager\ndef no_weight_gradients(disable=True):\n    global weight_gradients_disabled\n    old = weight_gradients_disabled\n    if disable:\n        weight_gradients_disabled = True\n    yield\n    weight_gradients_disabled = old\n\n#----------------------------------------------------------------------------\n\ndef conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):\n    if _should_use_custom_op(input):\n        return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias)\n    return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)\n\ndef 
conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):\n    if _should_use_custom_op(input):\n        return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias)\n    return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)\n\n#----------------------------------------------------------------------------\n\ndef _should_use_custom_op(input):\n    assert isinstance(input, torch.Tensor)\n    if (not enabled) or (not torch.backends.cudnn.enabled):\n        return False\n    if _use_pytorch_1_11_api:\n        # The work-around code doesn't work on PyTorch 1.11.0 onwards\n        return False\n    if input.device.type != 'cuda':\n        return False\n    return True\n\ndef _tuple_of_ints(xs, ndim):\n    xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim\n    assert len(xs) == ndim\n    assert all(isinstance(x, int) for x in xs)\n    return xs\n\n#----------------------------------------------------------------------------\n\n_conv2d_gradfix_cache = dict()\n_null_tensor = torch.empty([0])\n\ndef _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):\n    # Parse arguments.\n    ndim = 2\n    weight_shape = tuple(weight_shape)\n    stride = _tuple_of_ints(stride, ndim)\n    padding = _tuple_of_ints(padding, ndim)\n    output_padding = _tuple_of_ints(output_padding, ndim)\n    dilation = _tuple_of_ints(dilation, ndim)\n\n    # Lookup from cache.\n    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)\n    if key in _conv2d_gradfix_cache:\n        return _conv2d_gradfix_cache[key]\n\n    # Validate arguments.\n    assert groups >= 1\n    assert len(weight_shape) == ndim + 2\n    assert 
all(stride[i] >= 1 for i in range(ndim))\n    assert all(padding[i] >= 0 for i in range(ndim))\n    assert all(dilation[i] >= 0 for i in range(ndim))\n    if not transpose:\n        assert all(output_padding[i] == 0 for i in range(ndim))\n    else: # transpose\n        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))\n\n    # Helpers.\n    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)\n    def calc_output_padding(input_shape, output_shape):\n        if transpose:\n            return [0, 0]\n        return [\n            input_shape[i + 2]\n            - (output_shape[i + 2] - 1) * stride[i]\n            - (1 - 2 * padding[i])\n            - dilation[i] * (weight_shape[i + 2] - 1)\n            for i in range(ndim)\n        ]\n\n    # Forward & backward.\n    class Conv2d(torch.autograd.Function):\n        @staticmethod\n        def forward(ctx, input, weight, bias):\n            assert weight.shape == weight_shape\n            ctx.save_for_backward(\n                input if weight.requires_grad else _null_tensor,\n                weight if input.requires_grad else _null_tensor,\n            )\n            ctx.input_shape = input.shape\n\n            # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).\n            if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0):\n                a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1])\n                b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1)\n                c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2)\n                c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1)\n                c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3)\n                return c.contiguous(memory_format=(torch.channels_last if 
input.stride(1) == 1 else torch.contiguous_format))\n\n            # General case => cuDNN.\n            if transpose:\n                return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)\n            return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)\n\n        @staticmethod\n        def backward(ctx, grad_output):\n            input, weight = ctx.saved_tensors\n            input_shape = ctx.input_shape\n            grad_input = None\n            grad_weight = None\n            grad_bias = None\n\n            if ctx.needs_input_grad[0]:\n                p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape)\n                op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)\n                grad_input = op.apply(grad_output, weight, None)\n                assert grad_input.shape == input_shape\n\n            if ctx.needs_input_grad[1] and not weight_gradients_disabled:\n                grad_weight = Conv2dGradWeight.apply(grad_output, input)\n                assert grad_weight.shape == weight_shape\n\n            if ctx.needs_input_grad[2]:\n                grad_bias = grad_output.sum([0, 2, 3])\n\n            return grad_input, grad_weight, grad_bias\n\n    # Gradient with respect to the weights.\n    class Conv2dGradWeight(torch.autograd.Function):\n        @staticmethod\n        def forward(ctx, grad_output, input):\n            ctx.save_for_backward(\n                grad_output if input.requires_grad else _null_tensor,\n                input if grad_output.requires_grad else _null_tensor,\n            )\n            ctx.grad_output_shape = grad_output.shape\n            ctx.input_shape = input.shape\n\n            # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere).\n            if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 
0):\n                a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)\n                b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2)\n                c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape)\n                return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format))\n\n            # General case => cuDNN.\n            name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight'\n            flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32]\n            return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)\n\n        @staticmethod\n        def backward(ctx, grad2_grad_weight):\n            grad_output, input = ctx.saved_tensors\n            grad_output_shape = ctx.grad_output_shape\n            input_shape = ctx.input_shape\n            grad2_grad_output = None\n            grad2_input = None\n\n            if ctx.needs_input_grad[0]:\n                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)\n                assert grad2_grad_output.shape == grad_output_shape\n\n            if ctx.needs_input_grad[1]:\n                p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape)\n                op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs)\n                grad2_input = op.apply(grad_output, grad2_grad_weight, None)\n                assert grad2_input.shape == input_shape\n\n            return grad2_grad_output, grad2_input\n\n    _conv2d_gradfix_cache[key] = Conv2d\n    return 
Conv2d\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/conv2d_resample.py",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n\"\"\"2D convolution with optional up/downsampling.\"\"\"\n\nimport torch\n\nfrom .. import misc\nfrom . import conv2d_gradfix\nfrom . import upfirdn2d\nfrom .upfirdn2d import _parse_padding\nfrom .upfirdn2d import _get_filter_size\n\n#----------------------------------------------------------------------------\n\ndef _get_weight_shape(w):\n    with misc.suppress_tracer_warnings(): # this value will be treated as a constant\n        shape = [int(sz) for sz in w.shape]\n    misc.assert_shape(w, shape)\n    return shape\n\n#----------------------------------------------------------------------------\n\ndef _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):\n    \"\"\"Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.\n    \"\"\"\n    _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w)\n\n    # Flip weight if requested.\n    # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).\n    if not flip_weight and (kw > 1 or kh > 1):\n        w = w.flip([2, 3])\n\n    # Execute using conv2d_gradfix.\n    op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d\n    return op(x, w, stride=stride, padding=padding, groups=groups)\n\n#----------------------------------------------------------------------------\n\n@misc.profiled_function\ndef conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):\n    r\"\"\"2D convolution with 
optional up/downsampling.\n\n    Padding is performed only once at the beginning, not between the operations.\n\n    Args:\n        x:              Input tensor of shape\n                        `[batch_size, in_channels, in_height, in_width]`.\n        w:              Weight tensor of shape\n                        `[out_channels, in_channels//groups, kernel_height, kernel_width]`.\n        f:              Low-pass filter for up/downsampling. Must be prepared beforehand by\n                        calling upfirdn2d.setup_filter(). None = identity (default).\n        up:             Integer upsampling factor (default: 1).\n        down:           Integer downsampling factor (default: 1).\n        padding:        Padding with respect to the upsampled image. Can be a single number\n                        or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`\n                        (default: 0).\n        groups:         Split input channels into N groups (default: 1).\n        flip_weight:    False = convolution, True = correlation (default: True).\n        flip_filter:    False = convolution, True = correlation (default: False).\n\n    Returns:\n        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.\n    \"\"\"\n    # Validate arguments.\n    assert isinstance(x, torch.Tensor) and (x.ndim == 4)\n    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)\n    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)\n    assert isinstance(up, int) and (up >= 1)\n    assert isinstance(down, int) and (down >= 1)\n    assert isinstance(groups, int) and (groups >= 1)\n    out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)\n    fw, fh = _get_filter_size(f)\n    px0, px1, py0, py1 = _parse_padding(padding)\n\n    # Adjust padding to account for up/downsampling.\n    if up > 1:\n        px0 += (fw + up - 1) // 2\n        px1 += (fw - up) // 2\n        
py0 += (fh + up - 1) // 2\n        py1 += (fh - up) // 2\n    if down > 1:\n        px0 += (fw - down + 1) // 2\n        px1 += (fw - down) // 2\n        py0 += (fh - down + 1) // 2\n        py1 += (fh - down) // 2\n\n    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.\n    if kw == 1 and kh == 1 and (down > 1 and up == 1):\n        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter)\n        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)\n        return x\n\n    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.\n    if kw == 1 and kh == 1 and (up > 1 and down == 1):\n        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)\n        x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)\n        return x\n\n    # Fast path: downsampling only => use strided convolution.\n    if down > 1 and up == 1:\n        x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter)\n        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)\n        return x\n\n    # Fast path: upsampling with optional downsampling => use transpose strided convolution.\n    if up > 1:\n        if groups == 1:\n            w = w.transpose(0, 1)\n        else:\n            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)\n            w = w.transpose(1, 2)\n            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)\n        px0 -= kw - 1\n        px1 -= kw - up\n        py0 -= kh - 1\n        py1 -= kh - up\n        pxt = max(min(-px0, -px1), 0)\n        pyt = max(min(-py0, -py1), 0)\n        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight))\n        x = upfirdn2d.upfirdn2d(x=x, f=f, 
padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter)\n        if down > 1:\n            x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)\n        return x\n\n    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.\n    if up == 1 and down == 1:\n        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:\n            return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight)\n\n    # Fallback: Generic reference implementation.\n    x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter)\n    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)\n    if down > 1:\n        x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)\n    return x\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/filtered_lrelu.cpp",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include <torch/extension.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include \"filtered_lrelu.h\"\n\n//------------------------------------------------------------------------\n\nstatic std::tuple<torch::Tensor, torch::Tensor, int> filtered_lrelu(\n    torch::Tensor x, torch::Tensor fu, torch::Tensor fd, torch::Tensor b, torch::Tensor si,\n    int up, int down, int px0, int px1, int py0, int py1, int sx, int sy, float gain, float slope, float clamp, bool flip_filters, bool writeSigns)\n{\n    // Set CUDA device.\n    TORCH_CHECK(x.is_cuda(), \"x must reside on CUDA device\");\n    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));\n\n    // Validate arguments.\n    TORCH_CHECK(fu.device() == x.device() && fd.device() == x.device() && b.device() == x.device(), \"all input tensors must reside on the same device\");\n    TORCH_CHECK(fu.dtype() == torch::kFloat && fd.dtype() == torch::kFloat, \"fu and fd must be float32\");\n    TORCH_CHECK(b.dtype() == x.dtype(), \"x and b must have the same dtype\");\n    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat, \"x and b must be float16 or float32\");\n    TORCH_CHECK(x.dim() == 4, \"x must be rank 4\");\n    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, \"x is too large\");\n    TORCH_CHECK(x.numel() > 0, \"x is empty\");\n    TORCH_CHECK((fu.dim() == 1 || fu.dim() == 2) && (fd.dim() == 1 || fd.dim() == 2), \"fu and fd must be rank 1 or 2\");\n    
TORCH_CHECK(fu.size(0) <= INT_MAX && fu.size(-1) <= INT_MAX, \"fu is too large\");\n    TORCH_CHECK(fd.size(0) <= INT_MAX && fd.size(-1) <= INT_MAX, \"fd is too large\");\n    TORCH_CHECK(fu.numel() > 0, \"fu is empty\");\n    TORCH_CHECK(fd.numel() > 0, \"fd is empty\");\n    TORCH_CHECK(b.dim() == 1 && b.size(0) == x.size(1), \"b must be a vector with the same number of channels as x\");\n    TORCH_CHECK(up >= 1 && down >= 1, \"up and down must be at least 1\");\n\n    // Figure out how much shared memory is available on the device.\n    int maxSharedBytes = 0;\n    AT_CUDA_CHECK(cudaDeviceGetAttribute(&maxSharedBytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, x.device().index()));\n    int sharedKB = maxSharedBytes >> 10;\n\n    // Populate enough launch parameters to check if a CUDA kernel exists.\n    filtered_lrelu_kernel_params p;\n    p.up      = up;\n    p.down    = down;\n    p.fuShape = make_int2((int)fu.size(-1), fu.dim() == 2 ? (int)fu.size(0) : 0); // shape [n, 0] indicates separable filter.\n    p.fdShape = make_int2((int)fd.size(-1), fd.dim() == 2 ? (int)fd.size(0) : 0);\n    filtered_lrelu_kernel_spec test_spec = choose_filtered_lrelu_kernel<float, int32_t, false, false>(p, sharedKB);\n    if (!test_spec.exec)\n    {\n        // No kernel found - return empty tensors and indicate missing kernel with return code of -1.\n        return std::make_tuple(torch::Tensor(), torch::Tensor(), -1);\n    }\n\n    // Input/output element size.\n    int64_t sz = (x.dtype() == torch::kHalf) ? 
2 : 4;\n\n    // Input sizes.\n    int64_t xw = (int)x.size(3);\n    int64_t xh = (int)x.size(2);\n    int64_t fut_w = (int)fu.size(-1) - 1;\n    int64_t fut_h = (int)fu.size(0)  - 1;\n    int64_t fdt_w = (int)fd.size(-1) - 1;\n    int64_t fdt_h = (int)fd.size(0)  - 1;\n\n    // Logical size of upsampled buffer.\n    int64_t cw = xw * up + (px0 + px1) - fut_w;\n    int64_t ch = xh * up + (py0 + py1) - fut_h;\n    TORCH_CHECK(cw > fdt_w && ch > fdt_h, \"upsampled buffer must be at least the size of downsampling filter\");\n    TORCH_CHECK(cw <= INT_MAX && ch <= INT_MAX, \"upsampled buffer is too large\");\n\n    // Compute output size and allocate.\n    int64_t yw = (cw - fdt_w + (down - 1)) / down;\n    int64_t yh = (ch - fdt_h + (down - 1)) / down;\n    TORCH_CHECK(yw > 0 && yh > 0, \"output must be at least 1x1\");\n    TORCH_CHECK(yw <= INT_MAX && yh <= INT_MAX, \"output is too large\");\n    torch::Tensor y = torch::empty({x.size(0), x.size(1), yh, yw}, x.options(), x.suggest_memory_format());\n\n    // Allocate sign tensor.\n    torch::Tensor so;\n    torch::Tensor s = si;\n    bool readSigns = !!s.numel();\n    int64_t sw_active = 0; // Active width of sign tensor.\n    if (writeSigns)\n    {\n        sw_active = yw * down - (down - 1) + fdt_w;     // Active width in elements.\n        int64_t sh = yh * down - (down - 1) + fdt_h;    // Height = active height.\n        int64_t sw = (sw_active + 15) & ~15;            // Width  = active width in elements, rounded up to multiple of 16.\n        TORCH_CHECK(sh <= INT_MAX && (sw >> 2) <= INT_MAX, \"signs is too large\");\n        s = so = torch::empty({x.size(0), x.size(1), sh, sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);\n    }\n    else if (readSigns)\n        sw_active = s.size(3) << 2;\n\n    // Validate sign tensor if in use.\n    if (readSigns || writeSigns)\n    {\n        TORCH_CHECK(s.is_contiguous(), \"signs must be contiguous\");\n        TORCH_CHECK(s.dtype() == 
torch::kUInt8, \"signs must be uint8\");\n        TORCH_CHECK(s.device() == x.device(), \"signs must reside on the same device as x\");\n        TORCH_CHECK(s.dim() == 4, \"signs must be rank 4\");\n        TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), \"signs must have same batch & channels as x\");\n        TORCH_CHECK(s.size(2) <= INT_MAX && s.size(3) <= INT_MAX, \"signs is too large\");\n    }\n\n    // Populate rest of CUDA kernel parameters.\n    p.x         = x.data_ptr();\n    p.y         = y.data_ptr();\n    p.b         = b.data_ptr();\n    p.s         = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;\n    p.fu        = fu.data_ptr<float>();\n    p.fd        = fd.data_ptr<float>();\n    p.pad0      = make_int2(px0, py0);\n    p.gain      = gain;\n    p.slope     = slope;\n    p.clamp     = clamp;\n    p.flip      = (flip_filters) ? 1 : 0;\n    p.xShape    = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));\n    p.yShape    = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));\n    p.sShape    = (readSigns || writeSigns) ? make_int2((int)s.size(3), (int)s.size(2)) : make_int2(0, 0); // Width is in bytes. Contiguous.\n    p.sOfs      = make_int2(sx, sy);\n    p.swLimit   = (sw_active + 3) >> 2; // Rounded up to bytes.\n\n    // x, y, b strides are in bytes.\n    p.xStride   = make_longlong4(sz * x.stride(3), sz * x.stride(2), sz * x.stride(1), sz * x.stride(0));\n    p.yStride   = make_longlong4(sz * y.stride(3), sz * y.stride(2), sz * y.stride(1), sz * y.stride(0));\n    p.bStride   = sz * b.stride(0);\n\n    // fu, fd strides are in elements.\n    p.fuStride  = make_longlong3(fu.stride(-1), fu.dim() == 2 ? fu.stride(0) : 0, 0);\n    p.fdStride  = make_longlong3(fd.stride(-1), fd.dim() == 2 ? fd.stride(0) : 0, 0);\n\n    // Determine if indices don't fit in int32. 
Support negative strides although Torch currently never produces those.\n    bool index64b = false;\n    if (std::abs(p.bStride * x.size(1)) > INT_MAX) index64b = true;\n    if (std::min(x.size(0) * p.xStride.w, 0ll) + std::min(x.size(1) * p.xStride.z, 0ll) + std::min(x.size(2) * p.xStride.y, 0ll) + std::min(x.size(3) * p.xStride.x, 0ll) < -INT_MAX) index64b = true;\n    if (std::max(x.size(0) * p.xStride.w, 0ll) + std::max(x.size(1) * p.xStride.z, 0ll) + std::max(x.size(2) * p.xStride.y, 0ll) + std::max(x.size(3) * p.xStride.x, 0ll) >  INT_MAX) index64b = true;\n    if (std::min(y.size(0) * p.yStride.w, 0ll) + std::min(y.size(1) * p.yStride.z, 0ll) + std::min(y.size(2) * p.yStride.y, 0ll) + std::min(y.size(3) * p.yStride.x, 0ll) < -INT_MAX) index64b = true;\n    if (std::max(y.size(0) * p.yStride.w, 0ll) + std::max(y.size(1) * p.yStride.z, 0ll) + std::max(y.size(2) * p.yStride.y, 0ll) + std::max(y.size(3) * p.yStride.x, 0ll) >  INT_MAX) index64b = true;\n    if (s.numel() > INT_MAX) index64b = true;\n\n    // Choose CUDA kernel.\n    filtered_lrelu_kernel_spec spec = { 0 };\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"filtered_lrelu_cuda\", [&]\n    {\n        if constexpr (sizeof(scalar_t) <= 4) // Exclude doubles. 
constexpr prevents template instantiation.\n        {\n            // Choose kernel based on index type, datatype and sign read/write modes.\n            if      (!index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, true,  false>(p, sharedKB);\n            else if (!index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, true >(p, sharedKB);\n            else if (!index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int32_t, false, false>(p, sharedKB);\n            else if ( index64b &&  writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, true,  false>(p, sharedKB);\n            else if ( index64b && !writeSigns &&  readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, true >(p, sharedKB);\n            else if ( index64b && !writeSigns && !readSigns) spec = choose_filtered_lrelu_kernel<scalar_t, int64_t, false, false>(p, sharedKB);\n        }\n    });\n    TORCH_CHECK(spec.exec, \"internal error - CUDA kernel not found\") // This should not happen because we tested earlier that kernel exists.\n\n    // Launch CUDA kernel.\n    void* args[] = {&p};\n    int bx = spec.numWarps * 32;\n    int gx = (p.yShape.x - 1) / spec.tileOut.x + 1;\n    int gy = (p.yShape.y - 1) / spec.tileOut.y + 1;\n    int gz = p.yShape.z * p.yShape.w;\n\n    // Repeat multiple horizontal tiles in a CTA?\n    if (spec.xrep)\n    {\n        p.tilesXrep = spec.xrep;\n        p.tilesXdim = gx;\n\n        gx = (gx + p.tilesXrep - 1) / p.tilesXrep;\n        std::swap(gx, gy);\n    }\n    else\n    {\n        p.tilesXrep = 0;\n        p.tilesXdim = 0;\n    }\n\n    // Launch filter setup kernel.\n    AT_CUDA_CHECK(cudaLaunchKernel(spec.setup, 1, 1024, args, 0, at::cuda::getCurrentCUDAStream()));\n\n    // Copy kernels to constant memory.\n    if      ( writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<true,  
false>(at::cuda::getCurrentCUDAStream())));\n    else if (!writeSigns &&  readSigns) AT_CUDA_CHECK((copy_filters<false, true >(at::cuda::getCurrentCUDAStream())));\n    else if (!writeSigns && !readSigns) AT_CUDA_CHECK((copy_filters<false, false>(at::cuda::getCurrentCUDAStream())));\n\n    // Set cache and shared memory configurations for main kernel.\n    AT_CUDA_CHECK(cudaFuncSetCacheConfig(spec.exec, cudaFuncCachePreferShared));\n    if (spec.dynamicSharedKB) // Need dynamically allocated shared memory?\n        AT_CUDA_CHECK(cudaFuncSetAttribute(spec.exec, cudaFuncAttributeMaxDynamicSharedMemorySize, spec.dynamicSharedKB << 10));\n    AT_CUDA_CHECK(cudaFuncSetSharedMemConfig(spec.exec, cudaSharedMemBankSizeFourByte));\n\n    // Launch main kernel.\n    const int maxSubGz = 65535; // CUDA maximum for block z dimension.\n    for (int zofs=0; zofs < gz; zofs += maxSubGz) // Do multiple launches if gz is too big.\n    {\n        p.blockZofs = zofs;\n        int subGz = std::min(maxSubGz, gz - zofs);\n        AT_CUDA_CHECK(cudaLaunchKernel(spec.exec, dim3(gx, gy, subGz), bx, args, spec.dynamicSharedKB << 10, at::cuda::getCurrentCUDAStream()));\n    }\n\n    // Done.\n    return std::make_tuple(y, so, 0);\n}\n\n//------------------------------------------------------------------------\n\nstatic torch::Tensor filtered_lrelu_act(torch::Tensor x, torch::Tensor si, int sx, int sy, float gain, float slope, float clamp, bool writeSigns)\n{\n    // Set CUDA device.\n    TORCH_CHECK(x.is_cuda(), \"x must reside on CUDA device\");\n    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));\n\n    // Validate arguments.\n    TORCH_CHECK(x.dim() == 4, \"x must be rank 4\");\n    TORCH_CHECK(x.size(0) * x.size(1) <= INT_MAX && x.size(2) <= INT_MAX && x.size(3) <= INT_MAX, \"x is too large\");\n    TORCH_CHECK(x.numel() > 0, \"x is empty\");\n    TORCH_CHECK(x.dtype() == torch::kHalf || x.dtype() == torch::kFloat || x.dtype() == torch::kDouble, \"x must be float16, float32 
or float64\");\n\n    // Output signs if we don't have sign input.\n    torch::Tensor so;\n    torch::Tensor s = si;\n    bool readSigns = !!s.numel();\n    if (writeSigns)\n    {\n        int64_t sw = x.size(3);\n        sw = (sw + 15) & ~15; // Round to a multiple of 16 for coalescing.\n        s = so = torch::empty({x.size(0), x.size(1), x.size(2), sw >> 2}, x.options().dtype(torch::kUInt8), at::MemoryFormat::Contiguous);\n    }\n\n    // Validate sign tensor if in use.\n    if (readSigns || writeSigns)\n    {\n        TORCH_CHECK(s.is_contiguous(), \"signs must be contiguous\");\n        TORCH_CHECK(s.dtype() == torch::kUInt8, \"signs must be uint8\");\n        TORCH_CHECK(s.device() == x.device(), \"signs must reside on the same device as x\");\n        TORCH_CHECK(s.dim() == 4, \"signs must be rank 4\");\n        TORCH_CHECK(s.size(0) == x.size(0) && s.size(1) == x.size(1), \"signs must have same batch & channels as x\");\n        TORCH_CHECK(s.size(2) <= INT_MAX && (s.size(3) << 2) <= INT_MAX, \"signs tensor is too large\");\n    }\n\n    // Initialize CUDA kernel parameters.\n    filtered_lrelu_act_kernel_params p;\n    p.x         = x.data_ptr();\n    p.s         = (readSigns || writeSigns) ? s.data_ptr<unsigned char>() : 0;\n    p.gain      = gain;\n    p.slope     = slope;\n    p.clamp     = clamp;\n    p.xShape    = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));\n    p.xStride   = make_longlong4(x.stride(3), x.stride(2), x.stride(1), x.stride(0));\n    p.sShape    = (readSigns || writeSigns) ? make_int2((int)s.size(3) << 2, (int)s.size(2)) : make_int2(0, 0); // Width is in elements. 
Contiguous.\n    p.sOfs      = make_int2(sx, sy);\n\n    // Choose CUDA kernel.\n    void* func = 0;\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"filtered_lrelu_act_cuda\", [&]\n    {\n        if (writeSigns)\n            func = choose_filtered_lrelu_act_kernel<scalar_t, true, false>();\n        else if (readSigns)\n            func = choose_filtered_lrelu_act_kernel<scalar_t, false, true>();\n        else\n            func = choose_filtered_lrelu_act_kernel<scalar_t, false, false>();\n    });\n    TORCH_CHECK(func, \"internal error - CUDA kernel not found\");\n\n    // Launch CUDA kernel.\n    void* args[] = {&p};\n    int bx = 128; // 4 warps per block.\n\n    // Logical size of launch = writeSigns ? p.s : p.x\n    uint32_t gx = writeSigns ? p.sShape.x : p.xShape.x;\n    uint32_t gy = writeSigns ? p.sShape.y : p.xShape.y;\n    uint32_t gz = p.xShape.z * p.xShape.w; // Same as in p.sShape if signs are in use.\n    gx = (gx - 1) / bx + 1;\n\n    // Make sure grid y and z dimensions are within CUDA launch limits. Kernel loops internally to do the rest.\n    const uint32_t gmax = 65535;\n    gy = std::min(gy, gmax);\n    gz = std::min(gz, gmax);\n\n    // Launch.\n    AT_CUDA_CHECK(cudaLaunchKernel(func, dim3(gx, gy, gz), bx, args, 0, at::cuda::getCurrentCUDAStream()));\n    return so;\n}\n\n//------------------------------------------------------------------------\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m)\n{\n    m.def(\"filtered_lrelu\",      &filtered_lrelu);      // The whole thing.\n    m.def(\"filtered_lrelu_act_\", &filtered_lrelu_act);  // Activation and sign tensor handling only. Modifies data tensor in-place.\n}\n\n//------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/filtered_lrelu.cu",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include <c10/util/Half.h>\n#include \"filtered_lrelu.h\"\n#include <cstdint>\n\n//------------------------------------------------------------------------\n// Helpers.\n\nenum // Filter modes.\n{\n    MODE_SUSD = 0,  // Separable upsampling, separable downsampling.\n    MODE_FUSD = 1,  // Full upsampling, separable downsampling.\n    MODE_SUFD = 2,  // Separable upsampling, full downsampling.\n    MODE_FUFD = 3,  // Full upsampling, full downsampling.\n};\n\ntemplate <class T> struct InternalType;\ntemplate <> struct InternalType<double>\n{\n    typedef double scalar_t; typedef double2 vec2_t; typedef double4 vec4_t;\n    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_double2(0, 0); }\n    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_double4(0, 0, 0, 0); }\n    __device__ __forceinline__ static double clamp(double x, double c) { return fmin(fmax(x, -c), c); }\n};\ntemplate <> struct InternalType<float>\n{\n    typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;\n    __device__ __forceinline__ static vec2_t zero_vec2(void) { return make_float2(0, 0); }\n    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }\n    __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }\n};\ntemplate <> struct InternalType<c10::Half>\n{\n    typedef float scalar_t; typedef float2 vec2_t; typedef float4 vec4_t;\n    __device__ __forceinline__ static vec2_t zero_vec2(void) { 
return make_float2(0, 0); }\n    __device__ __forceinline__ static vec4_t zero_vec4(void) { return make_float4(0, 0, 0, 0); }\n    __device__ __forceinline__ static float clamp(float x, float c) { return fminf(fmaxf(x, -c), c); }\n};\n\n#define MIN(A, B)       ((A) < (B) ? (A) : (B))\n#define MAX(A, B)       ((A) > (B) ? (A) : (B))\n#define CEIL_DIV(A, B) (((B)==1) ? (A) : \\\n                        ((B)==2) ? ((int)((A)+1) >> 1) : \\\n                        ((B)==4) ? ((int)((A)+3) >> 2) : \\\n                        (((A) + ((A) > 0 ? (B) - 1 : 0)) / (B)))\n\n// This works only up to blocks of size 256 x 256 and for all N that are powers of two.\ntemplate <int N> __device__ __forceinline__ void fast_div_mod(int& x, int& y, unsigned int i)\n{\n    if ((N & (N-1)) && N <= 256)\n        y = (i * ((1<<24)/N + 1)) >> 24; // Assumes N <= 256, i < N*256.\n    else\n        y = i/N;\n\n    x = i - y*N;\n}\n\n// Type cast stride before reading it.\ntemplate <class T> __device__ __forceinline__ T get_stride(const int64_t& x)\n{\n    return *reinterpret_cast<const T*>(&x);\n}\n\n//------------------------------------------------------------------------\n// Filters, setup kernel, copying function.\n\n#define MAX_FILTER_SIZE 32\n\n// Combined up/down filter buffers so that transfer can be done with one copy.\n__device__              float g_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in global memory, written by setup kernel.\n__device__ __constant__ float c_fbuf[2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE]; // Filters in constant memory, read by main kernel.\n\n// Accessors to combined buffers to index up/down filters individually.\n#define c_fu (c_fbuf)\n#define c_fd (c_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)\n#define g_fu (g_fbuf)\n#define g_fd (g_fbuf + MAX_FILTER_SIZE * MAX_FILTER_SIZE)\n\n// Set up filters into global memory buffer.\nstatic __global__ void setup_filters_kernel(filtered_lrelu_kernel_params p)\n{\n    for (int idx = threadIdx.x; idx < 
MAX_FILTER_SIZE * MAX_FILTER_SIZE; idx += blockDim.x)\n    {\n        int x, y;\n        fast_div_mod<MAX_FILTER_SIZE>(x, y, idx);\n\n        int fu_x = p.flip ? x : (p.fuShape.x - 1 - x);\n        int fu_y = p.flip ? y : (p.fuShape.y - 1 - y);\n        if (p.fuShape.y > 0)\n            g_fu[idx] = (x >= p.fuShape.x || y >= p.fuShape.y) ? 0.0f : p.fu[fu_x * p.fuStride.x + fu_y * p.fuStride.y];\n        else\n            g_fu[idx] = (x >= p.fuShape.x || y > 0) ? 0.0f : p.fu[fu_x * p.fuStride.x];\n\n        int fd_x = p.flip ? x : (p.fdShape.x - 1 - x);\n        int fd_y = p.flip ? y : (p.fdShape.y - 1 - y);\n        if (p.fdShape.y > 0)\n            g_fd[idx] = (x >= p.fdShape.x || y >= p.fdShape.y) ? 0.0f : p.fd[fd_x * p.fdStride.x + fd_y * p.fdStride.y];\n        else\n            g_fd[idx] = (x >= p.fdShape.x || y > 0) ? 0.0f : p.fd[fd_x * p.fdStride.x];\n    }\n}\n\n// Host function to copy filters written by setup kernel into constant buffer for main kernel.\ntemplate <bool, bool> static cudaError_t copy_filters(cudaStream_t stream)\n{\n    void* src = 0;\n    cudaError_t err = cudaGetSymbolAddress(&src, g_fbuf);\n    if (err) return err;\n    return cudaMemcpyToSymbolAsync(c_fbuf, src, 2 * MAX_FILTER_SIZE * MAX_FILTER_SIZE * sizeof(float), 0, cudaMemcpyDeviceToDevice, stream);\n}\n\n//------------------------------------------------------------------------\n// Coordinate spaces:\n// - Relative to input tensor:      inX, inY, tileInX, tileInY\n// - Relative to input tile:        relInX, relInY, tileInW, tileInH\n// - Relative to upsampled tile:    relUpX, relUpY, tileUpW, tileUpH\n// - Relative to output tile:       relOutX, relOutY, tileOutW, tileOutH\n// - Relative to output tensor:     outX, outY, tileOutX, tileOutY\n//\n// Relationships between coordinate spaces:\n// - inX = tileInX + relInX\n// - inY = tileInY + relInY\n// - relUpX = relInX * up + phaseInX\n// - relUpY = relInY * up + phaseInY\n// - relUpX = relOutX * down\n// - relUpY = relOutY * down\n// 
- outX = tileOutX + relOutX\n// - outY = tileOutY + relOutY\n\nextern __shared__ char s_buf_raw[]; // When sharedKB <= 48, allocate shared memory statically inside the kernel, otherwise use the externally allocated shared memory buffer.\n\ntemplate <class T, class index_t, int sharedKB, bool signWrite, bool signRead, int filterMode, int up, int fuSize, int down, int fdSize, int tileOutW, int tileOutH, int threadsPerBlock, bool enableXrep, bool enableWriteSkip>\nstatic __global__ void filtered_lrelu_kernel(filtered_lrelu_kernel_params p)\n{\n    // Check that we don't try to support non-existing filter modes.\n    static_assert(up   == 1 || up   == 2 || up   == 4, \"only up=1, up=2, up=4 scales supported\");\n    static_assert(down == 1 || down == 2 || down == 4, \"only down=1, down=2, down=4 scales supported\");\n    static_assert(fuSize >= up,   \"upsampling filter size must be at least upsampling factor\");\n    static_assert(fdSize >= down, \"downsampling filter size must be at least downsampling factor\");\n    static_assert(fuSize % up   == 0, \"upsampling filter size must be divisible with upsampling factor\");\n    static_assert(fdSize % down == 0, \"downsampling filter size must be divisible with downsampling factor\");\n    static_assert(fuSize <= MAX_FILTER_SIZE && fdSize <= MAX_FILTER_SIZE, \"filter size greater than MAX_FILTER_SIZE\");\n    static_assert(up   != 1 || (fuSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), \"up=1 supported only for 1x1 full filters\");\n    static_assert(down != 1 || (fdSize == 1 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), \"down=1 supported only for 1x1 full filters\");\n    static_assert(!(up   == 4 && (filterMode == MODE_FUFD || filterMode == MODE_FUSD)), \"full filters not supported for up=4\");\n    static_assert(!(down == 4 && (filterMode == MODE_FUFD || filterMode == MODE_SUFD)), \"full filters not supported for down=4\");\n\n    // Static definitions.\n    typedef typename 
InternalType<T>::scalar_t scalar_t;\n    typedef typename InternalType<T>::vec2_t vec2_t;\n    typedef typename InternalType<T>::vec4_t vec4_t;\n    const int tileUpW    = (tileOutW * down + (fdSize - 1) - (down - 1) + 3) & ~3;  // Upsampled tile width, rounded up to multiple of 4.\n    const int tileUpH    = tileOutH * down + (fdSize - 1) - (down - 1);             // Upsampled tile height.\n    const int tileInW    = CEIL_DIV(tileUpW  + (fuSize - 1), up);                   // Input tile width.\n    const int tileInH    = CEIL_DIV(tileUpH  + (fuSize - 1), up);                   // Input tile height.\n    const int tileUpH_up = CEIL_DIV(tileUpH, up) * up;                              // Upsampled tile height rounded up to a multiple of up.\n    const int tileInH_up = CEIL_DIV(tileUpH_up + (fuSize - 1), up);                 // For allocations only, to avoid shared memory read overruns with up=2 and up=4.\n\n    // Merge 1x1 downsampling into last upsampling step for upf1 and ups2.\n    const bool downInline = (down == 1) && ((up == 1 && filterMode == MODE_FUFD) || (up == 2 && filterMode == MODE_SUFD));\n\n    // Sizes of logical buffers.\n    const int szIn    = tileInH_up * tileInW;\n    const int szUpX   = tileInH_up * tileUpW;\n    const int szUpXY  = downInline ? 0 : (tileUpH * tileUpW);\n    const int szDownX = tileUpH * tileOutW;\n\n    // Sizes for shared memory arrays.\n    const int s_buf0_size_base =\n        (filterMode == MODE_SUSD) ? MAX(szIn, szUpXY) :\n        (filterMode == MODE_FUSD) ? MAX(szIn, szDownX) :\n        (filterMode == MODE_SUFD) ? MAX(szIn, szUpXY) :\n        (filterMode == MODE_FUFD) ? szIn :\n        -1;\n    const int s_buf1_size_base =\n        (filterMode == MODE_SUSD) ? MAX(szUpX, szDownX) :\n        (filterMode == MODE_FUSD) ? szUpXY :\n        (filterMode == MODE_SUFD) ? szUpX  :\n        (filterMode == MODE_FUFD) ? 
szUpXY :\n        -1;\n\n    // Ensure U128 alignment.\n    const int s_buf0_size = (s_buf0_size_base + 3) & ~3;\n    const int s_buf1_size = (s_buf1_size_base + 3) & ~3;\n\n    // Check at compile time that we don't use too much shared memory.\n    static_assert((s_buf0_size + s_buf1_size) * sizeof(scalar_t) <= (sharedKB << 10), \"shared memory overflow\");\n\n    // Declare shared memory arrays.\n    scalar_t* s_buf0;\n    scalar_t* s_buf1;\n    if (sharedKB <= 48)\n    {\n        // Allocate shared memory arrays here.\n        __shared__ scalar_t s_buf0_st[(sharedKB > 48) ? (1<<24) : (s_buf0_size + s_buf1_size)]; // Prevent launching if this isn't optimized away when unused.\n        s_buf0 = s_buf0_st;\n        s_buf1 = s_buf0 + s_buf0_size;\n    }\n    else\n    {\n        // Use the dynamically allocated shared memory array.\n        s_buf0 = (scalar_t*)s_buf_raw;\n        s_buf1 = s_buf0 + s_buf0_size;\n    }\n\n    // Pointers to the buffers.\n    scalar_t* s_tileIn;       // Input tile:                      [relInX * tileInH + relInY]\n    scalar_t* s_tileUpX;      // After horizontal upsampling:     [relInY * tileUpW + relUpX]\n    scalar_t* s_tileUpXY;     // After upsampling:                [relUpY * tileUpW + relUpX]\n    scalar_t* s_tileDownX;    // After horizontal downsampling:   [relUpY * tileOutW + relOutX]\n    if (filterMode == MODE_SUSD)\n    {\n        s_tileIn    = s_buf0;\n        s_tileUpX   = s_buf1;\n        s_tileUpXY  = s_buf0;\n        s_tileDownX = s_buf1;\n    }\n    else if (filterMode == MODE_FUSD)\n    {\n        s_tileIn    = s_buf0;\n        s_tileUpXY  = s_buf1;\n        s_tileDownX = s_buf0;\n    }\n    else if (filterMode == MODE_SUFD)\n    {\n        s_tileIn    = s_buf0;\n        s_tileUpX   = s_buf1;\n        s_tileUpXY  = s_buf0;\n    }\n    else if (filterMode == MODE_FUFD)\n    {\n        s_tileIn    = s_buf0;\n        s_tileUpXY  = s_buf1;\n    }\n\n    // Allow large grids in z direction via per-launch offset.\n    
int channelIdx = blockIdx.z + p.blockZofs;\n    int batchIdx = channelIdx / p.yShape.z;\n    channelIdx -= batchIdx * p.yShape.z;\n\n    // Offset to output feature map. In bytes.\n    index_t mapOfsOut = channelIdx * get_stride<index_t>(p.yStride.z) + batchIdx * get_stride<index_t>(p.yStride.w);\n\n    // Sign shift amount.\n    uint32_t signXo = ((threadIdx.x + p.sOfs.x) << 1) & 6;\n\n    // Inner tile loop.\n    #pragma unroll 1\n    for (int tileIdx = 0; !enableXrep || (tileIdx < MIN(p.tilesXrep, p.tilesXdim - p.tilesXrep * blockIdx.y)); tileIdx++)\n    {\n        // Locate output tile.\n        int tileX = enableXrep ? blockIdx.y * p.tilesXrep + tileIdx : blockIdx.x;\n        int tileOutX = tileX * tileOutW;\n        int tileOutY = (enableXrep ? blockIdx.x : blockIdx.y) * tileOutH;\n\n        // Locate input tile.\n        int tmpX = tileOutX * down - p.pad0.x;\n        int tmpY = tileOutY * down - p.pad0.y;\n        int tileInX = CEIL_DIV(tmpX, up);\n        int tileInY = CEIL_DIV(tmpY, up);\n        const int phaseInX = tileInX * up - tmpX;\n        const int phaseInY = tileInY * up - tmpY;\n\n        // Extra sync if input and output buffers are the same and we are not on first tile.\n        if (enableXrep && tileIdx > 0 && (filterMode == MODE_FUSD || (filterMode == MODE_SUFD && !downInline) || (filterMode == MODE_FUFD && downInline)))\n            __syncthreads();\n\n        // Load input tile & apply bias. 
Unrolled.\n        scalar_t b = (scalar_t)*(const T*)((const char*)p.b + (channelIdx * get_stride<index_t>(p.bStride)));\n        index_t mapOfsIn = channelIdx * get_stride<index_t>(p.xStride.z) + batchIdx * get_stride<index_t>(p.xStride.w);\n        int idx = threadIdx.x;\n        const int loopCountIN = CEIL_DIV(tileInW * tileInH, threadsPerBlock);\n        #pragma unroll\n        for (int loop = 0; loop < loopCountIN; loop++)\n        {\n            int relInX, relInY;\n            fast_div_mod<tileInW>(relInX, relInY, idx);\n            int inX = tileInX + relInX;\n            int inY = tileInY + relInY;\n            scalar_t v = 0;\n\n            if ((uint32_t)inX < p.xShape.x && (uint32_t)inY < p.xShape.y)\n                v = (scalar_t)*((const T*)((const char*)p.x + (inX * get_stride<index_t>(p.xStride.x) + inY * get_stride<index_t>(p.xStride.y) + mapOfsIn))) + b;\n\n            bool skip = (loop == loopCountIN-1) && (idx >= tileInW * tileInH);\n            if (!skip)\n                s_tileIn[idx] = v;\n\n            idx += threadsPerBlock;\n        }\n\n        if (filterMode == MODE_SUSD || filterMode == MODE_SUFD) // Separable upsampling filter.\n        {\n            // Horizontal upsampling.\n            __syncthreads();\n            if (up == 4)\n            {\n                for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)\n                {\n                    int relUpX0, relInY;\n                    fast_div_mod<tileUpW>(relUpX0, relInY, idx);\n                    int relInX0 = relUpX0 / up;\n                    int src0 = relInX0 + tileInW * relInY;\n                    int dst = relInY * tileUpW + relUpX0;\n                    vec4_t v = InternalType<T>::zero_vec4();\n                    scalar_t a = s_tileIn[src0];\n                    if (phaseInX == 0)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        
{\n                            v.x += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileIn[src0 + step + 1];\n                            v.y += a * (scalar_t)c_fu[step * up + 3];\n                            v.z += a * (scalar_t)c_fu[step * up + 2];\n                            v.w += a * (scalar_t)c_fu[step * up + 1];\n                        }\n                    }\n                    else if (phaseInX == 1)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 1];\n                            v.y += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileIn[src0 + step + 1];\n                            v.z += a * (scalar_t)c_fu[step * up + 3];\n                            v.w += a * (scalar_t)c_fu[step * up + 2];\n                        }\n                    }\n                    else if (phaseInX == 2)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 2];\n                            v.y += a * (scalar_t)c_fu[step * up + 1];\n                            v.z += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileIn[src0 + step + 1];\n                            v.w += a * (scalar_t)c_fu[step * up + 3];\n                        }\n                    }\n                    else // (phaseInX == 3)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 3];\n                            v.y += a * (scalar_t)c_fu[step * up + 2];\n                            v.z += a * 
(scalar_t)c_fu[step * up + 1];\n                            v.w += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileIn[src0 + step + 1];\n                        }\n                    }\n                    s_tileUpX[dst+0] = v.x;\n                    s_tileUpX[dst+1] = v.y;\n                    s_tileUpX[dst+2] = v.z;\n                    s_tileUpX[dst+3] = v.w;\n                }\n            }\n            else if (up == 2)\n            {\n                bool p0 = (phaseInX == 0);\n                for (int idx = threadIdx.x*up; idx < tileUpW * tileInH; idx += blockDim.x*up)\n                {\n                    int relUpX0, relInY;\n                    fast_div_mod<tileUpW>(relUpX0, relInY, idx);\n                    int relInX0 = relUpX0 / up;\n                    int src0 = relInX0 + tileInW * relInY;\n                    int dst = relInY * tileUpW + relUpX0;\n                    vec2_t v = InternalType<T>::zero_vec2();\n                    scalar_t a = s_tileIn[src0];\n                    if (p0) // (phaseInX == 0)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileIn[src0 + step + 1];\n                            v.y += a * (scalar_t)c_fu[step * up + 1];\n                        }\n                    }\n                    else // (phaseInX == 1)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 1];\n                            v.y += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileIn[src0 + step + 1];\n                        }\n                    }\n                    s_tileUpX[dst+0] = v.x;\n               
     s_tileUpX[dst+1] = v.y;\n                }\n            }\n\n            // Vertical upsampling & nonlinearity.\n\n            __syncthreads();\n            int groupMask = 15 << ((threadIdx.x & 31) & ~3);\n            int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.\n            int sShapeMaxY = MIN(p.sShape.y, tileOutY * down + tileUpH); // Avoid out-of-tile sign writes.\n            if (up == 4)\n            {\n                minY -= 3; // Adjust according to block height.\n                for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)\n                {\n                    int relUpX, relInY0;\n                    fast_div_mod<tileUpW>(relUpX, relInY0, idx);\n                    int relUpY0 = relInY0 * up;\n                    int src0 = relInY0 * tileUpW + relUpX;\n                    int dst = relUpY0 * tileUpW + relUpX;\n                    vec4_t v = InternalType<T>::zero_vec4();\n\n                    scalar_t a = s_tileUpX[src0];\n                    if (phaseInY == 0)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileUpX[src0 + (step + 1) * tileUpW];\n                            v.y += a * (scalar_t)c_fu[step * up + 3];\n                            v.z += a * (scalar_t)c_fu[step * up + 2];\n                            v.w += a * (scalar_t)c_fu[step * up + 1];\n                        }\n                    }\n                    else if (phaseInY == 1)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 1];\n                            v.y += a * (scalar_t)c_fu[step * up + 
0];\n                            a = s_tileUpX[src0 + (step + 1) * tileUpW];\n                            v.z += a * (scalar_t)c_fu[step * up + 3];\n                            v.w += a * (scalar_t)c_fu[step * up + 2];\n                        }\n                    }\n                    else if (phaseInY == 2)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 2];\n                            v.y += a * (scalar_t)c_fu[step * up + 1];\n                            v.z += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileUpX[src0 + (step + 1) * tileUpW];\n                            v.w += a * (scalar_t)c_fu[step * up + 3];\n                        }\n                    }\n                    else // (phaseInY == 3)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 3];\n                            v.y += a * (scalar_t)c_fu[step * up + 2];\n                            v.z += a * (scalar_t)c_fu[step * up + 1];\n                            v.w += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileUpX[src0 + (step + 1) * tileUpW];\n                        }\n                    }\n\n                    int x = tileOutX * down + relUpX;\n                    int y = tileOutY * down + relUpY0;\n                    int signX = x + p.sOfs.x;\n                    int signY = y + p.sOfs.y;\n                    int signZ = blockIdx.z + p.blockZofs;\n                    int signXb = signX >> 2;\n                    index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);\n                    index_t si1 = si0 + p.sShape.x;\n                    index_t si2 = si0 + 
p.sShape.x * 2;\n                    index_t si3 = si0 + p.sShape.x * 3;\n\n                    v.x *= (scalar_t)((float)up * (float)up * p.gain);\n                    v.y *= (scalar_t)((float)up * (float)up * p.gain);\n                    v.z *= (scalar_t)((float)up * (float)up * p.gain);\n                    v.w *= (scalar_t)((float)up * (float)up * p.gain);\n\n                    if (signWrite)\n                    {\n                        if (!enableWriteSkip)\n                        {\n                            // Determine and write signs.\n                            int sx = __float_as_uint(v.x) >> 31 <<  0;\n                            int sy = __float_as_uint(v.y) >> 31 <<  8;\n                            int sz = __float_as_uint(v.z) >> 31 << 16;\n                            int sw = __float_as_uint(v.w) >> 31 << 24;\n                            if (sx) v.x *= p.slope;\n                            if (sy) v.y *= p.slope;\n                            if (sz) v.z *= p.slope;\n                            if (sw) v.w *= p.slope;\n                            if (fabsf(v.x) > p.clamp) { sx = 2 <<  0; v.x = InternalType<T>::clamp(v.x, p.clamp); }\n                            if (fabsf(v.y) > p.clamp) { sy = 2 <<  8; v.y = InternalType<T>::clamp(v.y, p.clamp); }\n                            if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }\n                            if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }\n\n                            if ((uint32_t)signXb < p.swLimit && signY >= minY)\n                            {\n                                // Combine signs.\n                                uint32_t s = sx + sy + sw + sz;\n                                s <<= (signX & 3) << 1;\n                                s |= __shfl_xor_sync(groupMask, s, 1);\n                                s |= __shfl_xor_sync(groupMask, s, 2);\n\n                                // Write 
signs.\n                                if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >>  0); }\n                                if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >>  8); }\n                                if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }\n                                if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }\n                            }\n                        }\n                        else\n                        {\n                            // Determine and write signs.\n                            if ((uint32_t)signXb < p.swLimit && signY >= minY)\n                            {\n                                int sx = __float_as_uint(v.x) >> 31 <<  0;\n                                int sy = __float_as_uint(v.y) >> 31 <<  8;\n                                int sz = __float_as_uint(v.z) >> 31 << 16;\n                                int sw = __float_as_uint(v.w) >> 31 << 24;\n                                if (sx) v.x *= p.slope;\n                                if (sy) v.y *= p.slope;\n                                if (sz) v.z *= p.slope;\n                                if (sw) v.w *= p.slope;\n                                if (fabsf(v.x) > p.clamp) { sx = 2 <<  0; v.x = InternalType<T>::clamp(v.x, p.clamp); }\n                                if (fabsf(v.y) > p.clamp) { sy = 2 <<  8; v.y = InternalType<T>::clamp(v.y, p.clamp); }\n                                if (fabsf(v.z) > p.clamp) { sz = 2 << 16; v.z = InternalType<T>::clamp(v.z, p.clamp); }\n                                if (fabsf(v.w) > p.clamp) { sw = 2 << 24; v.w = InternalType<T>::clamp(v.w, p.clamp); }\n\n                                // Combine signs.\n                                uint32_t s = sx + sy + sw + sz;\n                                s <<= (signX & 3) << 1;\n                                s |= 
__shfl_xor_sync(groupMask, s, 1);\n                                s |= __shfl_xor_sync(groupMask, s, 2);\n\n                                // Write signs.\n                                if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >>  0); }\n                                if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >>  8); }\n                                if ((uint32_t)(signY + 2) < sShapeMaxY) { p.s[si2] = (unsigned char)(s >> 16); }\n                                if ((uint32_t)(signY + 3) < sShapeMaxY) { p.s[si3] = (unsigned char)(s >> 24); }\n                            }\n                            else\n                            {\n                                // Just compute the values.\n                                if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);\n                                if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);\n                                if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);\n                                if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);\n                            }\n                        }\n                    }\n                    else if (signRead) // Read signs and apply.\n                    {\n                        if ((uint32_t)signXb < p.swLimit)\n                        {\n                            int ss = (signX & 3) << 1;\n                            if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> ss; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }\n                            if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> ss; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }\n                            if ((uint32_t)(signY + 2) < p.sShape.y) { int s = p.s[si2] >> ss; if (s & 1) v.z *= p.slope; if (s & 2) v.z = 0.f; }\n                            if 
((uint32_t)(signY + 3) < p.sShape.y) { int s = p.s[si3] >> ss; if (s & 1) v.w *= p.slope; if (s & 2) v.w = 0.f; }\n                        }\n                    }\n                    else // Forward pass with no sign write.\n                    {\n                        if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);\n                        if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);\n                        if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);\n                        if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);\n                    }\n\n                    s_tileUpXY[dst + 0 * tileUpW] = v.x;\n                    if (relUpY0 + 1 < tileUpH) s_tileUpXY[dst + 1 * tileUpW] = v.y;\n                    if (relUpY0 + 2 < tileUpH) s_tileUpXY[dst + 2 * tileUpW] = v.z;\n                    if (relUpY0 + 3 < tileUpH) s_tileUpXY[dst + 3 * tileUpW] = v.w;\n                }\n            }\n            else if (up == 2)\n            {\n                minY -= 1; // Adjust according to block height.\n                for (int idx = threadIdx.x; idx < tileUpW * tileUpH_up / up; idx += blockDim.x)\n                {\n                    int relUpX, relInY0;\n                    fast_div_mod<tileUpW>(relUpX, relInY0, idx);\n                    int relUpY0 = relInY0 * up;\n                    int src0 = relInY0 * tileUpW + relUpX;\n                    int dst = relUpY0 * tileUpW + relUpX;\n                    vec2_t v = InternalType<T>::zero_vec2();\n\n                    scalar_t a = s_tileUpX[src0];\n                    if (phaseInY == 0)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileUpX[src0 + (step + 1) * tileUpW];\n              
              v.y += a * (scalar_t)c_fu[step * up + 1];\n                        }\n                    }\n                    else // (phaseInY == 1)\n                    {\n                        #pragma unroll\n                        for (int step = 0; step < fuSize / up; step++)\n                        {\n                            v.x += a * (scalar_t)c_fu[step * up + 1];\n                            v.y += a * (scalar_t)c_fu[step * up + 0];\n                            a = s_tileUpX[src0 + (step + 1) * tileUpW];\n                        }\n                    }\n\n                    int x = tileOutX * down + relUpX;\n                    int y = tileOutY * down + relUpY0;\n                    int signX = x + p.sOfs.x;\n                    int signY = y + p.sOfs.y;\n                    int signZ = blockIdx.z + p.blockZofs;\n                    int signXb = signX >> 2;\n                    index_t si0 = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);\n                    index_t si1 = si0 + p.sShape.x;\n\n                    v.x *= (scalar_t)((float)up * (float)up * p.gain);\n                    v.y *= (scalar_t)((float)up * (float)up * p.gain);\n\n                    if (signWrite)\n                    {\n                        if (!enableWriteSkip)\n                        {\n                            // Determine and write signs.\n                            int sx = __float_as_uint(v.x) >> 31 << 0;\n                            int sy = __float_as_uint(v.y) >> 31 << 8;\n                            if (sx) v.x *= p.slope;\n                            if (sy) v.y *= p.slope;\n                            if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }\n                            if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }\n\n                            if ((uint32_t)signXb < p.swLimit && signY >= minY)\n                            {\n                           
     // Combine signs.\n                                int s = sx + sy;\n                                s <<= signXo;\n                                s |= __shfl_xor_sync(groupMask, s, 1);\n                                s |= __shfl_xor_sync(groupMask, s, 2);\n\n                                // Write signs.\n                                if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >>  0); }\n                                if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >>  8); }\n                            }\n                        }\n                        else\n                        {\n                            // Determine and write signs.\n                            if ((uint32_t)signXb < p.swLimit && signY >= minY)\n                            {\n                                int sx = __float_as_uint(v.x) >> 31 << 0;\n                                int sy = __float_as_uint(v.y) >> 31 << 8;\n                                if (sx) v.x *= p.slope;\n                                if (sy) v.y *= p.slope;\n                                if (fabsf(v.x) > p.clamp) { sx = 2 << 0; v.x = InternalType<T>::clamp(v.x, p.clamp); }\n                                if (fabsf(v.y) > p.clamp) { sy = 2 << 8; v.y = InternalType<T>::clamp(v.y, p.clamp); }\n\n                                // Combine signs.\n                                int s = sx + sy;\n                                s <<= signXo;\n                                s |= __shfl_xor_sync(groupMask, s, 1);\n                                s |= __shfl_xor_sync(groupMask, s, 2);\n\n                                // Write signs.\n                                if ((uint32_t)(signY + 0) < sShapeMaxY) { p.s[si0] = (unsigned char)(s >>  0); }\n                                if ((uint32_t)(signY + 1) < sShapeMaxY) { p.s[si1] = (unsigned char)(s >>  8); }\n                            }\n                            else\n                            
{\n                                // Just compute the values.\n                                if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);\n                                if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);\n                            }\n                        }\n                    }\n                    else if (signRead) // Read signs and apply.\n                    {\n                        if ((uint32_t)signXb < p.swLimit)\n                        {\n                            if ((uint32_t)(signY + 0) < p.sShape.y) { int s = p.s[si0] >> signXo; if (s & 1) v.x *= p.slope; if (s & 2) v.x = 0.f; }\n                            if ((uint32_t)(signY + 1) < p.sShape.y) { int s = p.s[si1] >> signXo; if (s & 1) v.y *= p.slope; if (s & 2) v.y = 0.f; }\n                        }\n                    }\n                    else // Forward pass with no sign write.\n                    {\n                        if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);\n                        if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);\n                    }\n\n                    if (!downInline)\n                    {\n                        // Write into temporary buffer.\n                        s_tileUpXY[dst] = v.x;\n                        if (relUpY0 < tileUpH - 1)\n                            s_tileUpXY[dst + tileUpW] = v.y;\n                    }\n                    else\n                    {\n                        // Write directly into output buffer.\n                        if ((uint32_t)x < p.yShape.x)\n                        {\n                            int ymax = MIN(p.yShape.y, tileUpH + tileOutY * down);\n                            index_t ofs = x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut;\n                            if ((uint32_t)y + 0 < p.yShape.y) *((T*)((char*)p.y + ofs)) = 
(T)(v.x * (scalar_t)c_fd[0]);\n                            if ((uint32_t)y + 1 < ymax) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.y))) = (T)(v.y * (scalar_t)c_fd[0]);\n                        }\n                    }\n                }\n            }\n        }\n        else if (filterMode == MODE_FUSD || filterMode == MODE_FUFD)\n        {\n            // Full upsampling filter.\n\n            if (up == 2)\n            {\n                // 2 x 2-wide.\n                __syncthreads();\n                int minY = tileOutY ? (tileOutY - tileOutH) * down + tileUpH + p.sOfs.y : 0; // Skip already written signs.\n                for (int idx = threadIdx.x * 4; idx < tileUpW * tileUpH; idx += blockDim.x * 4)\n                {\n                    int relUpX0, relUpY0;\n                    fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);\n                    int relInX0 = CEIL_DIV(relUpX0 - phaseInX, up);\n                    int relInY0 = CEIL_DIV(relUpY0 - phaseInY, up);\n                    int src0 = relInX0 + tileInW * relInY0;\n                    int tap0y = (relInY0 * up + phaseInY - relUpY0);\n\n                    #define X_LOOP(TAPY, PX) \\\n                        for (int sx = 0; sx < fuSize / up; sx++) \\\n                        { \\\n                            v.x += a * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \\\n                            v.z += b * (scalar_t)c_fu[(sx * up + (((PX) - 0) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 0) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \\\n                            v.y += a * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; \\\n                            v.w += b * (scalar_t)c_fu[(sx * up + (((PX) - 1) & (up - 1))) + (sy * up + (TAPY)) * MAX_FILTER_SIZE]; if ((PX) == 1) { a = b; b = s_tileIn[src0 + 2 + sx + sy * tileInW]; } \\\n                        }\n\n               
     vec4_t v = InternalType<T>::zero_vec4();\n                    if (tap0y == 0 && phaseInX == 0)\n                        #pragma unroll\n                        for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];\n                            #pragma unroll\n                            X_LOOP(0, 0) }\n                    if (tap0y == 0 && phaseInX == 1)\n                        #pragma unroll\n                        for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];\n                            #pragma unroll\n                            X_LOOP(0, 1) }\n                    if (tap0y == 1 && phaseInX == 0)\n                        #pragma unroll\n                        for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];\n                            #pragma unroll\n                            X_LOOP(1, 0) }\n                    if (tap0y == 1 && phaseInX == 1)\n                        #pragma unroll\n                        for (int sy = 0; sy < fuSize / up; sy++) { scalar_t a = s_tileIn[src0 + sy * tileInW]; scalar_t b = s_tileIn[src0 + sy * tileInW + 1];\n                            #pragma unroll\n                            X_LOOP(1, 1) }\n\n                    #undef X_LOOP\n\n                    int x = tileOutX * down + relUpX0;\n                    int y = tileOutY * down + relUpY0;\n                    int signX = x + p.sOfs.x;\n                    int signY = y + p.sOfs.y;\n                    int signZ = blockIdx.z + p.blockZofs;\n                    int signXb = signX >> 2;\n                    index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);\n\n                    v.x *= (scalar_t)((float)up * (float)up * p.gain);\n                    v.y *= (scalar_t)((float)up * (float)up * 
p.gain);\n                    v.z *= (scalar_t)((float)up * (float)up * p.gain);\n                    v.w *= (scalar_t)((float)up * (float)up * p.gain);\n\n                    if (signWrite)\n                    {\n                        if (!enableWriteSkip)\n                        {\n                            // Determine and write signs.\n                            int sx = __float_as_uint(v.x) >> 31;\n                            int sy = __float_as_uint(v.y) >> 31;\n                            int sz = __float_as_uint(v.z) >> 31;\n                            int sw = __float_as_uint(v.w) >> 31;\n                            if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }\n                            if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }\n                            if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }\n                            if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }\n\n                            if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)\n                            {\n                                p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);\n                            }\n                        }\n                        else\n                        {\n                            // Determine and write signs.\n                            if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)\n                            {\n                                int sx = __float_as_uint(v.x) >> 31;\n                                int sy = __float_as_uint(v.y) >> 31;\n                                int sz = __float_as_uint(v.z) >> 31;\n                                int sw = __float_as_uint(v.w) >> 31;\n                       
         if (sx) v.x *= p.slope; if (fabsf(v.x) > p.clamp) { sx = 2; v.x = InternalType<T>::clamp(v.x, p.clamp); }\n                                if (sy) v.y *= p.slope; if (fabsf(v.y) > p.clamp) { sy = 2; v.y = InternalType<T>::clamp(v.y, p.clamp); }\n                                if (sz) v.z *= p.slope; if (fabsf(v.z) > p.clamp) { sz = 2; v.z = InternalType<T>::clamp(v.z, p.clamp); }\n                                if (sw) v.w *= p.slope; if (fabsf(v.w) > p.clamp) { sw = 2; v.w = InternalType<T>::clamp(v.w, p.clamp); }\n\n                                p.s[si] = sx + (sy << 2) + (sz << 4) + (sw << 6);\n                            }\n                            else\n                            {\n                                // Just compute the values.\n                                if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);\n                                if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);\n                                if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);\n                                if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);\n                            }\n                        }\n                    }\n                    else if (signRead) // Read sign and apply.\n                    {\n                        if ((uint32_t)signY < p.sShape.y)\n                        {\n                            int s = 0;\n                            if ((uint32_t)signXb     < p.swLimit) s  = p.s[si];\n                            if ((uint32_t)signXb + 1 < p.swLimit) s |= p.s[si + 1] << 8;\n                            s >>= (signX & 3) << 1;\n                            if (s & 0x01) v.x *= p.slope; if (s & 0x02) v.x = 0.f;\n                            if (s & 0x04) v.y *= p.slope; if (s & 0x08) v.y = 0.f;\n                            if (s & 0x10) v.z *= p.slope; if (s & 0x20) v.z = 0.f;\n                       
     if (s & 0x40) v.w *= p.slope; if (s & 0x80) v.w = 0.f;\n                        }\n                    }\n                    else // Forward pass with no sign write.\n                    {\n                        if (v.x < 0.f) v.x *= p.slope; v.x = InternalType<T>::clamp(v.x, p.clamp);\n                        if (v.y < 0.f) v.y *= p.slope; v.y = InternalType<T>::clamp(v.y, p.clamp);\n                        if (v.z < 0.f) v.z *= p.slope; v.z = InternalType<T>::clamp(v.z, p.clamp);\n                        if (v.w < 0.f) v.w *= p.slope; v.w = InternalType<T>::clamp(v.w, p.clamp);\n                    }\n\n                    s_tileUpXY[idx + 0] = v.x;\n                    s_tileUpXY[idx + 1] = v.y;\n                    s_tileUpXY[idx + 2] = v.z;\n                    s_tileUpXY[idx + 3] = v.w;\n                }\n            }\n            else if (up == 1)\n            {\n                __syncthreads();\n                uint32_t groupMask = 15 << ((threadIdx.x & 31) & ~3);\n                int minY = tileOutY ? 
(tileOutY - tileOutH) * down + tileUpH : 0; // Skip already written signs.\n                for (int idx = threadIdx.x; idx < tileUpW * tileUpH; idx += blockDim.x)\n                {\n                    int relUpX0, relUpY0;\n                    fast_div_mod<tileUpW>(relUpX0, relUpY0, idx);\n                    scalar_t v = s_tileIn[idx] * (scalar_t)c_fu[0]; // 1x1 filter.\n\n                    int x = tileOutX * down + relUpX0;\n                    int y = tileOutY * down + relUpY0;\n                    int signX = x + p.sOfs.x;\n                    int signY = y + p.sOfs.y;\n                    int signZ = blockIdx.z + p.blockZofs;\n                    int signXb = signX >> 2;\n                    index_t si = signXb + p.sShape.x * (signY + (index_t)p.sShape.y * signZ);\n                    v *= (scalar_t)((float)up * (float)up * p.gain);\n\n                    if (signWrite)\n                    {\n                        if (!enableWriteSkip)\n                        {\n                            // Determine and write sign.\n                            uint32_t s = 0;\n                            uint32_t signXbit = (1u << signXo);\n                            if (v < 0.f)\n                            {\n                                s = signXbit;\n                                v *= p.slope;\n                            }\n                            if (fabsf(v) > p.clamp)\n                            {\n                                s = signXbit * 2;\n                                v = InternalType<T>::clamp(v, p.clamp);\n                            }\n                            if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)\n                            {\n                                s += __shfl_xor_sync(groupMask, s, 1);  // Coalesce.\n                                s += __shfl_xor_sync(groupMask, s, 2);  // Coalesce.\n                                p.s[si] = s;                            // Write.\n      
                      }\n                        }\n                        else\n                        {\n                            // Determine and write sign.\n                            if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y && signY >= minY)\n                            {\n                                uint32_t s = 0;\n                                uint32_t signXbit = (1u << signXo);\n                                if (v < 0.f)\n                                {\n                                    s = signXbit;\n                                    v *= p.slope;\n                                }\n                                if (fabsf(v) > p.clamp)\n                                {\n                                    s = signXbit * 2;\n                                    v = InternalType<T>::clamp(v, p.clamp);\n                                }\n                                s += __shfl_xor_sync(groupMask, s, 1);  // Coalesce.\n                                s += __shfl_xor_sync(groupMask, s, 2);  // Coalesce.\n                                p.s[si] = s;                            // Write.\n                            }\n                            else\n                            {\n                                // Just compute the value.\n                                if (v < 0.f) v *= p.slope;\n                                v = InternalType<T>::clamp(v, p.clamp);\n                            }\n                        }\n                    }\n                    else if (signRead)\n                    {\n                        // Read sign and apply if within sign tensor bounds.\n                        if ((uint32_t)signXb < p.swLimit && (uint32_t)signY < p.sShape.y)\n                        {\n                            int s = p.s[si];\n                            s >>= signXo;\n                            if (s & 1) v *= p.slope;\n                            if (s & 2) v = 0.f;\n                   
     }\n                    }\n                    else // Forward pass with no sign write.\n                    {\n                        if (v < 0.f) v *= p.slope;\n                        v = InternalType<T>::clamp(v, p.clamp);\n                    }\n\n                    if (!downInline) // Write into temporary buffer.\n                        s_tileUpXY[idx] = v;\n                    else if ((uint32_t)x < p.yShape.x && (uint32_t)y < p.yShape.y) // Write directly into output buffer\n                        *((T*)((char*)p.y + (x * get_stride<index_t>(p.yStride.x) + y * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)(v * (scalar_t)c_fd[0]);\n                }\n            }\n        }\n\n        // Downsampling.\n        if (filterMode == MODE_SUSD || filterMode == MODE_FUSD)\n        {\n            // Horizontal downsampling.\n            __syncthreads();\n            if (down == 4 && tileOutW % 4 == 0)\n            {\n                // Calculate 4 pixels at a time.\n                for (int idx = threadIdx.x * 4; idx < tileOutW * tileUpH; idx += blockDim.x * 4)\n                {\n                    int relOutX0, relUpY;\n                    fast_div_mod<tileOutW>(relOutX0, relUpY, idx);\n                    int relUpX0 = relOutX0 * down;\n                    int src0 = relUpY * tileUpW + relUpX0;\n                    vec4_t v = InternalType<T>::zero_vec4();\n                    #pragma unroll\n                    for (int step = 0; step < fdSize; step++)\n                    {\n                        v.x += s_tileUpXY[src0 +  0 + step] * (scalar_t)c_fd[step];\n                        v.y += s_tileUpXY[src0 +  4 + step] * (scalar_t)c_fd[step];\n                        v.z += s_tileUpXY[src0 +  8 + step] * (scalar_t)c_fd[step];\n                        v.w += s_tileUpXY[src0 + 12 + step] * (scalar_t)c_fd[step];\n                    }\n                    s_tileDownX[idx+0] = v.x;\n                    s_tileDownX[idx+1] = v.y;\n                    
s_tileDownX[idx+2] = v.z;\n                    s_tileDownX[idx+3] = v.w;\n                }\n            }\n            else if ((down == 2 || down == 4) && (tileOutW % 2 == 0))\n            {\n                // Calculate 2 pixels at a time.\n                for (int idx = threadIdx.x * 2; idx < tileOutW * tileUpH; idx += blockDim.x * 2)\n                {\n                    int relOutX0, relUpY;\n                    fast_div_mod<tileOutW>(relOutX0, relUpY, idx);\n                    int relUpX0 = relOutX0 * down;\n                    int src0 = relUpY * tileUpW + relUpX0;\n                    vec2_t v = InternalType<T>::zero_vec2();\n                    #pragma unroll\n                    for (int step = 0; step < fdSize; step++)\n                    {\n                        v.x += s_tileUpXY[src0 +    0 + step] * (scalar_t)c_fd[step];\n                        v.y += s_tileUpXY[src0 + down + step] * (scalar_t)c_fd[step];\n                    }\n                    s_tileDownX[idx+0] = v.x;\n                    s_tileDownX[idx+1] = v.y;\n                }\n            }\n            else\n            {\n                // Calculate 1 pixel at a time.\n                for (int idx = threadIdx.x; idx < tileOutW * tileUpH; idx += blockDim.x)\n                {\n                    int relOutX0, relUpY;\n                    fast_div_mod<tileOutW>(relOutX0, relUpY, idx);\n                    int relUpX0 = relOutX0 * down;\n                    int src = relUpY * tileUpW + relUpX0;\n                    scalar_t v = 0.f;\n                    #pragma unroll\n                    for (int step = 0; step < fdSize; step++)\n                        v += s_tileUpXY[src + step] * (scalar_t)c_fd[step];\n                    s_tileDownX[idx] = v;\n                }\n            }\n\n            // Vertical downsampling & store output tile.\n            __syncthreads();\n            for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)\n            {\n         
       int relOutX, relOutY0;\n                fast_div_mod<tileOutW>(relOutX, relOutY0, idx);\n                int relUpY0 = relOutY0 * down;\n                int src0 = relUpY0 * tileOutW + relOutX;\n                scalar_t v = 0;\n                #pragma unroll\n                for (int step = 0; step < fdSize; step++)\n                    v += s_tileDownX[src0 + step * tileOutW] * (scalar_t)c_fd[step];\n\n                int outX = tileOutX + relOutX;\n                int outY = tileOutY + relOutY0;\n\n                if (outX < p.yShape.x & outY < p.yShape.y)\n                    *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;\n            }\n        }\n        else if (filterMode == MODE_SUFD || filterMode == MODE_FUFD)\n        {\n            // Full downsampling filter.\n            if (down == 2)\n            {\n                // 2-wide.\n                __syncthreads();\n                for (int idx = threadIdx.x * 2; idx < tileOutW * tileOutH; idx += blockDim.x * 2)\n                {\n                    int relOutX0, relOutY0;\n                    fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);\n                    int relUpX0 = relOutX0 * down;\n                    int relUpY0 = relOutY0 * down;\n                    int src0 = relUpY0 * tileUpW + relUpX0;\n                    vec2_t v = InternalType<T>::zero_vec2();\n                    #pragma unroll\n                    for (int sy = 0; sy < fdSize; sy++)\n                    #pragma unroll\n                    for (int sx = 0; sx < fdSize; sx++)\n                    {\n                        v.x += s_tileUpXY[src0 + 0 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];\n                        v.y += s_tileUpXY[src0 + 2 + sx + sy * tileUpW] * (scalar_t)c_fd[sx + sy * MAX_FILTER_SIZE];\n                    }\n\n                    int outX = tileOutX + relOutX0;\n                    int outY = tileOutY + 
relOutY0;\n                    if ((uint32_t)outY < p.yShape.y)\n                    {\n                        index_t ofs = outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut;\n                        if (outX + 0 < p.yShape.x) *((T*)((char*)p.y + ofs)) = (T)v.x;\n                        if (outX + 1 < p.yShape.x) *((T*)((char*)p.y + ofs + get_stride<index_t>(p.yStride.x))) = (T)v.y;\n                    }\n                }\n            }\n            else if (down == 1 && !downInline)\n            {\n                // Thread per pixel.\n                __syncthreads();\n                for (int idx = threadIdx.x; idx < tileOutW * tileOutH; idx += blockDim.x)\n                {\n                    int relOutX0, relOutY0;\n                    fast_div_mod<tileOutW>(relOutX0, relOutY0, idx);\n                    scalar_t v = s_tileUpXY[idx] * (scalar_t)c_fd[0]; // 1x1 filter.\n\n                    int outX = tileOutX + relOutX0;\n                    int outY = tileOutY + relOutY0;\n                    if ((uint32_t)outX < p.yShape.x && (uint32_t)outY < p.yShape.y)\n                        *((T*)((char*)p.y + (outX * get_stride<index_t>(p.yStride.x) + outY * get_stride<index_t>(p.yStride.y) + mapOfsOut))) = (T)v;\n                }\n            }\n        }\n\n        if (!enableXrep)\n            break;\n    }\n}\n\n//------------------------------------------------------------------------\n// Compute activation function and signs for upsampled data tensor, modifying data tensor in-place. Used for accelerating the generic variant.\n// Sign tensor is known to be contiguous, and p.x and p.s have the same z, w dimensions. 
64-bit indexing is always used.\n\ntemplate <class T, bool signWrite, bool signRead>\nstatic __global__ void filtered_lrelu_act_kernel(filtered_lrelu_act_kernel_params p)\n{\n    typedef typename InternalType<T>::scalar_t scalar_t;\n\n    // Indexing.\n    int32_t x = threadIdx.x + blockIdx.x * blockDim.x;\n    int32_t ymax = signWrite ? p.sShape.y : p.xShape.y;\n    int32_t qmax = p.xShape.z * p.xShape.w; // Combined minibatch*channel maximum index.\n\n    // Loop to accommodate oversized tensors.\n    for (int32_t q = blockIdx.z; q < qmax; q += gridDim.z)\n    for (int32_t y = blockIdx.y; y < ymax; y += gridDim.y)\n    {\n        // Extract z and w (channel, minibatch index).\n        int32_t w = q / p.xShape.z;\n        int32_t z = q - w * p.xShape.z;\n\n        // Choose behavior based on sign read/write mode.\n        if (signWrite)\n        {\n            // Process value if in p.x.\n            uint32_t s = 0;\n            if (x < p.xShape.x && y < p.xShape.y)\n            {\n                int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;\n                T* pv = ((T*)p.x) + ix;\n                scalar_t v = (scalar_t)(*pv);\n\n                // Gain, LReLU, clamp.\n                v *= p.gain;\n                if (v < 0.f)\n                {\n                    v *= p.slope;\n                    s = 1; // Sign.\n                }\n                if (fabsf(v) > p.clamp)\n                {\n                    v = InternalType<T>::clamp(v, p.clamp);\n                    s = 2; // Clamp.\n                }\n\n                *pv = (T)v; // Write value.\n            }\n\n            // Coalesce into threads 0 and 16 of warp.\n            uint32_t m = (threadIdx.x & 16) ? 
0xffff0000u : 0x0000ffffu;\n            s <<= ((threadIdx.x & 15) << 1); // Shift into place.\n            s |= __shfl_xor_sync(m, s, 1); // Distribute.\n            s |= __shfl_xor_sync(m, s, 2);\n            s |= __shfl_xor_sync(m, s, 4);\n            s |= __shfl_xor_sync(m, s, 8);\n\n            // Write signs if leader and in p.s.\n            if (!(threadIdx.x & 15) && x < p.sShape.x) // y is always in.\n            {\n                uint64_t is = x + p.sShape.x * (y + (int64_t)p.sShape.y * q); // Contiguous.\n                ((uint32_t*)p.s)[is >> 4] = s;\n            }\n        }\n        else if (signRead)\n        {\n            // Process value if in p.x.\n            if (x < p.xShape.x) // y is always in.\n            {\n                int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;\n                T* pv = ((T*)p.x) + ix;\n                scalar_t v = (scalar_t)(*pv);\n                v *= p.gain;\n\n                // Apply sign buffer offset.\n                uint32_t sx = x + p.sOfs.x;\n                uint32_t sy = y + p.sOfs.y;\n\n                // Read and apply signs if we land inside valid region of sign buffer.\n                if (sx < p.sShape.x && sy < p.sShape.y)\n                {\n                    uint64_t is = (sx >> 2) + (p.sShape.x >> 2) * (sy + (uint64_t)p.sShape.y * q); // Contiguous.\n                    unsigned char s = p.s[is];\n                    s >>= (sx & 3) << 1; // Shift into place.\n                    if (s & 1) // Sign?\n                        v *= p.slope;\n                    if (s & 2) // Clamp?\n                        v = 0.f;\n                }\n\n                *pv = (T)v; // Write value.\n            }\n        }\n        else\n        {\n            // Forward pass with no sign write. 
Process value if in p.x.\n            if (x < p.xShape.x) // y is always in.\n            {\n                int64_t ix = x * p.xStride.x + y * p.xStride.y + z * p.xStride.z + w * p.xStride.w;\n                T* pv = ((T*)p.x) + ix;\n                scalar_t v = (scalar_t)(*pv);\n                v *= p.gain;\n                if (v < 0.f)\n                    v *= p.slope;\n                if (fabsf(v) > p.clamp)\n                    v = InternalType<T>::clamp(v, p.clamp);\n                *pv = (T)v; // Write value.\n            }\n        }\n    }\n}\n\ntemplate <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void)\n{\n    return (void*)filtered_lrelu_act_kernel<T, signWrite, signRead>;\n}\n\n//------------------------------------------------------------------------\n// CUDA kernel selection.\n\ntemplate <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB)\n{\n    filtered_lrelu_kernel_spec s = { 0 };\n\n    // Return the first matching kernel.\n#define CASE(SH, U, FU, D, FD, MODE, TW, TH, W, XR, WS) \\\n    if (sharedKB >= SH) \\\n    if ((p.fuShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_SUFD)) || (p.fuShape.y > 0 && (MODE == MODE_FUSD || MODE == MODE_FUFD))) \\\n    if ((p.fdShape.y == 0 && (MODE == MODE_SUSD || MODE == MODE_FUSD)) || (p.fdShape.y > 0 && (MODE == MODE_SUFD || MODE == MODE_FUFD))) \\\n    if (p.up == U && p.fuShape.x <= FU && p.fuShape.y <= FU && p.down == D && p.fdShape.x <= FD && p.fdShape.y <= FD) \\\n    { \\\n        static_assert((D*TW % 4) == 0, \"down * tileWidth must be divisible by 4\"); \\\n        static_assert(FU % U == 0, \"upscaling filter size must be multiple of upscaling factor\"); \\\n        static_assert(FD % D == 0, \"downscaling filter size must be multiple of downscaling factor\"); \\\n        s.setup = (void*)setup_filters_kernel; \\\n        s.exec = 
(void*)filtered_lrelu_kernel<T, index_t, SH, signWrite, signRead, MODE, U, FU, D, FD, TW, TH, W*32, !!XR, !!WS>; \\\n        s.tileOut = make_int2(TW, TH); \\\n        s.numWarps = W; \\\n        s.xrep = XR; \\\n        s.dynamicSharedKB = (SH == 48) ? 0 : SH; \\\n        return s; \\\n    }\n\n    // Launch parameters for various kernel specializations.\n    // Small filters must be listed before large filters, otherwise the kernel for larger filter will always match first.\n    // Kernels that use more shared memory must be listed before those that use less, for the same reason.\n\n    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/1,1,  /*mode*/MODE_FUFD, /*tw,th,warps,xrep,wskip*/64,  178, 32,  0,  0) // 1t-upf1-downf1\n    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/1,1,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/152, 95,  16,  0,  0) // 4t-ups2-downf1\n    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/2,8,  /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56,  22,  16,  0,  0) // 4t-upf1-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/2,8,  /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56,  29,  16,  11, 0) // 4t-ups2-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/2,8,  /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/60,  28,  16,  0,  0) // 4t-upf2-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/2,8,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56,  28,  16,  0,  0) // 4t-ups2-downf2\n    CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8,  /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/56,  31,  16,  11, 0) // 4t-ups4-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/4,16, /*down,fd*/2,8,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/56,  36,  16,  0,  0) // 4t-ups4-downf2\n    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/4,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16,  22,  16,  12, 0) // 4t-ups2-downs4\n    CASE(/*sharedKB*/48, /*up,fu*/2,8,  /*down,fd*/4,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/29,  15,  16,  0,  
0) // 4t-upf2-downs4\n    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/1,1,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/96,  150, 28,  0,  0) // 6t-ups2-downf1\n    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32,  35,  24,  0,  0) // 6t-upf1-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  46,  16,  10, 0) // 6t-ups2-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/58,  28,  24,  8,  0) // 6t-upf2-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/52,  28,  16,  0,  0) // 6t-ups2-downf2\n    CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  51,  16,  5,  0) // 6t-ups4-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/4,24, /*down,fd*/2,12, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32,  56,  16,  6,  0) // 6t-ups4-downf2\n    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16,  18,  16,  12, 0) // 6t-ups2-downs4\n    CASE(/*sharedKB*/96, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27,  31,  32,  6,  0) // 6t-upf2-downs4 96kB\n    CASE(/*sharedKB*/48, /*up,fu*/2,12, /*down,fd*/4,24, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/27,  13,  24,  0,  0) // 6t-upf2-downs4\n    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/1,1,  /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/148, 89,  24,  0,  0) // 8t-ups2-downf1\n    CASE(/*sharedKB*/48, /*up,fu*/1,1,  /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/32,  31,  16,  5,  0) // 8t-upf1-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  41,  16,  9,  0) // 8t-ups2-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/56,  26,  
24,  0,  0) // 8t-upf2-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32,  40,  16,  0,  0) // 8t-ups2-downf2\n    CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/32,  46,  24,  5,  0) // 8t-ups4-downs2\n    CASE(/*sharedKB*/48, /*up,fu*/4,32, /*down,fd*/2,16, /*mode*/MODE_SUFD, /*tw,th,warps,xrep,wskip*/32,  50,  16,  0,  0) // 8t-ups4-downf2\n    CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/24,  24,  32,  12, 1) // 8t-ups2-downs4 96kB\n    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_SUSD, /*tw,th,warps,xrep,wskip*/16,  13,  16,  10, 1) // 8t-ups2-downs4\n    CASE(/*sharedKB*/96, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25,  28,  28,  4,  0) // 8t-upf2-downs4 96kB\n    CASE(/*sharedKB*/48, /*up,fu*/2,16, /*down,fd*/4,32, /*mode*/MODE_FUSD, /*tw,th,warps,xrep,wskip*/25,  10,  24,  0,  0) // 8t-upf2-downs4\n\n    #undef CASE\n    return s; // No kernel found.\n}\n\n//------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/filtered_lrelu.h",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include <cuda_runtime.h>\n\n//------------------------------------------------------------------------\n// CUDA kernel parameters.\n\nstruct filtered_lrelu_kernel_params\n{\n    // These parameters decide which kernel to use.\n    int             up;         // upsampling ratio (1, 2, 4)\n    int             down;       // downsampling ratio (1, 2, 4)\n    int2            fuShape;    // [size, 1] | [size, size]\n    int2            fdShape;    // [size, 1] | [size, size]\n\n    int             _dummy;     // Alignment.\n\n    // Rest of the parameters.\n    const void*     x;          // Input tensor.\n    void*           y;          // Output tensor.\n    const void*     b;          // Bias tensor.\n    unsigned char*  s;          // Sign tensor in/out. 
NULL if unused.\n    const float*    fu;         // Upsampling filter.\n    const float*    fd;         // Downsampling filter.\n\n    int2            pad0;       // Left/top padding.\n    float           gain;       // Additional gain factor.\n    float           slope;      // Leaky ReLU slope on negative side.\n    float           clamp;      // Clamp after nonlinearity.\n    int             flip;       // Filter kernel flip for gradient computation.\n\n    int             tilesXdim;  // Original number of horizontal output tiles.\n    int             tilesXrep;  // Number of horizontal tiles per CTA.\n    int             blockZofs;  // Block z offset to support large minibatch, channel dimensions.\n\n    int4            xShape;     // [width, height, channel, batch]\n    int4            yShape;     // [width, height, channel, batch]\n    int2            sShape;     // [width, height] - width is in bytes. Contiguous. Zeros if unused.\n    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.\n    int             swLimit;    // Active width of sign tensor in bytes.\n\n    longlong4       xStride;    // Strides of all tensors except signs, same component order as shapes.\n    longlong4       yStride;    //\n    int64_t         bStride;    //\n    longlong3       fuStride;   //\n    longlong3       fdStride;   //\n};\n\nstruct filtered_lrelu_act_kernel_params\n{\n    void*           x;          // Input/output, modified in-place.\n    unsigned char*  s;          // Sign tensor in/out. NULL if unused.\n\n    float           gain;       // Additional gain factor.\n    float           slope;      // Leaky ReLU slope on negative side.\n    float           clamp;      // Clamp after nonlinearity.\n\n    int4            xShape;     // [width, height, channel, batch]\n    longlong4       xStride;    // Input/output tensor strides, same order as in shape.\n    int2            sShape;     // [width, height] - width is in elements. 
Contiguous. Zeros if unused.\n    int2            sOfs;       // [ofs_x, ofs_y] - offset between upsampled data and sign tensor.\n};\n\n//------------------------------------------------------------------------\n// CUDA kernel specialization.\n\nstruct filtered_lrelu_kernel_spec\n{\n    void*   setup;              // Function for filter kernel setup.\n    void*   exec;               // Function for main operation.\n    int2    tileOut;            // Width/height of launch tile.\n    int     numWarps;           // Number of warps per thread block, determines launch block size.\n    int     xrep;               // For processing multiple horizontal tiles per thread block.\n    int     dynamicSharedKB;    // How much dynamic shared memory the exec kernel wants.\n};\n\n//------------------------------------------------------------------------\n// CUDA kernel selection.\n\ntemplate <class T, class index_t, bool signWrite, bool signRead> filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB);\ntemplate <class T, bool signWrite, bool signRead> void* choose_filtered_lrelu_act_kernel(void);\ntemplate <bool signWrite, bool signRead> cudaError_t copy_filters(cudaStream_t stream);\n\n//------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/filtered_lrelu.py",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport os\nimport numpy as np\nimport torch\nimport warnings\n\nfrom .. import custom_ops\nfrom .. import misc\nfrom . import upfirdn2d\nfrom . import bias_act\n\n#----------------------------------------------------------------------------\n\n_plugin = None\n\ndef _init():\n    global _plugin\n    if _plugin is None:\n        _plugin = custom_ops.get_plugin(\n            module_name='filtered_lrelu_plugin',\n            sources=['filtered_lrelu.cpp', 'filtered_lrelu_wr.cu', 'filtered_lrelu_rd.cu', 'filtered_lrelu_ns.cu'],\n            headers=['filtered_lrelu.h', 'filtered_lrelu.cu'],\n            source_dir=os.path.dirname(__file__),\n            extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],\n        )\n    return True\n\ndef _get_filter_size(f):\n    if f is None:\n        return 1, 1\n    assert isinstance(f, torch.Tensor)\n    assert 1 <= f.ndim <= 2\n    return f.shape[-1], f.shape[0] # width, height\n\ndef _parse_padding(padding):\n    if isinstance(padding, int):\n        padding = [padding, padding]\n    assert isinstance(padding, (list, tuple))\n    assert all(isinstance(x, (int, np.integer)) for x in padding)\n    padding = [int(x) for x in padding]\n    if len(padding) == 2:\n        px, py = padding\n        padding = [px, px, py, py]\n    px0, px1, py0, py1 = padding\n    return px0, px1, py0, py1\n\n#----------------------------------------------------------------------------\n\ndef filtered_lrelu(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), 
slope=0.2, clamp=None, flip_filter=False, impl='cuda'):\n    r\"\"\"Filtered leaky ReLU for a batch of 2D images.\n\n    Performs the following sequence of operations for each channel:\n\n    1. Add channel-specific bias if provided (`b`).\n\n    2. Upsample the image by inserting N-1 zeros after each pixel (`up`).\n\n    3. Pad the image with the specified number of zeros on each side (`padding`).\n       Negative padding corresponds to cropping the image.\n\n    4. Convolve the image with the specified upsampling FIR filter (`fu`), shrinking it\n       so that the footprint of all output pixels lies within the input image.\n\n    5. Multiply each value by the provided gain factor (`gain`).\n\n    6. Apply leaky ReLU activation function to each value.\n\n    7. Clamp each value between -clamp and +clamp, if `clamp` parameter is provided.\n\n    8. Convolve the image with the specified downsampling FIR filter (`fd`), shrinking\n       it so that the footprint of all output pixels lies within the input image.\n\n    9. Downsample the image by keeping every Nth pixel (`down`).\n\n    The fused op is considerably more efficient than performing the same calculation\n    using standard PyTorch ops. It supports gradients of arbitrary order.\n\n    Args:\n        x:           Float32/float16/float64 input tensor of the shape\n                     `[batch_size, num_channels, in_height, in_width]`.\n        fu:          Float32 upsampling FIR filter of the shape\n                     `[filter_height, filter_width]` (non-separable),\n                     `[filter_taps]` (separable), or\n                     `None` (identity).\n        fd:          Float32 downsampling FIR filter of the shape\n                     `[filter_height, filter_width]` (non-separable),\n                     `[filter_taps]` (separable), or\n                     `None` (identity).\n        b:           Bias vector, or `None` to disable. 
Must be a 1D tensor of the same type\n                     as `x`. The length of vector must match the channel dimension of `x`.\n        up:          Integer upsampling factor (default: 1).\n        down:        Integer downsampling factor (default: 1).\n        padding:     Padding with respect to the upsampled image. Can be a single number\n                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`\n                     (default: 0).\n        gain:        Overall scaling factor for signal magnitude (default: sqrt(2)).\n        slope:       Slope on the negative side of leaky ReLU (default: 0.2).\n        clamp:       Maximum magnitude for leaky ReLU output (default: None).\n        flip_filter: False = convolution, True = correlation (default: False).\n        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).\n\n    Returns:\n        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.\n    \"\"\"\n    assert isinstance(x, torch.Tensor)\n    assert impl in ['ref', 'cuda']\n    if impl == 'cuda' and x.device.type == 'cuda' and _init():\n        return _filtered_lrelu_cuda(up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter).apply(x, fu, fd, b, None, 0, 0)\n    return _filtered_lrelu_ref(x, fu=fu, fd=fd, b=b, up=up, down=down, padding=padding, gain=gain, slope=slope, clamp=clamp, flip_filter=flip_filter)\n\n#----------------------------------------------------------------------------\n\n@misc.profiled_function\ndef _filtered_lrelu_ref(x, fu=None, fd=None, b=None, up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):\n    \"\"\"Slow and memory-inefficient reference implementation of `filtered_lrelu()` using\n    existing `upfirdn2d()` and `bias_act()` ops.\n    \"\"\"\n    assert isinstance(x, torch.Tensor) and x.ndim == 4\n    fu_w, fu_h = _get_filter_size(fu)\n    fd_w, fd_h = _get_filter_size(fd)\n  
  if b is not None:\n        assert isinstance(b, torch.Tensor) and b.dtype == x.dtype\n        misc.assert_shape(b, [x.shape[1]])\n    assert isinstance(up, int) and up >= 1\n    assert isinstance(down, int) and down >= 1\n    px0, px1, py0, py1 = _parse_padding(padding)\n    assert gain == float(gain) and gain > 0\n    assert slope == float(slope) and slope >= 0\n    assert clamp is None or (clamp == float(clamp) and clamp >= 0)\n\n    # Calculate output size.\n    batch_size, channels, in_h, in_w = x.shape\n    in_dtype = x.dtype\n    out_w = (in_w * up + (px0 + px1) - (fu_w - 1) - (fd_w - 1) + (down - 1)) // down\n    out_h = (in_h * up + (py0 + py1) - (fu_h - 1) - (fd_h - 1) + (down - 1)) // down\n\n    # Compute using existing ops.\n    x = bias_act.bias_act(x=x, b=b) # Apply bias.\n    x = upfirdn2d.upfirdn2d(x=x, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.\n    x = bias_act.bias_act(x=x, act='lrelu', alpha=slope, gain=gain, clamp=clamp) # Bias, leaky ReLU, clamp.\n    x = upfirdn2d.upfirdn2d(x=x, f=fd, down=down, flip_filter=flip_filter) # Downsample.\n\n    # Check output shape & dtype.\n    misc.assert_shape(x, [batch_size, channels, out_h, out_w])\n    assert x.dtype == in_dtype\n    return x\n\n#----------------------------------------------------------------------------\n\n_filtered_lrelu_cuda_cache = dict()\n\ndef _filtered_lrelu_cuda(up=1, down=1, padding=0, gain=np.sqrt(2), slope=0.2, clamp=None, flip_filter=False):\n    \"\"\"Fast CUDA implementation of `filtered_lrelu()` using custom ops.\n    \"\"\"\n    assert isinstance(up, int) and up >= 1\n    assert isinstance(down, int) and down >= 1\n    px0, px1, py0, py1 = _parse_padding(padding)\n    assert gain == float(gain) and gain > 0\n    gain = float(gain)\n    assert slope == float(slope) and slope >= 0\n    slope = float(slope)\n    assert clamp is None or (clamp == float(clamp) and clamp >= 0)\n    clamp = float(clamp if clamp is not None else 
'inf')\n\n    # Lookup from cache.\n    key = (up, down, px0, px1, py0, py1, gain, slope, clamp, flip_filter)\n    if key in _filtered_lrelu_cuda_cache:\n        return _filtered_lrelu_cuda_cache[key]\n\n    # Forward op.\n    class FilteredLReluCuda(torch.autograd.Function):\n        @staticmethod\n        def forward(ctx, x, fu, fd, b, si, sx, sy): # pylint: disable=arguments-differ\n            assert isinstance(x, torch.Tensor) and x.ndim == 4\n\n            # Replace empty up/downsample kernels with full 1x1 kernels (faster than separable).\n            if fu is None:\n                fu = torch.ones([1, 1], dtype=torch.float32, device=x.device)\n            if fd is None:\n                fd = torch.ones([1, 1], dtype=torch.float32, device=x.device)\n            assert 1 <= fu.ndim <= 2\n            assert 1 <= fd.ndim <= 2\n\n            # Replace separable 1x1 kernels with full 1x1 kernels when scale factor is 1.\n            if up == 1 and fu.ndim == 1 and fu.shape[0] == 1:\n                fu = fu.square()[None]\n            if down == 1 and fd.ndim == 1 and fd.shape[0] == 1:\n                fd = fd.square()[None]\n\n            # Missing sign input tensor.\n            if si is None:\n                si = torch.empty([0])\n\n            # Missing bias tensor.\n            if b is None:\n                b = torch.zeros([x.shape[1]], dtype=x.dtype, device=x.device)\n\n            # Construct internal sign tensor only if gradients are needed.\n            write_signs = (si.numel() == 0) and (x.requires_grad or b.requires_grad)\n\n            # Warn if input storage strides are not in decreasing order due to e.g. 
channels-last layout.\n            strides = [x.stride(i) for i in range(x.ndim) if x.size(i) > 1]\n            if any(a < b for a, b in zip(strides[:-1], strides[1:])):\n                warnings.warn(\"low-performance memory layout detected in filtered_lrelu input\", RuntimeWarning)\n\n            # Call C++/Cuda plugin if datatype is supported.\n            if x.dtype in [torch.float16, torch.float32]:\n                if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device):\n                    warnings.warn(\"filtered_lrelu called with non-default cuda stream but concurrent execution is not supported\", RuntimeWarning)\n                y, so, return_code = _plugin.filtered_lrelu(x, fu, fd, b, si, up, down, px0, px1, py0, py1, sx, sy, gain, slope, clamp, flip_filter, write_signs)\n            else:\n                return_code = -1\n\n            # No Cuda kernel found? Fall back to generic implementation. Still more memory efficient than the reference implementation because\n            # only the bit-packed sign tensor is retained for gradient computation.\n            if return_code < 0:\n                warnings.warn(\"filtered_lrelu called with parameters that have no optimized CUDA kernel, using generic fallback\", RuntimeWarning)\n\n                y = x.add(b.unsqueeze(-1).unsqueeze(-1)) # Add bias.\n                y = upfirdn2d.upfirdn2d(x=y, f=fu, up=up, padding=[px0, px1, py0, py1], gain=up**2, flip_filter=flip_filter) # Upsample.\n                so = _plugin.filtered_lrelu_act_(y, si, sx, sy, gain, slope, clamp, write_signs) # Activation function and sign handling. 
Modifies y in-place.\n                y = upfirdn2d.upfirdn2d(x=y, f=fd, down=down, flip_filter=flip_filter) # Downsample.\n\n            # Prepare for gradient computation.\n            ctx.save_for_backward(fu, fd, (si if si.numel() else so))\n            ctx.x_shape = x.shape\n            ctx.y_shape = y.shape\n            ctx.s_ofs = sx, sy\n            return y\n\n        @staticmethod\n        def backward(ctx, dy): # pylint: disable=arguments-differ\n            fu, fd, si = ctx.saved_tensors\n            _, _, xh, xw = ctx.x_shape\n            _, _, yh, yw = ctx.y_shape\n            sx, sy = ctx.s_ofs\n            dx  = None # 0\n            dfu = None; assert not ctx.needs_input_grad[1]\n            dfd = None; assert not ctx.needs_input_grad[2]\n            db  = None # 3\n            dsi = None; assert not ctx.needs_input_grad[4]\n            dsx = None; assert not ctx.needs_input_grad[5]\n            dsy = None; assert not ctx.needs_input_grad[6]\n\n            if ctx.needs_input_grad[0] or ctx.needs_input_grad[3]:\n                pp = [\n                    (fu.shape[-1] - 1) + (fd.shape[-1] - 1) - px0,\n                    xw * up - yw * down + px0 - (up - 1),\n                    (fu.shape[0] - 1) + (fd.shape[0] - 1) - py0,\n                    xh * up - yh * down + py0 - (up - 1),\n                ]\n                gg = gain * (up ** 2) / (down ** 2)\n                ff = (not flip_filter)\n                sx = sx - (fu.shape[-1] - 1) + px0\n                sy = sy - (fu.shape[0]  - 1) + py0\n                dx = _filtered_lrelu_cuda(up=down, down=up, padding=pp, gain=gg, slope=slope, clamp=None, flip_filter=ff).apply(dy, fd, fu, None, si, sx, sy)\n\n            if ctx.needs_input_grad[3]:\n                db = dx.sum([0, 2, 3])\n\n            return dx, dfu, dfd, db, dsi, dsx, dsy\n\n    # Add to cache.\n    _filtered_lrelu_cuda_cache[key] = FilteredLReluCuda\n    return 
FilteredLReluCuda\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/filtered_lrelu_ns.cu",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include \"filtered_lrelu.cu\"\n\n// Template/kernel specializations for no signs mode (no gradients required).\n\n// Full op, 32-bit indexing.\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);\n\n// Full op, 64-bit indexing.\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB);\n\n// Activation/signs only for generic variant. 64-bit indexing.\ntemplate void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void);\ntemplate void* choose_filtered_lrelu_act_kernel<float,     false, false>(void);\ntemplate void* choose_filtered_lrelu_act_kernel<double,    false, false>(void);\n\n// Copy filters to constant memory.\ntemplate cudaError_t copy_filters<false, false>(cudaStream_t stream);\n"
  },
  {
    "path": "ADD/th_utils/ops/filtered_lrelu_rd.cu",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include \"filtered_lrelu.cu\"\n\n// Template/kernel specializations for sign read mode.\n\n// Full op, 32-bit indexing.\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);\n\n// Full op, 64-bit indexing.\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB);\n\n// Activation/signs only for generic variant. 64-bit indexing.\ntemplate void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void);\ntemplate void* choose_filtered_lrelu_act_kernel<float,     false, true>(void);\ntemplate void* choose_filtered_lrelu_act_kernel<double,    false, true>(void);\n\n// Copy filters to constant memory.\ntemplate cudaError_t copy_filters<false, true>(cudaStream_t stream);\n"
  },
  {
    "path": "ADD/th_utils/ops/filtered_lrelu_wr.cu",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include \"filtered_lrelu.cu\"\n\n// Template/kernel specializations for sign write mode.\n\n// Full op, 32-bit indexing.\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);\n\n// Full op, 64-bit indexing.\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);\ntemplate filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float,     int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB);\n\n// Activation/signs only for generic variant. 64-bit indexing.\ntemplate void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void);\ntemplate void* choose_filtered_lrelu_act_kernel<float,     true, false>(void);\ntemplate void* choose_filtered_lrelu_act_kernel<double,    true, false>(void);\n\n// Copy filters to constant memory.\ntemplate cudaError_t copy_filters<true, false>(cudaStream_t stream);\n"
  },
  {
    "path": "ADD/th_utils/ops/fma.py",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n\"\"\"Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.\"\"\"\n\nimport torch\n\n#----------------------------------------------------------------------------\n\ndef fma(a, b, c): # => a * b + c\n    return _FusedMultiplyAdd.apply(a, b, c)\n\n#----------------------------------------------------------------------------\n\nclass _FusedMultiplyAdd(torch.autograd.Function): # a * b + c\n    @staticmethod\n    def forward(ctx, a, b, c): # pylint: disable=arguments-differ\n        out = torch.addcmul(c, a, b)\n        ctx.save_for_backward(a, b)\n        ctx.c_shape = c.shape\n        return out\n\n    @staticmethod\n    def backward(ctx, dout): # pylint: disable=arguments-differ\n        a, b = ctx.saved_tensors\n        c_shape = ctx.c_shape\n        da = None\n        db = None\n        dc = None\n\n        if ctx.needs_input_grad[0]:\n            da = _unbroadcast(dout * b, a.shape)\n\n        if ctx.needs_input_grad[1]:\n            db = _unbroadcast(dout * a, b.shape)\n\n        if ctx.needs_input_grad[2]:\n            dc = _unbroadcast(dout, c_shape)\n\n        return da, db, dc\n\n#----------------------------------------------------------------------------\n\ndef _unbroadcast(x, shape):\n    extra_dims = x.ndim - len(shape)\n    assert extra_dims >= 0\n    dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)]\n    if len(dim):\n        x = x.sum(dim=dim, keepdim=True)\n    if extra_dims:\n        x = x.reshape(-1, 
*x.shape[extra_dims+1:])\n    assert x.shape == shape\n    return x\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/grid_sample_gradfix.py",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n\"\"\"Custom replacement for `torch.nn.functional.grid_sample` that\nsupports arbitrarily high order gradients between the input and output.\nOnly works on 2D images and assumes\n`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.\"\"\"\n\nimport torch\nfrom pkg_resources import parse_version\n\n# pylint: disable=redefined-builtin\n# pylint: disable=arguments-differ\n# pylint: disable=protected-access\n\n#----------------------------------------------------------------------------\n\nenabled = False  # Enable the custom op by setting this to true.\n_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11\n\n#----------------------------------------------------------------------------\n\ndef grid_sample(input, grid):\n    if _should_use_custom_op():\n        return _GridSample2dForward.apply(input, grid)\n    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)\n\n#----------------------------------------------------------------------------\n\ndef _should_use_custom_op():\n    return enabled\n\n#----------------------------------------------------------------------------\n\nclass _GridSample2dForward(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, input, grid):\n        assert input.ndim == 4\n        assert grid.ndim == 4\n        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', 
align_corners=False)\n        ctx.save_for_backward(input, grid)\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        input, grid = ctx.saved_tensors\n        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)\n        return grad_input, grad_grid\n\n#----------------------------------------------------------------------------\n\nclass _GridSample2dBackward(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, grad_output, input, grid):\n        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')\n        if _use_pytorch_1_11_api:\n            output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])\n            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)\n        else:\n            grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)\n        ctx.save_for_backward(grid)\n        return grad_input, grad_grid\n\n    @staticmethod\n    def backward(ctx, grad2_grad_input, grad2_grad_grid):\n        _ = grad2_grad_grid # unused\n        grid, = ctx.saved_tensors\n        grad2_grad_output = None\n        grad2_input = None\n        grad2_grid = None\n\n        if ctx.needs_input_grad[0]:\n            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)\n\n        assert not ctx.needs_input_grad[2]\n        return grad2_grad_output, grad2_input, grad2_grid\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/upfirdn2d.cpp",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include <torch/extension.h>\n#include <ATen/cuda/CUDAContext.h>\n#include <c10/cuda/CUDAGuard.h>\n#include \"upfirdn2d.h\"\n\n//------------------------------------------------------------------------\n\nstatic torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)\n{\n    // Validate arguments.\n    TORCH_CHECK(x.is_cuda(), \"x must reside on CUDA device\");\n    TORCH_CHECK(f.device() == x.device(), \"f must reside on the same device as x\");\n    TORCH_CHECK(f.dtype() == torch::kFloat, \"f must be float32\");\n    TORCH_CHECK(x.numel() <= INT_MAX, \"x is too large\");\n    TORCH_CHECK(f.numel() <= INT_MAX, \"f is too large\");\n    TORCH_CHECK(x.numel() > 0, \"x has zero size\");\n    TORCH_CHECK(f.numel() > 0, \"f has zero size\");\n    TORCH_CHECK(x.dim() == 4, \"x must be rank 4\");\n    TORCH_CHECK(f.dim() == 2, \"f must be rank 2\");\n    TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, \"x memory footprint is too large\");\n    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, \"f must be at least 1x1\");\n    TORCH_CHECK(upx >= 1 && upy >= 1, \"upsampling factor must be at least 1\");\n    TORCH_CHECK(downx >= 1 && downy >= 1, \"downsampling factor must be at least 1\");\n\n    // Create output tensor.\n    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));\n    int outW = ((int)x.size(3) * upx 
+ padx0 + padx1 - (int)f.size(1) + downx) / downx;\n    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;\n    TORCH_CHECK(outW >= 1 && outH >= 1, \"output must be at least 1x1\");\n    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());\n    TORCH_CHECK(y.numel() <= INT_MAX, \"output is too large\");\n    TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, \"output memory footprint is too large\");\n\n    // Initialize CUDA kernel parameters.\n    upfirdn2d_kernel_params p;\n    p.x             = x.data_ptr();\n    p.f             = f.data_ptr<float>();\n    p.y             = y.data_ptr();\n    p.up            = make_int2(upx, upy);\n    p.down          = make_int2(downx, downy);\n    p.pad0          = make_int2(padx0, pady0);\n    p.flip          = (flip) ? 1 : 0;\n    p.gain          = gain;\n    p.inSize        = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));\n    p.inStride      = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));\n    p.filterSize    = make_int2((int)f.size(1), (int)f.size(0));\n    p.filterStride  = make_int2((int)f.stride(1), (int)f.stride(0));\n    p.outSize       = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));\n    p.outStride     = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));\n    p.sizeMajor     = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;\n    p.sizeMinor     = (p.inStride.z == 1) ? 
p.inSize.z : 1;\n\n    // Choose CUDA kernel.\n    upfirdn2d_kernel_spec spec;\n    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), \"upfirdn2d_cuda\", [&]\n    {\n        spec = choose_upfirdn2d_kernel<scalar_t>(p);\n    });\n\n    // Set looping options.\n    p.loopMajor     = (p.sizeMajor - 1) / 16384 + 1;\n    p.loopMinor     = spec.loopMinor;\n    p.loopX         = spec.loopX;\n    p.launchMinor   = (p.sizeMinor - 1) / p.loopMinor + 1;\n    p.launchMajor   = (p.sizeMajor - 1) / p.loopMajor + 1;\n\n    // Compute grid size.\n    dim3 blockSize, gridSize;\n    if (spec.tileOutW < 0) // large\n    {\n        blockSize = dim3(4, 32, 1);\n        gridSize = dim3(\n            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,\n            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,\n            p.launchMajor);\n    }\n    else // small\n    {\n        blockSize = dim3(256, 1, 1);\n        gridSize = dim3(\n            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,\n            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,\n            p.launchMajor);\n    }\n\n    // Launch CUDA kernel.\n    void* args[] = {&p};\n    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));\n    return y;\n}\n\n//------------------------------------------------------------------------\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m)\n{\n    m.def(\"upfirdn2d\", &upfirdn2d);\n}\n\n//------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/upfirdn2d.cu",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include <c10/util/Half.h>\n#include \"upfirdn2d.h\"\n\n//------------------------------------------------------------------------\n// Helpers.\n\ntemplate <class T> struct InternalType;\ntemplate <> struct InternalType<double>     { typedef double scalar_t; };\ntemplate <> struct InternalType<float>      { typedef float  scalar_t; };\ntemplate <> struct InternalType<c10::Half>  { typedef float  scalar_t; };\n\nstatic __device__ __forceinline__ int floor_div(int a, int b)\n{\n    int t = 1 - a / b;\n    return (a + t * b) / b - t;\n}\n\n//------------------------------------------------------------------------\n// Generic CUDA implementation for large filters.\n\ntemplate <class T> static __global__ void upfirdn2d_kernel_large(upfirdn2d_kernel_params p)\n{\n    typedef typename InternalType<T>::scalar_t scalar_t;\n\n    // Calculate thread index.\n    int minorBase = blockIdx.x * blockDim.x + threadIdx.x;\n    int outY = minorBase / p.launchMinor;\n    minorBase -= outY * p.launchMinor;\n    int outXBase = blockIdx.y * p.loopX * blockDim.y + threadIdx.y;\n    int majorBase = blockIdx.z * p.loopMajor;\n    if (outXBase >= p.outSize.x | outY >= p.outSize.y | majorBase >= p.sizeMajor)\n        return;\n\n    // Setup Y receptive field.\n    int midY = outY * p.down.y + p.up.y - 1 - p.pad0.y;\n    int inY = min(max(floor_div(midY, p.up.y), 0), p.inSize.y);\n    int h = min(max(floor_div(midY + p.filterSize.y, p.up.y), 0), p.inSize.y) - inY;\n    int filterY = midY + p.filterSize.y - (inY + 1) * p.up.y;\n    
if (p.flip)\n        filterY = p.filterSize.y - 1 - filterY;\n\n    // Loop over major, minor, and X.\n    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)\n    for (int minorIdx = 0, minor = minorBase; minorIdx < p.loopMinor & minor < p.sizeMinor; minorIdx++, minor += p.launchMinor)\n    {\n        int nc = major * p.sizeMinor + minor;\n        int n = nc / p.inSize.z;\n        int c = nc - n * p.inSize.z;\n        for (int loopX = 0, outX = outXBase; loopX < p.loopX & outX < p.outSize.x; loopX++, outX += blockDim.y)\n        {\n            // Setup X receptive field.\n            int midX = outX * p.down.x + p.up.x - 1 - p.pad0.x;\n            int inX = min(max(floor_div(midX, p.up.x), 0), p.inSize.x);\n            int w = min(max(floor_div(midX + p.filterSize.x, p.up.x), 0), p.inSize.x) - inX;\n            int filterX = midX + p.filterSize.x - (inX + 1) * p.up.x;\n            if (p.flip)\n                filterX = p.filterSize.x - 1 - filterX;\n\n            // Initialize pointers.\n            const T* xp = &((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];\n            const float* fp = &p.f[filterX * p.filterStride.x + filterY * p.filterStride.y];\n            int filterStepX = ((p.flip) ? p.up.x : -p.up.x) * p.filterStride.x;\n            int filterStepY = ((p.flip) ? 
p.up.y : -p.up.y) * p.filterStride.y;\n\n            // Inner loop.\n            scalar_t v = 0;\n            for (int y = 0; y < h; y++)\n            {\n                for (int x = 0; x < w; x++)\n                {\n                    v += (scalar_t)(*xp) * (scalar_t)(*fp);\n                    xp += p.inStride.x;\n                    fp += filterStepX;\n                }\n                xp += p.inStride.y - w * p.inStride.x;\n                fp += filterStepY - w * filterStepX;\n            }\n\n            // Store result.\n            v *= p.gain;\n            ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;\n        }\n    }\n}\n\n//------------------------------------------------------------------------\n// Specialized CUDA implementation for small filters.\n\ntemplate <class T, int upx, int upy, int downx, int downy, int filterW, int filterH, int tileOutW, int tileOutH, int loopMinor>\nstatic __global__ void upfirdn2d_kernel_small(upfirdn2d_kernel_params p)\n{\n    typedef typename InternalType<T>::scalar_t scalar_t;\n    const int tileInW = ((tileOutW - 1) * downx + filterW - 1) / upx + 1;\n    const int tileInH = ((tileOutH - 1) * downy + filterH - 1) / upy + 1;\n    __shared__ volatile scalar_t sf[filterH][filterW];\n    __shared__ volatile scalar_t sx[tileInH][tileInW][loopMinor];\n\n    // Calculate tile index.\n    int minorBase = blockIdx.x;\n    int tileOutY = minorBase / p.launchMinor;\n    minorBase -= tileOutY * p.launchMinor;\n    minorBase *= loopMinor;\n    tileOutY *= tileOutH;\n    int tileOutXBase = blockIdx.y * p.loopX * tileOutW;\n    int majorBase = blockIdx.z * p.loopMajor;\n    if (tileOutXBase >= p.outSize.x | tileOutY >= p.outSize.y | majorBase >= p.sizeMajor)\n        return;\n\n    // Load filter (flipped).\n    for (int tapIdx = threadIdx.x; tapIdx < filterH * filterW; tapIdx += blockDim.x)\n    {\n        int fy = tapIdx / filterW;\n        int fx = tapIdx - fy * filterW;\n  
      scalar_t v = 0;\n        if (fx < p.filterSize.x & fy < p.filterSize.y)\n        {\n            int ffx = (p.flip) ? fx : p.filterSize.x - 1 - fx;\n            int ffy = (p.flip) ? fy : p.filterSize.y - 1 - fy;\n            v = (scalar_t)p.f[ffx * p.filterStride.x + ffy * p.filterStride.y];\n        }\n        sf[fy][fx] = v;\n    }\n\n    // Loop over major and X.\n    for (int majorIdx = 0, major = majorBase; majorIdx < p.loopMajor & major < p.sizeMajor; majorIdx++, major++)\n    {\n        int baseNC = major * p.sizeMinor + minorBase;\n        int n = baseNC / p.inSize.z;\n        int baseC = baseNC - n * p.inSize.z;\n        for (int loopX = 0, tileOutX = tileOutXBase; loopX < p.loopX & tileOutX < p.outSize.x; loopX++, tileOutX += tileOutW)\n        {\n            // Load input pixels.\n            int tileMidX = tileOutX * downx + upx - 1 - p.pad0.x;\n            int tileMidY = tileOutY * downy + upy - 1 - p.pad0.y;\n            int tileInX = floor_div(tileMidX, upx);\n            int tileInY = floor_div(tileMidY, upy);\n            __syncthreads();\n            for (int inIdx = threadIdx.x; inIdx < tileInH * tileInW * loopMinor; inIdx += blockDim.x)\n            {\n                int relC = inIdx;\n                int relInX = relC / loopMinor;\n                int relInY = relInX / tileInW;\n                relC -= relInX * loopMinor;\n                relInX -= relInY * tileInW;\n                int c = baseC + relC;\n                int inX = tileInX + relInX;\n                int inY = tileInY + relInY;\n                scalar_t v = 0;\n                if (inX >= 0 & inY >= 0 & inX < p.inSize.x & inY < p.inSize.y & c < p.inSize.z)\n                    v = (scalar_t)((const T*)p.x)[inX * p.inStride.x + inY * p.inStride.y + c * p.inStride.z + n * p.inStride.w];\n                sx[relInY][relInX][relC] = v;\n            }\n\n            // Loop over output pixels.\n            __syncthreads();\n            for (int outIdx = threadIdx.x; outIdx < 
tileOutH * tileOutW * loopMinor; outIdx += blockDim.x)\n            {\n                int relC = outIdx;\n                int relOutX = relC / loopMinor;\n                int relOutY = relOutX / tileOutW;\n                relC -= relOutX * loopMinor;\n                relOutX -= relOutY * tileOutW;\n                int c = baseC + relC;\n                int outX = tileOutX + relOutX;\n                int outY = tileOutY + relOutY;\n\n                // Setup receptive field.\n                int midX = tileMidX + relOutX * downx;\n                int midY = tileMidY + relOutY * downy;\n                int inX = floor_div(midX, upx);\n                int inY = floor_div(midY, upy);\n                int relInX = inX - tileInX;\n                int relInY = inY - tileInY;\n                int filterX = (inX + 1) * upx - midX - 1; // flipped\n                int filterY = (inY + 1) * upy - midY - 1; // flipped\n\n                // Inner loop.\n                if (outX < p.outSize.x & outY < p.outSize.y & c < p.outSize.z)\n                {\n                    scalar_t v = 0;\n                    #pragma unroll\n                    for (int y = 0; y < filterH / upy; y++)\n                        #pragma unroll\n                        for (int x = 0; x < filterW / upx; x++)\n                            v += sx[relInY + y][relInX + x][relC] * sf[filterY + y * upy][filterX + x * upx];\n                    v *= p.gain;\n                    ((T*)p.y)[outX * p.outStride.x + outY * p.outStride.y + c * p.outStride.z + n * p.outStride.w] = (T)v;\n                }\n            }\n        }\n    }\n}\n\n//------------------------------------------------------------------------\n// CUDA kernel selection.\n\ntemplate <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p)\n{\n    int s = p.inStride.z, fx = p.filterSize.x, fy = p.filterSize.y;\n    upfirdn2d_kernel_spec spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,1, 4}; // contiguous\n    if 
(s == 1)           spec = {(void*)upfirdn2d_kernel_large<T>, -1,-1,4, 1}; // channels_last\n\n    // No up/downsampling.\n    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)\n    {\n        // contiguous\n        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 64,32,1>, 64,32,1, 1};\n        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 64,32,1>, 64,32,1, 1};\n        if (s != 1 && fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,   64,16,1>, 64,16,1, 1};\n        if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,   64,16,1>, 64,16,1, 1};\n        if (s != 1 && fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,   64,16,1>, 64,16,1, 1};\n        if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,   64,16,1>, 64,16,1, 1};\n        if (s != 1 && fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,   64,16,1>, 64,16,1, 1};\n        if (s != 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1,  128,8,1>, 128,8,1, 1};\n        if (s != 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1,  128,8,1>, 128,8,1, 1};\n        if (s != 1 && fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,   128,8,1>, 128,8,1, 1};\n        if (s != 1 && fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24,  32,32,1>, 32,32,1, 1};\n        if (s != 1 && fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16,  32,32,1>, 32,32,1, 1};\n        if (s != 1 && fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,   32,32,1>, 32,32,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,24, 32,32,1>,  
32,32,1,  1};\n        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,16, 32,32,1>,  32,32,1,  1};\n        if (s == 1 && fx <= 7  && fy <= 7 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 7,7,   16,16,8>,  16,16,8,  1};\n        if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 6,6,   16,16,8>,  16,16,8,  1};\n        if (s == 1 && fx <= 5  && fy <= 5 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 5,5,   16,16,8>,  16,16,8,  1};\n        if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 4,4,   16,16,8>,  16,16,8,  1};\n        if (s == 1 && fx <= 3  && fy <= 3 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 3,3,   16,16,8>,  16,16,8,  1};\n        if (s == 1 && fx <= 24 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 24,1,  128,1,16>, 128,1,16, 1};\n        if (s == 1 && fx <= 16 && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 16,1,  128,1,16>, 128,1,16, 1};\n        if (s == 1 && fx <= 8  && fy <= 1 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 8,1,   128,1,16>, 128,1,16, 1};\n        if (s == 1 && fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,24,  1,128,16>, 1,128,16, 1};\n        if (s == 1 && fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,16,  1,128,16>, 1,128,16, 1};\n        if (s == 1 && fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,1, 1,8,   1,128,16>, 1,128,16, 1};\n    }\n\n    // 2x upsampling.\n    if (p.up.x == 2 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)\n    {\n        // contiguous\n        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 64,32,1>, 64,32,1, 1};\n        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 64,32,1>, 64,32,1, 1};\n        if (s != 1 && fx <= 8  && fy <= 8 ) spec = 
{(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8,   64,16,1>, 64,16,1, 1};\n        if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6,   64,16,1>, 64,16,1, 1};\n        if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4,   64,16,1>, 64,16,1, 1};\n        if (s != 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2,   64,16,1>, 64,16,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 24,24, 32,32,1>, 32,32,1, 1};\n        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 16,16, 32,32,1>, 32,32,1, 1};\n        if (s == 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 8,8,   16,16,8>, 16,16,8, 1};\n        if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 6,6,   16,16,8>, 16,16,8, 1};\n        if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 4,4,   16,16,8>, 16,16,8, 1};\n        if (s == 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 2,2, 1,1, 2,2,   16,16,8>, 16,16,8, 1};\n    }\n    if (p.up.x == 2 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)\n    {\n        // contiguous\n        if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,8,1>, 128,8,1, 1};\n        if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,8,1>, 128,8,1, 1};\n        if (s != 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,8,1>, 128,8,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 24,1, 128,1,16>, 128,1,16, 1};\n        if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 16,1, 128,1,16>, 128,1,16, 1};\n       
 if (s == 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 2,1, 1,1, 8,1,  128,1,16>, 128,1,16, 1};\n    }\n    if (p.up.x == 1 && p.up.y == 2 && p.down.x == 1 && p.down.y == 1)\n    {\n        // contiguous\n        if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 32,32,1>, 32,32,1, 1};\n        if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 32,32,1>, 32,32,1, 1};\n        if (s != 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  32,32,1>, 32,32,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,24, 1,128,16>, 1,128,16, 1};\n        if (s == 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,16, 1,128,16>, 1,128,16, 1};\n        if (s == 1 && fx <= 1 && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,2, 1,1, 1,8,  1,128,16>, 1,128,16, 1};\n    }\n\n    // 2x downsampling.\n    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 2)\n    {\n        // contiguous\n        if (s != 1 && fx <= 24 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 32,16,1>, 32,16,1, 1};\n        if (s != 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 32,16,1>, 32,16,1, 1};\n        if (s != 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,   32,8,1>,  32,8,1,  1};\n        if (s != 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6,   32,8,1>,  32,8,1,  1};\n        if (s != 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4,   32,8,1>,  32,8,1,  1};\n        if (s != 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2,   32,8,1>,  32,8,1,  1};\n        // channels_last\n        if (s == 1 && fx <= 24 && fy <= 24) spec = 
{(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 24,24, 16,16,1>, 16,16,1, 1};\n        if (s == 1 && fx <= 16 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 16,16, 16,16,1>, 16,16,1, 1};\n        if (s == 1 && fx <= 8  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 8,8,   8,8,8>,   8,8,8,   1};\n        if (s == 1 && fx <= 6  && fy <= 6 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 6,6,   8,8,8>,   8,8,8,   1};\n        if (s == 1 && fx <= 4  && fy <= 4 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 4,4,   8,8,8>,   8,8,8,   1};\n        if (s == 1 && fx <= 2  && fy <= 2 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,2, 2,2,   8,8,8>,   8,8,8,   1};\n    }\n    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 2 && p.down.y == 1)\n    {\n        // contiguous\n        if (s != 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,8,1>, 64,8,1, 1};\n        if (s != 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,8,1>, 64,8,1, 1};\n        if (s != 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,8,1>, 64,8,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 24 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 24,1, 64,1,8>, 64,1,8, 1};\n        if (s == 1 && fx <= 16 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 16,1, 64,1,8>, 64,1,8, 1};\n        if (s == 1 && fx <= 8  && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 2,1, 8,1,  64,1,8>, 64,1,8, 1};\n    }\n    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 2)\n    {\n        // contiguous\n        if (s != 1 && fx <= 1 && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 32,16,1>, 32,16,1, 1};\n        if (s != 1 && fx <= 1 && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 32,16,1>, 32,16,1, 1};\n        if (s != 1 && fx <= 1 && fy <= 8 ) spec = 
{(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  32,16,1>, 32,16,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 1  && fy <= 24) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,24, 1,64,8>, 1,64,8, 1};\n        if (s == 1 && fx <= 1  && fy <= 16) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,16, 1,64,8>, 1,64,8, 1};\n        if (s == 1 && fx <= 1  && fy <= 8 ) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,2, 1,8,  1,64,8>, 1,64,8, 1};\n    }\n\n    // 4x upsampling.\n    if (p.up.x == 4 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)\n    {\n        // contiguous\n        if (s != 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 64,32,1>, 64,32,1, 1};\n        if (s != 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 64,32,1>, 64,32,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 48 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 48,48, 32,32,1>, 32,32,1, 1};\n        if (s == 1 && fx <= 32 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 4,4, 1,1, 32,32, 32,32,1>, 32,32,1, 1};\n    }\n    if (p.up.x == 4 && p.up.y == 1 && p.down.x == 1 && p.down.y == 1)\n    {\n        // contiguous\n        if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,8,1>, 128,8,1, 1};\n        if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,8,1>, 128,8,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 48,1, 128,1,16>, 128,1,16, 1};\n        if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 4,1, 1,1, 32,1, 128,1,16>, 128,1,16, 1};\n    }\n    if (p.up.x == 1 && p.up.y == 4 && p.down.x == 1 && p.down.y == 1)\n    {\n        // contiguous\n        if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 32,32,1>, 32,32,1, 1};\n  
      if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 32,32,1>, 32,32,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,48, 1,128,16>, 1,128,16, 1};\n        if (s == 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,4, 1,1, 1,32, 1,128,16>, 1,128,16, 1};\n    }\n\n    // 4x downsampling (inefficient).\n    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 4 && p.down.y == 1)\n    {\n        // contiguous\n        if (s != 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,8,1>, 32,8,1, 1};\n        if (s != 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,8,1>, 32,8,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 48 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 48,1, 32,1,8>, 32,1,8, 1};\n        if (s == 1 && fx <= 32 && fy <= 1) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 4,1, 32,1, 32,1,8>, 32,1,8, 1};\n    }\n    if (p.up.x == 1 && p.up.y == 1 && p.down.x == 1 && p.down.y == 4)\n    {\n        // contiguous\n        if (s != 1 && fx <= 1 && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 32,8,1>, 32,8,1, 1};\n        if (s != 1 && fx <= 1 && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 32,8,1>, 32,8,1, 1};\n        // channels_last\n        if (s == 1 && fx <= 1  && fy <= 48) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,48, 1,32,8>, 1,32,8, 1};\n        if (s == 1 && fx <= 1  && fy <= 32) spec = {(void*)upfirdn2d_kernel_small<T, 1,1, 1,4, 1,32, 1,32,8>, 1,32,8, 1};\n    }\n    return spec;\n}\n\n//------------------------------------------------------------------------\n// Template specializations.\n\ntemplate upfirdn2d_kernel_spec choose_upfirdn2d_kernel<double>   (const upfirdn2d_kernel_params& p);\ntemplate upfirdn2d_kernel_spec choose_upfirdn2d_kernel<float>    (const 
upfirdn2d_kernel_params& p);\ntemplate upfirdn2d_kernel_spec choose_upfirdn2d_kernel<c10::Half>(const upfirdn2d_kernel_params& p);\n\n//------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/upfirdn2d.h",
    "content": "// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain all intellectual property\n// and proprietary rights in and to this software, related documentation\n// and any modifications thereto.  Any use, reproduction, disclosure or\n// distribution of this software and related documentation without an express\n// license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n#include <cuda_runtime.h>\n\n//------------------------------------------------------------------------\n// CUDA kernel parameters.\n\nstruct upfirdn2d_kernel_params\n{\n    const void*     x;\n    const float*    f;\n    void*           y;\n\n    int2            up;\n    int2            down;\n    int2            pad0;\n    int             flip;\n    float           gain;\n\n    int4            inSize;         // [width, height, channel, batch]\n    int4            inStride;\n    int2            filterSize;     // [width, height]\n    int2            filterStride;\n    int4            outSize;        // [width, height, channel, batch]\n    int4            outStride;\n    int             sizeMinor;\n    int             sizeMajor;\n\n    int             loopMinor;\n    int             loopMajor;\n    int             loopX;\n    int             launchMinor;\n    int             launchMajor;\n};\n\n//------------------------------------------------------------------------\n// CUDA kernel specialization.\n\nstruct upfirdn2d_kernel_spec\n{\n    void*   kernel;\n    int     tileOutW;\n    int     tileOutH;\n    int     loopMinor;\n    int     loopX;\n};\n\n//------------------------------------------------------------------------\n// CUDA kernel selection.\n\ntemplate <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);\n\n//------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/th_utils/ops/upfirdn2d.py",
    "content": "# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n\"\"\"Custom PyTorch ops for efficient resampling of 2D images.\"\"\"\n\nimport os\nimport numpy as np\nimport torch\n\nfrom .. import custom_ops\nfrom .. import misc\nfrom . import conv2d_gradfix\n\n#----------------------------------------------------------------------------\n\n_plugin = None\n\ndef _init():\n    global _plugin\n    if _plugin is None:\n        _plugin = custom_ops.get_plugin(\n            module_name='upfirdn2d_plugin',\n            sources=['upfirdn2d.cpp', 'upfirdn2d.cu'],\n            headers=['upfirdn2d.h'],\n            source_dir=os.path.dirname(__file__),\n            extra_cuda_cflags=['--use_fast_math', '--allow-unsupported-compiler'],\n        )\n    return True\n\ndef _parse_scaling(scaling):\n    if isinstance(scaling, int):\n        scaling = [scaling, scaling]\n    assert isinstance(scaling, (list, tuple))\n    assert all(isinstance(x, int) for x in scaling)\n    sx, sy = scaling\n    assert sx >= 1 and sy >= 1\n    return sx, sy\n\ndef _parse_padding(padding):\n    if isinstance(padding, int):\n        padding = [padding, padding]\n    assert isinstance(padding, (list, tuple))\n    assert all(isinstance(x, int) for x in padding)\n    if len(padding) == 2:\n        padx, pady = padding\n        padding = [padx, padx, pady, pady]\n    padx0, padx1, pady0, pady1 = padding\n    return padx0, padx1, pady0, pady1\n\ndef _get_filter_size(f):\n    if f is None:\n        return 1, 1\n    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]\n    fw = f.shape[-1]\n    fh = 
f.shape[0]\n    with misc.suppress_tracer_warnings():\n        fw = int(fw)\n        fh = int(fh)\n    misc.assert_shape(f, [fh, fw][:f.ndim])\n    assert fw >= 1 and fh >= 1\n    return fw, fh\n\n#----------------------------------------------------------------------------\n\ndef setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):\n    r\"\"\"Convenience function to setup 2D FIR filter for `upfirdn2d()`.\n\n    Args:\n        f:           Torch tensor, numpy array, or python list of the shape\n                     `[filter_height, filter_width]` (non-separable),\n                     `[filter_taps]` (separable),\n                     `[]` (impulse), or\n                     `None` (identity).\n        device:      Result device (default: cpu).\n        normalize:   Normalize the filter so that it retains the magnitude\n                     for constant input signal (DC)? (default: True).\n        flip_filter: Flip the filter? (default: False).\n        gain:        Overall scaling factor for signal magnitude (default: 1).\n        separable:   Return a separable filter? 
(default: select automatically).\n\n    Returns:\n        Float32 tensor of the shape\n        `[filter_height, filter_width]` (non-separable) or\n        `[filter_taps]` (separable).\n    \"\"\"\n    # Validate.\n    if f is None:\n        f = 1\n    f = torch.as_tensor(f, dtype=torch.float32)\n    assert f.ndim in [0, 1, 2]\n    assert f.numel() > 0\n    if f.ndim == 0:\n        f = f[np.newaxis]\n\n    # Separable?\n    if separable is None:\n        separable = (f.ndim == 1 and f.numel() >= 8)\n    if f.ndim == 1 and not separable:\n        f = f.ger(f)\n    assert f.ndim == (1 if separable else 2)\n\n    # Apply normalize, flip, gain, and device.\n    if normalize:\n        f /= f.sum()\n    if flip_filter:\n        f = f.flip(list(range(f.ndim)))\n    f = f * (gain ** (f.ndim / 2))\n    f = f.to(device=device)\n    return f\n\n#----------------------------------------------------------------------------\n\ndef upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):\n    r\"\"\"Pad, upsample, filter, and downsample a batch of 2D images.\n\n    Performs the following sequence of operations for each channel:\n\n    1. Upsample the image by inserting N-1 zeros after each pixel (`up`).\n\n    2. Pad the image with the specified number of zeros on each side (`padding`).\n       Negative padding corresponds to cropping the image.\n\n    3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it\n       so that the footprint of all output pixels lies within the input image.\n\n    4. Downsample the image by keeping every Nth pixel (`down`).\n\n    This sequence of operations bears close resemblance to scipy.signal.upfirdn().\n    The fused op is considerably more efficient than performing the same calculation\n    using standard PyTorch ops. 
It supports gradients of arbitrary order.\n\n    Args:\n        x:           Float32/float64/float16 input tensor of the shape\n                     `[batch_size, num_channels, in_height, in_width]`.\n        f:           Float32 FIR filter of the shape\n                     `[filter_height, filter_width]` (non-separable),\n                     `[filter_taps]` (separable), or\n                     `None` (identity).\n        up:          Integer upsampling factor. Can be a single int or a list/tuple\n                     `[x, y]` (default: 1).\n        down:        Integer downsampling factor. Can be a single int or a list/tuple\n                     `[x, y]` (default: 1).\n        padding:     Padding with respect to the upsampled image. Can be a single number\n                     or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`\n                     (default: 0).\n        flip_filter: False = convolution, True = correlation (default: False).\n        gain:        Overall scaling factor for signal magnitude (default: 1).\n        impl:        Implementation to use. 
Can be `'ref'` or `'cuda'` (default: `'cuda'`).\n\n    Returns:\n        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.\n    \"\"\"\n    assert isinstance(x, torch.Tensor)\n    assert impl in ['ref', 'cuda']\n    if impl == 'cuda' and x.device.type == 'cuda' and _init():\n        return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)\n    return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)\n\n#----------------------------------------------------------------------------\n\n@misc.profiled_function\ndef _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):\n    \"\"\"Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.\n    \"\"\"\n    # Validate arguments.\n    assert isinstance(x, torch.Tensor) and x.ndim == 4\n    if f is None:\n        f = torch.ones([1, 1], dtype=torch.float32, device=x.device)\n    assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]\n    assert f.dtype == torch.float32 and not f.requires_grad\n    batch_size, num_channels, in_height, in_width = x.shape\n    upx, upy = _parse_scaling(up)\n    downx, downy = _parse_scaling(down)\n    padx0, padx1, pady0, pady1 = _parse_padding(padding)\n\n    # Check that upsampled buffer is not smaller than the filter.\n    upW = in_width * upx + padx0 + padx1\n    upH = in_height * upy + pady0 + pady1\n    assert upW >= f.shape[-1] and upH >= f.shape[0]\n\n    # Upsample by inserting zeros.\n    x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])\n    x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])\n    x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])\n\n    # Pad or crop.\n    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])\n    x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]\n\n    # 
Setup filter.\n    f = f * (gain ** (f.ndim / 2))\n    f = f.to(x.dtype)\n    if not flip_filter:\n        f = f.flip(list(range(f.ndim)))\n\n    # Convolve with the filter.\n    f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)\n    if f.ndim == 4:\n        x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)\n    else:\n        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)\n        x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)\n\n    # Downsample by throwing away pixels.\n    x = x[:, :, ::downy, ::downx]\n    return x\n\n#----------------------------------------------------------------------------\n\n_upfirdn2d_cuda_cache = dict()\n\ndef _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):\n    \"\"\"Fast CUDA implementation of `upfirdn2d()` using custom ops.\n    \"\"\"\n    # Parse arguments.\n    upx, upy = _parse_scaling(up)\n    downx, downy = _parse_scaling(down)\n    padx0, padx1, pady0, pady1 = _parse_padding(padding)\n\n    # Lookup from cache.\n    key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)\n    if key in _upfirdn2d_cuda_cache:\n        return _upfirdn2d_cuda_cache[key]\n\n    # Forward op.\n    class Upfirdn2dCuda(torch.autograd.Function):\n        @staticmethod\n        def forward(ctx, x, f): # pylint: disable=arguments-differ\n            assert isinstance(x, torch.Tensor) and x.ndim == 4\n            if f is None:\n                f = torch.ones([1, 1], dtype=torch.float32, device=x.device)\n            if f.ndim == 1 and f.shape[0] == 1:\n                f = f.square().unsqueeze(0) # Convert separable-1 into full-1x1.\n            assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]\n            y = x\n            if f.ndim == 2:\n                y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)\n            else:\n                y = 
_plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, 1.0)\n                y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, gain)\n            ctx.save_for_backward(f)\n            ctx.x_shape = x.shape\n            return y\n\n        @staticmethod\n        def backward(ctx, dy): # pylint: disable=arguments-differ\n            f, = ctx.saved_tensors\n            _, _, ih, iw = ctx.x_shape\n            _, _, oh, ow = dy.shape\n            fw, fh = _get_filter_size(f)\n            p = [\n                fw - padx0 - 1,\n                iw * upx - ow * downx + padx0 - upx + 1,\n                fh - pady0 - 1,\n                ih * upy - oh * downy + pady0 - upy + 1,\n            ]\n            dx = None\n            df = None\n\n            if ctx.needs_input_grad[0]:\n                dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)\n\n            assert not ctx.needs_input_grad[1]\n            return dx, df\n\n    # Add to cache.\n    _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda\n    return Upfirdn2dCuda\n\n#----------------------------------------------------------------------------\n\ndef filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):\n    r\"\"\"Filter a batch of 2D images using the given 2D FIR filter.\n\n    By default, the result is padded so that its shape matches the input.\n    User-specified padding is applied on top of that, with negative values\n    indicating cropping. 
Pixels outside the image are assumed to be zero.\n\n    Args:\n        x:           Float32/float64/float16 input tensor of the shape\n                     `[batch_size, num_channels, in_height, in_width]`.\n        f:           Float32 FIR filter of the shape\n                     `[filter_height, filter_width]` (non-separable),\n                     `[filter_taps]` (separable), or\n                     `None` (identity).\n        padding:     Padding with respect to the output. Can be a single number or a\n                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`\n                     (default: 0).\n        flip_filter: False = convolution, True = correlation (default: False).\n        gain:        Overall scaling factor for signal magnitude (default: 1).\n        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).\n\n    Returns:\n        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.\n    \"\"\"\n    padx0, padx1, pady0, pady1 = _parse_padding(padding)\n    fw, fh = _get_filter_size(f)\n    p = [\n        padx0 + fw // 2,\n        padx1 + (fw - 1) // 2,\n        pady0 + fh // 2,\n        pady1 + (fh - 1) // 2,\n    ]\n    return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)\n\n#----------------------------------------------------------------------------\n\ndef upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):\n    r\"\"\"Upsample a batch of 2D images using the given 2D FIR filter.\n\n    By default, the result is padded so that its shape is a multiple of the input.\n    User-specified padding is applied on top of that, with negative values\n    indicating cropping. 
Pixels outside the image are assumed to be zero.\n\n    Args:\n        x:           Float32/float64/float16 input tensor of the shape\n                     `[batch_size, num_channels, in_height, in_width]`.\n        f:           Float32 FIR filter of the shape\n                     `[filter_height, filter_width]` (non-separable),\n                     `[filter_taps]` (separable), or\n                     `None` (identity).\n        up:          Integer upsampling factor. Can be a single int or a list/tuple\n                     `[x, y]` (default: 1).\n        padding:     Padding with respect to the output. Can be a single number or a\n                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`\n                     (default: 0).\n        flip_filter: False = convolution, True = correlation (default: False).\n        gain:        Overall scaling factor for signal magnitude (default: 1).\n        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).\n\n    Returns:\n        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.\n    \"\"\"\n    upx, upy = _parse_scaling(up)\n    padx0, padx1, pady0, pady1 = _parse_padding(padding)\n    fw, fh = _get_filter_size(f)\n    p = [\n        padx0 + (fw + upx - 1) // 2,\n        padx1 + (fw - upx) // 2,\n        pady0 + (fh + upy - 1) // 2,\n        pady1 + (fh - upy) // 2,\n    ]\n    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)\n\n#----------------------------------------------------------------------------\n\ndef downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):\n    r\"\"\"Downsample a batch of 2D images using the given 2D FIR filter.\n\n    By default, the result is padded so that its shape is a fraction of the input.\n    User-specified padding is applied on top of that, with negative values\n    indicating cropping. 
Pixels outside the image are assumed to be zero.\n\n    Args:\n        x:           Float32/float64/float16 input tensor of the shape\n                     `[batch_size, num_channels, in_height, in_width]`.\n        f:           Float32 FIR filter of the shape\n                     `[filter_height, filter_width]` (non-separable),\n                     `[filter_taps]` (separable), or\n                     `None` (identity).\n        down:        Integer downsampling factor. Can be a single int or a list/tuple\n                     `[x, y]` (default: 1).\n        padding:     Padding with respect to the input. Can be a single number or a\n                     list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`\n                     (default: 0).\n        flip_filter: False = convolution, True = correlation (default: False).\n        gain:        Overall scaling factor for signal magnitude (default: 1).\n        impl:        Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).\n\n    Returns:\n        Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.\n    \"\"\"\n    downx, downy = _parse_scaling(down)\n    padx0, padx1, pady0, pady1 = _parse_padding(padding)\n    fw, fh = _get_filter_size(f)\n    p = [\n        padx0 + (fw - downx + 1) // 2,\n        padx1 + (fw - downx) // 2,\n        pady0 + (fh - downy + 1) // 2,\n        pady1 + (fh - downy) // 2,\n    ]\n    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)\n\n#----------------------------------------------------------------------------\n"
  },
  {
    "path": "ADD/utils/util_net.py",
    "content": "#!/usr/bin/env python\n# -*- coding:utf-8 -*-\n# Power by Zongsheng Yue 2021-11-24 20:29:36\n\nimport math\nimport torch\nfrom pathlib import Path\nfrom collections import OrderedDict\nimport torch.nn.functional as F\nfrom copy import deepcopy\n\ndef calculate_parameters(net):\n    out = 0\n    for param in net.parameters():\n        out += param.numel()\n    return out\n\ndef pad_input(x, mod):\n    h, w = x.shape[-2:]\n    bottom = int(math.ceil(h/mod)*mod -h)\n    right = int(math.ceil(w/mod)*mod - w)\n    x_pad = F.pad(x, pad=(0, right, 0, bottom), mode='reflect')\n    return x_pad\n\ndef forward_chop(net, x, net_kwargs=None, scale=1, shave=10, min_size=160000):\n    n_GPUs = 1\n    b, c, h, w = x.size()\n    h_half, w_half = h // 2, w // 2\n    h_size, w_size = h_half + shave, w_half + shave\n    lr_list = [\n        x[:, :, 0:h_size, 0:w_size],\n        x[:, :, 0:h_size, (w - w_size):w],\n        x[:, :, (h - h_size):h, 0:w_size],\n        x[:, :, (h - h_size):h, (w - w_size):w]]\n\n    if w_size * h_size < min_size:\n        sr_list = []\n        for i in range(0, 4, n_GPUs):\n            lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0)\n            if net_kwargs is None:\n                sr_batch = net(lr_batch)\n            else:\n                sr_batch = net(lr_batch, **net_kwargs)\n            sr_list.extend(sr_batch.chunk(n_GPUs, dim=0))\n    else:\n        sr_list = [\n            forward_chop(patch, shave=shave, min_size=min_size) \\\n            for patch in lr_list\n        ]\n\n    h, w = scale * h, scale * w\n    h_half, w_half = scale * h_half, scale * w_half\n    h_size, w_size = scale * h_size, scale * w_size\n    shave *= scale\n\n    output = x.new(b, c, h, w)\n    output[:, :, 0:h_half, 0:w_half] \\\n        = sr_list[0][:, :, 0:h_half, 0:w_half]\n    output[:, :, 0:h_half, w_half:w] \\\n        = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size]\n    output[:, :, h_half:h, 0:w_half] \\\n        = sr_list[2][:, 
:, (h_size - h + h_half):h_size, 0:w_half]\n    output[:, :, h_half:h, w_half:w] \\\n        = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]\n\n    return output\n\ndef measure_time(net, inputs, num_forward=100):\n    '''\n    Measuring the average runing time (seconds) for pytorch.\n    out = net(*inputs)\n    '''\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n\n    start.record()\n    with torch.set_grad_enabled(False):\n        for _ in range(num_forward):\n            out = net(*inputs)\n    end.record()\n\n    torch.cuda.synchronize()\n\n    return start.elapsed_time(end) / 1000\n\ndef reload_model(model, ckpt):\n    if list(model.state_dict().keys())[0].startswith('module.'):\n        if list(ckpt.keys())[0].startswith('module.'):\n            ckpt = ckpt\n        else:\n            ckpt = OrderedDict({f'module.{key}':value for key, value in ckpt.items()})\n    else:\n        if list(ckpt.keys())[0].startswith('module.'):\n            ckpt = OrderedDict({key[7:]:value for key, value in ckpt.items()})\n        else:\n            ckpt = ckpt\n    model.load_state_dict(ckpt, True)\n\ndef compute_hinge_loss(real_output, fake_output, x_start_, r1_lambda):\n    if r1_lambda == 0:\n        real_loss_total = torch.relu(torch.ones_like(real_output) - real_output).mean()\n        fake_loss_total = torch.relu(torch.ones_like(fake_output) + fake_output).mean()\n\n    else:\n        real_loss_ = torch.relu(torch.ones_like(real_output) - real_output).mean()\n\n        # 计算真实样本的梯度\n        grad_real = torch.autograd.grad(outputs=real_output.sum(), inputs=x_start_, create_graph=True)[0]\n\n        # 计算梯度惩罚\n        grad_penalty = (grad_real.contiguous().view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean() * r1_lambda\n\n        real_loss_total = real_loss_ + grad_penalty\n        fake_loss_total = torch.relu(torch.ones_like(fake_output) + fake_output).mean()\n\n    real_loss = 
real_loss_total\n    fake_loss = fake_loss_total\n\n    loss_d = real_loss + fake_loss\n\n    return loss_d\n\n\n\ndef reload_model_(model, ckpt):\n    if list(model.state_dict().keys())[0].startswith('model.'):\n        if list(ckpt.keys())[0].startswith('model.'):\n            ckpt = ckpt\n        else:\n            ckpt = OrderedDict({f'model.{key}':value for key, value in ckpt.items()})\n    else:\n        if list(ckpt.keys())[0].startswith('model.'):\n            ckpt = OrderedDict({key[7:]:value for key, value in ckpt.items()})\n        else:\n            ckpt = ckpt\n    model.load_state_dict(ckpt, True)\n\n\n\ndef reload_model_IDE(model, ckpt):\n    extracted_dict = OrderedDict()\n    for key, value in ckpt.items():\n        if key.startswith('E_st'):\n            new_key = key.replace('E_st.', '')\n            extracted_dict[new_key] = value\n\n    model.load_state_dict(extracted_dict, True)\n\n\n\nclass EMA():\n    def __init__(self, model, decay):\n        self.model = model\n        self.decay = decay\n        self.shadow = {}\n        self.backup = {}\n\n    def register(self):\n        for name, param in self.model.named_parameters():\n            if param.requires_grad:\n                self.shadow[name] = param.data.clone()\n\n    def update(self):\n        for name, param in self.model.named_parameters():\n            if param.requires_grad:\n                assert name in self.shadow\n                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]\n                self.shadow[name] = new_average.clone()\n\n    def apply_shadow(self):\n        for name, param in self.model.named_parameters():\n            if param.requires_grad:\n                assert name in self.shadow\n                self.backup[name] = param.data\n                param.data = self.shadow[name]\n\n    def restore(self):\n        for name, param in self.model.named_parameters():\n            if param.requires_grad:\n                assert name in 
self.backup\n                param.data = self.backup[name]\n        self.backup = {}\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n    <img src=\"figs/logo.png\" width=\"400\">\n</p>\n\n<div align=\"center\">\n<h2>Improving the Stability and Efficiency of Diffusion Models for Content Consistent Super-Resolution</h2>\n\n\n<a href='https://arxiv.org/pdf/2401.00877'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> \n\n\n[Lingchen Sun](https://scholar.google.com/citations?hl=zh-CN&tzom=-480&user=ZCDjTn8AAAAJ)<sup>1,2</sup>\n| [Rongyuan Wu](https://scholar.google.com/citations?user=A-U8zE8AAAAJ&hl=zh-CN)<sup>1,2</sup> | \n[Jie Liang](https://scholar.google.com.sg/citations?user=REWxLZsAAAAJ&hl)<sup>2</sup> |\n[Zhengqiang Zhang](https://scholar.google.com/citations?hl=zh-CN&user=UX26wSMAAAAJ&view_op=list_works&sortby=pubdate)<sup>1,2</sup> | \n[Hongwei Yong](https://scholar.google.com.hk/citations?user=Xii74qQAAAAJ&hl=zh-CN)<sup>1</sup> | \n[Lei Zhang](https://www4.comp.polyu.edu.hk/~cslzhang)<sup>1,2</sup>\n\n<sup>1</sup>The Hong Kong Polytechnic University, <sup>2</sup>OPPO Research Institute\n</div>\n\n:star: If CCSR is helpful to your images or projects, please help star this repo. Thanks! :hugs:\n\n ## 🧡ྀི What's New in CCSR-v2?\nWe have implemented the CCSR-v2 code based on the [Diffusers](https://github.com/huggingface/diffusers). Compared to CCSR-v1, CCSR-v2 brings a host of upgrades:\n\n- 🛠️**Step Flexibility**: Offers flexibility in diffusion step selection, **allowing users to freely adjust the number of steps to suit their specific requirements**. 
This adaptability **requires no additional re-training**, ensuring seamless integration into diverse workflows.\n- ⚡**Efficiency**: Supports highly efficient inference with **as few as 2 or even 1 diffusion step**, drastically reducing computation time without compromising quality.\n- 📈**Enhanced Clarity**: With upgraded algorithms, CCSR-v2 restores images with crisper details while maintaining fidelity.\n- ⚖️**Results stability**: CCSR-v2 exhibits significantly improved stability in synthesizing fine image details, ensuring higher-quality outputs.\n- 🔄**Stage 2 Refinement**: In CCSR-v2, the output $\\hat{x}_{0 \\gets T}$ from Stage 1 is now directly fed into Stage 2, streamlining the restoration process into an efficient one-step diffusion workflow. This strategy boosts both speed and performance.\n\n![ccsr](figs/fig.png)\nVisual comparisons between the SR outputs with the same input low-quality image but two different noise samples by different DM-based\nmethods. `S` denotes diffusion sampling timesteps. Existing DM-based methods, including StableSR, PASD, SeeSR, SUPIR and AddSR, **show noticeable instability with the different noise samples**. OSEDiff directly takes low-quality image as input without\nnoise sampling. It is deterministic and stable, but **cannot perform multi-step diffusion** for high generative capacity. In contrast, **our proposed CCSR method\nis flexible for both multi-step diffusion and single-step diffusion, while producing stable results with high fidelity and visual quality**.\n\n## ⏰ Update\n- **2024.12.12**: Code and models for CCSR-v2 are released. 
👀 Please refer to this [branch](https://github.com/csslc/CCSR/tree/CCSR-v2.0).\n- **2024.9.25**: ⭐[CCSR-v2](https://arxiv.org/pdf/2401.00877) is released, offering reduced step requirements and supporting flexible diffusion step selection (2 or even 1 step) during the inference stage without the need for re-training.\n- **2023.12.23**: Code and models for [CCSR-v1](https://arxiv.org/pdf/2401.00877v1) are released. Please refer to this [branch](https://github.com/csslc/CCSR/tree/CCSR-v1.0).\n\n\n## 🌟 Overview Framework\n![ccsr](figs/framework.png)\n\n## 😍 Visual Results\n### Demo on Real-world SR\n\n[<img src=\"figs/compare_1.png\" height=\"213px\"/>](https://imgsli.com/MzI2MTg5) [<img src=\"figs/compare_2.png\" height=\"213px\"/>](https://imgsli.com/MzI2MTky/1/3) [<img src=\"figs/compare_3.png\" height=\"213px\"/>](https://imgsli.com/MzI2MTk0/0/2) [<img src=\"figs/compare_4.png\" height=\"213px\"/>](https://imgsli.com/MzI2MTk1/0/2) \n\n\n![ccsr](figs/compare_standard.png)\n\n![ccsr](figs/compare_efficient.png)\nFor more comparisons, please refer to our paper for details.\n\n## 📝 Quantitative comparisons\nWe propose new stability metrics, namely global standard deviation (G-STD) and local standard deviation (L-STD), to respectively measure the image-level and pixel-level variations of the SR results of diffusion-based methods.\n\nMore details about G-STD and L-STD can be found in our paper.\n\n![ccsr](figs/table.png)\n## ⚙ Dependencies and Installation\n```shell\n## git clone this repository\ngit clone https://github.com/csslc/CCSR.git\ncd CCSR\n\n\n# create an environment with python >= 3.9\nconda create -n ccsr python=3.9\nconda activate ccsr\npip install -r requirements.txt\n```\n## 🍭 Quick Inference\n**For ease of comparison, we have provided the test results of CCSR-v2 on the DIV2K, RealSR, and DrealSR benchmarks with varying diffusion steps, which can be accessed via [Google 
Drive](https://drive.google.com/drive/folders/1xjURQZgKAlENzMnAJA2PDG9h_UxfZzio?usp=sharing).**\n\n#### Step 1: Download the pretrained models\n- Download the pretrained SD-2.1-base models from [HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-2-1-base).\n- Download the CCSR-v2 models from and put the models in the `preset/models`:\n\n| Model Name             | Description                      | GoogleDrive                                                                                                                                                        | BaiduNetdisk                                                                                                                 |\n|:-----------------------|:---------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------------------------------------------------------------------------------------------------------------------|\n| Controlnet             | Trained in the stage 1.          | [download](https://drive.google.com/drive/folders/1aHwgodKwKYZJBKs0QlFzanSjMDhrNyRA?usp=sharing)                                                                   | [download](https://pan.baidu.com/s/1SKS70iE4GhhHGxqY1KS8mw) (pwd: ccsr)                                                      |\n| VAE                    | Trained in the stage 2.          | [download](https://drive.google.com/drive/folders/1yHfMV81Md6db4StHTP5MC-eSeLFeBKm8?usp=sharing)                                                                   | [download](https://pan.baidu.com/s/1fxOIeL6Hk6Muq9h8itAIKQ) (pwd: ccsr)                                                      |\n| Pre-trained Controlnet | The pre-trained model of stage1. 
| [download](https://drive.google.com/drive/folders/1LTtBRuObITOJwbW-sTDnHtp8xIUZFDHh?usp=sharing)                                                                   | [download](https://pan.baidu.com/s/1mDeuHBqNj_Iol7PCY_Xfww) (pwd: ccsr)                                                      |\n| Dino models            | The pre-trained models for disc. | [download](https://drive.google.com/drive/folders/1PcuZGUTJlltdPz2yk2ZIa4GCtb1yk_y6?usp=sharing)                                                                   | [download](https://pan.baidu.com/s/1nPdNwgua91mDDRApWUm39Q) (pwd: ccsr)                                                      |\n\n#### Step 2: Prepare testing data\nYou can put the testing images in the `preset/test_datasets`.\n\n#### Step 3: Running testing command \nFor one-step diffusion process:\n```\npython test_ccsr_tile.py \\\n--pretrained_model_path preset/models/stable-diffusion-2-1-base \\\n--controlnet_model_path preset/models \\\n--vae_model_path preset/models \\\n--baseline_name ccsr-v2 \\\n--image_path preset/test_datasets \\\n--output_dir experiments/test \\\n--sample_method ddpm \\\n--num_inference_steps 1 \\\n--t_min 0.0 \\\n--start_point lr \\\n--start_steps 999 \\\n--process_size 512 \\\n--guidance_scale 1.0 \\\n--sample_times 1 \\\n--use_vae_encode_condition \\\n--upscale 4\n```\nFor multi-step diffusion process:\n```\npython test_ccsr_tile.py \\\n--pretrained_model_path preset/models/stable-diffusion-2-1-base \\\n--controlnet_model_path preset/models \\\n--vae_model_path preset/models \\\n--baseline_name ccsr-v2 \\\n--image_path preset/test_datasets \\\n--output_dir experiments/test \\\n--sample_method ddpm \\\n--num_inference_steps 6 \\\n--t_max 0.6667 \\\n--t_min 0.5 \\\n--start_point lr \\\n--start_steps 999 \\\n--process_size 512 \\\n--guidance_scale 4.5 \\\n--sample_times 1 \\\n--use_vae_encode_condition \\\n--upscale 4\n```\nWe integrate [tile_diffusion](https://github.com/albarji/mixture-of-diffusers) and 
[tile_vae](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111/tree/main) to the [test_ccsr_tile.py](test_ccsr_tile.py) to save the GPU memory for inference.\nYou can change the tile size and stride according to the VRAM of your device.\n```\npython test_ccsr_tile.py \\\n--pretrained_model_path preset/models/stable-diffusion-2-1-base \\\n--controlnet_model_path preset/models \\\n--vae_model_path preset/models \\\n--baseline_name ccsr-v2 \\\n--image_path preset/test_datasets \\\n--output_dir experiments/test \\\n--sample_method ddpm \\\n--num_inference_steps 6 \\\n--t_max 0.6667 \\\n--t_min 0.5 \\\n--start_point lr \\\n--start_steps 999 \\\n--process_size 512 \\\n--guidance_scale 4.5 \\\n--sample_times 1 \\\n--use_vae_encode_condition \\\n--upscale 4 \\\n--tile_diffusion \\\n--tile_diffusion_size 512 \\\n--tile_diffusion_stride 256 \\\n--tile_vae \\\n--vae_decoder_tile_size 224 \\\n--vae_encoder_tile_size 1024 \\\n```\n\nYou can obtain `N` different SR results by setting `sample_times` as `N` to test the stability of CCSR. The data folder should be like this:\n\n```\n experiments/test\n ├── sample00   # the first group of SR results \n └── sample01   # the second group of SR results \n   ...\n └── sampleN   # the N-th group of SR results \n```\n\n## 📏 Evaluation\n1. Calculate the Image Quality Assessment for each restored group.\n\n   Fill in the required information in [cal_iqa.py](cal_iqa/cal_iqa.py) and run, then you can obtain the evaluation results in the folder like this:\n   ```\n    log_path\n    ├── log_name_npy  # save the IQA values of each restored group as the npy files\n    └── log_name.log   # log recode\n   ```\n\n2. Calculate the G-STD value for the diffusion-based SR method.\n\n   Fill in the required information in [iqa_G-STD.py](cal_iqa/iqa_G-STD.py) and run, then you can obtain the mean IQA values of N restored groups and G-STD value.\n\n3. 
Calculate the L-STD value for the diffusion-based SR method.\n\n   Fill in the required information in [iqa_L-STD.py](cal_iqa/iqa_L-STD.py) and run, then you can obtain the L-STD value.\n\n\n## 🚋 Train \n\n#### Step1: Prepare training data\n  Generate txt file for the training set.\n  Fill in the required information in [get_path](scripts/get_path.py) and run, then you can obtain the txt file recording the paths of ground-truth images. \n  You can save the txt file into `preset/gt_path.txt`.\n\n#### Step2: Train Stage1 Model\n1. Download pretrained [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) to provide generative capabilities.\n\n    ```shell\n    wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt --no-check-certificate\n    ```\n\n2. Start training.\n\n    ```shell\n   CUDA_VISIBLE_DEVICES=\"0,1,2,3,\" accelerate launch train_ccsr_stage1.py \\\n    --pretrained_model_name_or_path=\"preset/models/stable-diffusion-2-1-base\" \\\n    --controlnet_model_name_or_path='preset/models/pretrained_controlnet' \\\n    --enable_xformers_memory_efficient_attention \\\n    --output_dir=\"./experiments/ccsrv2_stage1\" \\\n    --mixed_precision=\"fp16\" \\\n    --resolution=512 \\\n    --learning_rate=5e-5 \\\n    --train_batch_size=4 \\\n    --gradient_accumulation_steps=6 \\\n    --dataloader_num_workers=0 \\\n    --checkpointing_steps=500 \\\n    --t_max=0.6667 \\\n    --max_train_steps=20000 \\\n    --dataset_root_folders 'preset/gt_path.txt' \n    ```\n\n#### Step3: Train Stage2 Model\n1. Put the model obtained from the stage1 into `controlnet_model_name_or_path`.\n2. 
Start training.\n    ```shell\n    CUDA_VISIBLE_DEVICES=\"0,1,2,3,\" accelerate launch train_ccsr_stage2.py \\\n    --pretrained_model_name_or_path=\"preset/models/stable-diffusion-2-1-base\" \\\n    --controlnet_model_name_or_path='preset/models/model_stage1' \\\n    --enable_xformers_memory_efficient_attention \\\n    --output_dir=\"./experiments/ccsrv2_stage2\" \\\n    --mixed_precision=\"fp16\" \\\n    --resolution=512 \\\n    --learning_rate=5e-6 \\\n    --train_batch_size=2 \\\n    --gradient_accumulation_steps=8 \\\n    --checkpointing_steps=500 \\\n    --is_start_lr=True \\\n    --t_max=0.6667 \\\n    --num_inference_steps=1 \\\n    --is_module \\\n    --lambda_l2=1.0 \\\n    --lambda_lpips=1.0 \\\n    --lambda_disc=0.05 \\\n    --lambda_disc_train=0.5 \\\n    --begin_disc=100 \\\n    --max_train_steps=2000 \\\n    --dataset_root_folders 'preset/gt_path.txt'  \n      ```\n    \n    \n    \n    \n\n### Citations\n\nIf our code helps your research or work, please consider citing our paper.\nThe following are BibTeX references:\n\n```\n@article{sun2023ccsr,\n  title={Improving the Stability of Diffusion Models for Content Consistent Super-Resolution},\n  author={Sun, Lingchen and Wu, Rongyuan and Zhang, Zhengqiang and Yong, Hongwei and Zhang, Lei},\n  journal={arXiv preprint arXiv:2401.00877},\n  year={2024}\n}\n```\n\n### License\nThis project is released under the [Apache 2.0 license](LICENSE).\n\n### Acknowledgement\nThis project is based on [ControlNet](https://github.com/lllyasviel/ControlNet), [BasicSR](https://github.com/XPixelGroup/BasicSR) and [SeeSR](https://github.com/cswry/SeeSR). Some codes are brought from [ADDSR](https://github.com/NJU-PCALab/AddSR). Thanks for their awesome works. \n\n### Contact\nIf you have any questions, please contact: ling-chen.sun@connect.polyu.hk\n\n\n<details>\n<summary>statistics</summary>\n\n![visitors](https://visitor-badge.laobi.icu/badge?page_id=csslc/CCSR)\n\n</details>\n\n\n"
  },
  {
    "path": "dataloaders/paired_dataset_txt.py",
    "content": "import glob\nimport os\nfrom PIL import Image\nimport random\nimport numpy as np\n\nfrom torch import nn\nfrom torchvision import transforms\nfrom torch.utils import data as data\nimport torch.nn.functional as F\n\nfrom .realesrgan import RealESRGAN_degradation\n\nclass PairedCaptionDataset(data.Dataset):\n    def __init__(\n            self,\n            root_folders=None,\n            tokenizer=None,\n            gt_ratio=0, # let lr is gt\n    ):\n        super(PairedCaptionDataset, self).__init__()\n\n        self.gt_ratio = gt_ratio\n        with open(root_folders, 'r') as f:\n            self.gt_list = [line.strip() for line in f.readlines()]\n\n        self.img_preproc = transforms.Compose([\n            transforms.RandomCrop((512, 512)),\n            transforms.Resize((512, 512)),\n            transforms.RandomHorizontalFlip(),\n            ])\n\n        self.degradation = RealESRGAN_degradation('dataloaders/params_ccsr.yml', device='cuda')\n        self.tokenizer = tokenizer\n\n    \n    def tokenize_caption(self, caption=\"\"):\n        inputs = self.tokenizer(\n            caption, max_length=self.tokenizer.model_max_length, padding=\"max_length\", truncation=True, return_tensors=\"pt\"\n        )\n\n        return inputs.input_ids\n\n    def __getitem__(self, index):\n\n        gt_path = self.gt_list[index]\n        gt_img = Image.open(gt_path).convert('RGB')\n        gt_img = self.img_preproc(gt_img)\n\n        gt_img, img_t = self.degradation.degrade_process(np.asarray(gt_img)/255., resize_bak=True)\n\n        if random.random() < self.gt_ratio:\n            lq_img = gt_img\n        else:\n            lq_img = img_t\n\n        # no caption used\n        lq_caption = ''\n\n        example = dict()\n        example[\"conditioning_pixel_values\"] = lq_img.squeeze(0) # [0, 1]\n        example[\"pixel_values\"] = gt_img.squeeze(0) * 2.0 - 1.0 # [-1, 1]\n        example[\"input_caption\"] = 
self.tokenize_caption(caption=lq_caption).squeeze(0)\n\n        lq_img = lq_img.squeeze()\n\n        return example\n\n    def __len__(self):\n        return len(self.gt_list)"
  },
  {
    "path": "dataloaders/params_ccsr.yml",
    "content": "scale: 4\ncolor_jitter_prob: 0.0\ngray_prob: 0.0\n\n# the first degradation process\nresize_prob: [0.2, 0.7, 0.1]  # up, down, keep\nresize_range: [0.3, 1.5]\ngaussian_noise_prob: 0.5\nnoise_range: [1, 15]\npoisson_scale_range: [0.05, 2.0]\ngray_noise_prob: 0.4\njpeg_range: [60, 95]\n\n\n# the second degradation process\nsecond_blur_prob: 0.5\nresize_prob2: [0.3, 0.4, 0.3]  # up, down, keep\nresize_range2: [0.6, 1.2]\ngaussian_noise_prob2: 0.5\nnoise_range2: [1, 12]\npoisson_scale_range2: [0.05, 1.0]\ngray_noise_prob2: 0.4\njpeg_range2: [60, 100]\n\nkernel_info:\n    blur_kernel_size: 21\n    kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']\n    kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]\n    sinc_prob: 0.1\n    blur_sigma: [0.2, 1.5]\n    betag_range: [0.5, 2.0]\n    betap_range: [1, 1.5]\n\n    blur_kernel_size2: 11\n    kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']\n    kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]\n    sinc_prob2: 0.1\n    blur_sigma2: [0.2, 1.0]\n    betag_range2: [0.5, 2.0]\n    betap_range2: [1, 1.5]\n\n    final_sinc_prob: 0.8\n"
  },
  {
    "path": "dataloaders/realesrgan.py",
    "content": "import os\nimport numpy as np\nimport cv2\nimport glob\nimport math\nimport yaml\nimport random\nfrom collections import OrderedDict\nimport torch\nimport torch.nn.functional as F\n\nfrom basicsr.data.transforms import augment\nfrom basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels\nfrom basicsr.utils import DiffJPEG, USMSharp, img2tensor, tensor2img\nfrom basicsr.utils.img_process_util import filter2D\nfrom basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt\nfrom torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation,\n                                               normalize, rgb_to_grayscale)\n\ncur_path = os.path.dirname(os.path.abspath(__file__))\n\n\ndef ordered_yaml():\n    \"\"\"Support OrderedDict for yaml.\n\n    Returns:\n        yaml Loader and Dumper.\n    \"\"\"\n    try:\n        from yaml import CDumper as Dumper\n        from yaml import CLoader as Loader\n    except ImportError:\n        from yaml import Dumper, Loader\n\n    _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG\n\n    def dict_representer(dumper, data):\n        return dumper.represent_dict(data.items())\n\n    def dict_constructor(loader, node):\n        return OrderedDict(loader.construct_pairs(node))\n\n    Dumper.add_representer(OrderedDict, dict_representer)\n    Loader.add_constructor(_mapping_tag, dict_constructor)\n    return Loader, Dumper\n\ndef opt_parse(opt_path):\n    with open(opt_path, mode='r') as f:\n        Loader, _ = ordered_yaml()\n        opt = yaml.load(f, Loader=Loader) \n\n    return opt\n\nclass RealESRGAN_degradation(object):\n    def __init__(self, opt_path='', device='cpu'):\n        self.opt = opt_parse(opt_path)\n        self.device = device #torch.device('cpu')\n        optk = self.opt['kernel_info']       \n\n        # blur settings for the first degradation\n        self.blur_kernel_size = 
optk['blur_kernel_size']\n        self.kernel_list = optk['kernel_list']\n        self.kernel_prob = optk['kernel_prob']\n        self.blur_sigma = optk['blur_sigma']\n        self.betag_range = optk['betag_range']\n        self.betap_range = optk['betap_range']\n        self.sinc_prob = optk['sinc_prob']\n\n        # blur settings for the second degradation\n        self.blur_kernel_size2 = optk['blur_kernel_size2']\n        self.kernel_list2 = optk['kernel_list2']\n        self.kernel_prob2 = optk['kernel_prob2']\n        self.blur_sigma2 = optk['blur_sigma2']\n        self.betag_range2 = optk['betag_range2']\n        self.betap_range2 = optk['betap_range2']\n        self.sinc_prob2 = optk['sinc_prob2']\n\n        # a final sinc filter\n        self.final_sinc_prob = optk['final_sinc_prob']\n\n        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # kernel size ranges from 7 to 21\n        self.pulse_tensor = torch.zeros(21, 21).float()  # convolving with pulse tensor brings no blurry effect\n        self.pulse_tensor[10, 10] = 1\n\n        self.jpeger = DiffJPEG(differentiable=False).to(self.device)\n        self.usm_shaper = USMSharp().to(self.device)\n    \n    def color_jitter_pt(self, img, brightness, contrast, saturation, hue):\n        fn_idx = torch.randperm(4)\n        for fn_id in fn_idx:\n            if fn_id == 0 and brightness is not None:\n                brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()\n                img = adjust_brightness(img, brightness_factor)\n\n            if fn_id == 1 and contrast is not None:\n                contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()\n                img = adjust_contrast(img, contrast_factor)\n\n            if fn_id == 2 and saturation is not None:\n                saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()\n                img = adjust_saturation(img, saturation_factor)\n\n            
if fn_id == 3 and hue is not None:\n                hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()\n                img = adjust_hue(img, hue_factor)\n        return img\n\n    def random_augment(self, img_gt):\n        # random horizontal flip\n        img_gt, status = augment(img_gt, hflip=True, rotation=False, return_status=True)\n        \"\"\"\n        # random color jitter \n        if np.random.uniform() < self.opt['color_jitter_prob']:\n            jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32)\n            img_gt = img_gt + jitter_val\n            img_gt = np.clip(img_gt, 0, 1)    \n\n        # random grayscale\n        if np.random.uniform() < self.opt['gray_prob']:\n            #img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY)\n            img_gt = cv2.cvtColor(img_gt, cv2.COLOR_RGB2GRAY)\n            img_gt = np.tile(img_gt[:, :, None], [1, 1, 3])\n        \"\"\"\n        # BGR to RGB, HWC to CHW, numpy to tensor\n        img_gt = img2tensor([img_gt], bgr2rgb=False, float32=True)[0].unsqueeze(0)\n\n        return img_gt\n\n    def random_kernels(self):\n        # ------------------------ Generate kernels (used in the first degradation) ------------------------ #\n        kernel_size = random.choice(self.kernel_range)\n        if np.random.uniform() < self.sinc_prob:\n            # this sinc filter setting is for kernels ranging from [7, 21]\n            if kernel_size < 13:\n                omega_c = np.random.uniform(np.pi / 3, np.pi)\n            else:\n                omega_c = np.random.uniform(np.pi / 5, np.pi)\n            kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)\n        else:\n            kernel = random_mixed_kernels(\n                    self.kernel_list,\n                    self.kernel_prob,\n                    kernel_size,\n                    self.blur_sigma,\n                    self.blur_sigma, [-math.pi, math.pi],\n                    self.betag_range,\n                  
  self.betap_range,\n                    noise_range=None)\n        # pad kernel\n        pad_size = (21 - kernel_size) // 2\n        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))\n\n        # ------------------------ Generate kernels (used in the second degradation) ------------------------ #\n        kernel_size = random.choice(self.kernel_range)\n        if np.random.uniform() < self.sinc_prob2:\n            if kernel_size < 13:\n                omega_c = np.random.uniform(np.pi / 3, np.pi)\n            else:\n                omega_c = np.random.uniform(np.pi / 5, np.pi)\n            kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)\n        else:\n            kernel2 = random_mixed_kernels(\n                self.kernel_list2,\n                self.kernel_prob2,\n                kernel_size,\n                self.blur_sigma2,\n                self.blur_sigma2, [-math.pi, math.pi],\n                self.betag_range2,\n                self.betap_range2,\n                noise_range=None)\n\n        # pad kernel\n        pad_size = (21 - kernel_size) // 2\n        kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))\n\n        # ------------------------------------- sinc kernel ------------------------------------- #\n        if np.random.uniform() < self.final_sinc_prob:\n            kernel_size = random.choice(self.kernel_range)\n            omega_c = np.random.uniform(np.pi / 3, np.pi)\n            sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)\n            sinc_kernel = torch.FloatTensor(sinc_kernel)\n        else:\n            sinc_kernel = self.pulse_tensor\n\n        kernel = torch.FloatTensor(kernel)\n        kernel2 = torch.FloatTensor(kernel2) \n\n        return kernel, kernel2, sinc_kernel\n\n    @torch.no_grad()\n    def degrade_process(self, img_gt, resize_bak=False):\n        img_gt = self.random_augment(img_gt)\n        kernel1, kernel2, sinc_kernel = 
self.random_kernels()\n        img_gt, kernel1, kernel2, sinc_kernel = img_gt.to(self.device), kernel1.to(self.device), kernel2.to(self.device), sinc_kernel.to(self.device)\n        #img_gt = self.usm_shaper(img_gt) # shaper gt\n        ori_h, ori_w = img_gt.size()[2:4]\n\n        #scale_final = random.randint(4, 16)\n        scale_final = 4\n\n        # ----------------------- The first degradation process ----------------------- #\n        # blur\n        out = filter2D(img_gt, kernel1)\n        # random resize\n        updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]\n        if updown_type == 'up':\n            scale = np.random.uniform(1, self.opt['resize_range'][1])\n        elif updown_type == 'down':\n            scale = np.random.uniform(self.opt['resize_range'][0], 1)\n        else:\n            scale = 1\n        mode = random.choice(['area', 'bilinear', 'bicubic'])\n        out = F.interpolate(out, scale_factor=scale, mode=mode)\n        # noise\n        gray_noise_prob = self.opt['gray_noise_prob']\n        if np.random.uniform() < self.opt['gaussian_noise_prob']:\n            out = random_add_gaussian_noise_pt(\n                out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)\n        else:\n            out = random_add_poisson_noise_pt(\n                out,\n                scale_range=self.opt['poisson_scale_range'],\n                gray_prob=gray_noise_prob,\n                clip=True,\n                rounds=False)\n        # JPEG compression\n        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])\n        out = torch.clamp(out, 0, 1)\n        out = self.jpeger(out, quality=jpeg_p)\n\n        # ----------------------- The second degradation process ----------------------- #\n        # blur\n        if np.random.uniform() < self.opt['second_blur_prob']:\n            out = filter2D(out, kernel2)\n        # random resize\n        updown_type = 
random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]\n        if updown_type == 'up':\n            scale = np.random.uniform(1, self.opt['resize_range2'][1])\n        elif updown_type == 'down':\n            scale = np.random.uniform(self.opt['resize_range2'][0], 1)\n        else:\n            scale = 1\n        mode = random.choice(['area', 'bilinear', 'bicubic'])\n        out = F.interpolate(\n            out, size=(int(ori_h / scale_final * scale), int(ori_w / scale_final * scale)), mode=mode)\n        # noise\n        gray_noise_prob = self.opt['gray_noise_prob2']\n        if np.random.uniform() < self.opt['gaussian_noise_prob2']:\n            out = random_add_gaussian_noise_pt(\n                out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)\n        else:\n            out = random_add_poisson_noise_pt(\n                out,\n                scale_range=self.opt['poisson_scale_range2'],\n                gray_prob=gray_noise_prob,\n                clip=True,\n                rounds=False)\n\n        # JPEG compression + the final sinc filter\n        # We also need to resize images to desired sizes. We group [resize back + sinc filter] together\n        # as one operation.\n        # We consider two orders:\n        #   1. [resize back + sinc filter] + JPEG compression\n        #   2. 
JPEG compression + [resize back + sinc filter]\n        # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.\n        if np.random.uniform() < 0.5:\n            # resize back + the final sinc filter\n            mode = random.choice(['area', 'bilinear', 'bicubic'])\n            out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)\n            out = filter2D(out, sinc_kernel)\n            # JPEG compression\n            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])\n            out = torch.clamp(out, 0, 1)\n            out = self.jpeger(out, quality=jpeg_p)\n        else:\n            # JPEG compression\n            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])\n            out = torch.clamp(out, 0, 1)\n            out = self.jpeger(out, quality=jpeg_p)\n            # resize back + the final sinc filter\n            mode = random.choice(['area', 'bilinear', 'bicubic'])\n            out = F.interpolate(out, size=(ori_h // scale_final, ori_w // scale_final), mode=mode)\n            out = filter2D(out, sinc_kernel)\n\n        if np.random.uniform() < self.opt['gray_prob']:\n            out = rgb_to_grayscale(out, num_output_channels=1)\n\n        if np.random.uniform() < self.opt['color_jitter_prob']:\n            brightness = self.opt.get('brightness', (0.5, 1.5))\n            contrast = self.opt.get('contrast', (0.5, 1.5))\n            saturation = self.opt.get('saturation', (0, 1.5))\n            hue = self.opt.get('hue', (-0.1, 0.1))\n            out = self.color_jitter_pt(out, brightness, contrast, saturation, hue)\n\n        if resize_bak:\n            mode = random.choice(['area', 'bilinear', 'bicubic'])\n            out = F.interpolate(out, size=(ori_h, ori_w), mode=mode)\n        # clamp and round\n        img_lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.\n\n        return img_gt, img_lq\n\n\n"
  },
  {
    "path": "models/DiffAugment.py",
    "content": "# BSD 2-Clause \"Simplified\" License\n# Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n# * Redistributions of source code must retain the above copyright notice, this\n#   list of conditions and the following disclaimer.\n#\n# * Redistributions in binary form must reproduce the above copyright notice,\n#   this list of conditions and the following disclaimer in the documentation\n#   and/or other materials provided with the distribution.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\n# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\n# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\n# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\n# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\n# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n#\n# Code from https://github.com/mit-han-lab/data-efficient-gans\n\n\"\"\"Training GANs with DiffAugment.\"\"\"\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\n\n\ndef DiffAugment(x: torch.Tensor, policy: str = '', channels_first: bool = True) -> torch.Tensor:\n    if policy:\n        if not channels_first:\n            x = x.permute(0, 3, 1, 2)\n        for p in policy.split(','):\n            for f in AUGMENT_FNS[p]:\n                x = f(x)\n        if not channels_first:\n            x = x.permute(0, 
2, 3, 1)\n        x = x.contiguous()\n    return x\n\n\ndef rand_brightness(x: torch.Tensor) -> torch.Tensor:\n    x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)\n    return x\n\n\ndef rand_saturation(x: torch.Tensor) -> torch.Tensor:\n    x_mean = x.mean(dim=1, keepdim=True)\n    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean\n    return x\n\n\ndef rand_contrast(x: torch.Tensor) -> torch.Tensor:\n    x_mean = x.mean(dim=[1, 2, 3], keepdim=True)\n    x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean\n    return x\n\n\ndef rand_translation(x: torch.Tensor, ratio: float = 0.125) -> torch.Tensor:\n    shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)\n    translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)\n    translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)\n    grid_batch, grid_x, grid_y = torch.meshgrid(\n        torch.arange(x.size(0), dtype=torch.long, device=x.device),\n        torch.arange(x.size(2), dtype=torch.long, device=x.device),\n        torch.arange(x.size(3), dtype=torch.long, device=x.device),\n    )\n    grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)\n    grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)\n    x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])\n    x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2)\n    return x\n\n\ndef rand_cutout(x: torch.Tensor, ratio: float = 0.2) -> torch.Tensor:\n    cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)\n    offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)\n    offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)\n    grid_batch, grid_x, grid_y = 
torch.meshgrid(\n        torch.arange(x.size(0), dtype=torch.long, device=x.device),\n        torch.arange(cutout_size[0], dtype=torch.long, device=x.device),\n        torch.arange(cutout_size[1], dtype=torch.long, device=x.device),\n    )\n    grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)\n    grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)\n    mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)\n    mask[grid_batch, grid_x, grid_y] = 0\n    x = x * mask.unsqueeze(1)\n    return x\n\n\ndef rand_resize(x: torch.Tensor, min_ratio: float = 0.8, max_ratio: float = 1.2) -> torch.Tensor:\n    resize_ratio = np.random.rand()*(max_ratio-min_ratio) + min_ratio\n    resized_img = F.interpolate(x, size=int(resize_ratio*x.shape[3]), mode='bilinear')\n    org_size = x.shape[3]\n    if int(resize_ratio*x.shape[3]) < x.shape[3]:\n        left_pad = (x.shape[3]-int(resize_ratio*x.shape[3]))/2.\n        left_pad = int(left_pad)\n        right_pad = x.shape[3] - left_pad - resized_img.shape[3]\n        x = F.pad(resized_img, (left_pad, right_pad, left_pad, right_pad), \"constant\", 0.)\n    else:\n        left = (int(resize_ratio*x.shape[3])-x.shape[3])/2.\n        left = int(left)\n        x = resized_img[:, :, left:(left+x.shape[3]), left:(left+x.shape[3])]\n    assert x.shape[2] == org_size\n    assert x.shape[3] == org_size\n    return x\n\n\nAUGMENT_FNS = {\n    'color': [rand_brightness, rand_saturation, rand_contrast],\n    'translation': [rand_translation],\n    'resize': [rand_resize],\n    'cutout': [rand_cutout],\n}"
  },
  {
    "path": "models/controlnet.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import FromOriginalControlnetMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.attention_processor import AttentionProcessor, AttnProcessor\nfrom diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom .unet_2d_blocks import (\n    CrossAttnDownBlock2D,\n    DownBlock2D,\n    UNetMidBlock2DCrossAttn,\n    get_down_block,\n)\nfrom .unet_2d_condition import UNet2DConditionModel\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass ControlNetOutput(BaseOutput):\n    \"\"\"\n    The output of [`ControlNetModel`].\n\n    Args:\n        down_block_res_samples (`tuple[torch.Tensor]`):\n            A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should\n            be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. 
Output can be\n            used to condition the original UNet's downsampling activations.\n        mid_down_block_re_sample (`torch.Tensor`):\n            The activation of the midde block (the lowest sample resolution). Each tensor should be of shape\n            `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.\n            Output can be used to condition the original UNet's middle block activation.\n    \"\"\"\n\n    down_block_res_samples: Tuple[torch.Tensor]\n    mid_block_res_sample: torch.Tensor\n\n\nclass ControlNetConditioningEmbedding(nn.Module):\n    \"\"\"\n    Quoting from https://arxiv.org/abs/2302.05543: \"Stable Diffusion uses a pre-processing method similar to VQ-GAN\n    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized\n    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the\n    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides\n    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full\n    model) to encode image-space conditions ... 
into feature maps ...\"\n    \"\"\"\n\n    def __init__(\n        self,\n        conditioning_embedding_channels: int,\n        conditioning_channels: int = 3,\n        block_out_channels: Tuple[int] = (16, 32, 96, 256),\n    ):\n        super().__init__()\n\n        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)\n\n        self.blocks = nn.ModuleList([])\n\n        for i in range(len(block_out_channels) - 1):\n            channel_in = block_out_channels[i]\n            channel_out = block_out_channels[i + 1]\n            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))\n            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))\n\n        self.conv_out = zero_module(\n            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)\n        )\n\n    def forward(self, conditioning):\n        embedding = self.conv_in(conditioning)\n        embedding = F.silu(embedding)\n\n        for block in self.blocks:\n            embedding = block(embedding)\n            embedding = F.silu(embedding)\n\n        embedding = self.conv_out(embedding)\n\n        return embedding\n\n\nclass ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):\n    \"\"\"\n    A ControlNet model.\n\n    Args:\n        in_channels (`int`, defaults to 4):\n            The number of channels in the input sample.\n        flip_sin_to_cos (`bool`, defaults to `True`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, defaults to 0):\n            The frequency shift to apply to the time embedding.\n        down_block_types (`tuple[str]`, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):\n    
    block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, defaults to 2):\n            The number of layers per block.\n        downsample_padding (`int`, defaults to 1):\n            The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, defaults to 1):\n            The scale factor to use for the mid block.\n        act_fn (`str`, defaults to \"silu\"):\n            The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32):\n            The number of groups to use for the normalization. If None, normalization and activation layers is skipped\n            in post-processing.\n        norm_eps (`float`, defaults to 1e-5):\n            The epsilon to use for the normalization.\n        cross_attention_dim (`int`, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.\n        attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):\n            The dimension of the attention heads.\n        use_linear_projection (`bool`, defaults to `False`):\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". \"text\" will use the `TextTimeEmbedding` layer.\n        num_class_embeds (`int`, *optional*, defaults to 0):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        upcast_attention (`bool`, defaults to `False`):\n        resnet_time_scale_shift (`str`, defaults to `\"default\"`):\n            Time scale shift config for ResNet blocks (see `ResnetBlock2D`). 
Choose from `default` or `scale_shift`.\n        projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):\n            The dimension of the `class_labels` input when `class_embed_type=\"projection\"`. Required when\n            `class_embed_type=\"projection\"`.\n        controlnet_conditioning_channel_order (`str`, defaults to `\"rgb\"`):\n            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.\n        conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):\n            The tuple of output channel for each block in the `conditioning_embedding` layer.\n        global_pool_conditions (`bool`, defaults to `False`):\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        in_channels: int = 4,\n        conditioning_channels: int = 3,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\",\n        ),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: int = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: int = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: Optional[str] = None,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        
addition_embed_type: Optional[str] = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        controlnet_conditioning_channel_order: str = \"rgb\",\n        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),\n        global_pool_conditions: bool = False,\n        addition_embed_type_num_heads=64,\n        use_vae_encode_condition=False,\n    ):\n        super().__init__()\n\n        # If `num_attention_heads` is not defined (which is the case for most models)\n        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.\n        # The reason for this behavior is to correct for incorrectly named variables that were introduced\n        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131\n        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking\n        # which is why we correct for the naming here.\n        num_attention_heads = num_attention_heads or attention_head_dim\n\n        # Check inputs\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        # input\n        conv_in_kernel = 3\n        conv_in_padding = (conv_in_kernel - 1) // 2\n        self.conv_in = nn.Conv2d(\n            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n        )\n\n        # use_vae_encode_condition\n        self.use_vae_encode_condition = use_vae_encode_condition\n        if self.use_vae_encode_condition:\n            print(f'============================')\n            print(f'use vae encode condition in CONTROLNET!!!')\n            print(f'============================')\n            self.condition_conv_in = nn.Conv2d(\n                in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n            )\n        else:\n            print(f'============================')\n            print(f'Not !!! 
use vae encode condition in CONTROLNET')\n            print(f'============================')\n            # control net conditioning embedding\n            self.controlnet_cond_embedding = ControlNetConditioningEmbedding(\n                conditioning_embedding_channels=block_out_channels[0],\n                block_out_channels=conditioning_embedding_out_channels,\n                conditioning_channels=conditioning_channels,\n            )\n\n        # time\n        time_embed_dim = block_out_channels[0] * 4\n        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)\n        timestep_input_dim = block_out_channels[0]\n        self.time_embedding = TimestepEmbedding(\n            timestep_input_dim,\n            time_embed_dim,\n            act_fn=act_fn,\n        )\n\n        if encoder_hid_dim_type is None and encoder_hid_dim is not None:\n            encoder_hid_dim_type = \"text_proj\"\n            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)\n            logger.info(\"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.\")\n\n        if encoder_hid_dim is None and encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.\"\n            )\n\n        if encoder_hid_dim_type == \"text_proj\":\n            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)\n        elif encoder_hid_dim_type == \"text_image_proj\":\n            # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image_proj\"` (Kadinsky 2.1)`\n            self.encoder_hid_proj = TextImageProjection(\n                text_embed_dim=encoder_hid_dim,\n                image_embed_dim=cross_attention_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n\n        elif encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'.\"\n            )\n        else:\n            self.encoder_hid_proj = None\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. 
it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        if addition_embed_type == \"text\":\n            if encoder_hid_dim is not None:\n                text_time_embedding_from_dim = encoder_hid_dim\n            else:\n                text_time_embedding_from_dim = cross_attention_dim\n\n            self.add_embedding = TextTimeEmbedding(\n                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads\n            )\n        elif addition_embed_type == \"text_image\":\n            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image\"` (Kadinsky 2.1)`\n            self.add_embedding = TextImageTimeEmbedding(\n                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim\n            )\n        elif addition_embed_type == \"text_time\":\n            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)\n            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n\n        elif addition_embed_type is not None:\n            raise ValueError(f\"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.\")\n\n        self.down_blocks = nn.ModuleList([])\n        self.controlnet_down_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = (num_attention_heads,) * len(down_block_types)\n\n        # down\n        output_channel = block_out_channels[0]\n\n        controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)\n        controlnet_block = zero_module(controlnet_block)\n        self.controlnet_down_blocks.append(controlnet_block)\n\n\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block,\n       
         transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim,\n                num_attention_heads=num_attention_heads[i],\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n                downsample_padding=downsample_padding,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n            )\n            self.down_blocks.append(down_block)\n\n            for _ in range(layers_per_block):\n                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)\n                controlnet_block = zero_module(controlnet_block)\n                self.controlnet_down_blocks.append(controlnet_block)\n\n            if not is_final_block:\n                controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)\n                controlnet_block = zero_module(controlnet_block)\n                self.controlnet_down_blocks.append(controlnet_block)\n\n        # mid\n        mid_block_channel = block_out_channels[-1]\n\n        controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)\n        controlnet_block = zero_module(controlnet_block)\n        self.controlnet_mid_block = controlnet_block\n\n        self.mid_block = UNetMidBlock2DCrossAttn(\n            transformer_layers_per_block=transformer_layers_per_block[-1],\n            in_channels=mid_block_channel,\n            
temb_channels=time_embed_dim,\n            resnet_eps=norm_eps,\n            resnet_act_fn=act_fn,\n            output_scale_factor=mid_block_scale_factor,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            cross_attention_dim=cross_attention_dim,\n            num_attention_heads=num_attention_heads[-1],\n            resnet_groups=norm_num_groups,\n            use_linear_projection=use_linear_projection,\n            upcast_attention=upcast_attention,\n        )\n\n    @classmethod\n    def from_unet(\n        cls,\n        unet: UNet2DConditionModel,\n        controlnet_conditioning_channel_order: str = \"rgb\",\n        conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),\n        load_weights_from_unet: bool = True,\n        use_vae_encode_condition: bool = False,\n    ):\n        r\"\"\"\n        Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].\n\n        Parameters:\n            unet (`UNet2DConditionModel`):\n                The UNet model weights to copy to the [`ControlNetModel`]. 
All configuration options are also copied\n                where applicable.\n        \"\"\"\n        transformer_layers_per_block = (\n            unet.config.transformer_layers_per_block if \"transformer_layers_per_block\" in unet.config else 1\n        )\n        encoder_hid_dim = unet.config.encoder_hid_dim if \"encoder_hid_dim\" in unet.config else None\n        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if \"encoder_hid_dim_type\" in unet.config else None\n        addition_embed_type = unet.config.addition_embed_type if \"addition_embed_type\" in unet.config else None\n        addition_time_embed_dim = (\n            unet.config.addition_time_embed_dim if \"addition_time_embed_dim\" in unet.config else None\n        )\n\n        controlnet = cls(\n            encoder_hid_dim=encoder_hid_dim,\n            encoder_hid_dim_type=encoder_hid_dim_type,\n            addition_embed_type=addition_embed_type,\n            addition_time_embed_dim=addition_time_embed_dim,\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=unet.config.in_channels,\n            flip_sin_to_cos=unet.config.flip_sin_to_cos,\n            freq_shift=unet.config.freq_shift,\n            down_block_types=unet.config.down_block_types,\n            only_cross_attention=unet.config.only_cross_attention,\n            block_out_channels=unet.config.block_out_channels,\n            layers_per_block=unet.config.layers_per_block,\n            downsample_padding=unet.config.downsample_padding,\n            mid_block_scale_factor=unet.config.mid_block_scale_factor,\n            act_fn=unet.config.act_fn,\n            norm_num_groups=unet.config.norm_num_groups,\n            norm_eps=unet.config.norm_eps,\n            cross_attention_dim=unet.config.cross_attention_dim,\n            attention_head_dim=unet.config.attention_head_dim,\n            num_attention_heads=unet.config.num_attention_heads,\n            
use_linear_projection=unet.config.use_linear_projection,\n            class_embed_type=unet.config.class_embed_type,\n            num_class_embeds=unet.config.num_class_embeds,\n            upcast_attention=unet.config.upcast_attention,\n            resnet_time_scale_shift=unet.config.resnet_time_scale_shift,\n            projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,\n            controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,\n            conditioning_embedding_out_channels=conditioning_embedding_out_channels,\n            use_vae_encode_condition=use_vae_encode_condition,\n        )\n\n        if load_weights_from_unet:\n            controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())\n            controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())\n            controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())\n\n            if controlnet.class_embedding:\n                controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())\n\n            controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)\n            controlnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)\n\n        return controlnet\n\n    @property\n    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = 
module.get_processor(return_deprecated_lora=True)\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. 
Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        self.set_attn_processor(AttnProcessor())\n\n    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module splits the input tensor in slices to compute attention in\n        several steps. This is useful for saving some memory in exchange for a small decrease in speed.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, input to the attention heads is halved, so attention is computed in two steps. If\n                `\"max\"`, maximum amount of memory is saved by running only one slice at a time. If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. 
In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any children which exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        controlnet_cond: torch.FloatTensor,\n        conditioning_scale: float = 1.0,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        guess_mode: bool = False,\n        return_dict: bool = True,\n        image_encoder_hidden_states: torch.Tensor = None,\n        vae_encode_condition_hidden_states: torch.Tensor = None, \n        use_vae_encode_condition = False,\n    ) -> 
Union[ControlNetOutput, Tuple]:\n        \"\"\"\n        The [`ControlNetModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor.\n            timestep (`Union[torch.Tensor, float, int]`):\n                The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.Tensor`):\n                The encoder hidden states.\n            controlnet_cond (`torch.FloatTensor`):\n                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.\n            conditioning_scale (`float`, defaults to `1.0`):\n                The scale factor for ControlNet outputs.\n            class_labels (`torch.Tensor`, *optional*, defaults to `None`):\n                Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.\n            timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):\n            attention_mask (`torch.Tensor`, *optional*, defaults to `None`):\n            added_cond_kwargs (`dict`):\n                Additional conditions for the Stable Diffusion XL UNet.\n            cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):\n                A kwargs dictionary that if specified is passed along to the `AttnProcessor`.\n            guess_mode (`bool`, defaults to `False`):\n                In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if\n                you remove all prompts. 
A `guidance_scale` between 3.0 and 5.0 is recommended.\n            return_dict (`bool`, defaults to `True`):\n                Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.\n\n        Returns:\n            [`~models.controlnet.ControlNetOutput`] **or** `tuple`:\n                If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is\n                returned where the first element is the sample tensor.\n        \"\"\"\n        # check channel order\n        channel_order = self.config.controlnet_conditioning_channel_order\n\n        if channel_order == \"rgb\":\n            # in rgb order by default\n            ...\n        elif channel_order == \"bgr\":\n            controlnet_cond = torch.flip(controlnet_cond, dims=[1])\n        else:\n            raise ValueError(f\"unknown `controlnet_conditioning_channel_order`: {channel_order}\")\n\n        # prepare attention_mask\n        if attention_mask is not None:\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # 1. time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # timesteps does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)\n            emb = emb + class_emb\n\n        if self.config.addition_embed_type is not None:\n            if self.config.addition_embed_type == \"text\":\n                aug_emb = self.add_embedding(encoder_hidden_states)\n\n            elif self.config.addition_embed_type == \"text_time\":\n                if \"text_embeds\" not in added_cond_kwargs:\n                    raise ValueError(\n                        f\"{self.__class__} has the config param `addition_embed_type` set to 
'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                    )\n                text_embeds = added_cond_kwargs.get(\"text_embeds\")\n                if \"time_ids\" not in added_cond_kwargs:\n                    raise ValueError(\n                        f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                    )\n                time_ids = added_cond_kwargs.get(\"time_ids\")\n                time_embeds = self.add_time_proj(time_ids.flatten())\n                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n\n                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n                add_embeds = add_embeds.to(emb.dtype)\n                aug_emb = self.add_embedding(add_embeds)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        # 2. pre-process\n        sample = self.conv_in(sample)\n\n        if not self.use_vae_encode_condition:\n            controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)\n        else:\n            controlnet_cond = self.condition_conv_in(vae_encode_condition_hidden_states)\n\n        sample = sample + controlnet_cond\n\n        # 3. 
down\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    image_encoder_hidden_states=image_encoder_hidden_states,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n            down_block_res_samples += res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                cross_attention_kwargs=cross_attention_kwargs,\n                image_encoder_hidden_states=image_encoder_hidden_states,\n            )\n\n        # 5. Control net blocks\n\n        controlnet_down_block_res_samples = ()\n\n        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):\n            down_block_res_sample = controlnet_block(down_block_res_sample)\n            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)\n\n        down_block_res_samples = controlnet_down_block_res_samples\n\n        mid_block_res_sample = self.controlnet_mid_block(sample)\n\n        # 6. 
scaling\n        if guess_mode and not self.config.global_pool_conditions:\n            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0\n\n            scales = scales * conditioning_scale\n            down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]\n            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one\n        else:\n            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]\n            mid_block_res_sample = mid_block_res_sample * conditioning_scale\n\n        if self.config.global_pool_conditions:\n            down_block_res_samples = [\n                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples\n            ]\n            mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)\n\n        if not return_dict:\n            return (down_block_res_samples, mid_block_res_sample)\n\n        return ControlNetOutput(\n            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample\n        )\n\n\ndef zero_module(module):\n    for p in module.parameters():\n        nn.init.zeros_(p)\n    return module\n\n\n"
  },
  {
    "path": "models/losses/__init__.py",
    "content": "from models.losses.contperceptual import LPIPSWithDiscriminator"
  },
  {
    "path": "models/losses/contperceptual.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom taming.modules.losses.vqperceptual import *  # TODO: taming dependency yes/no?\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import FromOriginalControlnetMixin\n\nclass LPIPSWithDiscriminator(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):\n    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,\n                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,\n                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,\n                 disc_loss=\"hinge\"):\n\n        super().__init__()\n        assert disc_loss in [\"hinge\", \"vanilla\"]\n        self.kl_weight = kl_weight\n        self.pixel_weight = pixelloss_weight\n        self.perceptual_loss = LPIPS().eval()\n        self.perceptual_weight = perceptual_weight\n        # output log variance\n        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)\n\n        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,\n                                                 n_layers=disc_num_layers,\n                                                 use_actnorm=use_actnorm\n                                                 ).apply(weights_init)\n        self.discriminator_iter_start = disc_start\n        self.disc_loss = hinge_d_loss if disc_loss == \"hinge\" else vanilla_d_loss\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n        self.disc_conditional = disc_conditional\n\n    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n        if last_layer is not None:\n            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n        else:\n            nll_grads = 
torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]\n            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]\n\n        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = d_weight * self.discriminator_weight\n        return d_weight\n\n    def forward(self, inputs, reconstructions, optimizer_idx,\n                global_step, posteriors=None, last_layer=None, cond=None, split=\"train\",\n                weights=None, return_dic=False):\n        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())\n        if self.perceptual_weight > 0:\n            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())\n            rec_loss = rec_loss + self.perceptual_weight * p_loss\n\n        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar\n        weighted_nll_loss = nll_loss\n        if weights is not None:\n            weighted_nll_loss = weights*nll_loss\n        weighted_nll_loss = torch.mean(weighted_nll_loss) / weighted_nll_loss.shape[0]\n        nll_loss = torch.mean(nll_loss) / nll_loss.shape[0]\n        if self.kl_weight>0:\n            kl_loss = posteriors.kl()\n            kl_loss = torch.mean(kl_loss) / kl_loss.shape[0]\n\n        # now the GAN part\n        if optimizer_idx == 0:\n            # generator update\n            if cond is None:\n                assert not self.disc_conditional\n                logits_fake = self.discriminator(reconstructions.contiguous())\n            else:\n                assert self.disc_conditional\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))\n            g_loss = -torch.mean(logits_fake)\n\n            if self.disc_factor > 0.0:\n                try:\n                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)\n               
 except RuntimeError:\n                    # assert not self.training\n                    d_weight = torch.tensor(1.0) * self.discriminator_weight\n            else:\n                # d_weight = torch.tensor(0.0)\n                d_weight = torch.tensor(0.0)\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            if self.kl_weight>0:\n                loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss\n                log = {\"{}/total_loss\".format(split): loss.clone().detach().mean(), \"{}/logvar\".format(split): self.logvar.detach(),\n                       \"{}/kl_loss\".format(split): kl_loss.detach().mean(), \"{}/nll_loss\".format(split): nll_loss.detach().mean(),\n                       \"{}/rec_loss\".format(split): rec_loss.detach().mean(),\n                       \"{}/d_weight\".format(split): d_weight.detach(),\n                       \"{}/disc_factor\".format(split): torch.tensor(disc_factor),\n                       \"{}/g_loss\".format(split): g_loss.detach().mean(),\n                       }\n                if return_dic:\n                    loss_dic = {}\n                    loss_dic['total_loss'] = loss.clone().detach().mean()\n                    loss_dic['logvar'] = self.logvar.detach()\n                    loss_dic['kl_loss'] = kl_loss.detach().mean()\n                    loss_dic['nll_loss'] = nll_loss.detach().mean()\n                    loss_dic['rec_loss'] = rec_loss.detach().mean()\n                    loss_dic['d_weight'] = d_weight.detach()\n                    loss_dic['disc_factor'] = torch.tensor(disc_factor)\n                    loss_dic['g_loss'] = g_loss.detach().mean()\n            else:\n                loss = weighted_nll_loss + d_weight * disc_factor * g_loss\n                log = {\"{}/total_loss\".format(split): loss.clone().detach().mean(), \"{}/logvar\".format(split): self.logvar.detach(),\n                       
\"{}/nll_loss\".format(split): nll_loss.detach().mean(),\n                       \"{}/rec_loss\".format(split): rec_loss.detach().mean(),\n                       \"{}/d_weight\".format(split): d_weight.detach(),\n                       \"{}/disc_factor\".format(split): torch.tensor(disc_factor),\n                       \"{}/g_loss\".format(split): g_loss.detach().mean(),\n                       }\n                if return_dic:\n                    loss_dic = {}\n                    loss_dic[\"{}/total_loss\".format(split)] = loss.clone().detach().mean()\n                    loss_dic[\"{}/logvar\".format(split)] = self.logvar.detach()\n                    loss_dic['nll_loss'.format(split)] = nll_loss.detach().mean()\n                    loss_dic['rec_loss'.format(split)] = rec_loss.detach().mean()\n                    loss_dic['d_weight'.format(split)] = d_weight.detach()\n                    loss_dic['disc_factor'.format(split)] = torch.tensor(disc_factor)\n                    loss_dic['g_loss'.format(split)] = g_loss.detach().mean()\n\n            if return_dic:\n                return loss, log, loss_dic\n            return loss, log\n\n        if optimizer_idx == 1:\n            # second pass for discriminator update\n            if cond is None:\n                logits_real = self.discriminator(inputs.contiguous().detach())\n                logits_fake = self.discriminator(reconstructions.contiguous().detach())\n            else:\n                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)\n\n            log = {\"{}/disc_loss\".format(split): d_loss.clone().detach().mean(),\n                   
\"{}/logits_real\".format(split): logits_real.detach().mean(),\n                   \"{}/logits_fake\".format(split): logits_fake.detach().mean()\n                   }\n\n            if return_dic:\n                loss_dic = {}\n                loss_dic[\"{}/disc_loss\".format(split)] = d_loss.clone().detach().mean()\n                loss_dic[\"{}/logits_real\".format(split)] = logits_real.detach().mean()\n                loss_dic[\"{}/logits_fake\".format(split)] = logits_fake.detach().mean()\n                return d_loss, log, loss_dic\n\n            return d_loss, log\n\n"
  },
  {
    "path": "models/losses/vqperceptual.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nfrom einops import repeat\n\nfrom taming.modules.discriminator.model import NLayerDiscriminator, weights_init\nfrom taming.modules.losses.lpips import LPIPS\nfrom taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss\n\n\ndef hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):\n    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]\n    loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])\n    loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])\n    loss_real = (weights * loss_real).sum() / weights.sum()\n    loss_fake = (weights * loss_fake).sum() / weights.sum()\n    d_loss = 0.5 * (loss_real + loss_fake)\n    return d_loss\n\ndef adopt_weight(weight, global_step, threshold=0, value=0.):\n    if global_step < threshold:\n        weight = value\n    return weight\n\n\ndef measure_perplexity(predicted_indices, n_embed):\n    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py\n    # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally\n    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)\n    avg_probs = encodings.mean(0)\n    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()\n    cluster_use = torch.sum(avg_probs > 0)\n    return perplexity, cluster_use\n\ndef l1(x, y):\n    return torch.abs(x-y)\n\n\ndef l2(x, y):\n    return torch.pow((x-y), 2)\n\n\nclass VQLPIPSWithDiscriminator(nn.Module):\n    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,\n                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,\n                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,\n                 disc_ndf=64, disc_loss=\"hinge\", n_classes=None, perceptual_loss=\"lpips\",\n                 pixel_loss=\"l1\"):\n        super().__init__()\n        assert disc_loss in [\"hinge\", \"vanilla\"]\n        assert perceptual_loss in [\"lpips\", \"clips\", \"dists\"]\n        assert pixel_loss in [\"l1\", \"l2\"]\n        self.codebook_weight = codebook_weight\n        self.pixel_weight = pixelloss_weight\n        if perceptual_loss == \"lpips\":\n            print(f\"{self.__class__.__name__}: Running with LPIPS.\")\n            self.perceptual_loss = LPIPS().eval()\n        else:\n            raise ValueError(f\"Unknown perceptual loss: >> {perceptual_loss} <<\")\n        self.perceptual_weight = perceptual_weight\n\n        if pixel_loss == \"l1\":\n            self.pixel_loss = l1\n        else:\n            self.pixel_loss = l2\n\n        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,\n                                                 n_layers=disc_num_layers,\n                                                 use_actnorm=use_actnorm,\n                                                 ndf=disc_ndf\n                                                 ).apply(weights_init)\n        
self.discriminator_iter_start = disc_start\n        if disc_loss == \"hinge\":\n            self.disc_loss = hinge_d_loss\n        elif disc_loss == \"vanilla\":\n            self.disc_loss = vanilla_d_loss\n        else:\n            raise ValueError(f\"Unknown GAN loss '{disc_loss}'.\")\n        print(f\"VQLPIPSWithDiscriminator running with {disc_loss} loss.\")\n        self.disc_factor = disc_factor\n        self.discriminator_weight = disc_weight\n        self.disc_conditional = disc_conditional\n        self.n_classes = n_classes\n\n    # def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n    #     if last_layer is not None:\n    #         nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\n    #         g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n    #     else:\n    #         nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]\n    #         g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]\n\n    #     d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n    #     d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n    #     d_weight = d_weight * self.discriminator_weight\n    #     return d_weight\n\n    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):\n        # if last_layer is not None:\n        #     nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]\n        #     g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]\n        # else:\n        #     nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]\n        #     g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]\n\n        # d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)\n        # d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()\n        d_weight = 1.0 * self.discriminator_weight\n        return 
d_weight\n\n    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,\n                global_step, last_layer=None, cond=None, split=\"train\", predicted_indices=None):\n        if not exists(codebook_loss):\n            codebook_loss = torch.tensor([0.]).to(inputs.device)\n        #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())\n        rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())\n        if self.perceptual_weight > 0:\n            p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())\n            rec_loss = rec_loss + self.perceptual_weight * p_loss\n        else:\n            p_loss = torch.tensor([0.0])\n\n        nll_loss = rec_loss\n        #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]\n        nll_loss = torch.mean(nll_loss)\n\n        # now the GAN part\n        if optimizer_idx == 0:\n            # generator update\n            if cond is None:\n                assert not self.disc_conditional\n                logits_fake = self.discriminator(reconstructions.contiguous())\n            else:\n                assert self.disc_conditional\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))\n            g_loss = -torch.mean(logits_fake)\n\n            try:\n                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)\n            except RuntimeError:\n                assert not self.training\n                d_weight = torch.tensor(0.0)\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()\n\n            log = {\"{}/total_loss\".format(split): loss.clone().detach().mean(),\n                   \"{}/quant_loss\".format(split): codebook_loss.detach().mean(),\n                   
\"{}/nll_loss\".format(split): nll_loss.detach().mean(),\n                   \"{}/rec_loss\".format(split): rec_loss.detach().mean(),\n                   \"{}/p_loss\".format(split): p_loss.detach().mean(),\n                   \"{}/d_weight\".format(split): d_weight.detach(),\n                   \"{}/disc_factor\".format(split): torch.tensor(disc_factor),\n                   \"{}/g_loss\".format(split): g_loss.detach().mean(),\n                   }\n            if predicted_indices is not None:\n                assert self.n_classes is not None\n                with torch.no_grad():\n                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)\n                log[f\"{split}/perplexity\"] = perplexity\n                log[f\"{split}/cluster_usage\"] = cluster_usage\n            return loss, log\n\n        if optimizer_idx == 1:\n            # second pass for discriminator update\n            if cond is None:\n                logits_real = self.discriminator(inputs.contiguous().detach())\n                logits_fake = self.discriminator(reconstructions.contiguous().detach())\n            else:\n                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))\n                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))\n\n            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)\n            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)\n\n            log = {\"{}/disc_loss\".format(split): d_loss.clone().detach().mean(),\n                   \"{}/logits_real\".format(split): logits_real.detach().mean(),\n                   \"{}/logits_fake\".format(split): logits_fake.detach().mean()\n                   }\n            return d_loss, log\n"
  },
  {
    "path": "models/shared.py",
    "content": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\n\"\"\"Shared architecture blocks.\"\"\"\n\nfrom typing import Callable\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\n\nfrom ADD.th_utils.ops import bias_act\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, fn: Callable):\n        super().__init__()\n        self.fn = fn\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return (self.fn(x) + x) / np.sqrt(2)\n\n\nclass FullyConnectedLayer(nn.Module):\n    def __init__(\n        self,\n        in_features: int,              # Number of input features.\n        out_features: int,             # Number of output features.\n        bias: bool  = True,            # Apply additive bias before the activation function?\n        activation: str   = 'linear',  # Activation function: 'relu', 'lrelu', etc.\n        lr_multiplier: float = 1.0,    # Learning rate multiplier.\n        weight_init: float = 1.0,      # Initial standard deviation of the weight tensor.\n        bias_init: float = 0.0,        # Initial value for the additive bias.\n    ):\n\n        super().__init__()\n        self.in_features = in_features\n        self.out_features = out_features\n        self.activation = activation\n        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))\n        bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])\n        self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None\n        self.weight_gain = 
lr_multiplier / np.sqrt(in_features)\n        self.bias_gain = lr_multiplier\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        w = self.weight.to(x.dtype) * self.weight_gain\n        b = self.bias\n        if b is not None:\n            b = b.to(x.dtype)\n            if self.bias_gain != 1:\n                b = b * self.bias_gain\n\n        if self.activation == 'linear' and b is not None:\n            x = torch.addmm(b.unsqueeze(0), x, w.t())\n        else:\n            x = x.matmul(w.t())\n            x = bias_act.bias_act(x, b, act=self.activation)\n        return x\n\n    def extra_repr(self) -> str:\n        return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'\n\n\nclass MLP(nn.Module):\n    def __init__(\n        self,\n        features_list: list[int],    # Number of features in each layer of the MLP.\n        activation: str = 'linear',  # Activation function: 'relu', 'lrelu', etc.\n        lr_multiplier: float = 1.0,  # Learning rate multiplier.\n        linear_out: bool = False     # Use the 'linear' activation function for the output layer?\n    ):\n        super().__init__()\n        num_layers = len(features_list) - 1\n        self.num_layers = num_layers\n        self.out_dim = features_list[-1]\n\n        for idx in range(num_layers):\n            in_features = features_list[idx]\n            out_features = features_list[idx + 1]\n            if linear_out and idx == num_layers-1:\n                activation = 'linear'\n            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)\n            setattr(self, f'fc{idx}', layer)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        ''' if x is sequence of tokens, shift tokens to batch and apply MLP to all'''\n        shift2batch = (x.ndim == 3)\n\n        if shift2batch:\n            B, K, C = x.shape\n            x = x.flatten(0,1)\n\n        for idx in 
range(self.num_layers):\n            layer = getattr(self, f'fc{idx}')\n            x = layer(x)\n\n        if shift2batch:\n            x = x.reshape(B, K, -1)\n\n        return x\n"
  },
  {
    "path": "models/unet_2d_blocks.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom typing import Any, Dict, Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nfrom diffusers.utils import is_torch_version, logging\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.attention import AdaGroupNorm\nfrom diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0\nfrom diffusers.models.dual_transformer_2d import DualTransformer2DModel\nfrom diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D\nfrom diffusers.models.transformer_2d import Transformer2DModel\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\ndef get_down_block(\n    down_block_type,\n    num_layers,\n    in_channels,\n    out_channels,\n    temb_channels,\n    add_downsample,\n    resnet_eps,\n    resnet_act_fn,\n    transformer_layers_per_block=1,\n    num_attention_heads=None,\n    resnet_groups=None,\n    cross_attention_dim=None,\n    downsample_padding=None,\n    dual_cross_attention=False,\n    use_linear_projection=False,\n    only_cross_attention=False,\n    upcast_attention=False,\n    resnet_time_scale_shift=\"default\",\n    attention_type=\"default\",\n    resnet_skip_time_act=False,\n    
resnet_out_scale_factor=1.0,\n    cross_attention_norm=None,\n    attention_head_dim=None,\n    downsample_type=None,\n):\n    # If attn head dim is not defined, we default it to the number of heads\n    if attention_head_dim is None:\n        logger.warn(\n            f\"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}.\"\n        )\n        attention_head_dim = num_attention_heads\n\n    down_block_type = down_block_type[7:] if down_block_type.startswith(\"UNetRes\") else down_block_type\n    if down_block_type == \"DownBlock2D\":\n        return DownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"ResnetDownsampleBlock2D\":\n        return ResnetDownsampleBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n        )\n    elif down_block_type == \"AttnDownBlock2D\":\n        if add_downsample is False:\n            downsample_type = None\n        else:\n            downsample_type = downsample_type or \"conv\"  # default to 'conv'\n        return AttnDownBlock2D(\n            num_layers=num_layers,\n           
 in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            downsample_type=downsample_type,\n        )\n    elif down_block_type == \"CrossAttnDownBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for CrossAttnDownBlock2D\")\n        return CrossAttnDownBlock2D(\n            num_layers=num_layers,\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            cross_attention_dim=cross_attention_dim,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            attention_type=attention_type,\n        )\n    elif down_block_type == \"SimpleCrossAttnDownBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D\")\n        return SimpleCrossAttnDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            
add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n            only_cross_attention=only_cross_attention,\n            cross_attention_norm=cross_attention_norm,\n        )\n    elif down_block_type == \"SkipDownBlock2D\":\n        return SkipDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            downsample_padding=downsample_padding,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"AttnSkipDownBlock2D\":\n        return AttnSkipDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"DownEncoderBlock2D\":\n        return DownEncoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            
resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"AttnDownEncoderBlock2D\":\n        return AttnDownEncoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            downsample_padding=downsample_padding,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif down_block_type == \"KDownBlock2D\":\n        return KDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n        )\n    elif down_block_type == \"KCrossAttnDownBlock2D\":\n        return KCrossAttnDownBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_downsample=add_downsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n            add_self_attention=True if not add_downsample else False,\n        )\n    raise ValueError(f\"{down_block_type} does not exist.\")\n\n\ndef get_up_block(\n    up_block_type,\n    num_layers,\n    in_channels,\n    out_channels,\n    prev_output_channel,\n    temb_channels,\n    add_upsample,\n    resnet_eps,\n    resnet_act_fn,\n    transformer_layers_per_block=1,\n    num_attention_heads=None,\n    resnet_groups=None,\n    cross_attention_dim=None,\n    dual_cross_attention=False,\n    use_linear_projection=False,\n    
only_cross_attention=False,\n    upcast_attention=False,\n    resnet_time_scale_shift=\"default\",\n    attention_type=\"default\",\n    resnet_skip_time_act=False,\n    resnet_out_scale_factor=1.0,\n    cross_attention_norm=None,\n    attention_head_dim=None,\n    upsample_type=None,\n):\n    # If attn head dim is not defined, we default it to the number of heads\n    if attention_head_dim is None:\n        logger.warn(\n            f\"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}.\"\n        )\n        attention_head_dim = num_attention_heads\n\n    up_block_type = up_block_type[7:] if up_block_type.startswith(\"UNetRes\") else up_block_type\n    if up_block_type == \"UpBlock2D\":\n        return UpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif up_block_type == \"ResnetUpsampleBlock2D\":\n        return ResnetUpsampleBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n        )\n    elif up_block_type == \"CrossAttnUpBlock2D\":\n        if cross_attention_dim is None:\n            
raise ValueError(\"cross_attention_dim must be specified for CrossAttnUpBlock2D\")\n        return CrossAttnUpBlock2D(\n            num_layers=num_layers,\n            transformer_layers_per_block=transformer_layers_per_block,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            num_attention_heads=num_attention_heads,\n            dual_cross_attention=dual_cross_attention,\n            use_linear_projection=use_linear_projection,\n            only_cross_attention=only_cross_attention,\n            upcast_attention=upcast_attention,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            attention_type=attention_type,\n        )\n    elif up_block_type == \"SimpleCrossAttnUpBlock2D\":\n        if cross_attention_dim is None:\n            raise ValueError(\"cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D\")\n        return SimpleCrossAttnUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            skip_time_act=resnet_skip_time_act,\n            output_scale_factor=resnet_out_scale_factor,\n            only_cross_attention=only_cross_attention,\n            
cross_attention_norm=cross_attention_norm,\n        )\n    elif up_block_type == \"AttnUpBlock2D\":\n        if add_upsample is False:\n            upsample_type = None\n        else:\n            upsample_type = upsample_type or \"conv\"  # default to 'conv'\n\n        return AttnUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            upsample_type=upsample_type,\n        )\n    elif up_block_type == \"SkipUpBlock2D\":\n        return SkipUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif up_block_type == \"AttnSkipUpBlock2D\":\n        return AttnSkipUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            prev_output_channel=prev_output_channel,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n        )\n    elif up_block_type == \"UpDecoderBlock2D\":\n        return UpDecoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            
out_channels=out_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            temb_channels=temb_channels,\n        )\n    elif up_block_type == \"AttnUpDecoderBlock2D\":\n        return AttnUpDecoderBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            resnet_groups=resnet_groups,\n            attention_head_dim=attention_head_dim,\n            resnet_time_scale_shift=resnet_time_scale_shift,\n            temb_channels=temb_channels,\n        )\n    elif up_block_type == \"KUpBlock2D\":\n        return KUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n        )\n    elif up_block_type == \"KCrossAttnUpBlock2D\":\n        return KCrossAttnUpBlock2D(\n            num_layers=num_layers,\n            in_channels=in_channels,\n            out_channels=out_channels,\n            temb_channels=temb_channels,\n            add_upsample=add_upsample,\n            resnet_eps=resnet_eps,\n            resnet_act_fn=resnet_act_fn,\n            cross_attention_dim=cross_attention_dim,\n            attention_head_dim=attention_head_dim,\n        )\n\n    raise ValueError(f\"{up_block_type} does not exist.\")\n\n\nclass AutoencoderTinyBlock(nn.Module):\n    def __init__(self, in_channels: int, out_channels: int, act_fn: str):\n        super().__init__()\n        act_fn = get_activation(act_fn)\n        self.conv = nn.Sequential(\n            nn.Conv2d(in_channels, 
out_channels, kernel_size=3, padding=1),\n            act_fn,\n            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),\n            act_fn,\n            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),\n        )\n        self.skip = (\n            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)\n            if in_channels != out_channels\n            else nn.Identity()\n        )\n        self.fuse = nn.ReLU()\n\n    def forward(self, x):\n        return self.fuse(self.conv(x) + self.skip(x))\n\n\nclass UNetMidBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",  # default, spatial\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        add_attention: bool = True,\n        attention_head_dim=1,\n        output_scale_factor=1.0,\n    ):\n        super().__init__()\n        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)\n        self.add_attention = add_attention\n\n        # there is always at least one resnet\n        resnets = [\n            ResnetBlock2D(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n            )\n        ]\n        attentions = []\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. 
Defaulting `attention_head_dim` to `in_channels`: {in_channels}.\"\n            )\n            attention_head_dim = in_channels\n\n        for _ in range(num_layers):\n            if self.add_attention:\n                attentions.append(\n                    Attention(\n                        in_channels,\n                        heads=in_channels // attention_head_dim,\n                        dim_head=attention_head_dim,\n                        rescale_output_factor=output_scale_factor,\n                        eps=resnet_eps,\n                        norm_num_groups=resnet_groups if resnet_time_scale_shift == \"default\" else None,\n                        spatial_norm_dim=temb_channels if resnet_time_scale_shift == \"spatial\" else None,\n                        residual_connection=True,\n                        bias=True,\n                        upcast_softmax=True,\n                        _from_deprecated_attn_block=True,\n                    )\n                )\n            else:\n                attentions.append(None)\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n    def forward(self, hidden_states, temb=None):\n        hidden_states = self.resnets[0](hidden_states, temb)\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            if attn is not None:\n                hidden_states = 
attn(hidden_states, temb=temb)\n            hidden_states = resnet(hidden_states, temb)\n\n        return hidden_states\n\n\nclass UNetMidBlock2DCrossAttn(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        num_attention_heads=1,\n        output_scale_factor=1.0,\n        cross_attention_dim=1280,\n        dual_cross_attention=False,\n        use_linear_projection=False,\n        upcast_attention=False,\n        attention_type=\"default\",\n        image_cross_attention_dim=512,\n    ):\n        super().__init__()\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)\n\n        # there is always at least one resnet\n        resnets = [\n            ResnetBlock2D(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n            )\n        ]\n        attentions = []\n\n        for _ in range(num_layers):\n            if not dual_cross_attention:\n                attentions.append(\n                    Transformer2DModel(\n                        num_attention_heads,\n                        in_channels // num_attention_heads,\n                        in_channels=in_channels,\n 
                       num_layers=transformer_layers_per_block,\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                        use_linear_projection=use_linear_projection,\n                        upcast_attention=upcast_attention,\n                        attention_type=attention_type,\n                    )\n                )\n            else:\n                attentions.append(\n                    DualTransformer2DModel(\n                        num_attention_heads,\n                        in_channels // num_attention_heads,\n                        in_channels=in_channels,\n                        num_layers=1,\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                    )\n                )\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = 
None,\n        image_encoder_hidden_states: Optional[torch.FloatTensor] = None,\n    ) -> torch.FloatTensor:\n        hidden_states = self.resnets[0](hidden_states, temb)\n\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n                )\n            else:\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                hidden_states = resnet(hidden_states, temb)\n\n        return hidden_states\n\n\nclass UNetMidBlock2DSimpleCrossAttn(nn.Module):\n    
def __init__(\n        self,\n        in_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim=1,\n        output_scale_factor=1.0,\n        cross_attention_dim=1280,\n        skip_time_act=False,\n        only_cross_attention=False,\n        cross_attention_norm=None,\n    ):\n        super().__init__()\n\n        self.has_cross_attention = True\n\n        self.attention_head_dim = attention_head_dim\n        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)\n\n        self.num_heads = in_channels // self.attention_head_dim\n\n        # there is always at least one resnet\n        resnets = [\n            ResnetBlock2D(\n                in_channels=in_channels,\n                out_channels=in_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=resnet_groups,\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n                skip_time_act=skip_time_act,\n            )\n        ]\n        attentions = []\n\n        for _ in range(num_layers):\n            processor = (\n                AttnAddedKVProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") else AttnAddedKVProcessor()\n            )\n\n            attentions.append(\n                Attention(\n                    query_dim=in_channels,\n                    cross_attention_dim=in_channels,\n                    heads=self.num_heads,\n                    dim_head=self.attention_head_dim,\n                    
added_kv_proj_dim=cross_attention_dim,\n                    norm_num_groups=resnet_groups,\n                    bias=True,\n                    upcast_softmax=True,\n                    only_cross_attention=only_cross_attention,\n                    cross_attention_norm=cross_attention_norm,\n                    processor=processor,\n                )\n            )\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=in_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                    skip_time_act=skip_time_act,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ):\n        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n\n        if attention_mask is None:\n            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.\n            mask = None if encoder_hidden_states is None else encoder_attention_mask\n        else:\n            # when attention_mask is defined: we don't even check for encoder_attention_mask.\n            # this is to maintain compatibility 
with UnCLIP, which uses 'attention_mask' param for cross-attn masks.\n            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.\n            #       then we can simplify this whole if/else block to:\n            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask\n            mask = attention_mask\n\n        hidden_states = self.resnets[0](hidden_states, temb)\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            # attn\n            hidden_states = attn(\n                hidden_states,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=mask,\n                **cross_attention_kwargs,\n            )\n\n            # resnet\n            hidden_states = resnet(hidden_states, temb)\n\n        return hidden_states\n\n\nclass AttnDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim=1,\n        output_scale_factor=1.0,\n        downsample_padding=1,\n        downsample_type=\"conv\",\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n        self.downsample_type = downsample_type\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. 
Defaulting `attention_head_dim` to `in_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            attentions.append(\n                Attention(\n                    out_channels,\n                    heads=out_channels // attention_head_dim,\n                    dim_head=attention_head_dim,\n                    rescale_output_factor=output_scale_factor,\n                    eps=resnet_eps,\n                    norm_num_groups=resnet_groups,\n                    residual_connection=True,\n                    bias=True,\n                    upcast_softmax=True,\n                    _from_deprecated_attn_block=True,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if downsample_type == \"conv\":\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        elif downsample_type == \"resnet\":\n            self.downsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        
in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        down=True,\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n    def forward(self, hidden_states, temb=None, upsample_size=None):\n        output_states = ()\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            hidden_states = resnet(hidden_states, temb)\n            hidden_states = attn(hidden_states)\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                if self.downsample_type == \"resnet\":\n                    hidden_states = downsampler(hidden_states, temb=temb)\n                else:\n                    hidden_states = downsampler(hidden_states)\n\n            output_states += (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass CrossAttnDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        num_attention_heads=1,\n        cross_attention_dim=1280,\n        output_scale_factor=1.0,\n        
downsample_padding=1,\n        add_downsample=True,\n        dual_cross_attention=False,\n        use_linear_projection=False,\n        only_cross_attention=False,\n        upcast_attention=False,\n        attention_type=\"default\",\n        image_cross_attention_dim=512,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            if not dual_cross_attention:\n                attentions.append(\n                    Transformer2DModel(\n                        num_attention_heads,\n                        out_channels // num_attention_heads,\n                        in_channels=out_channels,\n                        num_layers=transformer_layers_per_block,\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                        use_linear_projection=use_linear_projection,\n                        only_cross_attention=only_cross_attention,\n                        upcast_attention=upcast_attention,\n                        attention_type=attention_type,\n                    )\n                )\n            else:\n                attentions.append(\n                    
DualTransformer2DModel(\n                        num_attention_heads,\n                        out_channels // num_attention_heads,\n                        in_channels=out_channels,\n                        num_layers=1,\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                    )\n                )\n            \n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        additional_residuals=None,\n        image_encoder_hidden_states: Optional[torch.FloatTensor] = None,\n    ):\n        output_states = ()\n\n        blocks = list(zip(self.resnets, self.attentions))\n        for i, (resnet, attn) in enumerate(blocks):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return 
custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n                )\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n\n            # apply additional residuals to the output of the last pair of resnet and attention blocks\n            if i == len(blocks) - 1 and additional_residuals is not None:\n                hidden_states = hidden_states + additional_residuals\n\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n            output_states = output_states + (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass DownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 
0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor=1.0,\n        add_downsample=True,\n        downsample_padding=1,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states, temb=None):\n        output_states = ()\n\n        for resnet in self.resnets:\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    
hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = resnet(hidden_states, temb)\n\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n            output_states = output_states + (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass DownEncoderBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor=1.0,\n        add_downsample=True,\n        downsample_padding=1,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=None,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            
)\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n    def forward(self, hidden_states):\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states, temb=None)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n        return hidden_states\n\n\nclass AttnDownEncoderBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim=1,\n        output_scale_factor=1.0,\n        add_downsample=True,\n        downsample_padding=1,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. 
Defaulting `attention_head_dim` to `in_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=None,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            attentions.append(\n                Attention(\n                    out_channels,\n                    heads=out_channels // attention_head_dim,\n                    dim_head=attention_head_dim,\n                    rescale_output_factor=output_scale_factor,\n                    eps=resnet_eps,\n                    norm_num_groups=resnet_groups,\n                    residual_connection=True,\n                    bias=True,\n                    upcast_softmax=True,\n                    _from_deprecated_attn_block=True,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    Downsample2D(\n                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name=\"op\"\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n    def forward(self, hidden_states):\n        for resnet, attn in zip(self.resnets, self.attentions):\n            hidden_states = 
resnet(hidden_states, temb=None)\n            hidden_states = attn(hidden_states)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n        return hidden_states\n\n\nclass AttnSkipDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_pre_norm: bool = True,\n        attention_head_dim=1,\n        output_scale_factor=np.sqrt(2.0),\n        add_downsample=True,\n    ):\n        super().__init__()\n        self.attentions = nn.ModuleList([])\n        self.resnets = nn.ModuleList([])\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. 
Defaulting `attention_head_dim` to `in_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            self.resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=min(in_channels // 4, 32),\n                    groups_out=min(out_channels // 4, 32),\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            self.attentions.append(\n                Attention(\n                    out_channels,\n                    heads=out_channels // attention_head_dim,\n                    dim_head=attention_head_dim,\n                    rescale_output_factor=output_scale_factor,\n                    eps=resnet_eps,\n                    norm_num_groups=32,\n                    residual_connection=True,\n                    bias=True,\n                    upcast_softmax=True,\n                    _from_deprecated_attn_block=True,\n                )\n            )\n\n        if add_downsample:\n            self.resnet_down = ResnetBlock2D(\n                in_channels=out_channels,\n                out_channels=out_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=min(out_channels // 4, 32),\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n 
               use_in_shortcut=True,\n                down=True,\n                kernel=\"fir\",\n            )\n            self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])\n            self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))\n        else:\n            self.resnet_down = None\n            self.downsamplers = None\n            self.skip_conv = None\n\n    def forward(self, hidden_states, temb=None, skip_sample=None):\n        output_states = ()\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            hidden_states = resnet(hidden_states, temb)\n            hidden_states = attn(hidden_states)\n            output_states += (hidden_states,)\n\n        if self.downsamplers is not None:\n            hidden_states = self.resnet_down(hidden_states, temb)\n            for downsampler in self.downsamplers:\n                skip_sample = downsampler(skip_sample)\n\n            hidden_states = self.skip_conv(skip_sample) + hidden_states\n\n            output_states += (hidden_states,)\n\n        return hidden_states, output_states, skip_sample\n\n\nclass SkipDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_pre_norm: bool = True,\n        output_scale_factor=np.sqrt(2.0),\n        add_downsample=True,\n        downsample_padding=1,\n    ):\n        super().__init__()\n        self.resnets = nn.ModuleList([])\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            self.resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n       
             temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=min(in_channels // 4, 32),\n                    groups_out=min(out_channels // 4, 32),\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        if add_downsample:\n            self.resnet_down = ResnetBlock2D(\n                in_channels=out_channels,\n                out_channels=out_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=min(out_channels // 4, 32),\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n                use_in_shortcut=True,\n                down=True,\n                kernel=\"fir\",\n            )\n            self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])\n            self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))\n        else:\n            self.resnet_down = None\n            self.downsamplers = None\n            self.skip_conv = None\n\n    def forward(self, hidden_states, temb=None, skip_sample=None):\n        output_states = ()\n\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states, temb)\n            output_states += (hidden_states,)\n\n        if self.downsamplers is not None:\n            hidden_states = self.resnet_down(hidden_states, temb)\n            for downsampler in self.downsamplers:\n                skip_sample = downsampler(skip_sample)\n\n            hidden_states = self.skip_conv(skip_sample) + 
hidden_states\n\n            output_states += (hidden_states,)\n\n        return hidden_states, output_states, skip_sample\n\n\nclass ResnetDownsampleBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor=1.0,\n        add_downsample=True,\n        skip_time_act=False,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                    skip_time_act=skip_time_act,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                   
     non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        skip_time_act=skip_time_act,\n                        down=True,\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states, temb=None):\n        output_states = ()\n\n        for resnet in self.resnets:\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = resnet(hidden_states, temb)\n\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states, temb)\n\n            output_states = output_states + (hidden_states,)\n\n        return hidden_states, output_states\n\n\nclass SimpleCrossAttnDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n 
       resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim=1,\n        cross_attention_dim=1280,\n        output_scale_factor=1.0,\n        add_downsample=True,\n        skip_time_act=False,\n        only_cross_attention=False,\n        cross_attention_norm=None,\n    ):\n        super().__init__()\n\n        self.has_cross_attention = True\n\n        resnets = []\n        attentions = []\n\n        self.attention_head_dim = attention_head_dim\n        self.num_heads = out_channels // self.attention_head_dim\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                    skip_time_act=skip_time_act,\n                )\n            )\n\n            processor = (\n                AttnAddedKVProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") else AttnAddedKVProcessor()\n            )\n\n            attentions.append(\n                Attention(\n                    query_dim=out_channels,\n                    cross_attention_dim=out_channels,\n                    heads=self.num_heads,\n                    dim_head=attention_head_dim,\n                    added_kv_proj_dim=cross_attention_dim,\n                    norm_num_groups=resnet_groups,\n                    bias=True,\n                    upcast_softmax=True,\n                    only_cross_attention=only_cross_attention,\n                    
cross_attention_norm=cross_attention_norm,\n                    processor=processor,\n                )\n            )\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        skip_time_act=skip_time_act,\n                        down=True,\n                    )\n                ]\n            )\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ):\n        output_states = ()\n        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n\n        if attention_mask is None:\n            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.\n            mask = None if encoder_hidden_states is None else encoder_attention_mask\n        else:\n            # when attention_mask is defined: we don't even check for 
encoder_attention_mask.\n            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.\n            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.\n            #       then we can simplify this whole if/else block to:\n            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask\n            mask = attention_mask\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=mask,\n                    **cross_attention_kwargs,\n                )\n            else:\n                hidden_states = resnet(hidden_states, temb)\n\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=mask,\n                    **cross_attention_kwargs,\n                )\n\n            output_states = output_states + (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states, temb)\n\n            output_states = output_states + (hidden_states,)\n\n        return 
hidden_states, output_states\n\n\nclass KDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 4,\n        resnet_eps: float = 1e-5,\n        resnet_act_fn: str = \"gelu\",\n        resnet_group_size: int = 32,\n        add_downsample=False,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            groups = in_channels // resnet_group_size\n            groups_out = out_channels // resnet_group_size\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    dropout=dropout,\n                    temb_channels=temb_channels,\n                    groups=groups,\n                    groups_out=groups_out,\n                    eps=resnet_eps,\n                    non_linearity=resnet_act_fn,\n                    time_embedding_norm=\"ada_group\",\n                    conv_shortcut_bias=False,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_downsample:\n            # YiYi's comments- might be able to use FirDownsample2D, look into details later\n            self.downsamplers = nn.ModuleList([KDownsample2D()])\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states, temb=None):\n        output_states = ()\n\n        for resnet in self.resnets:\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n     
               hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = resnet(hidden_states, temb)\n\n            output_states += (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n        return hidden_states, output_states\n\n\nclass KCrossAttnDownBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        cross_attention_dim: int,\n        dropout: float = 0.0,\n        num_layers: int = 4,\n        resnet_group_size: int = 32,\n        add_downsample=True,\n        attention_head_dim: int = 64,\n        add_self_attention: bool = False,\n        resnet_eps: float = 1e-5,\n        resnet_act_fn: str = \"gelu\",\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n\n        for i in range(num_layers):\n            in_channels = in_channels if i == 0 else out_channels\n            groups = in_channels // resnet_group_size\n            groups_out = out_channels // resnet_group_size\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    dropout=dropout,\n                    temb_channels=temb_channels,\n                    groups=groups,\n                    groups_out=groups_out,\n                    eps=resnet_eps,\n                    non_linearity=resnet_act_fn,\n                    
time_embedding_norm=\"ada_group\",\n                    conv_shortcut_bias=False,\n                )\n            )\n            attentions.append(\n                KAttentionBlock(\n                    out_channels,\n                    out_channels // attention_head_dim,\n                    attention_head_dim,\n                    cross_attention_dim=cross_attention_dim,\n                    temb_channels=temb_channels,\n                    attention_bias=True,\n                    add_self_attention=add_self_attention,\n                    cross_attention_norm=\"layer_norm\",\n                    group_size=resnet_group_size,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n        self.attentions = nn.ModuleList(attentions)\n\n        if add_downsample:\n            self.downsamplers = nn.ModuleList([KDownsample2D()])\n        else:\n            self.downsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ):\n        output_states = ()\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if 
is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n                )\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    emb=temb,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    emb=temb,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n\n            if self.downsamplers is None:\n                output_states += (None,)\n            else:\n                output_states += (hidden_states,)\n\n        if self.downsamplers is not None:\n            for downsampler in self.downsamplers:\n                hidden_states = downsampler(hidden_states)\n\n        return hidden_states, output_states\n\n\nclass AttnUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim=1,\n        
output_scale_factor=1.0,\n        upsample_type=\"conv\",\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.upsample_type = upsample_type\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            attentions.append(\n                Attention(\n                    out_channels,\n                    heads=out_channels // attention_head_dim,\n                    dim_head=attention_head_dim,\n                    rescale_output_factor=output_scale_factor,\n                    eps=resnet_eps,\n                    norm_num_groups=resnet_groups,\n                    residual_connection=True,\n                    bias=True,\n                    upcast_softmax=True,\n                    _from_deprecated_attn_block=True,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if upsample_type == \"conv\":\n     
       self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        elif upsample_type == \"resnet\":\n            self.upsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        up=True,\n                    )\n                ]\n            )\n        else:\n            self.upsamplers = None\n\n    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            hidden_states = resnet(hidden_states, temb)\n            hidden_states = attn(hidden_states)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                if self.upsample_type == \"resnet\":\n                    hidden_states = upsampler(hidden_states, temb=temb)\n                else:\n                    hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\n\nclass CrossAttnUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        prev_output_channel: int,\n        temb_channels: int,\n        dropout: float = 
0.0,\n        num_layers: int = 1,\n        transformer_layers_per_block: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        num_attention_heads=1,\n        cross_attention_dim=1280,\n        output_scale_factor=1.0,\n        add_upsample=True,\n        dual_cross_attention=False,\n        use_linear_projection=False,\n        only_cross_attention=False,\n        upcast_attention=False,\n        attention_type=\"default\",\n        image_cross_attention_dim=512,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n        self.num_attention_heads = num_attention_heads\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            if not dual_cross_attention:\n                attentions.append(\n                    Transformer2DModel(\n                        num_attention_heads,\n                        out_channels // num_attention_heads,\n                        in_channels=out_channels,\n                        num_layers=transformer_layers_per_block,\n                 
       cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                        use_linear_projection=use_linear_projection,\n                        only_cross_attention=only_cross_attention,\n                        upcast_attention=upcast_attention,\n                        attention_type=attention_type,\n                    )\n                )\n            else:\n                attentions.append(\n                    DualTransformer2DModel(\n                        num_attention_heads,\n                        out_channels // num_attention_heads,\n                        in_channels=out_channels,\n                        num_layers=1,\n                        cross_attention_dim=cross_attention_dim,\n                        norm_num_groups=resnet_groups,\n                    )\n                )\n\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        upsample_size: Optional[int] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n        image_encoder_hidden_states: Optional[torch.FloatTensor] = None,\n    ):  \n    \n        for resnet, attn in zip(self.resnets, self.attentions):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            
res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n                )\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    return_dict=False,\n                )[0]\n                \n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, upsample_size)\n\n        
return hidden_states\n\n\nclass UpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor=1.0,\n        add_upsample=True,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):\n        for resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = 
res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = resnet(hidden_states, temb)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, upsample_size)\n\n        return hidden_states\n\n\nclass UpDecoderBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",  # default, spatial\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor=1.0,\n        add_upsample=True,\n        temb_channels=None,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            input_channels = in_channels if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=input_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    
eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n    def forward(self, hidden_states, temb=None):\n        for resnet in self.resnets:\n            hidden_states = resnet(hidden_states, temb=temb)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\n\nclass AttnUpDecoderBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim=1,\n        output_scale_factor=1.0,\n        add_upsample=True,\n        temb_channels=None,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. 
Defaulting `attention_head_dim` to `out_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        for i in range(num_layers):\n            input_channels = in_channels if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=input_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n            attentions.append(\n                Attention(\n                    out_channels,\n                    heads=out_channels // attention_head_dim,\n                    dim_head=attention_head_dim,\n                    rescale_output_factor=output_scale_factor,\n                    eps=resnet_eps,\n                    norm_num_groups=resnet_groups if resnet_time_scale_shift != \"spatial\" else None,\n                    spatial_norm_dim=temb_channels if resnet_time_scale_shift == \"spatial\" else None,\n                    residual_connection=True,\n                    bias=True,\n                    upcast_softmax=True,\n                    _from_deprecated_attn_block=True,\n                )\n            )\n\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])\n        else:\n            self.upsamplers = None\n\n    def forward(self, hidden_states, temb=None):\n        for resnet, attn in zip(self.resnets, self.attentions):\n            hidden_states = 
resnet(hidden_states, temb=temb)\n            hidden_states = attn(hidden_states, temb=temb)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\n\nclass AttnSkipUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_pre_norm: bool = True,\n        attention_head_dim=1,\n        output_scale_factor=np.sqrt(2.0),\n        add_upsample=True,\n    ):\n        super().__init__()\n        self.attentions = nn.ModuleList([])\n        self.resnets = nn.ModuleList([])\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            self.resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=min(resnet_in_channels + res_skip_channels // 4, 32),\n                    groups_out=min(out_channels // 4, 32),\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        if attention_head_dim is None:\n            logger.warn(\n                f\"It is not recommend to pass `attention_head_dim=None`. 
Defaulting `attention_head_dim` to `out_channels`: {out_channels}.\"\n            )\n            attention_head_dim = out_channels\n\n        self.attentions.append(\n            Attention(\n                out_channels,\n                heads=out_channels // attention_head_dim,\n                dim_head=attention_head_dim,\n                rescale_output_factor=output_scale_factor,\n                eps=resnet_eps,\n                norm_num_groups=32,\n                residual_connection=True,\n                bias=True,\n                upcast_softmax=True,\n                _from_deprecated_attn_block=True,\n            )\n        )\n\n        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)\n        if add_upsample:\n            self.resnet_up = ResnetBlock2D(\n                in_channels=out_channels,\n                out_channels=out_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=min(out_channels // 4, 32),\n                groups_out=min(out_channels // 4, 32),\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n                use_in_shortcut=True,\n                up=True,\n                kernel=\"fir\",\n            )\n            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n            self.skip_norm = torch.nn.GroupNorm(\n                num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True\n            )\n            self.act = nn.SiLU()\n        else:\n            self.resnet_up = None\n            self.skip_conv = None\n            self.skip_norm = None\n            self.act = None\n\n    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):\n        for 
resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            hidden_states = resnet(hidden_states, temb)\n\n        hidden_states = self.attentions[0](hidden_states)\n\n        if skip_sample is not None:\n            skip_sample = self.upsampler(skip_sample)\n        else:\n            skip_sample = 0\n\n        if self.resnet_up is not None:\n            skip_sample_states = self.skip_norm(hidden_states)\n            skip_sample_states = self.act(skip_sample_states)\n            skip_sample_states = self.skip_conv(skip_sample_states)\n\n            skip_sample = skip_sample + skip_sample_states\n\n            hidden_states = self.resnet_up(hidden_states, temb)\n\n        return hidden_states, skip_sample\n\n\nclass SkipUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_pre_norm: bool = True,\n        output_scale_factor=np.sqrt(2.0),\n        add_upsample=True,\n        upsample_padding=1,\n    ):\n        super().__init__()\n        self.resnets = nn.ModuleList([])\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            self.resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n            
        eps=resnet_eps,\n                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),\n                    groups_out=min(out_channels // 4, 32),\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                )\n            )\n\n        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)\n        if add_upsample:\n            self.resnet_up = ResnetBlock2D(\n                in_channels=out_channels,\n                out_channels=out_channels,\n                temb_channels=temb_channels,\n                eps=resnet_eps,\n                groups=min(out_channels // 4, 32),\n                groups_out=min(out_channels // 4, 32),\n                dropout=dropout,\n                time_embedding_norm=resnet_time_scale_shift,\n                non_linearity=resnet_act_fn,\n                output_scale_factor=output_scale_factor,\n                pre_norm=resnet_pre_norm,\n                use_in_shortcut=True,\n                up=True,\n                kernel=\"fir\",\n            )\n            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n            self.skip_norm = torch.nn.GroupNorm(\n                num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True\n            )\n            self.act = nn.SiLU()\n        else:\n            self.resnet_up = None\n            self.skip_conv = None\n            self.skip_norm = None\n            self.act = None\n\n    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):\n        for resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = 
res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            hidden_states = resnet(hidden_states, temb)\n\n        if skip_sample is not None:\n            skip_sample = self.upsampler(skip_sample)\n        else:\n            skip_sample = 0\n\n        if self.resnet_up is not None:\n            skip_sample_states = self.skip_norm(hidden_states)\n            skip_sample_states = self.act(skip_sample_states)\n            skip_sample_states = self.skip_conv(skip_sample_states)\n\n            skip_sample = skip_sample + skip_sample_states\n\n            hidden_states = self.resnet_up(hidden_states, temb)\n\n        return hidden_states, skip_sample\n\n\nclass ResnetUpsampleBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        prev_output_channel: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        output_scale_factor=1.0,\n        add_upsample=True,\n        skip_time_act=False,\n    ):\n        super().__init__()\n        resnets = []\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    
non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                    skip_time_act=skip_time_act,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        skip_time_act=skip_time_act,\n                        up=True,\n                    )\n                ]\n            )\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):\n        for resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = resnet(hidden_states, temb)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, temb)\n\n        return hidden_states\n\n\nclass SimpleCrossAttnUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        prev_output_channel: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 1,\n        resnet_eps: float = 1e-6,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_act_fn: str = \"swish\",\n        resnet_groups: int = 32,\n        resnet_pre_norm: bool = True,\n        attention_head_dim=1,\n        cross_attention_dim=1280,\n        output_scale_factor=1.0,\n        add_upsample=True,\n        skip_time_act=False,\n        only_cross_attention=False,\n        cross_attention_norm=None,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        self.has_cross_attention = True\n        self.attention_head_dim = attention_head_dim\n\n        self.num_heads = out_channels // self.attention_head_dim\n\n        for i in range(num_layers):\n            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels\n            resnet_in_channels = prev_output_channel if i == 0 else out_channels\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=resnet_in_channels + res_skip_channels,\n                    out_channels=out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    
groups=resnet_groups,\n                    dropout=dropout,\n                    time_embedding_norm=resnet_time_scale_shift,\n                    non_linearity=resnet_act_fn,\n                    output_scale_factor=output_scale_factor,\n                    pre_norm=resnet_pre_norm,\n                    skip_time_act=skip_time_act,\n                )\n            )\n\n            processor = (\n                AttnAddedKVProcessor2_0() if hasattr(F, \"scaled_dot_product_attention\") else AttnAddedKVProcessor()\n            )\n\n            attentions.append(\n                Attention(\n                    query_dim=out_channels,\n                    cross_attention_dim=out_channels,\n                    heads=self.num_heads,\n                    dim_head=self.attention_head_dim,\n                    added_kv_proj_dim=cross_attention_dim,\n                    norm_num_groups=resnet_groups,\n                    bias=True,\n                    upcast_softmax=True,\n                    only_cross_attention=only_cross_attention,\n                    cross_attention_norm=cross_attention_norm,\n                    processor=processor,\n                )\n            )\n        self.attentions = nn.ModuleList(attentions)\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList(\n                [\n                    ResnetBlock2D(\n                        in_channels=out_channels,\n                        out_channels=out_channels,\n                        temb_channels=temb_channels,\n                        eps=resnet_eps,\n                        groups=resnet_groups,\n                        dropout=dropout,\n                        time_embedding_norm=resnet_time_scale_shift,\n                        non_linearity=resnet_act_fn,\n                        output_scale_factor=output_scale_factor,\n                        pre_norm=resnet_pre_norm,\n                        skip_time_act=skip_time_act,\n      
                  up=True,\n                    )\n                ]\n            )\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        upsample_size: Optional[int] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ):\n        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n\n        if attention_mask is None:\n            # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.\n            mask = None if encoder_hidden_states is None else encoder_attention_mask\n        else:\n            # when attention_mask is defined: we don't even check for encoder_attention_mask.\n            # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.\n            # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.\n            #       then we can simplify this whole if/else block to:\n            #         mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask\n            mask = attention_mask\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # resnet\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)\n\n            if self.training and 
self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=mask,\n                    **cross_attention_kwargs,\n                )\n            else:\n                hidden_states = resnet(hidden_states, temb)\n\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=mask,\n                    **cross_attention_kwargs,\n                )\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states, temb)\n\n        return hidden_states\n\n\nclass KUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 5,\n        resnet_eps: float = 1e-5,\n        resnet_act_fn: str = \"gelu\",\n        resnet_group_size: Optional[int] = 32,\n        add_upsample=True,\n    ):\n        super().__init__()\n        resnets = []\n        k_in_channels = 2 * out_channels\n        k_out_channels = in_channels\n        num_layers = num_layers - 1\n\n        for i in range(num_layers):\n            in_channels = k_in_channels if i == 0 else out_channels\n            groups = in_channels // resnet_group_size\n 
           groups_out = out_channels // resnet_group_size\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=k_out_channels if (i == num_layers - 1) else out_channels,\n                    temb_channels=temb_channels,\n                    eps=resnet_eps,\n                    groups=groups,\n                    groups_out=groups_out,\n                    dropout=dropout,\n                    non_linearity=resnet_act_fn,\n                    time_embedding_norm=\"ada_group\",\n                    conv_shortcut_bias=False,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([KUpsample2D()])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):\n        res_hidden_states_tuple = res_hidden_states_tuple[-1]\n        if res_hidden_states_tuple is not None:\n            hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)\n\n        for resnet in self.resnets:\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module):\n                    def custom_forward(*inputs):\n                        return module(*inputs)\n\n                    return custom_forward\n\n                if is_torch_version(\">=\", \"1.11.0\"):\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False\n                    )\n                else:\n                    hidden_states = torch.utils.checkpoint.checkpoint(\n                        create_custom_forward(resnet), hidden_states, temb\n                    )\n            else:\n                hidden_states = 
resnet(hidden_states, temb)\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\n\nclass KCrossAttnUpBlock2D(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        temb_channels: int,\n        dropout: float = 0.0,\n        num_layers: int = 4,\n        resnet_eps: float = 1e-5,\n        resnet_act_fn: str = \"gelu\",\n        resnet_group_size: int = 32,\n        attention_head_dim=1,  # attention dim_head\n        cross_attention_dim: int = 768,\n        add_upsample: bool = True,\n        upcast_attention: bool = False,\n    ):\n        super().__init__()\n        resnets = []\n        attentions = []\n\n        is_first_block = in_channels == out_channels == temb_channels\n        is_middle_block = in_channels != out_channels\n        add_self_attention = True if is_first_block else False\n\n        self.has_cross_attention = True\n        self.attention_head_dim = attention_head_dim\n\n        # in_channels, and out_channels for the block (k-unet)\n        k_in_channels = out_channels if is_first_block else 2 * out_channels\n        k_out_channels = in_channels\n\n        num_layers = num_layers - 1\n\n        for i in range(num_layers):\n            in_channels = k_in_channels if i == 0 else out_channels\n            groups = in_channels // resnet_group_size\n            groups_out = out_channels // resnet_group_size\n\n            if is_middle_block and (i == num_layers - 1):\n                conv_2d_out_channels = k_out_channels\n            else:\n                conv_2d_out_channels = None\n\n            resnets.append(\n                ResnetBlock2D(\n                    in_channels=in_channels,\n                    out_channels=out_channels,\n                    conv_2d_out_channels=conv_2d_out_channels,\n                    temb_channels=temb_channels,\n                  
  eps=resnet_eps,\n                    groups=groups,\n                    groups_out=groups_out,\n                    dropout=dropout,\n                    non_linearity=resnet_act_fn,\n                    time_embedding_norm=\"ada_group\",\n                    conv_shortcut_bias=False,\n                )\n            )\n            attentions.append(\n                KAttentionBlock(\n                    k_out_channels if (i == num_layers - 1) else out_channels,\n                    k_out_channels // attention_head_dim\n                    if (i == num_layers - 1)\n                    else out_channels // attention_head_dim,\n                    attention_head_dim,\n                    cross_attention_dim=cross_attention_dim,\n                    temb_channels=temb_channels,\n                    attention_bias=True,\n                    add_self_attention=add_self_attention,\n                    cross_attention_norm=\"layer_norm\",\n                    upcast_attention=upcast_attention,\n                )\n            )\n\n        self.resnets = nn.ModuleList(resnets)\n        self.attentions = nn.ModuleList(attentions)\n\n        if add_upsample:\n            self.upsamplers = nn.ModuleList([KUpsample2D()])\n        else:\n            self.upsamplers = None\n\n        self.gradient_checkpointing = False\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],\n        temb: Optional[torch.FloatTensor] = None,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        upsample_size: Optional[int] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ):\n        res_hidden_states_tuple = res_hidden_states_tuple[-1]\n        if res_hidden_states_tuple is not None:\n            hidden_states = 
torch.cat([hidden_states, res_hidden_states_tuple], dim=1)\n\n        for resnet, attn in zip(self.resnets, self.attentions):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(resnet),\n                    hidden_states,\n                    temb,\n                    **ckpt_kwargs,\n                )\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    emb=temb,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n            else:\n                hidden_states = resnet(hidden_states, temb)\n                hidden_states = attn(\n                    hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    emb=temb,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                )\n\n        if self.upsamplers is not None:\n            for upsampler in self.upsamplers:\n                hidden_states = upsampler(hidden_states)\n\n        return hidden_states\n\n\n# can potentially later be renamed to 
`No-feed-forward` attention\nclass KAttentionBlock(nn.Module):\n    r\"\"\"\n    A basic Transformer block.\n\n    Parameters:\n        dim (`int`): The number of channels in the input and output.\n        num_attention_heads (`int`): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`): The number of channels in each head.\n        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.\n        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.\n        activation_fn (`str`, *optional*, defaults to `\"geglu\"`): Activation function to be used in feed-forward.\n        num_embeds_ada_norm (:\n            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.\n        attention_bias (:\n            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.\n    \"\"\"\n\n    def __init__(\n        self,\n        dim: int,\n        num_attention_heads: int,\n        attention_head_dim: int,\n        dropout: float = 0.0,\n        cross_attention_dim: Optional[int] = None,\n        attention_bias: bool = False,\n        upcast_attention: bool = False,\n        temb_channels: int = 768,  # for ada_group_norm\n        add_self_attention: bool = False,\n        cross_attention_norm: Optional[str] = None,\n        group_size: int = 32,\n    ):\n        super().__init__()\n        self.add_self_attention = add_self_attention\n\n        # 1. 
Self-Attn\n        if add_self_attention:\n            self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))\n            self.attn1 = Attention(\n                query_dim=dim,\n                heads=num_attention_heads,\n                dim_head=attention_head_dim,\n                dropout=dropout,\n                bias=attention_bias,\n                cross_attention_dim=None,\n                cross_attention_norm=None,\n            )\n\n        # 2. Cross-Attn\n        self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))\n        self.attn2 = Attention(\n            query_dim=dim,\n            cross_attention_dim=cross_attention_dim,\n            heads=num_attention_heads,\n            dim_head=attention_head_dim,\n            dropout=dropout,\n            bias=attention_bias,\n            upcast_attention=upcast_attention,\n            cross_attention_norm=cross_attention_norm,\n        )\n\n    def _to_3d(self, hidden_states, height, weight):\n        return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)\n\n    def _to_4d(self, hidden_states, height, weight):\n        return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: Optional[torch.FloatTensor] = None,\n        # TODO: mark emb as non-optional (self.norm2 requires it).\n        #       requires assessing impact of change to positional param interface.\n        emb: Optional[torch.FloatTensor] = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        encoder_attention_mask: Optional[torch.FloatTensor] = None,\n    ):\n        cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}\n\n        # 1. 
Self-Attention\n        if self.add_self_attention:\n            norm_hidden_states = self.norm1(hidden_states, emb)\n\n            height, weight = norm_hidden_states.shape[2:]\n            norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)\n\n            attn_output = self.attn1(\n                norm_hidden_states,\n                encoder_hidden_states=None,\n                attention_mask=attention_mask,\n                **cross_attention_kwargs,\n            )\n            attn_output = self._to_4d(attn_output, height, weight)\n\n            hidden_states = attn_output + hidden_states\n\n        # 2. Cross-Attention/None\n        norm_hidden_states = self.norm2(hidden_states, emb)\n\n        height, weight = norm_hidden_states.shape[2:]\n        norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)\n        attn_output = self.attn2(\n            norm_hidden_states,\n            encoder_hidden_states=encoder_hidden_states,\n            attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,\n            **cross_attention_kwargs,\n        )\n        attn_output = self._to_4d(attn_output, height, weight)\n\n        hidden_states = attn_output + hidden_states\n\n        return hidden_states\n"
  },
  {
    "path": "models/unet_2d_condition.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import UNet2DConditionLoadersMixin\nfrom diffusers.utils import BaseOutput, logging\nfrom diffusers.models.activations import get_activation\nfrom diffusers.models.attention_processor import AttentionProcessor, AttnProcessor\nfrom diffusers.models.embeddings import (\n    GaussianFourierProjection,\n    ImageHintTimeEmbedding,\n    ImageProjection,\n    ImageTimeEmbedding,\n    PositionNet,\n    TextImageProjection,\n    TextImageTimeEmbedding,\n    TextTimeEmbedding,\n    TimestepEmbedding,\n    Timesteps,\n)\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom .unet_2d_blocks import (\n    UNetMidBlock2DCrossAttn,\n    UNetMidBlock2DSimpleCrossAttn,\n    get_down_block,\n    get_up_block,\n)\n\nimport os, json\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@dataclass\nclass UNet2DConditionOutput(BaseOutput):\n    \"\"\"\n    The output of [`UNet2DConditionModel`].\n\n    Args:\n        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):\n            The hidden states output conditioned on 
`encoder_hidden_states` input. Output of last layer of model.\n    \"\"\"\n\n    sample: torch.FloatTensor = None\n\n\nclass UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):\n    r\"\"\"\n    A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample\n    shaped output.\n\n    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented\n    for all models (such as downloading or saving).\n\n    Parameters:\n        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):\n            Height and width of input/output sample.\n        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.\n        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.\n        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.\n        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):\n            Whether to flip the sin to cos in the time embedding.\n        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use.\n        mid_block_type (`str`, *optional*, defaults to `\"UNetMidBlock2DCrossAttn\"`):\n            Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or\n            `UNetMidBlock2DSimpleCrossAttn`. 
If `None`, the mid block layer is skipped.\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\")`):\n            The tuple of upsample blocks to use.\n        only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):\n            Whether to include self-attention in the basic transformer blocks, see\n            [`~models.attention.BasicTransformerBlock`].\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.\n        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.\n        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.\n        act_fn (`str`, *optional*, defaults to `\"silu\"`): The activation function to use.\n        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.\n            If `None`, normalization and activation layers is skipped in post-processing.\n        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.\n        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):\n            The dimension of the cross attention features.\n        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):\n            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. 
Only relevant for\n            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],\n            [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].\n        encoder_hid_dim (`int`, *optional*, defaults to None):\n            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`\n            dimension to `cross_attention_dim`.\n        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):\n            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text\n            embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.\n        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.\n        num_attention_heads (`int`, *optional*):\n            The number of attention heads. If not defined, defaults to `attention_head_dim`\n        resnet_time_scale_shift (`str`, *optional*, defaults to `\"default\"`): Time scale shift config\n            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.\n        class_embed_type (`str`, *optional*, defaults to `None`):\n            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,\n            `\"timestep\"`, `\"identity\"`, `\"projection\"`, or `\"simple_projection\"`.\n        addition_embed_type (`str`, *optional*, defaults to `None`):\n            Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or\n            \"text\". 
\"text\" will use the `TextTimeEmbedding` layer.\n        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):\n            Dimension for the timestep embeddings.\n        num_class_embeds (`int`, *optional*, defaults to `None`):\n            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing\n            class conditioning with `class_embed_type` equal to `None`.\n        time_embedding_type (`str`, *optional*, defaults to `positional`):\n            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.\n        time_embedding_dim (`int`, *optional*, defaults to `None`):\n            An optional override for the dimension of the projected time embedding.\n        time_embedding_act_fn (`str`, *optional*, defaults to `None`):\n            Optional activation function to use only once on the time embeddings before they are passed to the rest of\n            the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.\n        timestep_post_act (`str`, *optional*, defaults to `None`):\n            The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.\n        time_cond_proj_dim (`int`, *optional*, defaults to `None`):\n            The dimension of `cond_proj` layer in the timestep embedding.\n        conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.\n        conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.\n        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when\n            `class_embed_type=\"projection\"`. 
Required when `class_embed_type=\"projection\"`.\n        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time\n            embeddings with the class embeddings.\n        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):\n            Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If\n            `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the\n            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`\n            otherwise.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n\n    @register_to_config\n    def __init__(\n        self,\n        sample_size: Optional[int] = None,\n        in_channels: int = 4,\n        out_channels: int = 4,\n        center_input_sample: bool = False,\n        flip_sin_to_cos: bool = True,\n        freq_shift: int = 0,\n        down_block_types: Tuple[str] = (\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"CrossAttnDownBlock2D\",\n            \"DownBlock2D\",\n        ),\n        mid_block_type: Optional[str] = \"UNetMidBlock2DCrossAttn\",\n        up_block_types: Tuple[str] = (\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\"),\n        only_cross_attention: Union[bool, Tuple[bool]] = False,\n        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),\n        layers_per_block: Union[int, Tuple[int]] = 2,\n        downsample_padding: int = 1,\n        mid_block_scale_factor: float = 1,\n        act_fn: str = \"silu\",\n        norm_num_groups: Optional[int] = 32,\n        norm_eps: float = 1e-5,\n        cross_attention_dim: Union[int, Tuple[int]] = 1280,\n        transformer_layers_per_block: Union[int, Tuple[int]] = 1,\n        encoder_hid_dim: Optional[int] = None,\n        encoder_hid_dim_type: 
Optional[str] = None,\n        attention_head_dim: Union[int, Tuple[int]] = 8,\n        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,\n        dual_cross_attention: bool = False,\n        use_linear_projection: bool = False,\n        class_embed_type: Optional[str] = None,\n        addition_embed_type: Optional[str] = None,\n        addition_time_embed_dim: Optional[int] = None,\n        num_class_embeds: Optional[int] = None,\n        upcast_attention: bool = False,\n        resnet_time_scale_shift: str = \"default\",\n        resnet_skip_time_act: bool = False,\n        resnet_out_scale_factor: int = 1.0,\n        time_embedding_type: str = \"positional\",\n        time_embedding_dim: Optional[int] = None,\n        time_embedding_act_fn: Optional[str] = None,\n        timestep_post_act: Optional[str] = None,\n        time_cond_proj_dim: Optional[int] = None,\n        conv_in_kernel: int = 3,\n        conv_out_kernel: int = 3,\n        projection_class_embeddings_input_dim: Optional[int] = None,\n        attention_type: str = \"default\",\n        class_embeddings_concat: bool = False,\n        mid_block_only_cross_attention: Optional[bool] = None,\n        cross_attention_norm: Optional[str] = None,\n        addition_embed_type_num_heads=64,\n    ):\n        super().__init__()\n\n        self.sample_size = sample_size\n\n        if num_attention_heads is not None:\n            raise ValueError(\n                \"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19.\"\n            )\n\n        # If `num_attention_heads` is not defined (which is the case for most models)\n        # it will default to `attention_head_dim`. 
This looks weird upon first reading it and it is.\n        # The reason for this behavior is to correct for incorrectly named variables that were introduced\n        # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131\n        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking\n        # which is why we correct for the naming here.\n        num_attention_heads = num_attention_heads or attention_head_dim\n\n        # Check inputs\n        if len(down_block_types) != len(up_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}.\"\n            )\n\n        if len(block_out_channels) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. 
`down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}.\"\n            )\n\n        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):\n            raise ValueError(\n                f\"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}.\"\n            )\n\n        # input\n        conv_in_padding = (conv_in_kernel - 1) // 2\n        self.conv_in = nn.Conv2d(\n            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding\n        )\n\n        # time\n        if time_embedding_type == \"fourier\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2\n            if time_embed_dim % 2 != 0:\n                raise ValueError(f\"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.\")\n            self.time_proj = GaussianFourierProjection(\n                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos\n            )\n            timestep_input_dim = time_embed_dim\n        elif time_embedding_type == \"positional\":\n            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4\n\n            self.time_proj = Timesteps(block_out_channels[0], 
flip_sin_to_cos, freq_shift)\n            timestep_input_dim = block_out_channels[0]\n        else:\n            raise ValueError(\n                f\"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`.\"\n            )\n\n        self.time_embedding = TimestepEmbedding(\n            timestep_input_dim,\n            time_embed_dim,\n            act_fn=act_fn,\n            post_act_fn=timestep_post_act,\n            cond_proj_dim=time_cond_proj_dim,\n        )\n\n        if encoder_hid_dim_type is None and encoder_hid_dim is not None:\n            encoder_hid_dim_type = \"text_proj\"\n            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)\n            logger.info(\"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.\")\n\n        if encoder_hid_dim is None and encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}.\"\n            )\n\n        if encoder_hid_dim_type == \"text_proj\":\n            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)\n        elif encoder_hid_dim_type == \"text_image_proj\":\n            # image_embed_dim DOESN'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image_proj\"` (Kadinsky 2.1)`\n            self.encoder_hid_proj = TextImageProjection(\n                text_embed_dim=encoder_hid_dim,\n                image_embed_dim=cross_attention_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2\n            self.encoder_hid_proj = ImageProjection(\n                image_embed_dim=encoder_hid_dim,\n                cross_attention_dim=cross_attention_dim,\n            )\n        elif encoder_hid_dim_type is not None:\n            raise ValueError(\n                f\"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'.\"\n            )\n        else:\n            self.encoder_hid_proj = None\n\n        # class embedding\n        if class_embed_type is None and num_class_embeds is not None:\n            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)\n        elif class_embed_type == \"timestep\":\n            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)\n        elif class_embed_type == \"identity\":\n            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)\n        elif class_embed_type == \"projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except\n            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings\n            # 2. 
it projects from an arbitrary input dimension.\n            #\n            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.\n            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.\n            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.\n            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif class_embed_type == \"simple_projection\":\n            if projection_class_embeddings_input_dim is None:\n                raise ValueError(\n                    \"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set\"\n                )\n            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)\n        else:\n            self.class_embedding = None\n\n        if addition_embed_type == \"text\":\n            if encoder_hid_dim is not None:\n                text_time_embedding_from_dim = encoder_hid_dim\n            else:\n                text_time_embedding_from_dim = cross_attention_dim\n\n            self.add_embedding = TextTimeEmbedding(\n                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads\n            )\n        elif addition_embed_type == \"text_image\":\n            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much\n            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use\n            # case when `addition_embed_type == \"text_image\"` (Kadinsky 2.1)`\n            self.add_embedding = TextImageTimeEmbedding(\n                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim\n            )\n        elif addition_embed_type == \"text_time\":\n            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)\n            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)\n        elif addition_embed_type == \"image\":\n            # Kandinsky 2.2\n            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 ControlNet\n            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)\n        elif addition_embed_type is not None:\n            raise ValueError(f\"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.\")\n\n        if time_embedding_act_fn is None:\n            self.time_embed_act = None\n        else:\n            self.time_embed_act = get_activation(time_embedding_act_fn)\n\n        self.down_blocks = nn.ModuleList([])\n        self.up_blocks = nn.ModuleList([])\n\n        if isinstance(only_cross_attention, bool):\n            if mid_block_only_cross_attention is None:\n                mid_block_only_cross_attention = only_cross_attention\n\n            only_cross_attention = [only_cross_attention] * len(down_block_types)\n\n        if mid_block_only_cross_attention is None:\n            mid_block_only_cross_attention = False\n\n        if isinstance(num_attention_heads, int):\n            num_attention_heads = 
(num_attention_heads,) * len(down_block_types)\n\n        if isinstance(attention_head_dim, int):\n            attention_head_dim = (attention_head_dim,) * len(down_block_types)\n\n        if isinstance(cross_attention_dim, int):\n            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)\n\n        if isinstance(layers_per_block, int):\n            layers_per_block = [layers_per_block] * len(down_block_types)\n\n        if isinstance(transformer_layers_per_block, int):\n            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)\n\n        if class_embeddings_concat:\n            # The time embeddings are concatenated with the class embeddings. The dimension of the\n            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the\n            # regular time embeddings\n            blocks_time_embed_dim = time_embed_dim * 2\n        else:\n            blocks_time_embed_dim = time_embed_dim\n\n        # down\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            down_block = get_down_block(\n                down_block_type,\n                num_layers=layers_per_block[i],\n                transformer_layers_per_block=transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_downsample=not is_final_block,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=cross_attention_dim[i],\n                num_attention_heads=num_attention_heads[i],\n                downsample_padding=downsample_padding,\n     
           dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                attention_type=attention_type,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n            )\n            self.down_blocks.append(down_block)\n\n        # mid\n        if mid_block_type == \"UNetMidBlock2DCrossAttn\":\n            self.mid_block = UNetMidBlock2DCrossAttn(\n                transformer_layers_per_block=transformer_layers_per_block[-1],\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                output_scale_factor=mid_block_scale_factor,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                cross_attention_dim=cross_attention_dim[-1],\n                num_attention_heads=num_attention_heads[-1],\n                resnet_groups=norm_num_groups,\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                upcast_attention=upcast_attention,\n                attention_type=attention_type,\n            )\n        elif mid_block_type == \"UNetMidBlock2DSimpleCrossAttn\":\n            self.mid_block = UNetMidBlock2DSimpleCrossAttn(\n                in_channels=block_out_channels[-1],\n                temb_channels=blocks_time_embed_dim,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                
output_scale_factor=mid_block_scale_factor,\n                cross_attention_dim=cross_attention_dim[-1],\n                attention_head_dim=attention_head_dim[-1],\n                resnet_groups=norm_num_groups,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                skip_time_act=resnet_skip_time_act,\n                only_cross_attention=mid_block_only_cross_attention,\n                cross_attention_norm=cross_attention_norm,\n            )\n        elif mid_block_type is None:\n            self.mid_block = None\n        else:\n            raise ValueError(f\"unknown mid_block_type : {mid_block_type}\")\n\n        # count how many layers upsample the images\n        self.num_upsamplers = 0\n\n        # up\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        reversed_num_attention_heads = list(reversed(num_attention_heads))\n        reversed_layers_per_block = list(reversed(layers_per_block))\n        reversed_cross_attention_dim = list(reversed(cross_attention_dim))\n        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))\n        only_cross_attention = list(reversed(only_cross_attention))\n\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(up_block_types):\n            is_final_block = i == len(block_out_channels) - 1\n\n            prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]\n\n            # add upsample block for all BUT final layer\n            if not is_final_block:\n                add_upsample = True\n                self.num_upsamplers += 1\n            else:\n                add_upsample = False\n\n            up_block = get_up_block(\n                up_block_type,\n                num_layers=reversed_layers_per_block[i] + 1,\n                
transformer_layers_per_block=reversed_transformer_layers_per_block[i],\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                temb_channels=blocks_time_embed_dim,\n                add_upsample=add_upsample,\n                resnet_eps=norm_eps,\n                resnet_act_fn=act_fn,\n                resnet_groups=norm_num_groups,\n                cross_attention_dim=reversed_cross_attention_dim[i],\n                num_attention_heads=reversed_num_attention_heads[i],\n                dual_cross_attention=dual_cross_attention,\n                use_linear_projection=use_linear_projection,\n                only_cross_attention=only_cross_attention[i],\n                upcast_attention=upcast_attention,\n                resnet_time_scale_shift=resnet_time_scale_shift,\n                attention_type=attention_type,\n                resnet_skip_time_act=resnet_skip_time_act,\n                resnet_out_scale_factor=resnet_out_scale_factor,\n                cross_attention_norm=cross_attention_norm,\n                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,\n            )\n            self.up_blocks.append(up_block)\n            prev_output_channel = output_channel\n\n        # out\n        if norm_num_groups is not None:\n            self.conv_norm_out = nn.GroupNorm(\n                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps\n            )\n\n            self.conv_act = get_activation(act_fn)\n\n        else:\n            self.conv_norm_out = None\n            self.conv_act = None\n\n        conv_out_padding = (conv_out_kernel - 1) // 2\n        self.conv_out = nn.Conv2d(\n            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding\n        )\n\n        if attention_type == \"gated\":\n            positive_len = 768\n            
if isinstance(cross_attention_dim, int):\n                positive_len = cross_attention_dim\n            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):\n                positive_len = cross_attention_dim[0]\n            self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)\n\n    @property\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor(return_deprecated_lora=True)\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. 
This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    def set_default_attn_processor(self):\n        \"\"\"\n        Disables custom attention processors and sets the default attention implementation.\n        \"\"\"\n        self.set_attn_processor(AttnProcessor())\n\n    def set_attention_slice(self, slice_size):\n        r\"\"\"\n        Enable sliced attention computation.\n\n        When this option is enabled, the attention module splits the input tensor in slices to compute attention in\n        several steps. This is useful for saving some memory in exchange for a small decrease in speed.\n\n        Args:\n            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `\"auto\"`):\n                When `\"auto\"`, input to the attention heads is halved, so attention is computed in two steps. If\n                `\"max\"`, maximum amount of memory is saved by running only one slice at a time. 
If a number is\n                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`\n                must be a multiple of `slice_size`.\n        \"\"\"\n        sliceable_head_dims = []\n\n        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):\n            if hasattr(module, \"set_attention_slice\"):\n                sliceable_head_dims.append(module.sliceable_head_dim)\n\n            for child in module.children():\n                fn_recursive_retrieve_sliceable_dims(child)\n\n        # retrieve number of attention layers\n        for module in self.children():\n            fn_recursive_retrieve_sliceable_dims(module)\n\n        num_sliceable_layers = len(sliceable_head_dims)\n\n        if slice_size == \"auto\":\n            # half the attention head size is usually a good trade-off between\n            # speed and memory\n            slice_size = [dim // 2 for dim in sliceable_head_dims]\n        elif slice_size == \"max\":\n            # make smallest slice possible\n            slice_size = num_sliceable_layers * [1]\n\n        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size\n\n        if len(slice_size) != len(sliceable_head_dims):\n            raise ValueError(\n                f\"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different\"\n                f\" attention layers. 
Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}.\"\n            )\n\n        for i in range(len(slice_size)):\n            size = slice_size[i]\n            dim = sliceable_head_dims[i]\n            if size is not None and size > dim:\n                raise ValueError(f\"size {size} has to be smaller or equal to {dim}.\")\n\n        # Recursively walk through all the children.\n        # Any children which exposes the set_attention_slice method\n        # gets the message\n        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):\n            if hasattr(module, \"set_attention_slice\"):\n                module.set_attention_slice(slice_size.pop())\n\n            for child in module.children():\n                fn_recursive_set_attention_slice(child, slice_size)\n\n        reversed_slice_size = list(reversed(slice_size))\n        for module in self.children():\n            fn_recursive_set_attention_slice(module, reversed_slice_size)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if hasattr(module, \"gradient_checkpointing\"):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        sample: torch.FloatTensor,\n        timestep: Union[torch.Tensor, float, int],\n        encoder_hidden_states: torch.Tensor,\n        class_labels: Optional[torch.Tensor] = None,\n        timestep_cond: Optional[torch.Tensor] = None,\n        attention_mask: Optional[torch.Tensor] = None,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,\n        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,\n        mid_block_additional_residual: Optional[torch.Tensor] = None,\n        encoder_attention_mask: Optional[torch.Tensor] = None,\n        return_dict: bool = True,\n        image_encoder_hidden_states: torch.Tensor = None,\n    ) -> Union[UNet2DConditionOutput, 
Tuple]:\n        r\"\"\"\n        The [`UNet2DConditionModel`] forward method.\n\n        Args:\n            sample (`torch.FloatTensor`):\n                The noisy input tensor with the following shape `(batch, channel, height, width)`.\n            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.\n            encoder_hidden_states (`torch.FloatTensor`):\n                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.\n            encoder_attention_mask (`torch.Tensor`):\n                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If\n                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,\n                which adds large negative values to the attention scores corresponding to \"discard\" tokens.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain\n                tuple.\n            cross_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].\n            added_cond_kwargs: (`dict`, *optional*):\n                A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that\n                are passed along to the UNet blocks.\n\n        Returns:\n            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:\n                If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise\n                a `tuple` is returned where the first element is the sample tensor.\n        \"\"\"\n        # By default samples have to be AT least a multiple of the overall upsampling factor.\n        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).\n       
 # However, the upsampling interpolation output size can be forced to fit any upsampling size\n        # on the fly if necessary.\n        default_overall_up_factor = 2**self.num_upsamplers\n\n        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`\n        forward_upsample_size = False\n        upsample_size = None\n\n        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):\n            logger.info(\"Forward upsample size to force interpolation output size.\")\n            forward_upsample_size = True\n\n        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension\n        # expects mask of shape:\n        #   [batch, key_tokens]\n        # adds singleton query_tokens dimension:\n        #   [batch,                    1, key_tokens]\n        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:\n        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)\n        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)\n        if attention_mask is not None:\n            # assume that mask is expressed as:\n            #   (1 = keep,      0 = discard)\n            # convert mask into a bias that can be added to attention scores:\n            #       (keep = +0,     discard = -10000.0)\n            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0\n            attention_mask = attention_mask.unsqueeze(1)\n\n        # convert encoder_attention_mask to a bias the same way we do for attention_mask\n        if encoder_attention_mask is not None:\n            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0\n            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)\n\n        # 0. center input if necessary\n        if self.config.center_input_sample:\n            sample = 2 * sample - 1.0\n\n        # 1. 
time\n        timesteps = timestep\n        if not torch.is_tensor(timesteps):\n            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can\n            # This would be a good case for the `match` statement (Python 3.10+)\n            is_mps = sample.device.type == \"mps\"\n            if isinstance(timestep, float):\n                dtype = torch.float32 if is_mps else torch.float64\n            else:\n                dtype = torch.int32 if is_mps else torch.int64\n            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)\n        elif len(timesteps.shape) == 0:\n            timesteps = timesteps[None].to(sample.device)\n\n        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n        timesteps = timesteps.expand(sample.shape[0])\n\n        t_emb = self.time_proj(timesteps)\n\n        # `Timesteps` does not contain any weights and will always return f32 tensors\n        # but time_embedding might actually be running in fp16. 
so we need to cast here.\n        # there might be better ways to encapsulate this.\n        t_emb = t_emb.to(dtype=sample.dtype)\n\n        emb = self.time_embedding(t_emb, timestep_cond)\n        aug_emb = None\n\n        if self.class_embedding is not None:\n            if class_labels is None:\n                raise ValueError(\"class_labels should be provided when num_class_embeds > 0\")\n\n            if self.config.class_embed_type == \"timestep\":\n                class_labels = self.time_proj(class_labels)\n\n                # `Timesteps` does not contain any weights and will always return f32 tensors\n                # there might be better ways to encapsulate this.\n                class_labels = class_labels.to(dtype=sample.dtype)\n\n            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)\n\n            if self.config.class_embeddings_concat:\n                emb = torch.cat([emb, class_emb], dim=-1)\n            else:\n                emb = emb + class_emb\n\n        if self.config.addition_embed_type == \"text\":\n            aug_emb = self.add_embedding(encoder_hidden_states)\n        elif self.config.addition_embed_type == \"text_image\":\n            # Kandinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            text_embs = added_cond_kwargs.get(\"text_embeds\", encoder_hidden_states)\n            aug_emb = self.add_embedding(text_embs, image_embs)\n        elif self.config.addition_embed_type == \"text_time\":\n            # SDXL - style\n            if \"text_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the 
config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            text_embeds = added_cond_kwargs.get(\"text_embeds\")\n            if \"time_ids\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`\"\n                )\n            time_ids = added_cond_kwargs.get(\"time_ids\")\n            time_embeds = self.add_time_proj(time_ids.flatten())\n            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))\n\n            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)\n            add_embeds = add_embeds.to(emb.dtype)\n            aug_emb = self.add_embedding(add_embeds)\n        elif self.config.addition_embed_type == \"image\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            aug_emb = self.add_embedding(image_embs)\n        elif self.config.addition_embed_type == \"image_hint\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs or \"hint\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`\"\n                )\n            image_embs = added_cond_kwargs.get(\"image_embeds\")\n            hint = 
added_cond_kwargs.get(\"hint\")\n            aug_emb, hint = self.add_embedding(image_embs, hint)\n            sample = torch.cat([sample, hint], dim=1)\n\n        emb = emb + aug_emb if aug_emb is not None else emb\n\n        if self.time_embed_act is not None:\n            emb = self.time_embed_act(emb)\n\n        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_proj\":\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"text_image_proj\":\n            # Kadinsky 2.1 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)\n        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == \"image_proj\":\n            # Kandinsky 2.2 - style\n            if \"image_embeds\" not in added_cond_kwargs:\n                raise ValueError(\n                    f\"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`\"\n                )\n            image_embeds = added_cond_kwargs.get(\"image_embeds\")\n            encoder_hidden_states = self.encoder_hid_proj(image_embeds)\n        # 2. 
pre-process\n        sample = self.conv_in(sample)\n\n        # 2.5 GLIGEN position net\n        if cross_attention_kwargs is not None and cross_attention_kwargs.get(\"gligen\", None) is not None:\n            cross_attention_kwargs = cross_attention_kwargs.copy()\n            gligen_args = cross_attention_kwargs.pop(\"gligen\")\n            cross_attention_kwargs[\"gligen\"] = {\"objs\": self.position_net(**gligen_args)}\n\n        # 3. down\n\n        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None\n        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None\n\n        down_block_res_samples = (sample,)\n        for downsample_block in self.down_blocks:\n            if hasattr(downsample_block, \"has_cross_attention\") and downsample_block.has_cross_attention:\n                # For t2i-adapter CrossAttnDownBlock2D\n                additional_residuals = {}\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    additional_residuals[\"additional_residuals\"] = down_block_additional_residuals.pop(0)\n\n                sample, res_samples = downsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    attention_mask=attention_mask,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    encoder_attention_mask=encoder_attention_mask,\n                    image_encoder_hidden_states=image_encoder_hidden_states,\n                    **additional_residuals,\n                )\n            else:\n                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)\n\n                if is_adapter and len(down_block_additional_residuals) > 0:\n                    sample += down_block_additional_residuals.pop(0)\n\n            down_block_res_samples += 
res_samples\n\n        if is_controlnet:\n            new_down_block_res_samples = ()\n\n            for down_block_res_sample, down_block_additional_residual in zip(\n                down_block_res_samples, down_block_additional_residuals\n            ):\n                down_block_res_sample = down_block_res_sample + down_block_additional_residual\n                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)\n\n            down_block_res_samples = new_down_block_res_samples\n\n        # 4. mid\n        if self.mid_block is not None:\n            sample = self.mid_block(\n                sample,\n                emb,\n                encoder_hidden_states=encoder_hidden_states,\n                attention_mask=attention_mask,\n                cross_attention_kwargs=cross_attention_kwargs,\n                encoder_attention_mask=encoder_attention_mask,\n                image_encoder_hidden_states=image_encoder_hidden_states,\n            )\n            # To support T2I-Adapter-XL\n            if (\n                is_adapter\n                and len(down_block_additional_residuals) > 0\n                and sample.shape == down_block_additional_residuals[0].shape\n            ):\n                sample += down_block_additional_residuals.pop(0)\n\n        if is_controlnet:\n            sample = sample + mid_block_additional_residual\n\n        # 5. 
up\n        for i, upsample_block in enumerate(self.up_blocks):\n            is_final_block = i == len(self.up_blocks) - 1\n\n            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]\n            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]\n\n            # if we have not reached the final block and need to forward the\n            # upsample size, we do it here\n            if not is_final_block and forward_upsample_size:\n                upsample_size = down_block_res_samples[-1].shape[2:]\n\n            if hasattr(upsample_block, \"has_cross_attention\") and upsample_block.has_cross_attention:\n                sample = upsample_block(\n                    hidden_states=sample,\n                    temb=emb,\n                    res_hidden_states_tuple=res_samples,\n                    encoder_hidden_states=encoder_hidden_states,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                    upsample_size=upsample_size,\n                    attention_mask=attention_mask,\n                    encoder_attention_mask=encoder_attention_mask,\n                    image_encoder_hidden_states=image_encoder_hidden_states,\n                )\n            else:\n                sample = upsample_block(\n                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size\n                )\n\n        # 6. 
post-process\n        if self.conv_norm_out:\n            sample = self.conv_norm_out(sample)\n            sample = self.conv_act(sample)\n        sample = self.conv_out(sample)\n\n        if not return_dict:\n            return (sample,)\n\n        return UNet2DConditionOutput(sample=sample)\n\n    @classmethod\n    def from_pretrained_orig(cls, pretrained_model_path, subfolder=None, **kwargs):\n        if subfolder is not None:\n            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)\n\n        config_file = os.path.join(pretrained_model_path, 'config.json')\n        if not os.path.isfile(config_file):\n            raise RuntimeError(f\"{config_file} does not exist\")\n        with open(config_file, \"r\") as f:\n            config = json.load(f)\n\n        from diffusers.utils import WEIGHTS_NAME \n        from diffusers.utils import SAFETENSORS_WEIGHTS_NAME\n\n\n        model = cls.from_config(config)\n\n        ## for .bin file\n        # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)\n        # if not os.path.isfile(model_file):\n        #     raise RuntimeError(f\"{model_file} does not exist\")\n        # state_dict = torch.load(model_file, map_location=\"cpu\")\n        # model.load_state_dict(state_dict, strict=False)\n\n        ## for .safetensors file\n        import safetensors\n        model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)\n        if not os.path.isfile(model_file):\n            raise RuntimeError(f\"{model_file} does not exist\")\n        state_dict = safetensors.torch.load_file(model_file, device=\"cpu\")\n        model.load_state_dict(state_dict, strict=False)\n\n        return model\n\n    @classmethod\n    def from_pretrained_safetensor(cls, pretrained_model_path, subfolder=None, **kwargs):\n        if subfolder is not None:\n            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)\n\n        config_file = os.path.join(pretrained_model_path, 
'config.json')\n        if not os.path.isfile(config_file):\n            raise RuntimeError(f\"{config_file} does not exist\")\n        with open(config_file, \"r\") as f:\n            config = json.load(f)\n            \n        from diffusers.utils import SAFETENSORS_WEIGHTS_NAME\n        model = cls.from_config(config)\n        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)\n        if not os.path.isfile(model_file):\n            raise RuntimeError(f\"{model_file} does not exist\")\n        state_dict = torch.load(model_file, map_location=\"cpu\")\n        for k, v in model.state_dict().items():\n            if 'attn2_plus' in k:\n                print(k)\n                state_dict.update({k: v})\n        model.load_state_dict(state_dict, strict=False)\n\n        return model\n"
  },
  {
    "path": "models/vit_utils.py",
    "content": "# MIT License\n#\n# Copyright (c) 2021 Intel ISL (Intel Intelligent Systems Lab)\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n#\n# Based on code from https://github.com/isl-org/DPT\n\n\"\"\"Flexible configuration and feature extraction of timm VisionTransformers.\"\"\"\n\nimport types\nimport math\nfrom typing import Callable\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass AddReadout(nn.Module):\n    def __init__(self, start_index: bool = 1):\n        super(AddReadout, self).__init__()\n        self.start_index = start_index\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        if self.start_index == 2:\n            readout = (x[:, 0] + x[:, 1]) / 2\n        else:\n            readout = x[:, 0]\n        return x[:, self.start_index:] + readout.unsqueeze(1)\n\n\nclass Transpose(nn.Module):\n    def __init__(self, dim0: int, dim1: int):\n        super(Transpose, self).__init__()\n        
self.dim0 = dim0\n        self.dim1 = dim1\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        x = x.transpose(self.dim0, self.dim1)\n        return x.contiguous()\n\n\ndef forward_vit(pretrained: nn.Module, x: torch.Tensor) -> dict:\n    _, _, H, W = x.size()\n    _ = pretrained.model.forward_flex(x)\n    return {k: pretrained.rearrange(v) for k, v in activations.items()}\n\n\ndef _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor:\n    posemb_tok, posemb_grid = (\n        posemb[:, : self.start_index],\n        posemb[0, self.start_index :],\n    )\n\n    gs_old = int(math.sqrt(len(posemb_grid)))\n\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode=\"bilinear\", align_corners=False)\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)\n\n    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)\n\n    return posemb\n\n\ndef forward_flex(self, x: torch.Tensor) -> torch.Tensor:\n    # patch proj and dynamically resize\n    B, C, H, W = x.size()\n    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)\n    pos_embed = self._resize_pos_embed(\n        self.pos_embed, H // self.patch_size[1], W // self.patch_size[0]\n    )\n\n    # add cls token\n    cls_tokens = self.cls_token.expand(\n        x.size(0), -1, -1\n    )\n    x = torch.cat((cls_tokens, x), dim=1)\n\n    # forward pass\n    x = x + pos_embed\n    x = self.pos_drop(x)\n\n    for blk in self.blocks:\n        x = blk(x)\n\n    x = self.norm(x)\n    return x\n\n\nactivations = {}\n\n\ndef get_activation(name: str) -> Callable:\n    def hook(model, input, output):\n        activations[name] = output\n    return hook\n\n\ndef make_sd_backbone(\n    model: nn.Module,\n    hooks: list[int] = [2, 5, 8, 11],\n    hook_patch: bool = True,\n    start_index: list[int] = 1,\n):\n    assert len(hooks) == 4\n\n    pretrained = nn.Module()\n    
pretrained.model = model\n\n    # add hooks\n    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0'))\n    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1'))\n    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2'))\n    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3'))\n    if hook_patch:\n        pretrained.model.pos_drop.register_forward_hook(get_activation('4'))\n\n    # configure readout\n    pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))\n    pretrained.model.start_index = start_index\n    pretrained.model.patch_size = patch_size\n\n    # We inject this function into the VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)\n    pretrained.model._resize_pos_embed = types.MethodType(\n        _resize_pos_embed, pretrained.model\n    )\n\n    return pretrained\n\ndef make_vit_backbone(\n    model: nn.Module,\n    patch_size: list[int] = [16, 16],\n    hooks: list[int] = [2, 5, 8, 11],\n    hook_patch: bool = True,\n    start_index: list[int] = 1,\n):\n    assert len(hooks) == 4\n\n    pretrained = nn.Module()\n    pretrained.model = model\n\n    # add hooks\n    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation('0'))\n    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation('1'))\n    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation('2'))\n    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation('3'))\n    if hook_patch:\n        pretrained.model.pos_drop.register_forward_hook(get_activation('4'))\n\n    # configure readout\n    pretrained.rearrange = nn.Sequential(AddReadout(start_index), Transpose(1, 2))\n    pretrained.model.start_index = start_index\n    pretrained.model.patch_size = 
patch_size\n\n    # We inject this function into the VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)\n    pretrained.model._resize_pos_embed = types.MethodType(\n        _resize_pos_embed, pretrained.model\n    )\n\n    return pretrained\n"
  },
  {
    "path": "myutils/devices.py",
    "content": "import sys\nimport contextlib\nfrom functools import lru_cache\n\nimport torch\n#from modules import errors\n\nif sys.platform == \"darwin\":\n    from modules import mac_specific\n\n\ndef has_mps() -> bool:\n    if sys.platform != \"darwin\":\n        return False\n    else:\n        return mac_specific.has_mps\n\n\ndef get_cuda_device_string():\n    return \"cuda\"\n\n\ndef get_optimal_device_name():\n    if torch.cuda.is_available():\n        return get_cuda_device_string()\n\n    if has_mps():\n        return \"mps\"\n\n    return \"cpu\"\n\n\ndef get_optimal_device():\n    return torch.device(get_optimal_device_name())\n\n\ndef get_device_for(task):\n    return get_optimal_device()\n\n\ndef torch_gc():\n\n    if torch.cuda.is_available():\n        with torch.cuda.device(get_cuda_device_string()):\n            torch.cuda.empty_cache()\n            torch.cuda.ipc_collect()\n\n    if has_mps():\n        mac_specific.torch_mps_gc()\n\n\ndef enable_tf32():\n    if torch.cuda.is_available():\n\n        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't\n        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407\n        if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):\n            torch.backends.cudnn.benchmark = True\n\n        torch.backends.cuda.matmul.allow_tf32 = True\n        torch.backends.cudnn.allow_tf32 = True\n\n\nenable_tf32()\n#errors.run(enable_tf32, \"Enabling TF32\")\n\ncpu = torch.device(\"cpu\")\ndevice = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device(\"cuda\")\ndtype = torch.float16\ndtype_vae = torch.float16\ndtype_unet = torch.float16\nunet_needs_upcast = False\n\n\ndef cond_cast_unet(input):\n    return input.to(dtype_unet) if unet_needs_upcast else input\n\n\ndef cond_cast_float(input):\n    return input.float() if unet_needs_upcast else input\n\n\ndef randn(seed, 
shape):\n    torch.manual_seed(seed)\n    return torch.randn(shape, device=device)\n\n\ndef randn_without_seed(shape):\n    return torch.randn(shape, device=device)\n\n\ndef autocast(disable=False):\n    if disable:\n        return contextlib.nullcontext()\n\n    return torch.autocast(\"cuda\")\n\n\ndef without_autocast(disable=False):\n    return torch.autocast(\"cuda\", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()\n\n\nclass NansException(Exception):\n    pass\n\n\ndef test_for_nans(x, where):\n    if not torch.all(torch.isnan(x)).item():\n        return\n\n    if where == \"unet\":\n        message = \"A tensor with all NaNs was produced in Unet.\"\n\n    elif where == \"vae\":\n        message = \"A tensor with all NaNs was produced in VAE.\"\n\n    else:\n        message = \"A tensor with all NaNs was produced.\"\n\n    message += \" Use --disable-nan-check commandline argument to disable this check.\"\n\n    raise NansException(message)\n\n\n@lru_cache\ndef first_time_calculation():\n    \"\"\"\n    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and\n    spends about 2.7 seconds doing that, at least with NVidia.\n    \"\"\"\n\n    x = torch.zeros((1, 1)).to(device, dtype)\n    linear = torch.nn.Linear(1, 1).to(device, dtype)\n    linear(x)\n\n    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)\n    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)\n    conv2d(x)\n"
  },
  {
    "path": "myutils/img_util.py",
    "content": "import os\nimport PIL\nimport cv2\nimport math\nimport numpy as np\nimport torch\nimport torchvision\nimport imageio\n\nfrom einops import rearrange\n\ndef save_videos_grid(videos, path=None, rescale=True, n_rows=4, fps=8, discardN=0):\n    videos = rearrange(videos, \"b c t h w -> t b c h w\").cpu()\n    outputs = []\n    for x in videos:\n        x = torchvision.utils.make_grid(x, nrow=n_rows)\n        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)\n        if rescale:\n            x = (x / 2.0 + 0.5).clamp(0, 1)  # -1,1 -> 0,1\n        x = (x * 255).numpy().astype(np.uint8)\n        #x = adjust_gamma(x, 0.5)\n        outputs.append(x)\n\n    outputs = outputs[discardN:]\n\n    if path is not None:\n        #os.makedirs(os.path.dirname(path), exist_ok=True)\n        imageio.mimsave(path, outputs, duration=1000/fps, loop=0)\n\n    return outputs\n\ndef convert_image_to_fn(img_type, minsize, image, eps=0.02):\n    width, height = image.size\n    if min(width, height) < minsize:\n        scale = minsize/min(width, height) + eps\n        image = image.resize((math.ceil(width*scale), math.ceil(height*scale)))\n\n    if image.mode != img_type:\n        return image.convert(img_type)\n    return image"
  },
  {
    "path": "myutils/misc.py",
    "content": "import os\nimport binascii\nfrom safetensors import safe_open\n\nimport torch\n\nfrom diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint\n\ndef rand_name(length=8, suffix=''):\n    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')\n    if suffix:\n        if not suffix.startswith('.'):\n            suffix = '.' + suffix\n        name += suffix\n    return name\n\ndef cycle(dl):\n    while True:\n        for data in dl:\n            yield data\n\ndef exists(x):\n    return x is not None\n\ndef identity(x):\n    return x\n\ndef load_dreambooth_lora(unet, vae=None, model_path=None, alpha=1.0, model_base=\"\"):\n    if model_path is None: return unet\n    \n    if model_path.endswith(\".ckpt\"):\n        base_state_dict = torch.load(model_path)['state_dict']\n    elif model_path.endswith(\".safetensors\"):\n        state_dict = {}\n        with safe_open(model_path, framework=\"pt\", device=\"cpu\") as f:\n            for key in f.keys():\n                state_dict[key] = f.get_tensor(key)\n                            \n        is_lora = all(\"lora\" in k for k in state_dict.keys())\n        if not is_lora:\n            base_state_dict = state_dict\n        else:\n            base_state_dict = {}\n            with safe_open(model_base, framework=\"pt\", device=\"cpu\") as f:\n                for key in f.keys():\n                    base_state_dict[key] = f.get_tensor(key)\n                                 \n    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, unet.config)\n    unet_state_dict = unet.state_dict()\n    for key in converted_unet_checkpoint:\n        converted_unet_checkpoint[key] = alpha * converted_unet_checkpoint[key] + (1.0-alpha) * unet_state_dict[key]\n    unet.load_state_dict(converted_unet_checkpoint, strict=False)\n\n    if vae is not None:\n        converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, 
vae.config)\n        vae.load_state_dict(converted_vae_checkpoint)\n    \n    return unet, vae"
  },
  {
    "path": "myutils/vaehook.py",
    "content": "# ------------------------------------------------------------------------\n#\n#   Ultimate VAE Tile Optimization\n#\n#   Introducing a revolutionary new optimization designed to make\n#   the VAE work with giant images on limited VRAM!\n#   Say goodbye to the frustration of OOM and hello to seamless output!\n#\n# ------------------------------------------------------------------------\n#\n#   This script is a wild hack that splits the image into tiles,\n#   encodes each tile separately, and merges the result back together.\n#\n#   Advantages:\n#   - The VAE can now work with giant images on limited VRAM\n#       (~10 GB for 8K images!)\n#   - The merged output is completely seamless without any post-processing.\n#\n#   Drawbacks:\n#   - Giant RAM needed. To store the intermediate results for a 4096x4096\n#       image, you need 32 GB RAM (it consumes ~20GB); for 8192x8192\n#       you need a 128 GB RAM machine (it consumes ~100 GB)\n#   - NaNs always appear for 8k images when you use fp16 (half) VAE\n#       You must use --no-half-vae to disable half VAE for that giant image.\n#   - Slow speed. With default tile size, it takes around 50/200 seconds\n#       to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode\n#       a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)\n#   - The gradient calculation is not compatible with this hack. 
It\n#       will break any backward() or torch.autograd.grad() that passes VAE.\n#       (But you can still use the VAE to generate training data.)\n#\n#   How it works:\n#   1) The image is split into tiles.\n#       - To ensure perfect results, each tile is padded with 32 pixels\n#           on each side.\n#       - Then the conv2d/silu/upsample/downsample can produce identical\n#           results to the original image without splitting.\n#   2) The original forward is decomposed into a task queue and a task worker.\n#       - The task queue is a list of functions that will be executed in order.\n#       - The task worker is a loop that executes the tasks in the queue.\n#   3) The task queue is executed for each tile.\n#       - Current tile is sent to GPU.\n#       - local operations are directly executed.\n#       - Group norm calculation is temporarily suspended until the mean\n#           and var of all tiles are calculated.\n#       - The residual is pre-calculated and stored and addded back later.\n#       - When need to go to the next tile, the current tile is send to cpu.\n#   4) After all tiles are processed, tiles are merged on cpu and return.\n#\n#   Enjoy!\n#\n#   @author: LI YI @ Nanyang Technological University - Singapore\n#   @date: 2023-03-02\n#   @license: MIT License\n#\n#   Please give me a star if you like this project!\n#\n# -------------------------------------------------------------------------\n\nimport gc\nfrom time import time\nimport math\nfrom tqdm import tqdm\n\nimport torch\nimport torch.version\nimport torch.nn.functional as F\nfrom einops import rearrange\nimport sys\nsys.path.append('/home/notebook/code/personal/S9048295/code/PASD')\nimport myutils.devices as devices\n#from modules.shared import state\n#from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock\n\ntry:\n    import xformers\n    import xformers.ops\nexcept ImportError:\n    pass\n\nsd_flag = False\n\ndef 
get_recommend_encoder_tile_size():\n    if torch.cuda.is_available():\n        total_memory = torch.cuda.get_device_properties(\n            devices.device).total_memory // 2**20\n        if total_memory > 16*1000:\n            ENCODER_TILE_SIZE = 3072\n        elif total_memory > 12*1000:\n            ENCODER_TILE_SIZE = 2048\n        elif total_memory > 8*1000:\n            ENCODER_TILE_SIZE = 1536\n        else:\n            ENCODER_TILE_SIZE = 960\n    else:\n        ENCODER_TILE_SIZE = 512\n    return ENCODER_TILE_SIZE\n\n\ndef get_recommend_decoder_tile_size():\n    if torch.cuda.is_available():\n        total_memory = torch.cuda.get_device_properties(\n            devices.device).total_memory // 2**20\n        if total_memory > 30*1000:\n            DECODER_TILE_SIZE = 256\n        elif total_memory > 16*1000:\n            DECODER_TILE_SIZE = 192\n        elif total_memory > 12*1000:\n            DECODER_TILE_SIZE = 128\n        elif total_memory > 8*1000:\n            DECODER_TILE_SIZE = 96\n        else:\n            DECODER_TILE_SIZE = 64\n    else:\n        DECODER_TILE_SIZE = 64\n    return DECODER_TILE_SIZE\n\n\nif 'global const':\n    DEFAULT_ENABLED = False\n    DEFAULT_MOVE_TO_GPU = False\n    DEFAULT_FAST_ENCODER = True\n    DEFAULT_FAST_DECODER = True\n    DEFAULT_COLOR_FIX = 0\n    DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()\n    DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()\n\n\n# inplace version of silu\ndef inplace_nonlinearity(x):\n    # Test: fix for Nans\n    return F.silu(x, inplace=True)\n\n# extracted from ldm.modules.diffusionmodules.model\n\n# from diffusers lib\ndef attn_forward_new(self, h_):\n    batch_size, channel, height, width = h_.shape\n    hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)\n\n    attention_mask = None\n    encoder_hidden_states = None\n    batch_size, sequence_length, _ = hidden_states.shape\n    attention_mask = 
self.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n    query = self.to_q(hidden_states)\n\n    if encoder_hidden_states is None:\n        encoder_hidden_states = hidden_states\n    elif self.norm_cross:\n        encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)\n\n    key = self.to_k(encoder_hidden_states)\n    value = self.to_v(encoder_hidden_states)\n\n    query = self.head_to_batch_dim(query)\n    key = self.head_to_batch_dim(key)\n    value = self.head_to_batch_dim(value)\n\n    attention_probs = self.get_attention_scores(query, key, attention_mask)\n    hidden_states = torch.bmm(attention_probs, value)\n    hidden_states = self.batch_to_head_dim(hidden_states)\n\n    # linear proj\n    hidden_states = self.to_out[0](hidden_states)\n    # dropout\n    hidden_states = self.to_out[1](hidden_states)\n\n    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n    return hidden_states\n\ndef attn_forward(self, h_):\n    q = self.q(h_)\n    k = self.k(h_)\n    v = self.v(h_)\n\n    # compute attention\n    b, c, h, w = q.shape\n    q = q.reshape(b, c, h*w)\n    q = q.permute(0, 2, 1)   # b,hw,c\n    k = k.reshape(b, c, h*w)  # b,c,hw\n    w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]\n    w_ = w_ * (int(c)**(-0.5))\n    w_ = torch.nn.functional.softmax(w_, dim=2)\n\n    # attend to values\n    v = v.reshape(b, c, h*w)\n    w_ = w_.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)\n    # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]\n    h_ = torch.bmm(v, w_)\n    h_ = h_.reshape(b, c, h, w)\n\n    h_ = self.proj_out(h_)\n\n    return h_\n\n\ndef xformer_attn_forward(self, h_):\n    q = self.q(h_)\n    k = self.k(h_)\n    v = self.v(h_)\n\n    # compute attention\n    B, C, H, W = q.shape\n    q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))\n\n    q, k, v = map(\n        lambda t: t.unsqueeze(3)\n        
.reshape(B, t.shape[1], 1, C)\n        .permute(0, 2, 1, 3)\n        .reshape(B * 1, t.shape[1], C)\n        .contiguous(),\n        (q, k, v),\n    )\n    out = xformers.ops.memory_efficient_attention(\n        q, k, v, attn_bias=None, op=self.attention_op)\n\n    out = (\n        out.unsqueeze(0)\n        .reshape(B, 1, out.shape[1], C)\n        .permute(0, 2, 1, 3)\n        .reshape(B, out.shape[1], C)\n    )\n    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)\n    out = self.proj_out(out)\n    return out\n\n\ndef attn2task(task_queue, net):\n    if False: #isinstance(net, AttnBlock):\n        task_queue.append(('store_res', lambda x: x))\n        task_queue.append(('pre_norm', net.norm))\n        task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))\n        task_queue.append(['add_res', None])\n    elif False: #isinstance(net, MemoryEfficientAttnBlock):\n        task_queue.append(('store_res', lambda x: x))\n        task_queue.append(('pre_norm', net.norm))\n        task_queue.append(\n            ('attn', lambda x, net=net: xformer_attn_forward(net, x)))\n        task_queue.append(['add_res', None])\n    else:\n        task_queue.append(('store_res', lambda x: x))\n        task_queue.append(('pre_norm', net.group_norm))\n        task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))\n        task_queue.append(['add_res', None])\n\ndef resblock2task(queue, block):\n    \"\"\"\n    Turn a ResNetBlock into a sequence of tasks and append to the task queue\n\n    @param queue: the target task queue\n    @param block: ResNetBlock\n\n    \"\"\"\n    if block.in_channels != block.out_channels:\n        if sd_flag:\n            if block.use_conv_shortcut:\n                queue.append(('store_res', block.conv_shortcut))\n            else:\n                queue.append(('store_res', block.nin_shortcut))\n        else:\n            if block.use_in_shortcut:\n                queue.append(('store_res', 
block.conv_shortcut))\n            else:\n                queue.append(('store_res', block.nin_shortcut))\n\n    else:\n        queue.append(('store_res', lambda x: x))\n    queue.append(('pre_norm', block.norm1))\n    queue.append(('silu', inplace_nonlinearity))\n    queue.append(('conv1', block.conv1))\n    queue.append(('pre_norm', block.norm2))\n    queue.append(('silu', inplace_nonlinearity))\n    queue.append(('conv2', block.conv2))\n    queue.append(['add_res', None])\n\n\ndef build_sampling(task_queue, net, is_decoder):\n    \"\"\"\n    Build the sampling part of a task queue\n    @param task_queue: the target task queue\n    @param net: the network\n    @param is_decoder: currently building decoder or encoder\n    \"\"\"\n    if is_decoder:\n        if sd_flag:\n            resblock2task(task_queue, net.mid.block_1)\n            attn2task(task_queue, net.mid.attn_1)\n            print(task_queue)\n            resblock2task(task_queue, net.mid.block_2)\n            resolution_iter = reversed(range(net.num_resolutions))\n            block_ids = net.num_res_blocks + 1\n            condition = 0\n            module = net.up\n            func_name = 'upsample'\n        else:\n            resblock2task(task_queue, net.mid_block.resnets[0])\n            attn2task(task_queue, net.mid_block.attentions[0])\n            resblock2task(task_queue, net.mid_block.resnets[1])\n            resolution_iter = (range(len(net.up_blocks)))  # net.num_resolutions = 3\n            block_ids = 2 + 1\n            condition = len(net.up_blocks) - 1\n            module = net.up_blocks\n            func_name = 'upsamplers'\n    else:\n        resolution_iter = range(net.num_resolutions)\n        block_ids = net.num_res_blocks\n        condition = net.num_resolutions - 1\n        module = net.down\n        func_name = 'downsample'\n\n    for i_level in resolution_iter:\n        for i_block in range(block_ids):\n            if sd_flag:\n                resblock2task(task_queue, 
module[i_level].block[i_block])\n            else:\n                resblock2task(task_queue, module[i_level].resnets[i_block])\n        if i_level != condition:\n            if sd_flag:\n                task_queue.append((func_name, getattr(module[i_level], func_name)))\n            else:\n                task_queue.append((func_name, module[i_level].upsamplers[0]))\n\n    if not is_decoder:\n        if sd_flag:\n            resblock2task(task_queue, net.mid.block_1)\n            attn2task(task_queue, net.mid.attn_1)\n            resblock2task(task_queue, net.mid.block_2)\n        else:\n            resblock2task(task_queue, net.mid_block.resnets[0])\n            attn2task(task_queue, net.mid_block.attentions[0])\n            resblock2task(task_queue, net.mid_block.resnets[1])\n\n\ndef build_task_queue(net, is_decoder):\n    \"\"\"\n    Build a single task queue for the encoder or decoder\n    @param net: the VAE decoder or encoder network\n    @param is_decoder: currently building decoder or encoder\n    @return: the task queue\n    \"\"\"\n    task_queue = []\n    task_queue.append(('conv_in', net.conv_in))\n\n    # construct the sampling part of the task queue\n    # because encoder and decoder share the same architecture, we extract the sampling part\n    build_sampling(task_queue, net, is_decoder)\n    if is_decoder and not sd_flag:\n        net.give_pre_end = False\n        net.tanh_out = False\n\n    if not is_decoder or not net.give_pre_end:\n        if sd_flag:\n            task_queue.append(('pre_norm', net.norm_out))\n        else:\n            task_queue.append(('pre_norm', net.conv_norm_out))\n        task_queue.append(('silu', inplace_nonlinearity))\n        task_queue.append(('conv_out', net.conv_out))\n        if is_decoder and net.tanh_out:\n            task_queue.append(('tanh', torch.tanh))\n\n    return task_queue\n\n\ndef clone_task_queue(task_queue):\n    \"\"\"\n    Clone a task queue\n    @param task_queue: the task queue to be cloned\n    
@return: the cloned task queue\n    \"\"\"\n    return [[item for item in task] for task in task_queue]\n\n\ndef get_var_mean(input, num_groups, eps=1e-6):\n    \"\"\"\n    Get mean and var for group norm\n    \"\"\"\n    b, c = input.size(0), input.size(1)\n    channel_in_group = int(c/num_groups)\n    input_reshaped = input.contiguous().view(\n        1, int(b * num_groups), channel_in_group, *input.size()[2:])\n    var, mean = torch.var_mean(\n        input_reshaped, dim=[0, 2, 3, 4], unbiased=False)\n    return var, mean\n\n\ndef custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):\n    \"\"\"\n    Custom group norm with fixed mean and var\n\n    @param input: input tensor\n    @param num_groups: number of groups. by default, num_groups = 32\n    @param mean: mean, must be pre-calculated by get_var_mean\n    @param var: var, must be pre-calculated by get_var_mean\n    @param weight: weight, should be fetched from the original group norm\n    @param bias: bias, should be fetched from the original group norm\n    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm\n\n    @return: normalized tensor\n    \"\"\"\n    b, c = input.size(0), input.size(1)\n    channel_in_group = int(c/num_groups)\n    input_reshaped = input.contiguous().view(\n        1, int(b * num_groups), channel_in_group, *input.size()[2:])\n\n    out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,\n                       training=False, momentum=0, eps=eps)\n\n    out = out.view(b, c, *input.size()[2:])\n\n    # post affine transform\n    if weight is not None:\n        out *= weight.view(1, -1, 1, 1)\n    if bias is not None:\n        out += bias.view(1, -1, 1, 1)\n    return out\n\n\ndef crop_valid_region(x, input_bbox, target_bbox, is_decoder):\n    \"\"\"\n    Crop the valid region from the tile\n    @param x: input tile\n    @param input_bbox: original input bounding box\n    @param target_bbox: output bounding box\n    
@param scale: scale factor\n    @return: cropped tile\n    \"\"\"\n    padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]\n    margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]\n    return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]\n\n# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓\n\n\ndef perfcount(fn):\n    def wrapper(*args, **kwargs):\n        ts = time()\n\n        if torch.cuda.is_available():\n            torch.cuda.reset_peak_memory_stats(devices.device)\n        devices.torch_gc()\n        gc.collect()\n\n        ret = fn(*args, **kwargs)\n\n        devices.torch_gc()\n        gc.collect()\n        if torch.cuda.is_available():\n            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20\n            torch.cuda.reset_peak_memory_stats(devices.device)\n            print(\n                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')\n        else:\n            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')\n\n        return ret\n    return wrapper\n\n# copy end :)\n\n\nclass GroupNormParam:\n    def __init__(self):\n        self.var_list = []\n        self.mean_list = []\n        self.pixel_list = []\n        self.weight = None\n        self.bias = None\n\n    def add_tile(self, tile, layer):\n        var, mean = get_var_mean(tile, 32)\n        # For giant images, the variance can be larger than max float16\n        # In this case we create a copy to float32\n        if var.dtype == torch.float16 and var.isinf().any():\n            fp32_tile = tile.float()\n            var, mean = get_var_mean(fp32_tile, 32)\n        # ============= DEBUG: test for infinite =============\n        # if torch.isinf(var).any():\n        #    print('var: ', var)\n        # ====================================================\n        self.var_list.append(var)\n        self.mean_list.append(mean)\n        self.pixel_list.append(\n            
tile.shape[2]*tile.shape[3])\n        if hasattr(layer, 'weight'):\n            self.weight = layer.weight\n            self.bias = layer.bias\n        else:\n            self.weight = None\n            self.bias = None\n\n    def summary(self):\n        \"\"\"\n        summarize the mean and var and return a function\n        that apply group norm on each tile\n        \"\"\"\n        if len(self.var_list) == 0:\n            return None\n        var = torch.vstack(self.var_list)\n        mean = torch.vstack(self.mean_list)\n        max_value = max(self.pixel_list)\n        pixels = torch.tensor(\n            self.pixel_list, dtype=torch.float32, device=devices.device) / max_value\n        sum_pixels = torch.sum(pixels)\n        pixels = pixels.unsqueeze(\n            1) / sum_pixels\n        var = torch.sum(\n            var * pixels, dim=0)\n        mean = torch.sum(\n            mean * pixels, dim=0)\n        return lambda x:  custom_group_norm(x, 32, mean, var, self.weight, self.bias)\n\n    @staticmethod\n    def from_tile(tile, norm):\n        \"\"\"\n        create a function from a single tile without summary\n        \"\"\"\n        var, mean = get_var_mean(tile, 32)\n        if var.dtype == torch.float16 and var.isinf().any():\n            fp32_tile = tile.float()\n            var, mean = get_var_mean(fp32_tile, 32)\n            # if it is a macbook, we need to convert back to float16\n            if var.device.type == 'mps':\n                # clamp to avoid overflow\n                var = torch.clamp(var, 0, 60000)\n                var = var.half()\n                mean = mean.half()\n        if hasattr(norm, 'weight'):\n            weight = norm.weight\n            bias = norm.bias\n        else:\n            weight = None\n            bias = None\n\n        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):\n            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)\n        return group_norm_func\n\n\nclass 
VAEHook:\n    def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):\n        self.net = net                  # encoder | decoder\n        self.tile_size = tile_size\n        self.is_decoder = is_decoder\n        self.fast_mode = (fast_encoder and not is_decoder) or (\n            fast_decoder and is_decoder)\n        self.color_fix = color_fix and not is_decoder\n        self.to_gpu = to_gpu\n        self.pad = 11 if is_decoder else 32\n\n    def __call__(self, x):\n        B, C, H, W = x.shape\n        original_device = next(self.net.parameters()).device\n        try:\n            if self.to_gpu:\n                self.net.to(devices.get_optimal_device())\n            if max(H, W) <= self.pad * 2 + self.tile_size:\n                print(\"[Tiled VAE]: the input size is tiny and unnecessary to tile.\")\n                return self.net.original_forward(x)\n            else:\n                return self.vae_tile_forward(x)\n        finally:\n            self.net.to(original_device)\n\n    def get_best_tile_size(self, lowerbound, upperbound):\n        \"\"\"\n        Get the best tile size for GPU memory\n        \"\"\"\n        divider = 32\n        while divider >= 2:\n            remainer = lowerbound % divider\n            if remainer == 0:\n                return lowerbound\n            candidate = lowerbound - remainer + divider\n            if candidate <= upperbound:\n                return candidate\n            divider //= 2\n        return lowerbound\n\n    def split_tiles(self, h, w):\n        \"\"\"\n        Tool function to split the image into tiles\n        @param h: height of the image\n        @param w: width of the image\n        @return: tile_input_bboxes, tile_output_bboxes\n        \"\"\"\n        tile_input_bboxes, tile_output_bboxes = [], []\n        tile_size = self.tile_size\n        pad = self.pad\n        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)\n        num_width_tiles = 
math.ceil((w - 2 * pad) / tile_size)\n        # If any of the numbers are 0, we let it be 1\n        # This is to deal with long and thin images\n        num_height_tiles = max(num_height_tiles, 1)\n        num_width_tiles = max(num_width_tiles, 1)\n\n        # Suggestions from https://github.com/Kahsolt: auto shrink the tile size\n        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)\n        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)\n        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)\n        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)\n\n        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +\n              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')\n\n        for i in range(num_height_tiles):\n            for j in range(num_width_tiles):\n                # bbox: [x1, x2, y1, y2]\n                # the padding is is unnessary for image borders. 
So we directly start from (32, 32)\n                input_bbox = [\n                    pad + j * real_tile_width,\n                    min(pad + (j + 1) * real_tile_width, w),\n                    pad + i * real_tile_height,\n                    min(pad + (i + 1) * real_tile_height, h),\n                ]\n\n                # if the output bbox is close to the image boundary, we extend it to the image boundary\n                output_bbox = [\n                    input_bbox[0] if input_bbox[0] > pad else 0,\n                    input_bbox[1] if input_bbox[1] < w - pad else w,\n                    input_bbox[2] if input_bbox[2] > pad else 0,\n                    input_bbox[3] if input_bbox[3] < h - pad else h,\n                ]\n\n                # scale to get the final output bbox\n                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]\n                tile_output_bboxes.append(output_bbox)\n\n                # indistinguishable expand the input bbox by pad pixels\n                tile_input_bboxes.append([\n                    max(0, input_bbox[0] - pad),\n                    min(w, input_bbox[1] + pad),\n                    max(0, input_bbox[2] - pad),\n                    min(h, input_bbox[3] + pad),\n                ])\n\n        return tile_input_bboxes, tile_output_bboxes\n\n    @torch.no_grad()\n    def estimate_group_norm(self, z, task_queue, color_fix):\n        device = z.device\n        tile = z\n        last_id = len(task_queue) - 1\n        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':\n            last_id -= 1\n        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':\n            raise ValueError('No group norm found in the task queue')\n        # estimate until the last group norm\n        for i in range(last_id + 1):\n            task = task_queue[i]\n            if task[0] == 'pre_norm':\n                group_norm_func = GroupNormParam.from_tile(tile, task[1])\n                
task_queue[i] = ('apply_norm', group_norm_func)\n                if i == last_id:\n                    return True\n                tile = group_norm_func(tile)\n            elif task[0] == 'store_res':\n                task_id = i + 1\n                while task_id < last_id and task_queue[task_id][0] != 'add_res':\n                    task_id += 1\n                if task_id >= last_id:\n                    continue\n                task_queue[task_id][1] = task[1](tile)\n            elif task[0] == 'add_res':\n                tile += task[1].to(device)\n                task[1] = None\n            elif color_fix and task[0] == 'downsample':\n                for j in range(i, last_id + 1):\n                    if task_queue[j][0] == 'store_res':\n                        task_queue[j] = ('store_res_cpu', task_queue[j][1])\n                return True\n            else:\n                tile = task[1](tile)\n            try:\n                devices.test_for_nans(tile, \"vae\")\n            except:\n                print(f'Nan detected in fast mode estimation. 
Fast mode disabled.')\n                return False\n\n        raise IndexError('Should not reach here')\n\n    @perfcount\n    @torch.no_grad()\n    def vae_tile_forward(self, z):\n        \"\"\"\n        Decode a latent vector z into an image in a tiled manner.\n        @param z: latent vector\n        @return: image\n        \"\"\"\n        device = next(self.net.parameters()).device\n        net = self.net\n        tile_size = self.tile_size\n        is_decoder = self.is_decoder\n\n        z = z.detach() # detach the input to avoid backprop\n\n        N, height, width = z.shape[0], z.shape[2], z.shape[3]\n        net.last_z_shape = z.shape\n\n        # Split the input into tiles and build a task queue for each tile\n        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')\n\n        in_bboxes, out_bboxes = self.split_tiles(height, width)\n\n        # Prepare tiles by split the input latents\n        tiles = []\n        for input_bbox in in_bboxes:\n            tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()\n            tiles.append(tile)\n\n        num_tiles = len(tiles)\n        num_completed = 0\n\n        # Build task queues\n        single_task_queue = build_task_queue(net, is_decoder)\n        #print(single_task_queue)\n        if self.fast_mode:\n            # Fast mode: downsample the input image to the tile size,\n            # then estimate the group norm parameters on the downsampled image\n            scale_factor = tile_size / max(height, width)\n            z = z.to(device)\n            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')\n            # use nearest-exact to keep statictics as close as possible\n            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')\n\n            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #\n  
          # The downsampling will heavily distort its mean and std, so we need to recover it.\n            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)\n            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)\n            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old\n            del std_old, mean_old, std_new, mean_new\n            # occasionally the std_new is too small or too large, which exceeds the range of float16\n            # so we need to clamp it to max z's range.\n            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())\n            estimate_task_queue = clone_task_queue(single_task_queue)\n            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):\n                single_task_queue = estimate_task_queue\n            del downsampled_z\n\n        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]\n\n        # Dummy result\n        result = None\n        result_approx = None\n        #try:\n        #    with devices.autocast():\n        #        result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()\n        #except: pass\n        # Free memory of input latent tensor\n        del z\n\n        # Task queue execution\n        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f\"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: \")\n\n        # execute the task back and forth when switch tiles so that we always\n        # keep one tile on the GPU to reduce unnecessary data transfer\n        forward = True\n        interrupted = False\n        #state.interrupted = interrupted\n        while True:\n            #if state.interrupted: interrupted = True ; break\n\n            group_norm_param = GroupNormParam()\n            for i in range(num_tiles) if forward 
else reversed(range(num_tiles)):\n                #if state.interrupted: interrupted = True ; break\n\n                tile = tiles[i].to(device)\n                input_bbox = in_bboxes[i]\n                task_queue = task_queues[i]\n\n                interrupted = False\n                while len(task_queue) > 0:\n                    #if state.interrupted: interrupted = True ; break\n\n                    # DEBUG: current task\n                    # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)\n                    task = task_queue.pop(0)\n                    if task[0] == 'pre_norm':\n                        group_norm_param.add_tile(tile, task[1])\n                        break\n                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':\n                        task_id = 0\n                        res = task[1](tile)\n                        if not self.fast_mode or task[0] == 'store_res_cpu':\n                            res = res.cpu()\n                        while task_queue[task_id][0] != 'add_res':\n                            task_id += 1\n                        task_queue[task_id][1] = res\n                    elif task[0] == 'add_res':\n                        tile += task[1].to(device)\n                        task[1] = None\n                    else:\n                        tile = task[1](tile)\n                        #print(tiles[i].shape, tile.shape, task)\n                    pbar.update(1)\n\n                if interrupted: break\n\n                # check for NaNs in the tile.\n                # If there are NaNs, we abort the process to save user's time\n                #devices.test_for_nans(tile, \"vae\")\n\n                #print(tiles[i].shape, tile.shape, i, num_tiles)\n                if len(task_queue) == 0:\n                    tiles[i] = None\n                    num_completed += 1\n                    if result is None:      # NOTE: dim C varies from 
different cases, can only be inited dynamically\n                        result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)\n                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)\n                    del tile\n                elif i == num_tiles - 1 and forward:\n                    forward = False\n                    tiles[i] = tile\n                elif i == 0 and not forward:\n                    forward = True\n                    tiles[i] = tile\n                else:\n                    tiles[i] = tile.cpu()\n                    del tile\n\n            if interrupted: break\n            if num_completed == num_tiles: break\n\n            # insert the group norm task to the head of each task queue\n            group_norm_func = group_norm_param.summary()\n            if group_norm_func is not None:\n                for i in range(num_tiles):\n                    task_queue = task_queues[i]\n                    task_queue.insert(0, ('apply_norm', group_norm_func))\n\n        # Done!\n        pbar.close()\n        return result if result is not None else result_approx.to(device)"
  },
  {
    "path": "myutils/wavelet_color_fix.py",
    "content": "'''\n# --------------------------------------------------------------------------------\n#   Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)\n# --------------------------------------------------------------------------------\n'''\n\nimport torch\nfrom PIL import Image\nfrom torch import Tensor\nfrom torch.nn import functional as F\n\nfrom torchvision.transforms import ToTensor, ToPILImage\n\ndef adain_color_fix(target: Image, source: Image):\n    # Convert images to tensors\n    to_tensor = ToTensor()\n    target_tensor = to_tensor(target).unsqueeze(0)\n    source_tensor = to_tensor(source).unsqueeze(0)\n\n    # Apply adaptive instance normalization\n    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)\n\n    # Convert tensor back to image\n    to_image = ToPILImage()\n    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))\n\n    return result_image\n\ndef wavelet_color_fix(target: Image, source: Image):\n    # Convert images to tensors\n    to_tensor = ToTensor()\n    target_tensor = to_tensor(target).unsqueeze(0)\n    source_tensor = to_tensor(source).unsqueeze(0)\n\n    # Apply wavelet reconstruction\n    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)\n\n    # Convert tensor back to image\n    to_image = ToPILImage()\n    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))\n\n    return result_image\n\ndef calc_mean_std(feat: Tensor, eps=1e-5):\n    \"\"\"Calculate mean and std for adaptive_instance_normalization.\n    Args:\n        feat (Tensor): 4D tensor.\n        eps (float): A small value added to the variance to avoid\n            divide-by-zero. 
Default: 1e-5.\n    \"\"\"\n    size = feat.size()\n    assert len(size) == 4, 'The input feature should be 4D tensor.'\n    b, c = size[:2]\n    feat_var = feat.reshape(b, c, -1).var(dim=2) + eps\n    feat_std = feat_var.sqrt().reshape(b, c, 1, 1)\n    feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)\n    return feat_mean, feat_std\n\ndef adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):\n    \"\"\"Adaptive instance normalization.\n    Adjust the reference features to have the similar color and illuminations\n    as those in the degradate features.\n    Args:\n        content_feat (Tensor): The reference feature.\n        style_feat (Tensor): The degradate features.\n    \"\"\"\n    size = content_feat.size()\n    style_mean, style_std = calc_mean_std(style_feat)\n    content_mean, content_std = calc_mean_std(content_feat)\n    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)\n    return normalized_feat * style_std.expand(size) + style_mean.expand(size)\n\ndef wavelet_blur(image: Tensor, radius: int):\n    \"\"\"\n    Apply wavelet blur to the input tensor.\n    \"\"\"\n    # input shape: (1, 3, H, W)\n    # convolution kernel\n    kernel_vals = [\n        [0.0625, 0.125, 0.0625],\n        [0.125, 0.25, 0.125],\n        [0.0625, 0.125, 0.0625],\n    ]\n    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)\n    # add channel dimensions to the kernel to make it a 4D tensor\n    kernel = kernel[None, None]\n    # repeat the kernel across all input channels\n    kernel = kernel.repeat(3, 1, 1, 1)\n    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')\n    # apply convolution\n    output = F.conv2d(image, kernel, groups=3, dilation=radius)\n    return output\n\ndef wavelet_decomposition(image: Tensor, levels=5):\n    \"\"\"\n    Apply wavelet decomposition to the input tensor.\n    This function only returns the low frequency & the high 
frequency.\n    \"\"\"\n    high_freq = torch.zeros_like(image)\n    for i in range(levels):\n        radius = 2 ** i\n        low_freq = wavelet_blur(image, radius)\n        high_freq += (image - low_freq)\n        image = low_freq\n\n    return high_freq, low_freq\n\ndef wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):\n    \"\"\"\n    Apply wavelet decomposition, so that the content will have the same color as the style.\n    \"\"\"\n    # calculate the wavelet decomposition of the content feature\n    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)\n    del content_low_freq\n    # calculate the wavelet decomposition of the style feature\n    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)\n    del style_high_freq\n    # reconstruct the content feature with the style's high frequency\n    return content_high_freq + style_low_freq\n"
  },
  {
    "path": "pipelines/pipeline_ccsr.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nimport inspect\nimport os\nimport warnings\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport PIL.Image\nimport torch\nimport torch.nn.functional as F\nfrom torchvision.utils import save_image\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\n\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.loaders import TextualInversionLoaderMixin\n# from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom models.controlnet import ControlNetModel\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import (\n    PIL_INTERPOLATION,\n    is_accelerate_available,\n    is_accelerate_version,\n    logging,\n    replace_example_docstring,\n)\nfrom diffusers.utils.torch_utils import is_compiled_module, randn_tensor\n\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker\nfrom diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel\n\nfrom utils.vaehook import VAEHook, perfcount\nfrom tqdm import tqdm\nfrom torch import 
FloatTensor\nfrom PIL import Image\n\nimport time\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> # !pip install opencv-python transformers accelerate\n        >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler\n        >>> from diffusers.utils import load_image\n        >>> import numpy as np\n        >>> import torch\n\n        >>> import cv2\n        >>> from PIL import Image\n\n        >>> # download an image\n        >>> image = load_image(\n        ...     \"https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png\"\n        ... )\n        >>> image = np.array(image)\n\n        >>> # get canny image\n        >>> image = cv2.Canny(image, 100, 200)\n        >>> image = image[:, :, None]\n        >>> image = np.concatenate([image, image, image], axis=2)\n        >>> canny_image = Image.fromarray(image)\n\n        >>> # load control net and stable diffusion v1-5\n        >>> controlnet = ControlNetModel.from_pretrained(\"lllyasviel/sd-controlnet-canny\", torch_dtype=torch.float16)\n        >>> pipe = StableDiffusionControlNetPipeline.from_pretrained(\n        ...     \"runwayml/stable-diffusion-v1-5\", controlnet=controlnet, torch_dtype=torch.float16\n        ... )\n\n        >>> # speed up diffusion process with faster scheduler and memory optimization\n        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)\n        >>> # remove following line if xformers is not installed\n        >>> pipe.enable_xformers_memory_efficient_attention()\n\n        >>> pipe.enable_model_cpu_offload()\n\n        >>> # generate image\n        >>> generator = torch.manual_seed(0)\n        >>> image = pipe(\n        ...     \"futuristic-looking woman\", num_inference_steps=20, generator=generator, image=canny_image\n        ... 
).images[0]\n        ```\n\"\"\"\n\n\nclass StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):\n    r\"\"\"\n    Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    In addition the pipeline inherits the following loading methods:\n        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            Frozen text-encoder. Stable Diffusion uses the text portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):\n            Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets\n            as a list, the outputs from each ControlNet are added together to create one combined additional\n            conditioning.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Model that extracts features from generated images to be used as inputs for the `safety_checker`.\n    \"\"\"\n    _optional_components = [\"safety_checker\", \"feature_extractor\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure\"\n                \" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. 
For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                \"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        if isinstance(controlnet, (list, tuple)):\n            controlnet = MultiControlNetModel(controlnet)\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            tokenizer=tokenizer,\n            unet=unet,\n            controlnet=controlnet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n        )\n        self.scheduler = scheduler\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def _init_tiled_vae(self,\n            encoder_tile_size = 256,\n            decoder_tile_size = 256,\n            fast_decoder = False,\n            fast_encoder = False,\n            color_fix = False,\n            vae_to_gpu = True):\n        # save original forward (only once)\n        if not hasattr(self.vae.encoder, 'original_forward'):\n            setattr(self.vae.encoder, 'original_forward', self.vae.encoder.forward)\n        if not hasattr(self.vae.decoder, 'original_forward'):\n            setattr(self.vae.decoder, 'original_forward', self.vae.decoder.forward)\n\n        encoder = self.vae.encoder\n        decoder = self.vae.decoder\n\n        self.vae.encoder.forward = VAEHook(\n            encoder, encoder_tile_size, is_decoder=False, 
fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)\n        self.vae.decoder.forward = VAEHook(\n            decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)\n\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing\n    def enable_vae_slicing(self):\n        r\"\"\"\n        Enable sliced VAE decoding.\n\n        When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several\n        steps. This is useful to save some memory and allow larger batch sizes.\n        \"\"\"\n        self.vae.enable_slicing()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing\n    def disable_vae_slicing(self):\n        r\"\"\"\n        Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_slicing()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling\n    def enable_vae_tiling(self):\n        r\"\"\"\n        Enable tiled VAE decoding.\n\n        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in\n        several steps. This is useful to save a large amount of memory and to allow the processing of larger images.\n        \"\"\"\n        self.vae.enable_tiling()\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling\n    def disable_vae_tiling(self):\n        r\"\"\"\n        Disable tiled VAE decoding. 
If `enable_vae_tiling` was previously invoked, this method will go back to\n        computing decoding in one step.\n        \"\"\"\n        self.vae.disable_tiling()\n\n    def enable_sequential_cpu_offload(self, gpu_id=0):\n        r\"\"\"\n        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,\n        text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a\n        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.\n        Note that offloading happens on a submodule basis. Memory savings are higher than with\n        `enable_model_cpu_offload`, but performance is lower.\n        \"\"\"\n        if is_accelerate_available():\n            from accelerate import cpu_offload\n        else:\n            raise ImportError(\"Please install accelerate via `pip install accelerate`\")\n\n        device = torch.device(f\"cuda:{gpu_id}\")\n\n        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:\n            cpu_offload(cpu_offloaded_model, device)\n\n        if self.safety_checker is not None:\n            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)\n\n    def enable_model_cpu_offload(self, gpu_id=0):\n        r\"\"\"\n        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared\n        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`\n        method is called, and the model remains in GPU until the next model runs. 
Memory savings are lower than with\n        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.\n        \"\"\"\n        if is_accelerate_available() and is_accelerate_version(\">=\", \"0.17.0.dev0\"):\n            from accelerate import cpu_offload_with_hook\n        else:\n            raise ImportError(\"`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.\")\n\n        device = torch.device(f\"cuda:{gpu_id}\")\n\n        hook = None\n        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:\n            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)\n\n        if self.safety_checker is not None:\n            # the safety checker can offload the vae again\n            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)\n\n        # control net hook has be manually offloaded as it alternates with unet\n        cpu_offload_with_hook(self.controlnet, device)\n\n        # We'll offload the last model manually.\n        self.final_offload_hook = hook\n\n    @property\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device\n    def _execution_device(self):\n        r\"\"\"\n        Returns the device on which the pipeline's models will be executed. 
After calling\n        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module\n        hooks.\n        \"\"\"\n        if not hasattr(self.unet, \"_hf_hook\"):\n            return self.device\n        for module in self.unet.modules():\n            if (\n                hasattr(module, \"_hf_hook\")\n                and hasattr(module._hf_hook, \"execution_device\")\n                and module._hf_hook.execution_device is not None\n            ):\n                return torch.device(module._hf_hook.execution_device)\n        return self.device\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt\n    def _encode_prompt(\n        self,\n        prompt,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        negative_prompt=None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        negative_prompt_embeds: Optional[torch.FloatTensor] = None,\n    ):\n        r\"\"\"\n        Encodes the prompt into text encoder hidden states.\n\n        Args:\n             prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            do_classifier_free_guidance (`bool`):\n                whether to use classifier free guidance or not\n            negative_prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts not to guide the image generation. If not defined, one has to pass\n                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is\n                less than `1`).\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            negative_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt\n                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input\n                argument.\n        \"\"\"\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        if prompt_embeds is None:\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n            text_inputs = self.tokenizer(\n                prompt,\n                padding=\"max_length\",\n                max_length=self.tokenizer.model_max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n            text_input_ids = text_inputs.input_ids\n            untruncated_ids = self.tokenizer(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(\n                text_input_ids, untruncated_ids\n            ):\n                removed_text = self.tokenizer.batch_decode(\n                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]\n                )\n                logger.warning(\n                    \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                    f\" {self.tokenizer.model_max_length} tokens: {removed_text}\"\n              
  )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = text_inputs.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            prompt_embeds = self.text_encoder(\n                text_input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            prompt_embeds = prompt_embeds[0]\n\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        bs_embed, seq_len, _ = prompt_embeds.shape\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)\n\n        # get unconditional embeddings for classifier free guidance\n        if do_classifier_free_guidance and negative_prompt_embeds is None:\n            uncond_tokens: List[str]\n            if negative_prompt is None:\n                uncond_tokens = [\"\"] * batch_size\n            elif prompt is not None and type(prompt) is not type(negative_prompt):\n                raise TypeError(\n                    f\"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=\"\n                    f\" {type(prompt)}.\"\n                )\n            elif isinstance(negative_prompt, str):\n                uncond_tokens = [negative_prompt]\n            elif batch_size != len(negative_prompt):\n                raise ValueError(\n                    f\"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:\"\n                    f\" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches\"\n                    \" the batch size of `prompt`.\"\n                )\n            else:\n                uncond_tokens = negative_prompt\n\n            # textual inversion: procecss multi-vector tokens if necessary\n            if isinstance(self, TextualInversionLoaderMixin):\n                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)\n\n            max_length = prompt_embeds.shape[1]\n            uncond_input = self.tokenizer(\n                uncond_tokens,\n                padding=\"max_length\",\n                max_length=max_length,\n                truncation=True,\n                return_tensors=\"pt\",\n            )\n\n            if hasattr(self.text_encoder.config, \"use_attention_mask\") and self.text_encoder.config.use_attention_mask:\n                attention_mask = uncond_input.attention_mask.to(device)\n            else:\n                attention_mask = None\n\n            negative_prompt_embeds = self.text_encoder(\n                uncond_input.input_ids.to(device),\n                attention_mask=attention_mask,\n            )\n            negative_prompt_embeds = negative_prompt_embeds[0]\n\n        if do_classifier_free_guidance:\n            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method\n            seq_len = negative_prompt_embeds.shape[1]\n\n            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)\n            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(image, output_type=\"pil\")\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors=\"pt\").to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n    def decode_latents(self, latents):\n        warnings.warn(\n            \"The decode_latents method is deprecated and will be removed in a future version. 
Please\"\n            \" use VaeImageProcessor instead\",\n            FutureWarning,\n        )\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(inspect.signature(self.scheduler.step).parameters.keys())\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        #extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(\n        self,\n        prompt,\n        image,\n        height,\n        width,\n        callback_steps,\n        negative_prompt=None,\n        prompt_embeds=None,\n        negative_prompt_embeds=None,\n        controlnet_conditioning_scale=1.0,\n    ):\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if (callback_steps 
is None) or (\n            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n\n        if negative_prompt is not None and negative_prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:\"\n                f\" {negative_prompt_embeds}. 
Please make sure to only forward one of the two.\"\n            )\n\n        if prompt_embeds is not None and negative_prompt_embeds is not None:\n            if prompt_embeds.shape != negative_prompt_embeds.shape:\n                raise ValueError(\n                    \"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but\"\n                    f\" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`\"\n                    f\" {negative_prompt_embeds.shape}.\"\n                )\n\n        # `prompt` needs more sophisticated handling when there are multiple\n        # conditionings.\n        if isinstance(self.controlnet, MultiControlNetModel):\n            if isinstance(prompt, list):\n                logger.warning(\n                    f\"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}\"\n                    \" prompts. The conditionings will be fixed across the prompts.\"\n                )\n\n        # Check `image`\n        is_compiled = hasattr(F, \"scaled_dot_product_attention\") and isinstance(\n            self.controlnet, torch._dynamo.eval_frame.OptimizedModule\n        )\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            self.check_image(image, prompt, prompt_embeds)\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if not isinstance(image, list):\n                raise TypeError(\"For multiple controlnets: `image` must be type `list`\")\n\n            # When `image` is a nested list:\n            # (e.g. 
[[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])\n            elif any(isinstance(i, list) for i in image):\n                raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n            elif len(image) != len(self.controlnet.nets):\n                raise ValueError(\n                    \"For multiple controlnets: `image` must have the same length as the number of controlnets.\"\n                )\n\n            for image_ in image:\n                self.check_image(image_, prompt, prompt_embeds)\n        else:\n            assert False\n\n        # Check `controlnet_conditioning_scale`\n        if (\n            isinstance(self.controlnet, ControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, ControlNetModel)\n        ):\n            if not isinstance(controlnet_conditioning_scale, float):\n                raise TypeError(\"For single controlnet: `controlnet_conditioning_scale` must be type `float`.\")\n        elif (\n            isinstance(self.controlnet, MultiControlNetModel)\n            or is_compiled\n            and isinstance(self.controlnet._orig_mod, MultiControlNetModel)\n        ):\n            if isinstance(controlnet_conditioning_scale, list):\n                if any(isinstance(i, list) for i in controlnet_conditioning_scale):\n                    raise ValueError(\"A single batch of multiple conditionings are supported at the moment.\")\n            elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(\n                self.controlnet.nets\n            ):\n                raise ValueError(\n                    \"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have\"\n                    \" the same length as the number of controlnets\"\n                )\n        else:\n            assert False\n\n    def check_image(self, image, prompt, 
prompt_embeds):\n        image_is_pil = isinstance(image, PIL.Image.Image)\n        image_is_tensor = isinstance(image, torch.Tensor)\n        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)\n        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)\n\n        if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:\n            raise TypeError(\n                \"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors\"\n            )\n\n        if image_is_pil:\n            image_batch_size = 1\n        elif image_is_tensor:\n            image_batch_size = image.shape[0]\n        elif image_is_pil_list:\n            image_batch_size = len(image)\n        elif image_is_tensor_list:\n            image_batch_size = len(image)\n\n        if prompt is not None and isinstance(prompt, str):\n            prompt_batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            prompt_batch_size = len(prompt)\n        elif prompt_embeds is not None:\n            prompt_batch_size = prompt_embeds.shape[0]\n\n        if image_batch_size != 1 and image_batch_size != prompt_batch_size:\n            raise ValueError(\n                f\"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}\"\n            )\n\n    def prepare_image(\n        self,\n        image,\n        width,\n        height,\n        batch_size,\n        num_images_per_prompt,\n        device,\n        dtype,\n        do_classifier_free_guidance=False,\n        guess_mode=False,\n    ):\n        if not isinstance(image, torch.Tensor):\n            if isinstance(image, PIL.Image.Image):\n                image = [image]\n\n            if isinstance(image[0], PIL.Image.Image):\n                images = []\n\n                for image_ in image:\n                    image_ = image_.convert(\"RGB\")\n                    #image_ = image_.resize((width, height), resample=PIL_INTERPOLATION[\"lanczos\"])\n                    image_ = np.array(image_)\n                    image_ = image_[None, :]\n                    images.append(image_)\n\n                image = images\n\n                image = np.concatenate(image, axis=0)\n                image = np.array(image).astype(np.float32) / 255.0\n                image = image.transpose(0, 3, 1, 2)\n                image = torch.from_numpy(image)#.flip(1)\n            elif isinstance(image[0], torch.Tensor):\n                image = torch.cat(image, dim=0)\n\n        image_batch_size = image.shape[0]\n\n        if image_batch_size == 1:\n            repeat_by = batch_size\n        else:\n            # image batch size is the same as prompt batch size\n            repeat_by = num_images_per_prompt\n\n        image = image.repeat_interleave(repeat_by, dim=0)\n\n        image = image.to(device=device, dtype=dtype)\n\n        if do_classifier_free_guidance and not guess_mode:\n            image = torch.cat([image] * 2)\n\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, 
latents=None):\n        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def _default_height_width(self, height, width, image):\n        # NOTE: It is possible that a list of images have different\n        # dimensions for each image, so just checking the first image\n        # is not _exactly_ correct, but it is simple.\n        while isinstance(image, list):\n            image = image[0]\n\n        if height is None:\n            if isinstance(image, PIL.Image.Image):\n                height = image.height\n            elif isinstance(image, torch.Tensor):\n                height = image.shape[2]\n\n            height = (height // 8) * 8  # round down to nearest multiple of 8\n\n        if width is None:\n            if isinstance(image, PIL.Image.Image):\n                width = image.width\n            elif isinstance(image, torch.Tensor):\n                width = image.shape[3]\n\n            width = (width // 8) * 8  # round down to nearest multiple of 8\n\n        return height, width\n\n    # override DiffusionPipeline\n    def save_pretrained(\n        self,\n        save_directory: Union[str, os.PathLike],\n        safe_serialization: bool = False,\n        variant: Optional[str] = None,\n    ):\n      
  if isinstance(self.controlnet, ControlNetModel):\n            super().save_pretrained(save_directory, safe_serialization, variant)\n        else:\n            raise NotImplementedError(\"Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.\")\n        \n    def previous_timestep(self, timestep):\n        if self.scheduler.custom_timesteps:\n            index = (self.scheduler.timesteps == timestep).nonzero(as_tuple=True)[0][0]\n            if index == self.scheduler.timesteps.shape[0] - 1:\n                prev_t = torch.tensor(-1)\n            else:\n                prev_t = self.scheduler.timesteps[index + 1]\n        else:\n            num_inference_steps = (\n                self.scheduler.num_inference_steps if self.scheduler.num_inference_steps else self.scheduler.config.num_train_timesteps\n            )\n            prev_t = timestep - self.scheduler.config.num_train_timesteps // num_inference_steps\n\n        return prev_t   \n\n    def predict_start_from_noise(self, sample, t, model_output):\n        t = t.to(self.scheduler.alphas_cumprod.device)\n        prev_t = self.previous_timestep(t)\n\n        # 1. compute alphas, betas\n        alpha_prod_t = self.scheduler.alphas_cumprod[t].to(sample.device)\n        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else self.scheduler.one\n        alpha_prod_t_prev = alpha_prod_t_prev.to(sample.device)\n        beta_prod_t = 1 - alpha_prod_t\n        beta_prod_t_prev = 1 - alpha_prod_t_prev\n        current_alpha_t = alpha_prod_t / alpha_prod_t_prev\n        current_beta_t = 1 - current_alpha_t\n\n        # 2. 
compute predicted original sample from predicted noise also called\n        # \"predicted x_0\" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf\n        if self.scheduler.config.prediction_type == \"epsilon\":\n            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n        elif self.scheduler.config.prediction_type == \"sample\":\n            pred_original_sample = model_output\n        elif self.scheduler.config.prediction_type == \"v_prediction\":\n            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output\n        else:\n            raise ValueError(\n                f\"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample` or\"\n                \" `v_prediction`  for the DDPMScheduler.\"\n            )\n\n        return pred_original_sample\n    \n    def _sliding_windows(self,h: int, w: int, tile_size: int, tile_stride: int) -> List[Tuple[int, int, int, int]]:\n        hi_list = list(range(0, h - tile_size + 1, tile_stride))\n        if (h - tile_size) % tile_stride != 0:\n            hi_list.append(h - tile_size)\n        \n        wi_list = list(range(0, w - tile_size + 1, tile_stride))\n        if (w - tile_size) % tile_stride != 0:\n            wi_list.append(w - tile_size)\n        \n        coords = []\n        for hi in hi_list:\n            for wi in wi_list:\n                coords.append((hi, hi + tile_size, wi, wi + tile_size))\n        return coords\n\n    # Helper methods within the class\n    def _prepare_controlnet_inputs(self, latent_model_input, latents, prompt_embeds, do_classifier_free_guidance, guess_mode):\n        if guess_mode and do_classifier_free_guidance:\n            return latents, prompt_embeds.chunk(2)[1]\n        return latent_model_input, prompt_embeds\n\n    def _predict_noise(self, latent_model_input, t, image, prompt_embeds, cross_attention_kwargs, vae_conditions, tile_diffusion, 
tile_size, tile_stride, conditioning_scale, guess_mode):\n        if not tile_diffusion:\n            noise_pred = self._unet_predict(latent_model_input, t, image, prompt_embeds, cross_attention_kwargs, vae_conditions)\n        else:\n            noise_pred = self._tile_predict(latent_model_input, t, image, prompt_embeds, cross_attention_kwargs, vae_conditions, tile_size, tile_stride, conditioning_scale, guess_mode)\n        return noise_pred\n\n    def _unet_predict(self, latent_model_input, t, image, prompt_embeds, cross_attention_kwargs, vae_conditions):\n        down_res_samples, mid_res_sample = self.controlnet(\n            latent_model_input, t, encoder_hidden_states=prompt_embeds, controlnet_cond=image,\n            conditioning_scale=1.0, guess_mode=False,\n            return_dict=False, vae_encode_condition_hidden_states=vae_conditions\n        )\n        noise_pred = self.unet(\n            latent_model_input, t, encoder_hidden_states=prompt_embeds,\n            cross_attention_kwargs=cross_attention_kwargs,\n            down_block_additional_residuals=down_res_samples,\n            mid_block_additional_residual=mid_res_sample,\n            return_dict=False,\n        )[0]\n        return noise_pred\n\n    def _tile_predict(self, latent_model_input, t, image, prompt_embeds, cross_attention_kwargs, vae_conditions, tile_size, tile_stride, conditioning_scale, guess_mode):\n        tile_weight = self.gaussian_weights(int(tile_size//8), int(tile_size//8), 1).to(latent_model_input.device)\n        noise_pred = torch.zeros_like(latent_model_input, dtype=torch.float32)\n        count = torch.zeros_like(latent_model_input, dtype=torch.float32)\n        h, w = latent_model_input.shape[2:4]\n\n        for hi, hi_end, wi, wi_end in self._sliding_windows(h, w, int(tile_size // 8), int(tile_stride // 8)):\n            tile = latent_model_input[:, :, hi:hi_end, wi:wi_end]\n            tile_cond = vae_conditions[:, :, hi:hi_end, wi:wi_end] if vae_conditions is not None 
else None\n            tile_image = image[:, :, hi*8:hi_end*8, wi*8:wi_end*8]\n            # tile_cond = self.vae.encode(tile_image * 2 - 1).latent_dist.sample() * self.vae.config.scaling_factor\n            \n            down_block_res_samples, mid_block_res_sample = [None]*10, None\n            down_res_samples, mid_res_sample = self.controlnet(\n                tile, t, encoder_hidden_states=prompt_embeds, controlnet_cond=tile_image,\n                conditioning_scale=1.0, guess_mode=False,\n                return_dict=False, vae_encode_condition_hidden_states=tile_cond\n            )\n            tile_noise = self.unet(\n                tile, t, encoder_hidden_states=prompt_embeds,\n                cross_attention_kwargs=cross_attention_kwargs,\n                down_block_additional_residuals=down_res_samples,\n                mid_block_additional_residual=mid_res_sample,\n                return_dict=False,\n            )[0]\n            noise_pred[:, :, hi:hi_end, wi:wi_end] += tile_noise * tile_weight\n            count[:, :, hi:hi_end, wi:wi_end] += tile_weight\n        noise_pred /= count\n        return noise_pred.to(torch.float16)\n\n    def _initial_step(self, do_classifier_free_guidance, latents, t, timesteps, prompt_embeds, image, vae_conditions, tile_diffusion, tile_size, tile_stride):\n        if do_classifier_free_guidance:\n            prompt_embeds = prompt_embeds.chunk(2)[0]\n            image = image.chunk(2)[0]\n            vae_conditions = vae_conditions.chunk(2)[0]\n        \n        noise_pred = self._predict_noise(latents, t, image, prompt_embeds, None, vae_conditions, tile_diffusion, tile_size, tile_stride, 1.0, False)\n        x0_T = self.predict_start_from_noise(latents, t, noise_pred)\n        noise_tao = torch.randn_like(latents)\n        latents = self.scheduler.add_noise(x0_T, noise_tao, timesteps)\n        return latents, x0_T\n\n    def _postprocess_latents(self, latents, output_type, do_denormalize):\n        latents = 
latents.to(torch.float16)\n        if output_type != \"latent\":\n            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0].to(torch.float32)\n            image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)\n        else:\n            image = latents\n        return image\n\n    def gaussian_weights(self, tile_width: int, tile_height: int, nbatches: int) -> torch.Tensor:\n        \"\"\"Generates a gaussian mask of weights for tile contributions\"\"\"\n        from numpy import pi, exp, sqrt\n        import numpy as np\n\n        latent_width = tile_width\n        latent_height = tile_height\n\n        var = 0.01\n        midpoint = (latent_width - 1) / 2  # -1 because index goes from 0 to latent_width - 1\n        x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)]\n        midpoint = latent_height / 2\n        y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)]\n\n        weights = np.outer(y_probs, x_probs)\n        return torch.tile(torch.tensor(weights, device=self.device), (nbatches, 4, 1, 1))\n    \n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        t_max: float,\n        t_min: float,\n        tile_diffusion: bool, \n        tile_size: float,\n        tile_stride: float,\n        prompt: Union[str, List[str]] = None,\n        image: Union[FloatTensor, Image.Image, List[FloatTensor], List[Image.Image]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 7.5,\n        negative_prompt: Optional[Union[str, List[str]]] = None,\n        num_images_per_prompt: Optional[int] = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, 
List[torch.Generator]]] = None,\n        latents: Optional[FloatTensor] = None,\n        prompt_embeds: Optional[FloatTensor] = None,\n        negative_prompt_embeds: Optional[FloatTensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, FloatTensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n        conditioning_scale: Union[float, List[float]] = 1.0,\n        guess_mode: bool = False,\n        start_steps: int = 999,\n        use_vae_encode_condition: bool = False,\n        start_point: str = 'noise',\n    ) -> Union[StableDiffusionPipelineOutput, tuple]:\n        r\"\"\"\n        Optimized diffusion pipeline call for image super-resolution.\n        For 'Improving the Stability and Efficiency of Diffusion Models for Content Consistent Super-Resolution'.\n\n        Examples:\n            # Example usage:\n            # pipeline(t_max=0.6667, t_min=0.5, tile_diffusion=True, tile_size=256, tile_stride=128, prompt=\"\", num_inference_steps=6)\n            pass\n        \"\"\"\n        # 0. Set default height and width\n        height, width = self._default_height_width(height, width, image)\n        \n        # 1. Determine batch size\n        if prompt is not None:\n            batch_size = 1 if isinstance(prompt, str) else len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n        \n        device = self._execution_device\n        do_classifier_free_guidance = guidance_scale > 1.0\n        \n        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet\n\n        # 2. 
Prepare image\n        image = self.prepare_image(\n            image=image, width=width, height=height,\n            batch_size=batch_size * num_images_per_prompt,\n            num_images_per_prompt=num_images_per_prompt,\n            device=device, dtype=controlnet.dtype,\n            do_classifier_free_guidance=do_classifier_free_guidance,\n            guess_mode=guess_mode\n        )\n\n        # 3. Prepare scheduler timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 4. Prepare extra step kwargs\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n        \n        ### calculate the running time for each inference step\n        torch.cuda.synchronize()\n        start_time = time.time()\n        # 5. Encode prompts\n        prompt_embeds = self._encode_prompt(\n            prompt, device, num_images_per_prompt, do_classifier_free_guidance,\n            negative_prompt, prompt_embeds=prompt_embeds,\n            negative_prompt_embeds=negative_prompt_embeds\n        ) \n\n        # 6. Prepare latent variables\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            self.unet.config.in_channels,\n            height, width, prompt_embeds.dtype,\n            device, generator, latents\n        )\n\n        # 7. Initialize latent variables based on start_point\n        latents_condition_image = self.vae.encode(image * 2 - 1).latent_dist.sample() * self.vae.config.scaling_factor\n        if start_point != 'noise':\n            start_steps_tensor = torch.randint(start_steps, start_steps + 1, (latents.shape[0],), device=latents.device).long()\n            latents = self.scheduler.add_noise(latents_condition_image[0:1, ...], latents, start_steps_tensor)\n\n        # 8. 
Optionally prepare VAE-encoded condition\n        vae_encode_condition_hidden_states = (\n            latents_condition_image if use_vae_encode_condition else None\n        )\n\n        # 9. Initial prediction at t_max if needed\n        total_steps = len(timesteps)\n        t_tao = timesteps[-round(total_steps * t_max)]\n          \n        \n        if t_max != 1:\n            t = torch.randint(start_steps, start_steps+1, (batch_size,), device=latents.device)\n            latents = latents.to(torch.float16)\n            # we do not do the classifier free guidance in this step\n            latent_model_input = self.scheduler.scale_model_input(latents, t)\n            latents, x0_T = self._initial_step(do_classifier_free_guidance, latent_model_input, t, t_tao, prompt_embeds, image, vae_encode_condition_hidden_states, tile_diffusion, tile_size, tile_stride)\n        # redefine timesteps\n        timesteps = timesteps[-round(total_steps * t_max):]\n        timesteps = timesteps[:-round(total_steps * t_min)] if t_min > 0 else timesteps\n        \n        # 10. 
Denoising loop\n        if num_inference_steps==1:\n            latents = x0_T\n        else:\n            with self.progress_bar(total=len(timesteps)) as progress_bar:\n                for i, t in enumerate(timesteps):\n                    latents = latents.to(torch.float16)\n                    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents\n                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)\n            \n                    controlnet_latent_model_input, controlnet_prompt_embeds = self._prepare_controlnet_inputs(latent_model_input, latents, prompt_embeds, do_classifier_free_guidance, guess_mode)\n            \n                    noise_pred = self._predict_noise(\n                        controlnet_latent_model_input, t, image, controlnet_prompt_embeds, cross_attention_kwargs, \n                        vae_encode_condition_hidden_states, tile_diffusion, tile_size, tile_stride, conditioning_scale, guess_mode\n                    )\n            \n                    if do_classifier_free_guidance:\n                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)\n\n                    latents_old = latents\n                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]\n                    \n\n                    # call the callback, if provided\n                    progress_bar.update()\n                    if i == len(timesteps) - 1:\n                        if callback is not None and i % callback_steps == 0:\n                            callback(i, t, latents)\n\n        # Predict x0 for t_min\n        if t_min:\n            x0_tmin = self.predict_start_from_noise(latents_old, t, noise_pred)\n            latents = x0_tmin\n\n        # 11. 
Post-processing\n        has_nsfw_concept = None\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self._postprocess_latents(latents, output_type, do_denormalize)\n        \n        ## calculate the inference time for each inference step\n        torch.cuda.synchronize()\n        end_time = time.time()\n        total_time = end_time - start_time\n        \n        return total_time, StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)\n"
  },
  {
    "path": "requirements.txt",
    "content": "diffusers==0.21.0\ntorch==2.0.1\npytorch_lightning\naccelerate==1.2.0\ntransformers==4.25.0\nxformers==0.0.22 \nloralib\nfairscale==0.4.13\nbasicsr==1.4.2\ntimm==0.9.5\npydantic==1.10.11\nhuggingface_hub==0.25.2\nopencv-python-headless\nlpips\n"
  },
  {
    "path": "scripts/get_path.py",
    "content": "import os\n\ndef write_png_paths(folder_path, txt_path):\n    with open(txt_path, 'w') as f:\n        for root, dirs, files in os.walk(folder_path):\n            for file in files:\n                if file.endswith('.png'):\n                    f.write(os.path.join(root, file) + '\\n')\n\n# Example usage:\nfolder_path = ''\ntxt_path = '/gt_path.txt'\nwrite_png_paths(folder_path, txt_path)"
  },
  {
    "path": "scripts/test/test_ccsr_multistep.sh",
    "content": "python test_ccsr_tile.py \\\n--pretrained_model_path preset/models/stable-diffusion-2-1-base \\\n--controlnet_model_path preset/models \\\n--vae_model_path preset/models \\\n--baseline_name ccsr-v2 \\\n--image_path preset/test_datasets \\\n--output_dir experiments/test \\\n--sample_method ddpm \\\n--num_inference_steps 6 \\\n--t_max 0.6667 \\\n--t_min 0.5 \\\n--start_point lr \\\n--start_steps 999 \\\n--process_size 512 \\\n--guidance_scale 4.5 \\\n--sample_times 1 \\\n--use_vae_encode_condition \\\n--upscale 4"
  },
  {
    "path": "scripts/test/test_ccsr_onestep.sh",
    "content": "\npython test_ccsr_tile.py \\\n--pretrained_model_path preset/models/stable-diffusion-2-1-base \\\n--controlnet_model_path preset/models \\\n--vae_model_path preset/models \\\n--baseline_name ccsr-v2 \\\n--image_path preset/test_datasets \\\n--output_dir experiments/test \\\n--sample_method ddpm \\\n--num_inference_steps 1 \\\n--t_min 0.0 \\\n--start_point lr \\\n--start_steps 999 \\\n--process_size 512 \\\n--guidance_scale 1.0 \\\n--sample_times 1 \\\n--use_vae_encode_condition \\\n--upscale 4"
  },
  {
    "path": "scripts/test/test_ccsr_tile.sh",
    "content": "python test_ccsr_tile.py \\\n--pretrained_model_path preset/models/stable-diffusion-2-1-base \\\n--controlnet_model_path preset/models \\\n--vae_model_path preset/models \\\n--baseline_name ccsr-v2 \\\n--image_path preset/test_datasets \\\n--output_dir experiments/test \\\n--sample_method ddpm \\\n--num_inference_steps 6 \\\n--t_max 0.6667 \\\n--t_min 0.5 \\\n--start_point lr \\\n--start_steps 999 \\\n--process_size 512 \\\n--guidance_scale 4.5 \\\n--sample_times 1 \\\n--use_vae_encode_condition \\\n--upscale 4 \\\n--tile_diffusion \\\n--tile_diffusion_size 512 \\\n--tile_diffusion_stride 256 \\\n--tile_vae \\\n--vae_decoder_tile_size 224 \\\n--vae_encoder_tile_size 1024"
  },
  {
    "path": "scripts/train/train_ccsr_stage1.sh",
    "content": "CUDA_VISIBLE_DEVICES=\"0,1,2,3,\" accelerate launch train_ccsr_stage1.py \\\n--pretrained_model_name_or_path=\"preset/models/stable-diffusion-2-1-base\" \\\n--controlnet_model_name_or_path='preset/models/pretrained_controlnet' \\\n--enable_xformers_memory_efficient_attention \\\n--output_dir=\"./experiments/ccsrv2_stage1\" \\\n--mixed_precision=\"fp16\" \\\n--resolution=512 \\\n--learning_rate=5e-5 \\\n--train_batch_size=4 \\\n--gradient_accumulation_steps=6 \\\n--dataloader_num_workers=0 \\\n--checkpointing_steps=500 \\\n--t_max=0.6667 \\\n--max_train_steps=20000 \\\n--dataset_root_folders 'preset/gt_path.txt' "
  },
  {
    "path": "scripts/train/train_ccsr_stage2.sh",
    "content": "CUDA_VISIBLE_DEVICES=\"0,1,2,3,\" accelerate launch train_ccsr_stage2.py \\\n--pretrained_model_name_or_path=\"preset/models/stable-diffusion-2-1-base\" \\\n--controlnet_model_name_or_path='preset/models/model_stage1' \\\n--enable_xformers_memory_efficient_attention \\\n--output_dir=\"./experiments/ccsrv2_stage2\" \\\n--mixed_precision=\"fp16\" \\\n--resolution=512 \\\n--learning_rate=5e-6 \\\n--train_batch_size=2 \\\n--gradient_accumulation_steps=8 \\\n--checkpointing_steps=500 \\\n--is_start_lr=True \\\n--t_max=0.6667 \\\n--num_inference_steps=1 \\\n--is_module \\\n--lambda_l2=1.0 \\\n--lambda_lpips=1.0 \\\n--lambda_disc=0.05 \\\n--lambda_disc_train=0.5 \\\n--begin_disc=100 \\\n--max_train_steps=2000 \\\n--dataset_root_folders 'preset/gt_path.txt'  "
  },
  {
    "path": "scripts/train/train_controlnet.sh",
    "content": "\nCUDA_VISIBLE_DEVICES=\"0,1,2,3,\" accelerate launch train_controlnet.py \\\n--pretrained_model_name_or_path=\"preset/models/stable-diffusion-2-1-base\" \\\n--controlnet_model_name_or_path='' \\\n --enable_xformers_memory_efficient_attention \\\n --output_dir=\"./experiments/pretrained_controlnet\" \\\n --mixed_precision=\"fp16\" \\\n --resolution=512 \\\n --learning_rate=5e-5 \\\n --train_batch_size=4 \\\n --gradient_accumulation_steps=6 \\\n --dataloader_num_workers=0 \\\n --checkpointing_steps=5000 \\\n --max_train_steps=40000 \\\n --dataset_root_folders 'preset/gt_path.txt'"
  },
  {
    "path": "test_ccsr_tile.py",
    "content": "import os\nimport glob\nimport math\nimport time\nimport argparse\n\nimport numpy as np\nfrom PIL import Image\nimport safetensors.torch\n\nimport torch\nfrom torchvision import transforms\nimport torchvision.transforms.functional as F\n\nfrom accelerate import Accelerator\nfrom accelerate.utils import set_seed\nfrom diffusers import (\n    AutoencoderKL,\n    UniPCMultistepScheduler,\n    DPMSolverMultistepScheduler,\n    DDPMScheduler,\n    UNet2DConditionModel,\n)\n\nfrom diffusers.utils import check_min_version\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor\n\nfrom pipelines.pipeline_ccsr import StableDiffusionControlNetPipeline\nfrom myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix\nfrom models.controlnet import ControlNetModel\n\n\n\ndef load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention):\n    \n    scheduler_mapping = {\n        'unipcmultistep': UniPCMultistepScheduler,\n        'ddpm': DDPMScheduler,\n        'dpmmultistep': DPMSolverMultistepScheduler,\n    }\n\n    try:\n        scheduler_cls = scheduler_mapping[args.sample_method]\n    except KeyError:\n        raise ValueError(f\"Invalid sample_method: {args.sample_method}\")\n\n    scheduler = scheduler_cls.from_pretrained(args.pretrained_model_path, subfolder=\"scheduler\")\n\n    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder=\"text_encoder\")\n    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder=\"tokenizer\")\n    feature_extractor = CLIPImageProcessor.from_pretrained(os.path.join(args.pretrained_model_path, \"feature_extractor\"))\n    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_path, subfolder=\"unet\")\n    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_path, subfolder=\"controlnet\")\n\n    vae_path = args.vae_model_path if 
args.vae_model_path else args.pretrained_model_path\n    vae = AutoencoderKL.from_pretrained(vae_path, subfolder=\"vae\")\n\n    # Freeze models\n    for model in [vae, text_encoder, unet, controlnet]:\n        model.requires_grad_(False)\n\n    # Enable xformers if available\n    if enable_xformers_memory_efficient_attention:\n        if is_xformers_available():\n            unet.enable_xformers_memory_efficient_attention()\n            controlnet.enable_xformers_memory_efficient_attention()\n        else:\n            raise ValueError(\"xformers is not available. Ensure it is installed correctly.\")\n\n    # Initialize pipeline\n    validation_pipeline = StableDiffusionControlNetPipeline(\n        vae=vae,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n        feature_extractor=feature_extractor,\n        unet=unet,\n        controlnet=controlnet,\n        scheduler=scheduler,\n        safety_checker=None,\n        requires_safety_checker=False,\n    )\n\n    if args.tile_vae:\n        validation_pipeline._init_tiled_vae(\n            encoder_tile_size=args.vae_encoder_tile_size,\n            decoder_tile_size=args.vae_decoder_tile_size\n        )\n\n    # Set weight dtype based on mixed precision\n    dtype_mapping = {\n        \"fp16\": torch.float16,\n        \"bf16\": torch.bfloat16,\n    }\n    weight_dtype = dtype_mapping.get(accelerator.mixed_precision, torch.float32)\n\n    # Move models to accelerator device with appropriate dtype\n    for model in [text_encoder, vae, unet, controlnet]:\n        model.to(accelerator.device, dtype=weight_dtype)\n\n    return validation_pipeline\n\ndef main(args, enable_xformers_memory_efficient_attention=True,):\n    \n    detailed_output_dir = os.path.join(\n        args.output_dir,\n        f\"sr_{args.baseline_name}_{args.sample_method}_{str(args.num_inference_steps).zfill(3)}steps_{args.start_point}{args.start_steps}_size{args.process_size}_cfg{args.guidance_scale}\"\n    )\n\n    accelerator = 
Accelerator(\n        mixed_precision=args.mixed_precision,\n    )\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the output folder creation\n    # We need to initialize the trackers we use, and also store our configuration.\n    # The trackers initializes automatically on the main process.\n    if accelerator.is_main_process:\n        os.makedirs(detailed_output_dir, exist_ok=True)\n        accelerator.init_trackers(\"Controlnet\")\n\n    pipeline = load_pipeline(args, accelerator, enable_xformers_memory_efficient_attention)\n\n    if accelerator.is_main_process:\n        generator = torch.Generator(device=accelerator.device)\n        if args.seed is not None:\n            generator.manual_seed(args.seed)\n\n        image_paths = sorted(glob.glob(os.path.join(args.image_path, \"*.*\"))) if os.path.isdir(args.image_path) else [args.image_path]\n\n        time_records = []\n        for image_path in image_paths:\n            validation_image = Image.open(image_path).convert(\"RGB\")\n            negative_prompt = args.negative_prompt\n            validation_prompt = args.added_prompt \n\n            ori_width, ori_height = validation_image.size\n            resize_flag = False\n            rscale = args.upscale\n            if ori_width < args.process_size//rscale or ori_height < args.process_size//rscale:\n                scale = (args.process_size//rscale)/min(ori_width, ori_height)\n                tmp_image = validation_image.resize((round(scale*ori_width), round(scale*ori_height)))\n                validation_image = tmp_image\n                resize_flag = True\n\n\n            validation_image = validation_image.resize((validation_image.size[0]*rscale, validation_image.size[1]*rscale))\n            validation_image = validation_image.resize((validation_image.size[0]//8*8, validation_image.size[1]//8*8))\n            width, height = validation_image.size\n            resize_flag 
= True #\n      \n            for sample_idx in range(args.sample_times):\n                os.makedirs(f'{detailed_output_dir}/sample{str(sample_idx).zfill(2)}/', exist_ok=True)\n                \n            for sample_idx in range(args.sample_times):\n\n                inference_time, image = pipeline(\n                    args.t_max,\n                    args.t_min,\n                    args.tile_diffusion,\n                    args.tile_diffusion_size,\n                    args.tile_diffusion_stride,\n                    args.added_prompt,\n                    validation_image,\n                    num_inference_steps=args.num_inference_steps,\n                    generator=generator,\n                    height=height,\n                    width=width,\n                    guidance_scale=args.guidance_scale,\n                    negative_prompt=args.negative_prompt,\n                    conditioning_scale=args.conditioning_scale,\n                    start_steps=args.start_steps,\n                    start_point=args.start_point,\n                    use_vae_encode_condition=args.use_vae_encode_condition,\n                )\n                image = image.images[0]\n\n                print(f\"Inference time: {inference_time:.4f} seconds\")\n                time_records.append(inference_time)\n\n                # Apply color fixing if specified\n                if args.align_method != 'nofix':\n                    fix_func = wavelet_color_fix if args.align_method == 'wavelet' else adain_color_fix\n                    image = fix_func(image, validation_image)\n                    \n                if resize_flag: \n                    image = image.resize((ori_width*rscale, ori_height*rscale))\n                \n                image_tensor = torch.clamp(F.to_tensor(image), 0, 1)\n                final_image = transforms.ToPILImage()(image_tensor)\n                base_name = os.path.splitext(os.path.basename(image_path))[0]\n                save_path = 
os.path.join(detailed_output_dir, f\"sample{str(sample_idx).zfill(2)}\", f\"{base_name}.png\")\n                image.save(save_path)\n        \n        # Calculate the average inference time, excluding the first few for stabilization\n        if len(time_records) > 3:\n            average_time = np.mean(time_records[3:])\n        else:\n            average_time = np.mean(time_records)\n        if accelerator.is_main_process:\n            print(f\"Average inference time: {average_time:.4f} seconds\")   \n                    \n\n    # Save the run settings to a file\n    settings_path = os.path.join(detailed_output_dir, \"settings.txt\")\n    with open(settings_path, 'w') as f:\n        f.write(\"------------------ start ------------------\\n\")\n        for key, value in vars(args).items():\n            f.write(f\"{key} : {value}\\n\")\n        f.write(\"------------------- end -------------------\\n\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description=\"Stable Diffusion ControlNet Pipeline for Super-Resolution\")\n    parser.add_argument(\"--controlnet_model_path\", type=str, default=\"\", help=\"Path to ControlNet model\")\n    parser.add_argument(\"--pretrained_model_path\", type=str, default=\"\", help=\"Path to pretrained model\")\n    parser.add_argument(\"--vae_model_path\", type=str, default=\"\", help=\"Path to VAE model\")\n    parser.add_argument(\"--added_prompt\", type=str, default=\"clean, high-resolution, 8k\", help=\"Additional prompt for generation\")\n    parser.add_argument(\"--negative_prompt\", type=str, default=\"blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed\", help=\"Negative prompt to avoid certain features\")\n    parser.add_argument(\"--image_path\", type=str, default=\"\", help=\"Path to input image or directory\")\n    parser.add_argument(\"--output_dir\", type=str, default=\"\", help=\"Directory to save outputs\")\n    parser.add_argument(\"--mixed_precision\", type=str, 
choices=[\"no\", \"fp16\", \"bf16\"], default=\"fp16\", help=\"Mixed precision mode\")\n    parser.add_argument(\"--guidance_scale\", type=float, default=1.0, help=\"Guidance scale for generation\")\n    parser.add_argument(\"--conditioning_scale\", type=float, default=1.0, help=\"Conditioning scale\")\n    parser.add_argument(\"--num_inference_steps\", type=int, default=1, help=\"Number of inference steps(not the final inference time)\")\n    # final_inference_time = num_inference_steps * (t_max - t_min) + 1\n    parser.add_argument(\"--t_max\", type=float, default=0.6666, help=\"Maximum timestep\")\n    parser.add_argument(\"--t_min\", type=float, default=0.0, help=\"Minimum timestep\")\n    parser.add_argument(\"--process_size\", type=int, default=512, help=\"Processing size of the image\")\n    parser.add_argument(\"--upscale\", type=int, default=1, help=\"Upscaling factor\")\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"Random seed\")\n    parser.add_argument(\"--sample_times\", type=int, default=5, help=\"Number of samples to generate per image\")\n    parser.add_argument(\"--sample_method\", type=str, choices=['unipcmultistep', 'ddpm', 'dpmmultistep'], default='ddpm', help=\"Sampling method\")\n    parser.add_argument(\"--align_method\", type=str, choices=['wavelet', 'adain', 'nofix'], default='adain', help=\"Alignment method for color fixing\")\n    parser.add_argument(\"--start_steps\", type=int, default=999, help=\"Starting steps\")\n    parser.add_argument(\"--start_point\", type=str, choices=['lr', 'noise'], default='lr', help=\"Starting point for generation\")\n    parser.add_argument(\"--baseline_name\", type=str, default='ccsr-v2', help=\"Baseline name for output naming\")\n    parser.add_argument(\"--use_vae_encode_condition\", action='store_true', help=\"Use VAE encoding LQ condition\")\n    \n    # Tiling settings for high-resolution SR\n    parser.add_argument(\"--tile_diffusion\", action=\"store_true\", help=\"Optionally! 
Enable tile-based diffusion\")\n    parser.add_argument(\"--tile_diffusion_size\", type=int, default=512, help=\"Tile size for diffusion\")\n    parser.add_argument(\"--tile_diffusion_stride\", type=int, default=256, help=\"Stride size for diffusion tiles\")\n    parser.add_argument(\"--tile_vae\", action=\"store_true\", help=\"Optionally! Enable tiling for VAE\")\n    parser.add_argument(\"--vae_decoder_tile_size\", type=int, default=224, help=\"Tile size for VAE decoder\")\n    parser.add_argument(\"--vae_encoder_tile_size\", type=int, default=1024, help=\"Tile size for VAE encoder\")\n\n    args = parser.parse_args()\n    main(args)"
  },
  {
    "path": "train_ccsr_stage1.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    StableDiffusionControlNetPipeline,\n    UniPCMultistepScheduler,\n)\nfrom models.controlnet import ControlNetModel\nfrom models.unet_2d_condition import UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\nfrom dataloaders.paired_dataset_txt import PairedCaptionDataset\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.21.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef image_grid(imgs, rows, cols):\n    assert len(imgs) == rows * cols\n\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n\n\ndef log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):\n    logger.info(\"Running validation... \")\n\n    controlnet = accelerator.unwrap_model(controlnet)\n\n    pipeline = StableDiffusionControlNetPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=vae,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n        unet=unet,\n        controlnet=controlnet,\n        safety_checker=None,\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of 
`args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with torch.autocast(\"cuda\"):\n                image = pipeline(\n                    validation_prompt, validation_image, num_inference_steps=20, generator=generator\n                ).images[0]\n\n            images.append(image)\n\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images = []\n\n                formatted_images.append(np.asarray(validation_image))\n\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images.append(wandb.Image(validation_image, caption=\"Controlnet conditioning\"))\n\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    
formatted_images.append(image)\n\n            tracker.log({\"validation\": formatted_images})\n        else:\n            logger.warn(f\"image logging not implemented for {tracker.name}\")\n\n        return image_logs\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i})](./images_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {base_model}\ntags:\n- stable-diffusion\n- stable-diffusion-diffusers\n- text-to-image\n- diffusers\n- controlnet\ninference: true\n---\n    \"\"\"\n    model_card = 
f\"\"\"\n# controlnet-{repo_id}\n\nThese are controlnet weights trained on {base_model} with new type of conditioning.\n{img_str}\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a ControlNet training script.\")\n\n\n    parser.add_argument('--dataset_root_folders', type=str, default=\"\")\n    parser.add_argument(\"--t_max\", type=float, default=0.6667)\n    parser.add_argument(\"--start_timesteps\", type=int, default=999)\n\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=\"\",\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--controlnet_model_name_or_path\",\n        type=str,\n        default='',\n        help=\"Path to pretrained controlnet model.\"\n        \" If not specified controlnet weights are initialized from unet.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"./experiments/test\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    \n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1000)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. \"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be\"\n            \" float32 precision.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    \n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=0, help=\"A seed for reproducible training.\")\n    \n    \n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-5,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the 
local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"fp16\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by using setting grads to None instead of zero. 
Be aware, that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the controlnet conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=[\"\"],\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=[\"\"],\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--is_start_lr\",\n        type=bool,\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=100,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=1,\n        help=(\n            \"Run validation every X steps. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"train_ccsr_stage1\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder.\"\n        )\n\n    return args\n\ndef previous_timestep(timestep):\n    if noise_scheduler.custom_timesteps:\n        index = (noise_scheduler.timesteps == timestep).nonzero(as_tuple=True)[0][0]\n        if index == noise_scheduler.timesteps.shape[0] - 1:\n            prev_t = torch.tensor(-1)\n        else:\n            prev_t = noise_scheduler.timesteps[index + 1]\n    else:\n        num_inference_steps = (\n            noise_scheduler.num_inference_steps if noise_scheduler.num_inference_steps else noise_scheduler.config.num_train_timesteps\n        )\n        prev_t = timestep - noise_scheduler.config.num_train_timesteps // num_inference_steps\n\n    return prev_t\n\ndef predict_start_from_noise(sample, t, model_output):\n    t = t.to(noise_scheduler.alphas_cumprod.device)\n    prev_t = previous_timestep(t)\n\n    # 1. 
compute alphas, betas\n    alpha_prod_t = noise_scheduler.alphas_cumprod[t].to(sample.device)\n    alpha_prod_t_prev = noise_scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else noise_scheduler.one\n    alpha_prod_t_prev = alpha_prod_t_prev.to(sample.device)\n    beta_prod_t = 1 - alpha_prod_t\n    beta_prod_t_prev = 1 - alpha_prod_t_prev\n    current_alpha_t = alpha_prod_t / alpha_prod_t_prev\n    current_beta_t = 1 - current_alpha_t\n\n    # 2. compute predicted original sample from predicted noise also called\n    # \"predicted x_0\" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf\n    if noise_scheduler.config.prediction_type == \"epsilon\":\n        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n    elif noise_scheduler.config.prediction_type == \"sample\":\n        pred_original_sample = model_output\n    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output\n    else:\n        raise ValueError(\n            f\"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or\"\n            \" `v_prediction`  for the DDPMScheduler.\"\n        )\n\n    return pred_original_sample\n\n# def main(args):\nargs = parse_args()\nlogging_dir = Path(args.output_dir, args.logging_dir)\n\naccelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\naccelerator = Accelerator(\n    gradient_accumulation_steps=args.gradient_accumulation_steps,\n    mixed_precision=args.mixed_precision,\n    log_with=args.report_to,\n    project_config=accelerator_project_config,\n)\n\n# Make one log on every process with the configuration for debugging.\nlogging.basicConfig(\n    format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n    datefmt=\"%m/%d/%Y %H:%M:%S\",\n    level=logging.INFO,\n)\nlogger.info(accelerator.state, 
main_process_only=False)\nif accelerator.is_local_main_process:\n    transformers.utils.logging.set_verbosity_warning()\n    diffusers.utils.logging.set_verbosity_info()\nelse:\n    transformers.utils.logging.set_verbosity_error()\n    diffusers.utils.logging.set_verbosity_error()\n\n# If passed along, set the training seed now.\nif args.seed is not None:\n    set_seed(args.seed)\n\n# Handle the repository creation\nif accelerator.is_main_process:\n    if args.output_dir is not None:\n        os.makedirs(args.output_dir, exist_ok=True)\n\n    if args.push_to_hub:\n        repo_id = create_repo(\n            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n        ).repo_id\n\n# Load the tokenizer\nif args.tokenizer_name:\n    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\nelif args.pretrained_model_name_or_path:\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n# import correct text encoder class\ntext_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n\n# Load scheduler and models\nnoise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\ntext_encoder = text_encoder_cls.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n)\nvae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\nunet = UNet2DConditionModel.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n)\n\nif args.controlnet_model_name_or_path:\n    logger.info(\"Loading existing controlnet weights\")\n    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path, 
subfolder=\"controlnet\")\nelse:\n    logger.info(\"Initializing controlnet weights from unet\")\n    controlnet = ControlNetModel.from_unet(unet, use_vae_encode_condition=True)\n\n# `accelerate` 0.16.0 will have better support for customized saving\nif version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        i = len(weights) - 1\n        for i, model in enumerate(models):\n            sub_dir = \"unet\" if isinstance(model, UNet2DConditionModel) else \"controlnet\"\n            model.save_pretrained(os.path.join(output_dir, sub_dir))\n            # make sure to pop weight so that corresponding model is not saved again\n            weights.pop()\n\n    def load_model_hook(models, input_dir):\n        assert len(models) == 2\n        for i in range(len(models)):\n            # pop models so that they are not loaded again\n            model = models.pop()\n\n            # load diffusers style into model\n            if not isinstance(model, UNet2DConditionModel):\n                load_model = ControlNetModel.from_pretrained(input_dir, subfolder=\"controlnet\") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True\n            else:\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True\n\n            model.register_to_config(**load_model.config)\n\n            model.load_state_dict(load_model.state_dict())\n            del load_model\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\nvae.requires_grad_(False)\nunet.requires_grad_(False)\ntext_encoder.requires_grad_(False)\ncontrolnet.train()\n\nif args.enable_xformers_memory_efficient_attention:\n    if is_xformers_available():\n        import 
xformers\n\n        xformers_version = version.parse(xformers.__version__)\n        if xformers_version == version.parse(\"0.0.16\"):\n            logger.warn(\n                \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n            )\n        unet.enable_xformers_memory_efficient_attention()\n        controlnet.enable_xformers_memory_efficient_attention()\n    else:\n        raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\nif args.gradient_checkpointing:\n    unet.enable_gradient_checkpointing()\n    controlnet.enable_gradient_checkpointing()\n\n# Check that all trainable models are in full precision\nlow_precision_error_string = (\n    \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n    \" doing mixed precision training, copy of the weights should still be float32.\"\n)\n\nif accelerator.unwrap_model(controlnet).dtype != torch.float32:\n    raise ValueError(\n        f\"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}\"\n    )\nif accelerator.unwrap_model(unet).dtype != torch.float32:\n    raise ValueError(\n        f\"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. 
{low_precision_error_string}\"\n    )\n\n\n# Enable TF32 for faster training on Ampere GPUs,\n# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\nif args.allow_tf32:\n    torch.backends.cuda.matmul.allow_tf32 = True\n\nif args.scale_lr:\n    args.learning_rate = (\n        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n    )\n\n# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\nif args.use_8bit_adam:\n    try:\n        import bitsandbytes as bnb\n    except ImportError:\n        raise ImportError(\n            \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n        )\n\n    optimizer_class = bnb.optim.AdamW8bit\nelse:\n    optimizer_class = torch.optim.AdamW\n\nparams_to_optimize = list(controlnet.parameters())\noptimizer = optimizer_class(\n    params_to_optimize,\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n\ntrain_dataset = PairedCaptionDataset(root_folders=args.dataset_root_folders,\n                                    tokenizer=tokenizer,\n                                    gt_ratio=0) # let lr is gt\n\ntrain_dataloader = torch.utils.data.DataLoader(\n    train_dataset,\n    num_workers=args.dataloader_num_workers,\n    batch_size=args.train_batch_size,\n    shuffle=False\n)\n\n# For mixed precision training we cast the text_encoder and vae weights to half-precision\n# as these models are only used for inference, keeping weights in full precision is not required.\nweight_dtype = torch.float32\nif accelerator.mixed_precision == \"fp16\":\n    weight_dtype = torch.float16\nelif accelerator.mixed_precision == \"bf16\":\n    weight_dtype = torch.bfloat16\n\n# Move vae, unet and text_encoder to device and cast to weight_dtype\nvae.to(accelerator.device, 
dtype=weight_dtype)\ntext_encoder.to(accelerator.device, dtype=weight_dtype)\nunet.to(accelerator.device, dtype=weight_dtype)\n\n# Scheduler and math around the number of training steps.\noverrode_max_train_steps = False\nnum_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\nif args.max_train_steps is None:\n    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    overrode_max_train_steps = True\n\nlr_scheduler = get_scheduler(\n    args.lr_scheduler,\n    optimizer=optimizer,\n    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n    num_training_steps=args.max_train_steps * accelerator.num_processes,\n    num_cycles=args.lr_num_cycles,\n    power=args.lr_power,\n)\n\n# Prepare everything with our `accelerator`.\ncontrolnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n    controlnet, optimizer, train_dataloader, lr_scheduler\n)\n\n\n# We need to initialize the trackers we use, and also store our configuration.\n# The trackers initializes automatically on the main process.\nif accelerator.is_main_process:\n    tracker_config = dict(vars(args))\n\n    # tensorboard cannot handle list types for config\n    tracker_config.pop(\"validation_prompt\")\n    tracker_config.pop(\"validation_image\")\n\n    accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n# Train!\ntotal_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\nlogger.info(\"***** Running training *****\")\n\nlogger.info(f\"  Num Epochs = {args.num_train_epochs}\")\nlogger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\nlogger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\nlogger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\nlogger.info(f\"  Total optimization steps = {args.max_train_steps}\")\nglobal_step = 0\nfirst_epoch = 0\n\n# Potentially load in the weights and states from a previous save\nif args.resume_from_checkpoint:\n    if args.resume_from_checkpoint != \"latest\":\n        path = os.path.basename(args.resume_from_checkpoint)\n    else:\n        # Get the most recent checkpoint\n        dirs = os.listdir(args.output_dir)\n        dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n        dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n        path = dirs[-1] if len(dirs) > 0 else None\n\n    if path is None:\n        accelerator.print(\n            f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n        )\n        args.resume_from_checkpoint = None\n        initial_global_step = 0\n    else:\n        accelerator.print(f\"Resuming from checkpoint {path}\")\n        accelerator.load_state(os.path.join(args.output_dir, path))\n        global_step = int(path.split(\"-\")[1])\n\n        initial_global_step = global_step\n        first_epoch = global_step // num_update_steps_per_epoch\nelse:\n    initial_global_step = 0\n\nprogress_bar = tqdm(\n    range(0, args.max_train_steps),\n    initial=initial_global_step,\n    desc=\"Steps\",\n    # Only show the progress bar once on each machine.\n    disable=not accelerator.is_local_main_process,\n)\n\n\nfor epoch in range(first_epoch, args.num_train_epochs):\n    for step, batch in enumerate(train_dataloader):\n        with accelerator.accumulate(controlnet):\n            pixel_values = batch[\"pixel_values\"].to(accelerator.device, dtype=weight_dtype)\n            # Convert images to latent space\n            latents = vae.encode(pixel_values).latent_dist.sample()\n            latents = latents * vae.config.scaling_factor\n\n 
           # Sample noise that we'll add to the latents\n            noise = torch.randn_like(latents)\n            bsz = latents.shape[0]\n            # Sample a random timestep for each image\n            t_max = round(noise_scheduler.config.num_train_timesteps*args.t_max)\n            timesteps = torch.randint(0, t_max, (bsz,), device=latents.device)\n            timesteps = timesteps.long()\n\n            # Add noise to the latents according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n            # # Get the text embedding for conditioning\n            encoder_hidden_states = text_encoder(batch[\"input_caption\"].to(accelerator.device))[0]\n\n            controlnet_image = batch[\"conditioning_pixel_values\"].to(accelerator.device, dtype=weight_dtype)\n\n            vae_encode_condition_hidden_states = vae.encode(2*controlnet_image-1).latent_dist.sample()\n            vae_encode_condition_hidden_states = vae_encode_condition_hidden_states * vae.config.scaling_factor\n\n            down_block_res_samples, mid_block_res_sample = controlnet(\n                noisy_latents,\n                timesteps,\n                encoder_hidden_states=encoder_hidden_states,\n                controlnet_cond=controlnet_image,\n                return_dict=False,\n                vae_encode_condition_hidden_states=vae_encode_condition_hidden_states,\n            )\n\n            # Predict the noise residual\n            model_pred = unet(\n                noisy_latents,\n                timesteps,\n                encoder_hidden_states=encoder_hidden_states,\n                down_block_additional_residuals=[\n                    sample.to(dtype=weight_dtype) for sample in down_block_res_samples\n                ],\n                mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),\n            ).sample\n\n            # Get 
the target for loss depending on the prediction type\n            if noise_scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                target = noise_scheduler.get_velocity(latents, noise, timesteps)\n            else:\n                raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n            loss_ori = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\") # original loss\n\n            ###   loss for t_max ###\n            noise2 = torch.randn_like(latents)\n            timesteps = args.start_timesteps * torch.ones(model_pred.shape[0]).to(accelerator.device)\n            timesteps = timesteps.long()\n\n            # Add noise to the latents according to the noise magnitude at each timestep\n            # (this is the forward diffusion process)\n            if args.start_timesteps==1:\n                noisy_latents = noise_scheduler.add_noise(vae_encode_condition_hidden_states, noise2, timesteps)\n                down_block_res_samples, mid_block_res_sample = controlnet(\n                    noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=encoder_hidden_states,\n                    controlnet_cond=controlnet_image,\n                    return_dict=False,\n                    vae_encode_condition_hidden_states=vae_encode_condition_hidden_states,\n                )\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=encoder_hidden_states,\n                    down_block_additional_residuals=[\n                        sample.to(dtype=weight_dtype) for sample in down_block_res_samples\n                    ],\n                    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),\n                
).sample\n\n                # Predict x0 for T\n                x0_T = noisy_latents - model_pred\n\n            else:\n                if args.is_start_lr:\n                    noisy_latents = noise_scheduler.add_noise(vae_encode_condition_hidden_states, noise2, timesteps)\n                else:\n                    noisy_latents = noise_scheduler.add_noise(latents, noise2, timesteps)\n\n                down_block_res_samples, mid_block_res_sample = controlnet(\n                    noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=encoder_hidden_states,\n                    controlnet_cond=controlnet_image,\n                    return_dict=False,\n                    vae_encode_condition_hidden_states=vae_encode_condition_hidden_states,\n                )\n\n                # Predict the noise residual\n                model_pred = unet(\n                    noisy_latents,\n                    timesteps,\n                    encoder_hidden_states=encoder_hidden_states,\n                    down_block_additional_residuals=[\n                        sample.to(dtype=weight_dtype) for sample in down_block_res_samples\n                    ],\n                    mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),\n                ).sample\n\n                # Predict x0 for T\n                x0_T = predict_start_from_noise(noisy_latents, timesteps[0], model_pred)\n\n            # Re-add noise on x0_tmax\n            noise3 = torch.randn_like(latents)\n            timesteps = t_max * torch.ones(model_pred.shape[0]).to(accelerator.device)\n            timesteps = timesteps.long()\n            noisy_latents = noise_scheduler.add_noise(x0_T, noise3, timesteps[0])\n\n            down_block_res_samples, mid_block_res_sample = controlnet(\n                noisy_latents,\n                timesteps,\n                encoder_hidden_states=encoder_hidden_states,\n                controlnet_cond=controlnet_image,\n    
            return_dict=False,\n                vae_encode_condition_hidden_states=vae_encode_condition_hidden_states,\n            )\n\n            # Predict the noise residual\n            model_pred = unet(\n                noisy_latents,\n                timesteps,\n                encoder_hidden_states=encoder_hidden_states,\n                down_block_additional_residuals=[\n                    sample.to(dtype=weight_dtype) for sample in down_block_res_samples\n                ],\n                mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),\n            ).sample\n\n            # Predict x0 for t_max\n            x0_tmax = predict_start_from_noise(noisy_latents, timesteps[0], model_pred)\n\n            loss_x0 = F.mse_loss(x0_T.float(), latents.float(), reduction=\"mean\")\n            loss_x0_from_tao = F.mse_loss(x0_tmax.float(), latents.float(), reduction=\"mean\")\n\n            loss = loss_ori + loss_x0 + loss_x0_from_tao\n\n            accelerator.backward(loss)\n            if accelerator.sync_gradients:\n                params_to_clip = controlnet.parameters()\n                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n        # Checks if the accelerator has performed an optimization step behind the scenes\n        if accelerator.sync_gradients:\n            progress_bar.update(1)\n            global_step += 1\n\n            if accelerator.is_main_process:\n                if global_step % args.checkpointing_steps == 0:\n                    \n                    save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                    accelerator.save_state(save_path)\n                    logger.info(f\"Saved state to {save_path}\")\n\n                # if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n              
  if False:\n                    image_logs = log_validation(\n                        vae,\n                        text_encoder,\n                        tokenizer,\n                        unet,\n                        controlnet,\n                        args,\n                        accelerator,\n                        weight_dtype,\n                        global_step,\n                    )\n\n        logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n        progress_bar.set_postfix(**logs)\n        accelerator.log(logs, step=global_step)\n\n        if global_step >= args.max_train_steps:\n            break\n\n# Create the pipeline using using the trained modules and save it.\naccelerator.wait_for_everyone()\nif accelerator.is_main_process:\n    controlnet = accelerator.unwrap_model(controlnet)\n    controlnet.save_pretrained(args.output_dir)\n\n\n    if args.push_to_hub:\n        save_model_card(\n            repo_id,\n            image_logs=image_logs,\n            base_model=args.pretrained_model_name_or_path,\n            repo_folder=args.output_dir,\n        )\n        upload_folder(\n            repo_id=repo_id,\n            folder_path=args.output_dir,\n            commit_message=\"End of training\",\n            ignore_patterns=[\"step_*\", \"epoch_*\"],\n        )\n\naccelerator.end_training()"
  },
  {
    "path": "train_ccsr_stage2.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\nfrom ADD.models.discriminator import ProjectedDiscriminator\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    StableDiffusionControlNetPipeline,\n    UniPCMultistepScheduler,\n)\nfrom models.controlnet import ControlNetModel\nfrom models.unet_2d_condition import UNet2DConditionModel\n# from models.losses import LPIPSWithDiscriminator\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\nfrom accelerate import DistributedDataParallelKwargs\n\nfrom dataloaders.paired_dataset_txt import PairedCaptionDataset\n\n\nfrom ADD.models.vit import vit_large, 
vit_small\nimport ADD.utils.util_net as util_net\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.21.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef image_grid(imgs, rows, cols):\n    assert len(imgs) == rows * cols\n\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n\n\ndef log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):\n    logger.info(\"Running validation... \")\n\n    controlnet = accelerator.unwrap_model(controlnet)\n\n    pipeline = StableDiffusionControlNetPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=vae,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n        unet=unet,\n        controlnet=controlnet,\n        safety_checker=None,\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = 
args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with torch.autocast(\"cuda\"):\n                image = pipeline(\n                    validation_prompt, validation_image, num_inference_steps=20, generator=generator\n                ).images[0]\n\n            images.append(image)\n\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images = []\n\n                formatted_images.append(np.asarray(validation_image))\n\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images.append(wandb.Image(validation_image, 
caption=\"Controlnet conditioning\"))\n\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    formatted_images.append(image)\n\n            tracker.log({\"validation\": formatted_images})\n        else:\n            logger.warn(f\"image logging not implemented for {tracker.name}\")\n\n        return image_logs\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i})](./images_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: 
creativeml-openrail-m\nbase_model: {base_model}\ntags:\n- stable-diffusion\n- stable-diffusion-diffusers\n- text-to-image\n- diffusers\n- controlnet\ninference: true\n---\n    \"\"\"\n    model_card = f\"\"\"\n# controlnet-{repo_id}\n\nThese are controlnet weights trained on {base_model} with new type of conditioning.\n{img_str}\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a ControlNet training script.\")\n\n\n    parser.add_argument('--dataset_root_folders', type=str, default=\"\")\n    parser.add_argument(\"--is_module\", action=\"store_true\")\n    parser.add_argument(\"--t_max\", type=float, default=0.6666)\n    parser.add_argument(\"--t_min\", type=float, default=0.5)\n    parser.add_argument(\"--num_inference_steps\", type=int, default=1)\n    parser.add_argument(\"--start_timesteps\", type=int, default=999)\n\n    parser.add_argument(\"--lambda_l2\", type=float, default=1.0)\n    parser.add_argument(\"--lambda_lpips\", type=float, default=1.0)\n    parser.add_argument(\"--lambda_disc\", type=float, default=0.05)\n    parser.add_argument(\"--lambda_disc_train\", type=float, default=0.5)\n    parser.add_argument(\"--begin_disc\", type=float, default=100)\n\n    parser.add_argument(\n        \"--is_start_lr\",\n        type=bool,\n        default=True,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--vae_model_name_or_path\",\n        type=str,\n        default='',\n        help=\"Path to pretrained vae model.\"\n        \" If not specified vae weights are initialized from pre-trained model.\",\n    )\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=\"\",\n        help=\"Path to pretrained model or model identifier from 
huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--controlnet_model_name_or_path\",\n        type=str,\n        default='',\n        help=\"Path to pretrained controlnet model.\"\n        \" If not specified controlnet weights are initialized from unet.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"./experiments/test\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    \n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1000)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. 
\"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be\"\n            \" float32 precision.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    \n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=0, help=\"A seed for reproducible training.\")\n    \n    \n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-5,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the 
local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"fp16\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by using setting grads to None instead of zero. 
Be aware, that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the controlnet conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=[\"\"],\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=[\"\"],\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=100,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=1,\n        help=(\n            \"Run validation every X steps. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"train_ccsr_stage2\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n    \n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder.\"\n        )\n\n    return args\n\ndef previous_timestep(timestep):\n    if noise_scheduler.custom_timesteps:\n        index = (noise_scheduler.timesteps == timestep).nonzero(as_tuple=True)[0][0]\n        if index == noise_scheduler.timesteps.shape[0] - 1:\n            prev_t = torch.tensor(-1)\n        else:\n            prev_t = noise_scheduler.timesteps[index + 1]\n    else:\n        num_inference_steps = (\n            noise_scheduler.num_inference_steps if noise_scheduler.num_inference_steps else noise_scheduler.config.num_train_timesteps\n        )\n        prev_t = timestep - noise_scheduler.config.num_train_timesteps // num_inference_steps\n\n    return prev_t\n\ndef predict_start_from_noise(sample, t, model_output):\n    t = t.to(noise_scheduler.alphas_cumprod.device)\n    prev_t = previous_timestep(t)\n\n    # 1. 
compute alphas, betas\n    alpha_prod_t = noise_scheduler.alphas_cumprod[t].to(sample.device)\n    alpha_prod_t_prev = noise_scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else noise_scheduler.one\n    alpha_prod_t_prev = alpha_prod_t_prev.to(sample.device)\n    beta_prod_t = 1 - alpha_prod_t\n    beta_prod_t_prev = 1 - alpha_prod_t_prev\n    current_alpha_t = alpha_prod_t / alpha_prod_t_prev\n    current_beta_t = 1 - current_alpha_t\n\n    # 2. compute predicted original sample from predicted noise also called\n    # \"predicted x_0\" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf\n    if noise_scheduler.config.prediction_type == \"epsilon\":\n        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n    elif noise_scheduler.config.prediction_type == \"sample\":\n        pred_original_sample = model_output\n    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output\n    else:\n        raise ValueError(\n            f\"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or\"\n            \" `v_prediction`  for the DDPMScheduler.\"\n        )\n\n    return pred_original_sample\n\n# def main(args):\nargs = parse_args()\nlogging_dir = Path(args.output_dir, args.logging_dir)\n\naccelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\naccelerator = Accelerator(\n    gradient_accumulation_steps=args.gradient_accumulation_steps,\n    mixed_precision=args.mixed_precision,\n    log_with=args.report_to,\n    project_config=accelerator_project_config,\n)\n\n# Make one log on every process with the configuration for debugging.\nlogging.basicConfig(\n    format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n    datefmt=\"%m/%d/%Y %H:%M:%S\",\n    level=logging.INFO,\n)\nlogger.info(accelerator.state, 
main_process_only=False)\nif accelerator.is_local_main_process:\n    transformers.utils.logging.set_verbosity_warning()\n    diffusers.utils.logging.set_verbosity_info()\nelse:\n    transformers.utils.logging.set_verbosity_error()\n    diffusers.utils.logging.set_verbosity_error()\n\n# If passed along, set the training seed now.\nif args.seed is not None:\n    set_seed(args.seed)\n\n# Handle the repository creation\nif accelerator.is_main_process:\n    if args.output_dir is not None:\n        os.makedirs(args.output_dir, exist_ok=True)\n\n    if args.push_to_hub:\n        repo_id = create_repo(\n            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n        ).repo_id\n\n# Load the tokenizer\nif args.tokenizer_name:\n    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\nelif args.pretrained_model_name_or_path:\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n# import correct text encoder class\ntext_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n\n# Load scheduler and smodels\nnoise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\nnoise_scheduler.set_timesteps(args.num_inference_steps)\n\ntext_encoder = text_encoder_cls.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n)\nunet = UNet2DConditionModel.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n)\n\n# Load VAE model\nif args.vae_model_name_or_path:\n    logger.info(\"Loading existing vae weights\")\n    vae = AutoencoderKL.from_pretrained(args.vae_model_name_or_path, subfolder=\"vae\", revision=args.revision)\nelse:\n    
logger.info(\"Loading pretrained vae weights\")\n    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\n\n# Load Controlnet model\nif args.controlnet_model_name_or_path:\n    logger.info(\"Loading existing controlnet weights\")\n    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path,  subfolder=\"controlnet\")\nelse:\n    logger.info(\"Initializing controlnet weights from unet\")\n    controlnet = ControlNetModel.from_unet(unet, use_vae_encode_condition=True)\n    \n# # Load discriminator model\n# discriminatornet = LPIPSWithDiscriminator(disc_start=1.0, kl_weight=0, perceptual_weight=1.0, disc_weight=0.5, disc_factor=1.0)\n\n# Load discriminator model\ndiscriminatornet = ProjectedDiscriminator(c_dim=384).train()\ncriterion_GAN = torch.nn.BCEWithLogitsLoss()\n# 实例化提取cls_lr的特征网络\nmodel_fea = vit_small(patch_size=14, img_size=518, block_chunks=0, init_values=1.0)\nutil_net.reload_model(model_fea, torch.load('preset/models/dino/dinov2_vits14_pretrain.pth'))\nmodel_fea.requires_grad_(False)\n\n# load lpips model\nimport lpips\nnet_lpips = lpips.LPIPS(net='vgg').cuda()\n\n# `accelerate` 0.16.0 will have better support for customized saving\nif version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        i = len(weights) - 1\n        assert len(models) == 2 and len(weights) == 2\n        for i, model in enumerate(models):\n            if i==0:\n                sub_dir = 'vae'\n                model.save_pretrained(os.path.join(output_dir, sub_dir))\n            # make sure to pop weight so that corresponding model is not saved again\n            weights.pop()\n\n    def load_model_hook(models, input_dir):\n        assert len(models) == 2\n        for i in range(len(models)):\n            # pop models 
so that they are not loaded again\n            model = models.pop()\n\n            # load diffusers style into model\n            if not isinstance(model, UNet2DConditionModel):\n                load_model = ControlNetModel.from_pretrained(input_dir, subfolder=\"controlnet\") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True\n            else:\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True\n\n            model.register_to_config(**load_model.config)\n\n            model.load_state_dict(load_model.state_dict())\n            del load_model\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\nvae.requires_grad_(False)\nunet.requires_grad_(False)\ntext_encoder.requires_grad_(False)\ncontrolnet.requires_grad_(False)\ndiscriminatornet.train()\nvae.train()\n\n# unlease vae decoder for training\nfor name, params in vae.named_parameters():\n    if 'decoder' in name:\n        params.requires_grad = True\n\nif args.enable_xformers_memory_efficient_attention:\n    if is_xformers_available():\n        import xformers\n\n        xformers_version = version.parse(xformers.__version__)\n        if xformers_version == version.parse(\"0.0.16\"):\n            logger.warn(\n                \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n            )\n        unet.enable_xformers_memory_efficient_attention()\n        controlnet.enable_xformers_memory_efficient_attention()\n    else:\n        raise ValueError(\"xformers is not available. 
Make sure it is installed correctly\")\n\nif args.gradient_checkpointing:\n    vae.enable_gradient_checkpointing()\n    discriminatornet.enable_gradient_checkpointing()\n\n# Check that all trainable models are in full precision\nlow_precision_error_string = (\n    \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n    \" doing mixed precision training, copy of the weights should still be float32.\"\n)\n\nif accelerator.unwrap_model(vae).dtype != torch.float32:\n    raise ValueError(\n        f\"vae loaded as datatype {accelerator.unwrap_model(vae).dtype}. {low_precision_error_string}\"\n    )\n\n# Enable TF32 for faster training on Ampere GPUs,\n# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\nif args.allow_tf32:\n    torch.backends.cuda.matmul.allow_tf32 = True\n\nif args.scale_lr:\n    args.learning_rate = (\n        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n    )\n\n# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\nif args.use_8bit_adam:\n    try:\n        import bitsandbytes as bnb\n    except ImportError:\n        raise ImportError(\n            \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n        )\n\n    optimizer_class = bnb.optim.AdamW8bit\nelse:\n    optimizer_class = torch.optim.AdamW\n\n# Optimizer creation\nparams_to_optimize = list(vae.parameters())\noptimizer = optimizer_class(\n    params_to_optimize,\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n    \nparams_to_optimize_disc = list(discriminatornet.parameters())\noptimizer_disc = optimizer_class(\n    params_to_optimize_disc,\n    lr=args.learning_rate,\n    betas=(0.9, 0.999),\n    weight_decay=args.adam_weight_decay,\n    
eps=args.adam_epsilon,\n)\n\ntrain_dataset = PairedCaptionDataset(root_folders=args.dataset_root_folders,\n                                    tokenizer=tokenizer,\n                                    gt_ratio=0) # let lr is gt\n\ntrain_dataloader = torch.utils.data.DataLoader(\n    train_dataset,\n    num_workers=args.dataloader_num_workers,\n    batch_size=args.train_batch_size,\n    shuffle=False\n)\n\n# For mixed precision training we cast the text_encoder and vae weights to half-precision\n# as these models are only used for inference, keeping weights in full precision is not required.\nweight_dtype = torch.float32\nif accelerator.mixed_precision == \"fp16\":\n    weight_dtype = torch.float16\nelif accelerator.mixed_precision == \"bf16\":\n    weight_dtype = torch.bfloat16\n\n# Move controlnet, unet and text_encoder to device and cast to weight_dtype\ncontrolnet.to(accelerator.device, dtype=weight_dtype)\ntext_encoder.to(accelerator.device, dtype=weight_dtype)\nunet.to(accelerator.device, dtype=weight_dtype)\nmodel_fea.to(accelerator.device, dtype=weight_dtype)\n\n# Scheduler and math around the number of training steps.\noverrode_max_train_steps = False\nnum_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\nif args.max_train_steps is None:\n    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    overrode_max_train_steps = True\n\nlr_scheduler = get_scheduler(\n    args.lr_scheduler,\n    optimizer=optimizer,\n    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n    num_training_steps=args.max_train_steps * accelerator.num_processes,\n    num_cycles=args.lr_num_cycles,\n    power=args.lr_power,\n)\n\nlr_scheduler_disc = get_scheduler(\n    args.lr_scheduler,\n    optimizer=optimizer_disc,\n    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n    num_training_steps=args.max_train_steps * accelerator.num_processes,\n    num_cycles=args.lr_num_cycles,\n    
power=args.lr_power,\n)\n\n# Prepare everything with our `accelerator`.\nvae, discriminatornet, optimizer, optimizer_disc, train_dataloader, lr_scheduler, lr_scheduler_disc = accelerator.prepare(\n    vae, discriminatornet, optimizer, optimizer_disc, train_dataloader, lr_scheduler, lr_scheduler_disc\n)\n\n# We need to initialize the trackers we use, and also store our configuration.\n# The trackers initializes automatically on the main process.\nif accelerator.is_main_process:\n    tracker_config = dict(vars(args))\n\n    # tensorboard cannot handle list types for config\n    tracker_config.pop(\"validation_prompt\")\n    tracker_config.pop(\"validation_image\")\n\n    accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n# Train!\ntotal_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\nlogger.info(\"***** Running training *****\")\n\nlogger.info(f\"  Num Epochs = {args.num_train_epochs}\")\nlogger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\nlogger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\nlogger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\nlogger.info(f\"  Total optimization steps = {args.max_train_steps}\")\nglobal_step = 0\nfirst_epoch = 0\n\n# Potentially load in the weights and states from a previous save\nif args.resume_from_checkpoint:\n    if args.resume_from_checkpoint != \"latest\":\n        path = os.path.basename(args.resume_from_checkpoint)\n    else:\n        # Get the most recent checkpoint\n        dirs = os.listdir(args.output_dir)\n        dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n        dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n        path = dirs[-1] if len(dirs) > 0 else None\n\n    if path is None:\n        accelerator.print(\n            f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. 
Starting a new training run.\"\n        )\n        args.resume_from_checkpoint = None\n        initial_global_step = 0\n    else:\n        accelerator.print(f\"Resuming from checkpoint {path}\")\n        accelerator.load_state(os.path.join(args.output_dir, path))\n        global_step = int(path.split(\"-\")[1])\n\n        initial_global_step = global_step\n        first_epoch = global_step // num_update_steps_per_epoch\nelse:\n    initial_global_step = 0\n\nprogress_bar = tqdm(\n    range(0, args.max_train_steps),\n    initial=initial_global_step,\n    desc=\"Steps\",\n    # Only show the progress bar once on each machine.\n    disable=not accelerator.is_local_main_process,\n)\n\nfor epoch in range(first_epoch, args.num_train_epochs):\n    for step, batch in enumerate(train_dataloader):\n        l_acc = [vae, discriminatornet]\n        with accelerator.accumulate(*l_acc):\n            with torch.no_grad():\n                total_time_steps = noise_scheduler.timesteps\n                num_time_steps = len(total_time_steps)\n                if num_time_steps != 1:\n                    timesteps_loop = total_time_steps[-round(num_time_steps*args.t_max):]\n                    timesteps_loop = timesteps_loop[:-round(num_time_steps*args.t_min)]\n                    t_max = timesteps_loop[0]\n                    t_min = timesteps_loop[-1]\n\n                pixel_values = batch[\"pixel_values\"].to(accelerator.device)\n                if args.is_module:\n                    latents_gt = vae.module.encode(pixel_values).latent_dist.sample()\n                    latents_gt = latents_gt * vae.module.config.scaling_factor # Convert images to latent space\n                else:\n                    latents_gt = vae.encode(pixel_values).latent_dist.sample()\n                    latents_gt = latents_gt * vae.config.scaling_factor # Convert images to latent space\n\n                encoder_hidden_states = text_encoder(batch[\"input_caption\"].to(accelerator.device))[0]\n\n         
       controlnet_image = batch[\"conditioning_pixel_values\"].to(accelerator.device)\n                controlnet_image_encode = 2*controlnet_image-1\n                if args.is_module:\n                    vae_encode_condition_hidden_states = vae.module.encode(controlnet_image_encode).latent_dist.sample()\n                    vae_encode_condition_hidden_states = vae_encode_condition_hidden_states * vae.module.config.scaling_factor\n                else:\n                    vae_encode_condition_hidden_states = vae.encode(controlnet_image_encode).latent_dist.sample()\n                    vae_encode_condition_hidden_states = vae_encode_condition_hidden_states * vae.config.scaling_factor # Convert images to latent space\n                                \n                if global_step > args.begin_disc:\n                    lambda_l2 = args.lambda_l2\n                    lambda_lpips = args.lambda_lpips\n                    lambda_disc = args.lambda_disc\n                    lambda_disc_train = args.lambda_disc_train\n                else:\n                    lambda_l2 = args.lambda_l2\n                    lambda_lpips = 0\n                    lambda_disc = 0\n                    lambda_disc_train = args.lambda_disc_train\n\n                noise = torch.randn_like(latents_gt)\n                bsz = latents_gt.shape[0]\n                \n                timesteps = args.start_timesteps * torch.ones(latents_gt.shape[0]).to(accelerator.device)\n                timesteps = timesteps.long()\n\n                # Add noise to the latents according to the noise magnitude at each timestep\n                # (this is the forward diffusion process)\n                if args.start_timesteps==1:\n                    noisy_latents = noise_scheduler.add_noise(vae_encode_condition_hidden_states, noise, timesteps)\n                    noisy_latents = noisy_latents.to(dtype=weight_dtype)\n                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)\n       
             controlnet_image = controlnet_image.to(dtype=weight_dtype)\n                    vae_encode_condition_hidden_states = vae_encode_condition_hidden_states.to(dtype=weight_dtype)\n\n                    down_block_res_samples, mid_block_res_sample = controlnet(\n                        noisy_latents,\n                        timesteps,\n                        encoder_hidden_states=encoder_hidden_states,\n                        controlnet_cond=controlnet_image,\n                        return_dict=False,\n                        vae_encode_condition_hidden_states=vae_encode_condition_hidden_states,\n                    )\n\n                    # Predict the noise residual\n                    model_pred = unet(\n                        noisy_latents,\n                        timesteps,\n                        encoder_hidden_states=encoder_hidden_states,\n                        down_block_additional_residuals=[\n                            sample.to(dtype=weight_dtype) for sample in down_block_res_samples\n                        ],\n                        mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),\n                    ).sample\n\n                    # Predict x0 for T\n                    x0_T = noisy_latents - model_pred\n                else:\n                    # Sample noise based on LR (controlnet_image) or a Random Noise?\n                    if args.is_start_lr:\n                        noisy_latents = noise_scheduler.add_noise(vae_encode_condition_hidden_states, noise, timesteps)\n                        noisy_latents = noisy_latents.to(dtype=weight_dtype)\n                    else:\n                        noisy_latents = noise_scheduler.add_noise(latents_gt, noise, timesteps)\n                        noisy_latents = noisy_latents.to(dtype=weight_dtype)\n\n                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)\n                    controlnet_image = 
controlnet_image.to(dtype=weight_dtype)\n                    vae_encode_condition_hidden_states = vae_encode_condition_hidden_states.to(dtype=weight_dtype)\n\n                    down_block_res_samples, mid_block_res_sample = controlnet(\n                            noisy_latents,\n                            timesteps,\n                            encoder_hidden_states=encoder_hidden_states,\n                            controlnet_cond=controlnet_image,\n                            return_dict=False,\n                            vae_encode_condition_hidden_states=vae_encode_condition_hidden_states,\n                    )\n\n                    # Predict the noise residual\n                    model_pred = unet(\n                        noisy_latents,\n                        timesteps,\n                        encoder_hidden_states=encoder_hidden_states,\n                        down_block_additional_residuals=down_block_res_samples,\n                        mid_block_additional_residual=mid_block_res_sample,\n                        return_dict=False,\n                    )[0]\n\n                    # Predict x0 for T\n                    x0_T = predict_start_from_noise(noisy_latents, timesteps[0], model_pred)\n\n                if num_time_steps!=1:\n                    # Re-add noise on x0_tmax\n                    noise2 = torch.randn_like(latents_gt)\n                    timesteps = t_max * torch.ones(model_pred.shape[0]).to(accelerator.device)\n                    timesteps = timesteps.long()\n                    latents = noise_scheduler.add_noise(x0_T, noise2, timesteps[0])\n\n\n                    # Denoising loop\n                    for i, t in enumerate(timesteps_loop):\n                        \n                        # controlnet_latent_model_input = noise_scheduler.scale_model_input(latents, t)\n                        latents = latents.to(dtype=weight_dtype)\n                        down_block_res_samples, mid_block_res_sample = controlnet(\n      
                      latents,\n                            t,\n                            encoder_hidden_states=encoder_hidden_states,\n                            controlnet_cond=controlnet_image,\n                            return_dict=False,\n                            vae_encode_condition_hidden_states=vae_encode_condition_hidden_states,\n                        )\n\n                        # predict the noise residual\n                        noise_pred = unet(\n                            latents,\n                            t,\n                            encoder_hidden_states=encoder_hidden_states,\n                            down_block_additional_residuals=down_block_res_samples,\n                            mid_block_additional_residual=mid_block_res_sample,\n                            return_dict=False,\n                        )[0]\n\n                        # compute the previous noisy sample x_t -> x_t-1\n                        latents_old = latents\n                        latents = noise_scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                    x0_tmin = predict_start_from_noise(latents_old, t, noise_pred)\n                    latents = x0_tmin\n                    latents = latents.to(dtype=torch.float32)\n                else:\n                    latents = x0_T.to(dtype=torch.float32)\n\n        \n            # optimize the generator: vae decoder\n            discriminatornet.requires_grad_(False)\n            if args.is_module:\n                image = vae.module.decode(latents / vae.module.config.scaling_factor, return_dict=False)[0].clamp(-1, 1)\n            else:\n                image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0].clamp(-1, 1)\n            # compute the discriminator loss & update parameters\n            _, cls_lr = model_fea(F.interpolate(controlnet_image, size=518, mode='bilinear'))\n\n            # compute the generator loss\n            pred_fake, _ = 
discriminatornet(image, cls_lr.detach())\n            pred_fake = torch.cat(pred_fake, dim=1)\n            gan_loss = -torch.mean(pred_fake)\n\n            loss_x0 = F.mse_loss(image.float(), pixel_values.float(), reduction=\"mean\") * lambda_l2\n            if lambda_lpips != 0:\n                loss_lpips = net_lpips(image.float(), pixel_values.float()).mean() * lambda_lpips\n                loss_x0 = loss_lpips + loss_x0\n\n            loss = loss_x0 + lambda_disc * gan_loss\n\n            accelerator.backward(loss)\n\n            if accelerator.sync_gradients:\n                params_to_clip = vae.parameters()\n                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n            # update discriminator\n            discriminatornet.requires_grad_(True)\n            if args.is_module:\n                discriminatornet.module.dino.requires_grad_(False)\n            else:\n                discriminatornet.dino.requires_grad_(False)\n            pred_real, features = discriminatornet(pixel_values, cls_lr.detach())\n            pred_fake, _ = discriminatornet(image.detach(), cls_lr.detach())\n            pred_fake = torch.cat(pred_fake, dim=1)\n\n            pred_real = torch.cat(pred_real, dim=1)\n            loss_real = torch.mean(torch.relu(1.0 - pred_real)) * lambda_disc_train\n\n            loss_fake = torch.mean(torch.relu(1.0 + pred_fake)) * lambda_disc_train\n            loss_disc = loss_real + loss_fake\n\n            accelerator.backward(loss_disc)\n            if accelerator.sync_gradients:\n                params_to_clip = discriminatornet.parameters()\n                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n            optimizer_disc.step()\n            lr_scheduler_disc.step()\n            optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)\n            
model_fea.zero_grad(set_to_none=args.set_grads_to_none)\n\n        # Checks if the accelerator has performed an optimization step behind the scenes\n        if accelerator.sync_gradients:\n            progress_bar.update(1)\n            global_step += 1\n\n            if accelerator.is_main_process:\n                if global_step % args.checkpointing_steps == 0:\n\n                    save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                    accelerator.save_state(save_path)\n                    logger.info(f\"Saved state to {save_path}\")\n\n                # if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                if False:\n                    image_logs = log_validation(\n                        vae,\n                        text_encoder,\n                        tokenizer,\n                        unet,\n                        controlnet,\n                        args,\n                        accelerator,\n                        weight_dtype,\n                        global_step,\n                    )\n\n        logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n        progress_bar.set_postfix(**logs)\n        accelerator.log(logs, step=global_step)\n\n        if global_step >= args.max_train_steps:\n            break\n\n# Create the pipeline using using the trained modules and save it.\naccelerator.wait_for_everyone()\nif accelerator.is_main_process:\n    vae = accelerator.unwrap_model(vae)\n    vae.save_pretrained(args.output_dir)\n\n    if args.push_to_hub:\n        save_model_card(\n            repo_id,\n            image_logs=image_logs,\n            base_model=args.pretrained_model_name_or_path,\n            repo_folder=args.output_dir,\n        )\n        upload_folder(\n            repo_id=repo_id,\n            folder_path=args.output_dir,\n            commit_message=\"End of training\",\n            ignore_patterns=[\"step_*\", 
\"epoch_*\"],\n        )\n\naccelerator.end_training()"
  },
  {
    "path": "train_controlnet.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2023 The HuggingFace Inc. team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n\nimport argparse\nimport logging\nimport math\nimport os\nimport random\nimport shutil\nfrom pathlib import Path\n\nimport accelerate\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nimport torch.utils.checkpoint\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\nfrom huggingface_hub import create_repo, upload_folder\nfrom packaging import version\nfrom PIL import Image\nfrom torchvision import transforms\nfrom tqdm.auto import tqdm\nfrom transformers import AutoTokenizer, PretrainedConfig\n\nimport diffusers\n\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    StableDiffusionControlNetPipeline,\n    UniPCMultistepScheduler,\n)\nfrom models.controlnet import ControlNetModel\nfrom models.unet_2d_condition import UNet2DConditionModel\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.utils import check_min_version, is_wandb_available\nfrom diffusers.utils.import_utils import is_xformers_available\n\nfrom dataloaders.paired_dataset_txt import PairedCaptionDataset\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. 
Remove at your own risks.\ncheck_min_version(\"0.21.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef image_grid(imgs, rows, cols):\n    assert len(imgs) == rows * cols\n\n    w, h = imgs[0].size\n    grid = Image.new(\"RGB\", size=(cols * w, rows * h))\n\n    for i, img in enumerate(imgs):\n        grid.paste(img, box=(i % cols * w, i // cols * h))\n    return grid\n\n\ndef log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):\n    logger.info(\"Running validation... \")\n\n    controlnet = accelerator.unwrap_model(controlnet)\n\n    pipeline = StableDiffusionControlNetPipeline.from_pretrained(\n        args.pretrained_model_name_or_path,\n        vae=vae,\n        text_encoder=text_encoder,\n        tokenizer=tokenizer,\n        unet=unet,\n        controlnet=controlnet,\n        safety_checker=None,\n        revision=args.revision,\n        torch_dtype=weight_dtype,\n    )\n    pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)\n    pipeline = pipeline.to(accelerator.device)\n    pipeline.set_progress_bar_config(disable=True)\n\n    if args.enable_xformers_memory_efficient_attention:\n        pipeline.enable_xformers_memory_efficient_attention()\n\n    if args.seed is None:\n        generator = None\n    else:\n        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)\n\n    if len(args.validation_image) == len(args.validation_prompt):\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_image) == 1:\n        validation_images = args.validation_image * len(args.validation_prompt)\n        validation_prompts = args.validation_prompt\n    elif len(args.validation_prompt) == 1:\n        validation_images = args.validation_image\n        validation_prompts = args.validation_prompt * len(args.validation_image)\n    else:\n        raise ValueError(\n            \"number of 
`args.validation_image` and `args.validation_prompt` should be checked in `parse_args`\"\n        )\n\n    image_logs = []\n\n    for validation_prompt, validation_image in zip(validation_prompts, validation_images):\n        validation_image = Image.open(validation_image).convert(\"RGB\")\n\n        images = []\n\n        for _ in range(args.num_validation_images):\n            with torch.autocast(\"cuda\"):\n                image = pipeline(\n                    validation_prompt, validation_image, num_inference_steps=20, generator=generator\n                ).images[0]\n\n            images.append(image)\n\n        image_logs.append(\n            {\"validation_image\": validation_image, \"images\": images, \"validation_prompt\": validation_prompt}\n        )\n\n    for tracker in accelerator.trackers:\n        if tracker.name == \"tensorboard\":\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images = []\n\n                formatted_images.append(np.asarray(validation_image))\n\n                for image in images:\n                    formatted_images.append(np.asarray(image))\n\n                formatted_images = np.stack(formatted_images)\n\n                tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats=\"NHWC\")\n        elif tracker.name == \"wandb\":\n            formatted_images = []\n\n            for log in image_logs:\n                images = log[\"images\"]\n                validation_prompt = log[\"validation_prompt\"]\n                validation_image = log[\"validation_image\"]\n\n                formatted_images.append(wandb.Image(validation_image, caption=\"Controlnet conditioning\"))\n\n                for image in images:\n                    image = wandb.Image(image, caption=validation_prompt)\n                    
formatted_images.append(image)\n\n            tracker.log({\"validation\": formatted_images})\n        else:\n            logger.warn(f\"image logging not implemented for {tracker.name}\")\n\n        return image_logs\n\n\ndef import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path,\n        subfolder=\"text_encoder\",\n        revision=revision,\n    )\n    model_class = text_encoder_config.architectures[0]\n\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"RobertaSeriesModelWithTransformation\":\n        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation\n\n        return RobertaSeriesModelWithTransformation\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):\n    img_str = \"\"\n    if image_logs is not None:\n        img_str = \"You can find some example images below.\\n\"\n        for i, log in enumerate(image_logs):\n            images = log[\"images\"]\n            validation_prompt = log[\"validation_prompt\"]\n            validation_image = log[\"validation_image\"]\n            validation_image.save(os.path.join(repo_folder, \"image_control.png\"))\n            img_str += f\"prompt: {validation_prompt}\\n\"\n            images = [validation_image] + images\n            image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f\"images_{i}.png\"))\n            img_str += f\"![images_{i})](./images_{i}.png)\\n\"\n\n    yaml = f\"\"\"\n---\nlicense: creativeml-openrail-m\nbase_model: {base_model}\ntags:\n- stable-diffusion\n- stable-diffusion-diffusers\n- text-to-image\n- diffusers\n- controlnet\ninference: true\n---\n    \"\"\"\n    model_card = 
f\"\"\"\n# controlnet-{repo_id}\n\nThese are controlnet weights trained on {base_model} with new type of conditioning.\n{img_str}\n\"\"\"\n    with open(os.path.join(repo_folder, \"README.md\"), \"w\") as f:\n        f.write(yaml + model_card)\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a ControlNet training script.\")\n\n\n    parser.add_argument('--dataset_root_folders', type=str, default=\"\")\n    parser.add_argument(\"--t_max\", type=float, default=0.6667)\n    parser.add_argument(\"--start_timesteps\", type=int, default=999)\n\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=\"\",\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--controlnet_model_name_or_path\",\n        type=str,\n        default='',\n        help=\"Path to pretrained controlnet model.\"\n        \" If not specified controlnet weights are initialized from unet.\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"./experiments/test\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=512,\n        help=(\n            \"The resolution for input images, all the images in the train/validation dataset will be resized to this\"\n            \" resolution\"\n        ),\n    )\n    \n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1000)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  
If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=500,\n        help=(\n            \"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. \"\n            \"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference.\"\n            \"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components.\"\n            \"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step\"\n            \"instructions.\"\n        ),\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=(\n            \"Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be\"\n            \" float32 precision.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tokenizer_name\",\n        type=str,\n        default=None,\n        help=\"Pretrained tokenizer name or path if not the same as model_name\",\n    )\n    \n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=0, help=\"A seed for reproducible training.\")\n    \n    \n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. 
Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=5e-5,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. 
Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--use_8bit_adam\", action=\"store_true\", help=\"Whether or not to use 8-bit Adam from bitsandbytes.\"\n    )\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam optimizer.\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-2, help=\"Weight decay to use.\")\n    parser.add_argument(\"--adam_epsilon\", type=float, default=1e-08, help=\"Epsilon value for the Adam optimizer\")\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\"--push_to_hub\", action=\"store_true\", help=\"Whether or not to push the model to the Hub.\")\n    parser.add_argument(\"--hub_token\", type=str, default=None, help=\"The token to use to push to the Model Hub.\")\n    parser.add_argument(\n        \"--hub_model_id\",\n        type=str,\n        default=None,\n        help=\"The name of the repository to keep in sync with the 
local `output_dir`.\",\n    )\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=\"fp16\",\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--enable_xformers_memory_efficient_attention\", action=\"store_true\", help=\"Whether or not to use xformers.\"\n    )\n    parser.add_argument(\n        \"--set_grads_to_none\",\n        action=\"store_true\",\n        help=(\n            \"Save more memory by using setting grads to None instead of zero. 
Be aware, that this changes certain\"\n            \" behaviors, so disable this argument if it causes any problems. More info:\"\n            \" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_name\",\n        type=str,\n        default=None,\n        help=(\n            \"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,\"\n            \" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,\"\n            \" or to a folder containing files that 🤗 Datasets can understand.\"\n        ),\n    )\n    parser.add_argument(\n        \"--dataset_config_name\",\n        type=str,\n        default=None,\n        help=\"The config of the Dataset, leave as None if there's only one config.\",\n    )\n    parser.add_argument(\n        \"--image_column\", type=str, default=\"image\", help=\"The column of the dataset containing the target image.\"\n    )\n    parser.add_argument(\n        \"--conditioning_image_column\",\n        type=str,\n        default=\"conditioning_image\",\n        help=\"The column of the dataset containing the controlnet conditioning image.\",\n    )\n    parser.add_argument(\n        \"--caption_column\",\n        type=str,\n        default=\"text\",\n        help=\"The column of the dataset containing a caption or a list of captions.\",\n    )\n    parser.add_argument(\n        \"--max_train_samples\",\n        type=int,\n        default=None,\n        help=(\n            \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n            \"value if set.\"\n        ),\n    )\n    parser.add_argument(\n        \"--proportion_empty_prompts\",\n        type=float,\n        default=0,\n        help=\"Proportion of image prompts to be replaced with empty strings. 
Defaults to 0 (no prompt replacement).\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=[\"\"],\n        nargs=\"+\",\n        help=(\n            \"A set of prompts evaluated every `--validation_steps` and logged to `--report_to`.\"\n            \" Provide either a matching number of `--validation_image`s, a single `--validation_image`\"\n            \" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--validation_image\",\n        type=str,\n        default=[\"\"],\n        nargs=\"+\",\n        help=(\n            \"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`\"\n            \" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a\"\n            \" a single `--validation_prompt` to be used with all `--validation_image`s, or a single\"\n            \" `--validation_image` that will be used with all `--validation_prompt`s.\"\n        ),\n    )\n    parser.add_argument(\n        \"--is_start_lr\",\n        type=bool,\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=100,\n        help=\"Number of images to be generated for each `--validation_image`, `--validation_prompt` pair\",\n    )\n    parser.add_argument(\n        \"--validation_steps\",\n        type=int,\n        default=1,\n        help=(\n            \"Run validation every X steps. 
Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`\"\n            \" and logging the images.\"\n        ),\n    )\n    parser.add_argument(\n        \"--tracker_project_name\",\n        type=str,\n        default=\"train_ccsr_stage1\",\n        help=(\n            \"The `project_name` argument passed to Accelerator.init_trackers for\"\n            \" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator\"\n        ),\n    )\n\n    \n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    if args.resolution % 8 != 0:\n        raise ValueError(\n            \"`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder.\"\n        )\n\n    return args\n    \ndef previous_timestep(timestep):\n    if noise_scheduler.custom_timesteps:\n        index = (noise_scheduler.timesteps == timestep).nonzero(as_tuple=True)[0][0]\n        if index == noise_scheduler.timesteps.shape[0] - 1:\n            prev_t = torch.tensor(-1)\n        else:\n            prev_t = noise_scheduler.timesteps[index + 1]\n    else:\n        num_inference_steps = (\n            noise_scheduler.num_inference_steps if noise_scheduler.num_inference_steps else noise_scheduler.config.num_train_timesteps\n        )\n        prev_t = timestep - noise_scheduler.config.num_train_timesteps // num_inference_steps\n\n    return prev_t\n\ndef predict_start_from_noise(sample, t, model_output):\n    t = t.to(noise_scheduler.alphas_cumprod.device)\n    prev_t = previous_timestep(t)\n\n    # 1. 
compute alphas, betas\n    alpha_prod_t = noise_scheduler.alphas_cumprod[t].to(sample.device)\n    alpha_prod_t_prev = noise_scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else noise_scheduler.one\n    alpha_prod_t_prev = alpha_prod_t_prev.to(sample.device)\n    beta_prod_t = 1 - alpha_prod_t\n    beta_prod_t_prev = 1 - alpha_prod_t_prev\n    current_alpha_t = alpha_prod_t / alpha_prod_t_prev\n    current_beta_t = 1 - current_alpha_t\n\n    # 2. compute predicted original sample from predicted noise also called\n    # \"predicted x_0\" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf\n    if noise_scheduler.config.prediction_type == \"epsilon\":\n        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)\n    elif noise_scheduler.config.prediction_type == \"sample\":\n        pred_original_sample = model_output\n    elif noise_scheduler.config.prediction_type == \"v_prediction\":\n        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output\n    else:\n        raise ValueError(\n            f\"prediction_type given as {noise_scheduler.config.prediction_type} must be one of `epsilon`, `sample` or\"\n            \" `v_prediction`  for the DDPMScheduler.\"\n        )\n\n    return pred_original_sample\n\n# def main(args):\nargs = parse_args()\nlogging_dir = Path(args.output_dir, args.logging_dir)\n\naccelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n\naccelerator = Accelerator(\n    gradient_accumulation_steps=args.gradient_accumulation_steps,\n    mixed_precision=args.mixed_precision,\n    log_with=args.report_to,\n    project_config=accelerator_project_config,\n)\n\n# Make one log on every process with the configuration for debugging.\nlogging.basicConfig(\n    format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n    datefmt=\"%m/%d/%Y %H:%M:%S\",\n    level=logging.INFO,\n)\nlogger.info(accelerator.state, 
main_process_only=False)\nif accelerator.is_local_main_process:\n    transformers.utils.logging.set_verbosity_warning()\n    diffusers.utils.logging.set_verbosity_info()\nelse:\n    transformers.utils.logging.set_verbosity_error()\n    diffusers.utils.logging.set_verbosity_error()\n\n# If passed along, set the training seed now.\nif args.seed is not None:\n    set_seed(args.seed)\n\n# Handle the repository creation\nif accelerator.is_main_process:\n    if args.output_dir is not None:\n        os.makedirs(args.output_dir, exist_ok=True)\n\n    if args.push_to_hub:\n        repo_id = create_repo(\n            repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token\n        ).repo_id\n\n# Load the tokenizer\nif args.tokenizer_name:\n    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)\nelif args.pretrained_model_name_or_path:\n    tokenizer = AutoTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        use_fast=False,\n    )\n\n# import correct text encoder class\ntext_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)\n\n# Load scheduler and models\nnoise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"scheduler\")\ntext_encoder = text_encoder_cls.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision\n)\nvae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder=\"vae\", revision=args.revision)\nunet = UNet2DConditionModel.from_pretrained(\n    args.pretrained_model_name_or_path, subfolder=\"unet\", revision=args.revision\n)\n\nif args.controlnet_model_name_or_path:\n    logger.info(\"Loading existing controlnet weights\")\n    controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path,  
subfolder=\"controlnet\")\nelse:\n    logger.info(\"Initializing controlnet weights from unet\")\n    controlnet = ControlNetModel.from_unet(unet, use_vae_encode_condition=True)\n# `accelerate` 0.16.0 will have better support for customized saving\nif version.parse(accelerate.__version__) >= version.parse(\"0.16.0\"):\n    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format\n    def save_model_hook(models, weights, output_dir):\n        i = len(weights) - 1\n        for i, model in enumerate(models):\n            sub_dir = \"unet\" if isinstance(model, UNet2DConditionModel) else \"controlnet\"\n            model.save_pretrained(os.path.join(output_dir, sub_dir))\n            # make sure to pop weight so that corresponding model is not saved again\n            weights.pop()\n\n    def load_model_hook(models, input_dir):\n        assert len(models) == 2\n        for i in range(len(models)):\n            # pop models so that they are not loaded again\n            model = models.pop()\n\n            # load diffusers style into model\n            if not isinstance(model, UNet2DConditionModel):\n                load_model = ControlNetModel.from_pretrained(input_dir, subfolder=\"controlnet\") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True\n            else:\n                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder=\"unet\") # , low_cpu_mem_usage=False, ignore_mismatched_sizes=True\n\n            model.register_to_config(**load_model.config)\n\n            model.load_state_dict(load_model.state_dict())\n            del load_model\n\n    accelerator.register_save_state_pre_hook(save_model_hook)\n    accelerator.register_load_state_pre_hook(load_model_hook)\n\nvae.requires_grad_(False)\nunet.requires_grad_(False)\ntext_encoder.requires_grad_(False)\ncontrolnet.train()\n\nif args.enable_xformers_memory_efficient_attention:\n    if is_xformers_available():\n        import xformers\n\n 
       xformers_version = version.parse(xformers.__version__)\n        if xformers_version == version.parse(\"0.0.16\"):\n            logger.warn(\n                \"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details.\"\n            )\n        unet.enable_xformers_memory_efficient_attention()\n        controlnet.enable_xformers_memory_efficient_attention()\n    else:\n        raise ValueError(\"xformers is not available. Make sure it is installed correctly\")\n\nif args.gradient_checkpointing:\n    unet.enable_gradient_checkpointing()\n    controlnet.enable_gradient_checkpointing()\n\n# Check that all trainable models are in full precision\nlow_precision_error_string = (\n    \" Please make sure to always have all model weights in full float32 precision when starting training - even if\"\n    \" doing mixed precision training, copy of the weights should still be float32.\"\n)\n\nif accelerator.unwrap_model(controlnet).dtype != torch.float32:\n    raise ValueError(\n        f\"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}\"\n    )\nif accelerator.unwrap_model(unet).dtype != torch.float32:\n    raise ValueError(\n        f\"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. 
{low_precision_error_string}\"\n    )\n\n\n# Enable TF32 for faster training on Ampere GPUs,\n# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\nif args.allow_tf32:\n    torch.backends.cuda.matmul.allow_tf32 = True\n\n# Optimizer creation\nif args.scale_lr:\n    args.learning_rate = (\n        args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n    )\n\n# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs\nif args.use_8bit_adam:\n    try:\n        import bitsandbytes as bnb\n    except ImportError:\n        raise ImportError(\n            \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n        )\n\n    optimizer_class = bnb.optim.AdamW8bit\nelse:\n    optimizer_class = torch.optim.AdamW\n\nparams_to_optimize = list(controlnet.parameters())\noptimizer = optimizer_class(\n    params_to_optimize,\n    lr=args.learning_rate,\n    betas=(args.adam_beta1, args.adam_beta2),\n    weight_decay=args.adam_weight_decay,\n    eps=args.adam_epsilon,\n)\n\n# Training dataset creation\ntrain_dataset = PairedCaptionDataset(root_folders=args.dataset_root_folders,\n                                    tokenizer=tokenizer,\n                                    gt_ratio=0) # let lr is gt\n\ntrain_dataloader = torch.utils.data.DataLoader(\n    train_dataset,\n    num_workers=args.dataloader_num_workers,\n    batch_size=args.train_batch_size,\n    shuffle=True\n)\n\n# For mixed precision training we cast the text_encoder and vae weights to half-precision\n# as these models are only used for inference, keeping weights in full precision is not required.\nweight_dtype = torch.float32\nif accelerator.mixed_precision == \"fp16\":\n    weight_dtype = torch.float16\nelif accelerator.mixed_precision == \"bf16\":\n    weight_dtype = torch.bfloat16\n\n# Move vae, unet and text_encoder to device and cast to 
weight_dtype\nvae.to(accelerator.device, dtype=weight_dtype)\ntext_encoder.to(accelerator.device, dtype=weight_dtype)\nunet.to(accelerator.device, dtype=weight_dtype)\n\n# Scheduler and math around the number of training steps.\noverrode_max_train_steps = False\nnum_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\nif args.max_train_steps is None:\n    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    overrode_max_train_steps = True\n\nlr_scheduler = get_scheduler(\n    args.lr_scheduler,\n    optimizer=optimizer,\n    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n    num_training_steps=args.max_train_steps * accelerator.num_processes,\n    num_cycles=args.lr_num_cycles,\n    power=args.lr_power,\n)\n\n# Prepare everything with our `accelerator`.\ncontrolnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n    controlnet, optimizer, train_dataloader, lr_scheduler\n)\n\n# We need to initialize the trackers we use, and also store our configuration.\n# The trackers initializes automatically on the main process.\nif accelerator.is_main_process:\n    tracker_config = dict(vars(args))\n\n    # tensorboard cannot handle list types for config\n    tracker_config.pop(\"validation_prompt\")\n    tracker_config.pop(\"validation_image\")\n\n    accelerator.init_trackers(args.tracker_project_name, config=tracker_config)\n\n# Begin to train\ntotal_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\nlogger.info(\"***** Running training *****\")\nlogger.info(f\"  Num Epochs = {args.num_train_epochs}\")\nlogger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\nlogger.info(f\"  Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}\")\nlogger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\nlogger.info(f\"  Total optimization steps = {args.max_train_steps}\")\nglobal_step = 0\nfirst_epoch = 0\n\n# Potentially load in the weights and states from a previous save\nif args.resume_from_checkpoint:\n    if args.resume_from_checkpoint != \"latest\":\n        path = os.path.basename(args.resume_from_checkpoint)\n    else:\n        # Get the most recent checkpoint\n        dirs = os.listdir(args.output_dir)\n        dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n        dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n        path = dirs[-1] if len(dirs) > 0 else None\n\n    if path is None:\n        accelerator.print(\n            f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n        )\n        args.resume_from_checkpoint = None\n        initial_global_step = 0\n    else:\n        accelerator.print(f\"Resuming from checkpoint {path}\")\n        accelerator.load_state(os.path.join(args.output_dir, path))\n        global_step = int(path.split(\"-\")[1])\n\n        initial_global_step = global_step\n        first_epoch = global_step // num_update_steps_per_epoch\nelse:\n    initial_global_step = 0\n\nprogress_bar = tqdm(\n    range(0, args.max_train_steps),\n    initial=initial_global_step,\n    desc=\"Steps\",\n    # Only show the progress bar once on each machine.\n    disable=not accelerator.is_local_main_process,\n)\n\n\nfor epoch in range(first_epoch, args.num_train_epochs):\n    for step, batch in enumerate(train_dataloader):\n        with accelerator.accumulate(controlnet):\n            pixel_values = batch[\"pixel_values\"].to(accelerator.device, dtype=weight_dtype)\n            # Convert images to latent space\n            latents = vae.encode(pixel_values).latent_dist.sample()\n            latents = latents * vae.config.scaling_factor\n\n 
           # Sample noise that we'll add to the latents\n            noise = torch.randn_like(latents)\n            bsz = latents.shape[0]\n            # Sample a random timestep for each image\n            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)\n            timesteps = timesteps.long()\n\n            # Add noise to the latents according to the noise magnitude at each timestep\n            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)\n\n            # # Get the text embedding for conditioning\n            encoder_hidden_states = text_encoder(batch[\"input_caption\"].to(accelerator.device))[0]\n\n            controlnet_image = batch[\"conditioning_pixel_values\"].to(accelerator.device, dtype=weight_dtype)\n\n            vae_encode_condition_hidden_states = vae.encode(2*controlnet_image-1).latent_dist.sample()\n            vae_encode_condition_hidden_states = vae_encode_condition_hidden_states * vae.config.scaling_factor\n\n            down_block_res_samples, mid_block_res_sample = controlnet(\n                noisy_latents,\n                timesteps,\n                encoder_hidden_states=encoder_hidden_states,\n                controlnet_cond=controlnet_image,\n                return_dict=False,\n                vae_encode_condition_hidden_states=vae_encode_condition_hidden_states,\n            )\n\n            # Predict the noise residual\n            model_pred = unet(\n                noisy_latents,\n                timesteps,\n                encoder_hidden_states=encoder_hidden_states,\n                down_block_additional_residuals=[\n                    sample.to(dtype=weight_dtype) for sample in down_block_res_samples\n                ],\n                mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),\n            ).sample\n\n            # Get the target for loss depending on the prediction type\n            if 
noise_scheduler.config.prediction_type == \"epsilon\":\n                target = noise\n            elif noise_scheduler.config.prediction_type == \"v_prediction\":\n                target = noise_scheduler.get_velocity(latents, noise, timesteps)\n            else:\n                raise ValueError(f\"Unknown prediction type {noise_scheduler.config.prediction_type}\")\n            loss = F.mse_loss(model_pred.float(), target.float(), reduction=\"mean\") # original loss\n\n            accelerator.backward(loss)\n            if accelerator.sync_gradients:\n                params_to_clip = controlnet.parameters()\n                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n            optimizer.step()\n            lr_scheduler.step()\n            optimizer.zero_grad(set_to_none=args.set_grads_to_none)\n\n        # Checks if the accelerator has performed an optimization step behind the scenes\n        if accelerator.sync_gradients:\n            progress_bar.update(1)\n            global_step += 1\n\n            if accelerator.is_main_process:\n                if global_step % args.checkpointing_steps == 0:\n                    save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                    accelerator.save_state(save_path)\n                    logger.info(f\"Saved state to {save_path}\")\n\n                # if args.validation_prompt is not None and global_step % args.validation_steps == 0:\n                if False:\n                    image_logs = log_validation(\n                        vae,\n                        text_encoder,\n                        tokenizer,\n                        unet,\n                        controlnet,\n                        args,\n                        accelerator,\n                        weight_dtype,\n                        global_step,\n                    )\n\n        logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n        
progress_bar.set_postfix(**logs)\n        accelerator.log(logs, step=global_step)\n\n        if global_step >= args.max_train_steps:\n            break\n\n# Create the pipeline using using the trained modules and save it.\naccelerator.wait_for_everyone()\nif accelerator.is_main_process:\n    controlnet = accelerator.unwrap_model(controlnet)\n    controlnet.save_pretrained(args.output_dir)\n\n    unet = accelerator.unwrap_model(unet)\n    unet.save_pretrained(args.output_dir)\n\n    if args.push_to_hub:\n        save_model_card(\n            repo_id,\n            image_logs=image_logs,\n            base_model=args.pretrained_model_name_or_path,\n            repo_folder=args.output_dir,\n        )\n        upload_folder(\n            repo_id=repo_id,\n            folder_path=args.output_dir,\n            commit_message=\"End of training\",\n            ignore_patterns=[\"step_*\", \"epoch_*\"],\n        )\n\naccelerator.end_training()\n\n\n# if __name__ == \"__main__\":\n#     args = parse_args()\n#     main(args)\n"
  },
  {
    "path": "utils/devices.py",
    "content": "import sys\nimport contextlib\nfrom functools import lru_cache\n\nimport torch\n#from modules import errors\n\nif sys.platform == \"darwin\":\n    from modules import mac_specific\n\n\ndef has_mps() -> bool:\n    if sys.platform != \"darwin\":\n        return False\n    else:\n        return mac_specific.has_mps\n\n\ndef get_cuda_device_string():\n    return \"cuda\"\n\n\ndef get_optimal_device_name():\n    if torch.cuda.is_available():\n        return get_cuda_device_string()\n\n    if has_mps():\n        return \"mps\"\n\n    return \"cpu\"\n\n\ndef get_optimal_device():\n    return torch.device(get_optimal_device_name())\n\n\ndef get_device_for(task):\n    return get_optimal_device()\n\n\ndef torch_gc():\n\n    if torch.cuda.is_available():\n        with torch.cuda.device(get_cuda_device_string()):\n            torch.cuda.empty_cache()\n            torch.cuda.ipc_collect()\n\n    if has_mps():\n        mac_specific.torch_mps_gc()\n\n\ndef enable_tf32():\n    if torch.cuda.is_available():\n\n        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't\n        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407\n        if any(torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())):\n            torch.backends.cudnn.benchmark = True\n\n        torch.backends.cuda.matmul.allow_tf32 = True\n        torch.backends.cudnn.allow_tf32 = True\n\n\nenable_tf32()\n#errors.run(enable_tf32, \"Enabling TF32\")\n\ncpu = torch.device(\"cpu\")\ndevice = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = torch.device(\"cuda\")\ndtype = torch.float16\ndtype_vae = torch.float16\ndtype_unet = torch.float16\nunet_needs_upcast = False\n\n\ndef cond_cast_unet(input):\n    return input.to(dtype_unet) if unet_needs_upcast else input\n\n\ndef cond_cast_float(input):\n    return input.float() if unet_needs_upcast else input\n\n\ndef randn(seed, 
shape):\n    torch.manual_seed(seed)\n    return torch.randn(shape, device=device)\n\n\ndef randn_without_seed(shape):\n    return torch.randn(shape, device=device)\n\n\ndef autocast(disable=False):\n    if disable:\n        return contextlib.nullcontext()\n\n    return torch.autocast(\"cuda\")\n\n\ndef without_autocast(disable=False):\n    return torch.autocast(\"cuda\", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()\n\n\nclass NansException(Exception):\n    pass\n\n\ndef test_for_nans(x, where):\n    if not torch.all(torch.isnan(x)).item():\n        return\n\n    if where == \"unet\":\n        message = \"A tensor with all NaNs was produced in Unet.\"\n\n    elif where == \"vae\":\n        message = \"A tensor with all NaNs was produced in VAE.\"\n\n    else:\n        message = \"A tensor with all NaNs was produced.\"\n\n    message += \" Use --disable-nan-check commandline argument to disable this check.\"\n\n    raise NansException(message)\n\n\n@lru_cache\ndef first_time_calculation():\n    \"\"\"\n    just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and\n    spends about 2.7 seconds doing that, at least wih NVidia.\n    \"\"\"\n\n    x = torch.zeros((1, 1)).to(device, dtype)\n    linear = torch.nn.Linear(1, 1).to(device, dtype)\n    linear(x)\n\n    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)\n    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)\n    conv2d(x)\n"
  },
  {
    "path": "utils/img_util.py",
    "content": "import os\nimport PIL\nimport cv2\nimport math\nimport numpy as np\nimport torch\nimport torchvision\nimport imageio\n\nfrom einops import rearrange\n\ndef save_videos_grid(videos, path=None, rescale=True, n_rows=4, fps=8, discardN=0):\n    videos = rearrange(videos, \"b c t h w -> t b c h w\").cpu()\n    outputs = []\n    for x in videos:\n        x = torchvision.utils.make_grid(x, nrow=n_rows)\n        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)\n        if rescale:\n            x = (x / 2.0 + 0.5).clamp(0, 1)  # -1,1 -> 0,1\n        x = (x * 255).numpy().astype(np.uint8)\n        #x = adjust_gamma(x, 0.5)\n        outputs.append(x)\n\n    outputs = outputs[discardN:]\n\n    if path is not None:\n        #os.makedirs(os.path.dirname(path), exist_ok=True)\n        imageio.mimsave(path, outputs, duration=1000/fps, loop=0)\n\n    return outputs\n\ndef convert_image_to_fn(img_type, minsize, image, eps=0.02):\n    width, height = image.size\n    if min(width, height) < minsize:\n        scale = minsize/min(width, height) + eps\n        image = image.resize((math.ceil(width*scale), math.ceil(height*scale)))\n\n    if image.mode != img_type:\n        return image.convert(img_type)\n    return image"
  },
  {
    "path": "utils/misc.py",
    "content": "import os\nimport binascii\nfrom safetensors import safe_open\n\nimport torch\n\nfrom diffusers.pipelines.stable_diffusion.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_vae_checkpoint\n\ndef rand_name(length=8, suffix=''):\n    name = binascii.b2a_hex(os.urandom(length)).decode('utf-8')\n    if suffix:\n        if not suffix.startswith('.'):\n            suffix = '.' + suffix\n        name += suffix\n    return name\n\ndef cycle(dl):\n    while True:\n        for data in dl:\n            yield data\n\ndef exists(x):\n    return x is not None\n\ndef identity(x):\n    return x\n\ndef load_dreambooth_lora(unet, vae=None, model_path=None, alpha=1.0, model_base=\"\"):\n    if model_path is None: return unet\n    \n    if model_path.endswith(\".ckpt\"):\n        base_state_dict = torch.load(model_path)['state_dict']\n    elif model_path.endswith(\".safetensors\"):\n        state_dict = {}\n        with safe_open(model_path, framework=\"pt\", device=\"cpu\") as f:\n            for key in f.keys():\n                state_dict[key] = f.get_tensor(key)\n                            \n        is_lora = all(\"lora\" in k for k in state_dict.keys())\n        if not is_lora:\n            base_state_dict = state_dict\n        else:\n            base_state_dict = {}\n            with safe_open(model_base, framework=\"pt\", device=\"cpu\") as f:\n                for key in f.keys():\n                    base_state_dict[key] = f.get_tensor(key)\n                                 \n    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, unet.config)\n    unet_state_dict = unet.state_dict()\n    for key in converted_unet_checkpoint:\n        converted_unet_checkpoint[key] = alpha * converted_unet_checkpoint[key] + (1.0-alpha) * unet_state_dict[key]\n    unet.load_state_dict(converted_unet_checkpoint, strict=False)\n\n    if vae is not None:\n        converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, 
vae.config)\n        vae.load_state_dict(converted_vae_checkpoint)\n    \n    return unet, vae"
  },
  {
    "path": "utils/vaehook.py",
    "content": "# ------------------------------------------------------------------------\n#\n#   Ultimate VAE Tile Optimization\n#\n#   Introducing a revolutionary new optimization designed to make\n#   the VAE work with giant images on limited VRAM!\n#   Say goodbye to the frustration of OOM and hello to seamless output!\n#\n# ------------------------------------------------------------------------\n#\n#   This script is a wild hack that splits the image into tiles,\n#   encodes each tile separately, and merges the result back together.\n#\n#   Advantages:\n#   - The VAE can now work with giant images on limited VRAM\n#       (~10 GB for 8K images!)\n#   - The merged output is completely seamless without any post-processing.\n#\n#   Drawbacks:\n#   - Giant RAM needed. To store the intermediate results for a 4096x4096\n#       images, you need 32 GB RAM it consumes ~20GB); for 8192x8192\n#       you need 128 GB RAM machine (it consumes ~100 GB)\n#   - NaNs always appear in for 8k images when you use fp16 (half) VAE\n#       You must use --no-half-vae to disable half VAE for that giant image.\n#   - Slow speed. With default tile size, it takes around 50/200 seconds\n#       to encode/decode a 4096x4096 image; and 200/900 seconds to encode/decode\n#       a 8192x8192 image. (The speed is limited by both the GPU and the CPU.)\n#   - The gradient calculation is not compatible with this hack. 
It\n#       will break any backward() or torch.autograd.grad() that passes VAE.\n#       (But you can still use the VAE to generate training data.)\n#\n#   How it works:\n#   1) The image is split into tiles.\n#       - To ensure perfect results, each tile is padded with 32 pixels\n#           on each side.\n#       - Then the conv2d/silu/upsample/downsample can produce identical\n#           results to the original image without splitting.\n#   2) The original forward is decomposed into a task queue and a task worker.\n#       - The task queue is a list of functions that will be executed in order.\n#       - The task worker is a loop that executes the tasks in the queue.\n#   3) The task queue is executed for each tile.\n#       - Current tile is sent to GPU.\n#       - local operations are directly executed.\n#       - Group norm calculation is temporarily suspended until the mean\n#           and var of all tiles are calculated.\n#       - The residual is pre-calculated and stored and addded back later.\n#       - When need to go to the next tile, the current tile is send to cpu.\n#   4) After all tiles are processed, tiles are merged on cpu and return.\n#\n#   Enjoy!\n#\n#   @author: LI YI @ Nanyang Technological University - Singapore\n#   @date: 2023-03-02\n#   @license: MIT License\n#\n#   Please give me a star if you like this project!\n#\n# -------------------------------------------------------------------------\n\nimport gc\nfrom time import time\nimport math\nfrom tqdm import tqdm\n\nimport torch\nimport torch.version\nimport torch.nn.functional as F\nfrom einops import rearrange\nimport sys\nimport myutils.devices as devices\n#from modules.shared import state\n#from ldm.modules.diffusionmodules.model import AttnBlock, MemoryEfficientAttnBlock\n\ntry:\n    import xformers\n    import xformers.ops\nexcept ImportError:\n    pass\n\nsd_flag = False\n\ndef get_recommend_encoder_tile_size():\n    if torch.cuda.is_available():\n        total_memory = 
torch.cuda.get_device_properties(\n            devices.device).total_memory // 2**20\n        if total_memory > 16*1000:\n            ENCODER_TILE_SIZE = 3072\n        elif total_memory > 12*1000:\n            ENCODER_TILE_SIZE = 2048\n        elif total_memory > 8*1000:\n            ENCODER_TILE_SIZE = 1536\n        else:\n            ENCODER_TILE_SIZE = 960\n    else:\n        ENCODER_TILE_SIZE = 512\n    return ENCODER_TILE_SIZE\n\n\ndef get_recommend_decoder_tile_size():\n    if torch.cuda.is_available():\n        total_memory = torch.cuda.get_device_properties(\n            devices.device).total_memory // 2**20\n        if total_memory > 30*1000:\n            DECODER_TILE_SIZE = 256\n        elif total_memory > 16*1000:\n            DECODER_TILE_SIZE = 192\n        elif total_memory > 12*1000:\n            DECODER_TILE_SIZE = 128\n        elif total_memory > 8*1000:\n            DECODER_TILE_SIZE = 96\n        else:\n            DECODER_TILE_SIZE = 64\n    else:\n        DECODER_TILE_SIZE = 64\n    return DECODER_TILE_SIZE\n\n\nif 'global const':\n    DEFAULT_ENABLED = False\n    DEFAULT_MOVE_TO_GPU = False\n    DEFAULT_FAST_ENCODER = True\n    DEFAULT_FAST_DECODER = True\n    DEFAULT_COLOR_FIX = 0\n    DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()\n    DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()\n\n\n# inplace version of silu\ndef inplace_nonlinearity(x):\n    # Test: fix for Nans\n    return F.silu(x, inplace=True)\n\n# extracted from ldm.modules.diffusionmodules.model\n\n# from diffusers lib\ndef attn_forward_new(self, h_):\n    batch_size, channel, height, width = h_.shape\n    hidden_states = h_.view(batch_size, channel, height * width).transpose(1, 2)\n\n    attention_mask = None\n    encoder_hidden_states = None\n    batch_size, sequence_length, _ = hidden_states.shape\n    attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n\n    query = self.to_q(hidden_states)\n\n    if 
encoder_hidden_states is None:\n        encoder_hidden_states = hidden_states\n    elif self.norm_cross:\n        encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states)\n\n    key = self.to_k(encoder_hidden_states)\n    value = self.to_v(encoder_hidden_states)\n\n    query = self.head_to_batch_dim(query)\n    key = self.head_to_batch_dim(key)\n    value = self.head_to_batch_dim(value)\n\n    attention_probs = self.get_attention_scores(query, key, attention_mask)\n    hidden_states = torch.bmm(attention_probs, value)\n    hidden_states = self.batch_to_head_dim(hidden_states)\n\n    # linear proj\n    hidden_states = self.to_out[0](hidden_states)\n    # dropout\n    hidden_states = self.to_out[1](hidden_states)\n\n    hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n    return hidden_states\n\ndef attn_forward(self, h_):\n    q = self.q(h_)\n    k = self.k(h_)\n    v = self.v(h_)\n\n    # compute attention\n    b, c, h, w = q.shape\n    q = q.reshape(b, c, h*w)\n    q = q.permute(0, 2, 1)   # b,hw,c\n    k = k.reshape(b, c, h*w)  # b,c,hw\n    w_ = torch.bmm(q, k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]\n    w_ = w_ * (int(c)**(-0.5))\n    w_ = torch.nn.functional.softmax(w_, dim=2)\n\n    # attend to values\n    v = v.reshape(b, c, h*w)\n    w_ = w_.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)\n    # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]\n    h_ = torch.bmm(v, w_)\n    h_ = h_.reshape(b, c, h, w)\n\n    h_ = self.proj_out(h_)\n\n    return h_\n\n\ndef xformer_attn_forward(self, h_):\n    q = self.q(h_)\n    k = self.k(h_)\n    v = self.v(h_)\n\n    # compute attention\n    B, C, H, W = q.shape\n    q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))\n\n    q, k, v = map(\n        lambda t: t.unsqueeze(3)\n        .reshape(B, t.shape[1], 1, C)\n        .permute(0, 2, 1, 3)\n        .reshape(B * 1, t.shape[1], C)\n        .contiguous(),\n   
     (q, k, v),\n    )\n    out = xformers.ops.memory_efficient_attention(\n        q, k, v, attn_bias=None, op=self.attention_op)\n\n    out = (\n        out.unsqueeze(0)\n        .reshape(B, 1, out.shape[1], C)\n        .permute(0, 2, 1, 3)\n        .reshape(B, out.shape[1], C)\n    )\n    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)\n    out = self.proj_out(out)\n    return out\n\n\ndef attn2task(task_queue, net):\n    if False: #isinstance(net, AttnBlock):\n        task_queue.append(('store_res', lambda x: x))\n        task_queue.append(('pre_norm', net.norm))\n        task_queue.append(('attn', lambda x, net=net: attn_forward(net, x)))\n        task_queue.append(['add_res', None])\n    elif False: #isinstance(net, MemoryEfficientAttnBlock):\n        task_queue.append(('store_res', lambda x: x))\n        task_queue.append(('pre_norm', net.norm))\n        task_queue.append(\n            ('attn', lambda x, net=net: xformer_attn_forward(net, x)))\n        task_queue.append(['add_res', None])\n    else:\n        task_queue.append(('store_res', lambda x: x))\n        task_queue.append(('pre_norm', net.group_norm))\n        task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))\n        task_queue.append(['add_res', None])\n\ndef resblock2task(queue, block):\n    \"\"\"\n    Turn a ResNetBlock into a sequence of tasks and append to the task queue\n\n    @param queue: the target task queue\n    @param block: ResNetBlock\n\n    \"\"\"\n    if block.in_channels != block.out_channels:\n        if sd_flag:\n            if block.use_conv_shortcut:\n                queue.append(('store_res', block.conv_shortcut))\n            else:\n                queue.append(('store_res', block.nin_shortcut))\n        else:\n            if block.use_in_shortcut:\n                queue.append(('store_res', block.conv_shortcut))\n            else:\n                queue.append(('store_res', block.nin_shortcut))\n\n    else:\n        
queue.append(('store_res', lambda x: x))\n    queue.append(('pre_norm', block.norm1))\n    queue.append(('silu', inplace_nonlinearity))\n    queue.append(('conv1', block.conv1))\n    queue.append(('pre_norm', block.norm2))\n    queue.append(('silu', inplace_nonlinearity))\n    queue.append(('conv2', block.conv2))\n    queue.append(['add_res', None])\n\n\ndef build_sampling(task_queue, net, is_decoder):\n    \"\"\"\n    Build the sampling part of a task queue\n    @param task_queue: the target task queue\n    @param net: the network\n    @param is_decoder: currently building decoder or encoder\n    \"\"\"\n    if is_decoder:\n        # resblock2task(task_queue, net.mid.block_1)\n        # attn2task(task_queue, net.mid.attn_1)\n        # resblock2task(task_queue, net.mid.block_2)\n        # resolution_iter = reversed(range(net.num_resolutions))\n        # block_ids = net.num_res_blocks + 1\n        # condition = 0\n        # module = net.up\n        # func_name = 'upsample'\n        resblock2task(task_queue, net.mid_block.resnets[0])\n        attn2task(task_queue, net.mid_block.attentions[0])\n        resblock2task(task_queue, net.mid_block.resnets[1])\n        resolution_iter = (range(len(net.up_blocks)))  # range(0,4)\n        block_ids = 2 + 1\n        condition = len(net.up_blocks) - 1\n        module = net.up_blocks\n        func_name = 'upsamplers'\n    else:\n        # resolution_iter = range(net.num_resolutions)\n        # block_ids = net.num_res_blocks\n        # condition = net.num_resolutions - 1\n        # module = net.down\n        # func_name = 'downsample'\n        resolution_iter = (range(len(net.down_blocks)))  # range(0,4)\n        block_ids = 2\n        condition = len(net.down_blocks) - 1\n        module = net.down_blocks\n        func_name = 'downsamplers'\n\n    for i_level in resolution_iter:\n        for i_block in range(block_ids):\n            resblock2task(task_queue, module[i_level].resnets[i_block])\n        if i_level != condition:\n     
       if is_decoder:\n                task_queue.append((func_name, module[i_level].upsamplers[0]))\n            else:\n                task_queue.append((func_name, module[i_level].downsamplers[0]))\n\n    if not is_decoder:\n        resblock2task(task_queue, net.mid_block.resnets[0])\n        attn2task(task_queue, net.mid_block.attentions[0])\n        resblock2task(task_queue, net.mid_block.resnets[1])\n\n\ndef build_task_queue(net, is_decoder):\n    \"\"\"\n    Build a single task queue for the encoder or decoder\n    @param net: the VAE decoder or encoder network\n    @param is_decoder: currently building decoder or encoder\n    @return: the task queue\n    \"\"\"\n    task_queue = []\n    task_queue.append(('conv_in', net.conv_in))\n\n    # construct the sampling part of the task queue\n    # because encoder and decoder share the same architecture, we extract the sampling part\n    build_sampling(task_queue, net, is_decoder)\n    if is_decoder and not sd_flag:\n        net.give_pre_end = False\n        net.tanh_out = False\n\n    if not is_decoder or not net.give_pre_end:\n        if sd_flag:\n            task_queue.append(('pre_norm', net.norm_out))\n        else:\n            task_queue.append(('pre_norm', net.conv_norm_out))\n        task_queue.append(('silu', inplace_nonlinearity))\n        task_queue.append(('conv_out', net.conv_out))\n        if is_decoder and net.tanh_out:\n            task_queue.append(('tanh', torch.tanh))\n\n    return task_queue\n\n\ndef clone_task_queue(task_queue):\n    \"\"\"\n    Clone a task queue\n    @param task_queue: the task queue to be cloned\n    @return: the cloned task queue\n    \"\"\"\n    return [[item for item in task] for task in task_queue]\n\n\ndef get_var_mean(input, num_groups, eps=1e-6):\n    \"\"\"\n    Get mean and var for group norm\n    \"\"\"\n    b, c = input.size(0), input.size(1)\n    channel_in_group = int(c/num_groups)\n    input_reshaped = input.contiguous().view(\n        1, int(b * num_groups), 
channel_in_group, *input.size()[2:])\n    var, mean = torch.var_mean(\n        input_reshaped, dim=[0, 2, 3, 4], unbiased=False)\n    return var, mean\n\n\ndef custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):\n    \"\"\"\n    Custom group norm with fixed mean and var\n\n    @param input: input tensor\n    @param num_groups: number of groups. by default, num_groups = 32\n    @param mean: mean, must be pre-calculated by get_var_mean\n    @param var: var, must be pre-calculated by get_var_mean\n    @param weight: weight, should be fetched from the original group norm\n    @param bias: bias, should be fetched from the original group norm\n    @param eps: epsilon, by default, eps = 1e-6 to match the original group norm\n\n    @return: normalized tensor\n    \"\"\"\n    b, c = input.size(0), input.size(1)\n    channel_in_group = int(c/num_groups)\n    input_reshaped = input.contiguous().view(\n        1, int(b * num_groups), channel_in_group, *input.size()[2:])\n\n    out = F.batch_norm(input_reshaped, mean, var, weight=None, bias=None,\n                       training=False, momentum=0, eps=eps)\n\n    out = out.view(b, c, *input.size()[2:])\n\n    # post affine transform\n    if weight is not None:\n        out *= weight.view(1, -1, 1, 1)\n    if bias is not None:\n        out += bias.view(1, -1, 1, 1)\n    return out\n\n\ndef crop_valid_region(x, input_bbox, target_bbox, is_decoder):\n    \"\"\"\n    Crop the valid region from the tile\n    @param x: input tile\n    @param input_bbox: original input bounding box\n    @param target_bbox: output bounding box\n    @param scale: scale factor\n    @return: cropped tile\n    \"\"\"\n    padded_bbox = [i * 8 if is_decoder else i//8 for i in input_bbox]\n    margin = [target_bbox[i] - padded_bbox[i] for i in range(4)]\n    return x[:, :, margin[2]:x.size(2)+margin[3], margin[0]:x.size(3)+margin[1]]\n\n# ↓↓↓ https://github.com/Kahsolt/stable-diffusion-webui-vae-tile-infer ↓↓↓\n\n\ndef 
perfcount(fn):\n    def wrapper(*args, **kwargs):\n        ts = time()\n\n        if torch.cuda.is_available():\n            torch.cuda.reset_peak_memory_stats(devices.device)\n        devices.torch_gc()\n        gc.collect()\n\n        ret = fn(*args, **kwargs)\n\n        devices.torch_gc()\n        gc.collect()\n        if torch.cuda.is_available():\n            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20\n            torch.cuda.reset_peak_memory_stats(devices.device)\n            print(\n                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')\n        else:\n            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')\n\n        return ret\n    return wrapper\n\n# copy end :)\n\n\nclass GroupNormParam:\n    def __init__(self):\n        self.var_list = []\n        self.mean_list = []\n        self.pixel_list = []\n        self.weight = None\n        self.bias = None\n\n    def add_tile(self, tile, layer):\n        var, mean = get_var_mean(tile, 32)\n        # For giant images, the variance can be larger than max float16\n        # In this case we create a copy to float32\n        if var.dtype == torch.float16 and var.isinf().any():\n            fp32_tile = tile.float()\n            var, mean = get_var_mean(fp32_tile, 32)\n        # ============= DEBUG: test for infinite =============\n        # if torch.isinf(var).any():\n        #    print('var: ', var)\n        # ====================================================\n        self.var_list.append(var)\n        self.mean_list.append(mean)\n        self.pixel_list.append(\n            tile.shape[2]*tile.shape[3])\n        if hasattr(layer, 'weight'):\n            self.weight = layer.weight\n            self.bias = layer.bias\n        else:\n            self.weight = None\n            self.bias = None\n\n    def summary(self):\n        \"\"\"\n        summarize the mean and var and return a function\n        that apply group norm on each tile\n        \"\"\"\n    
    if len(self.var_list) == 0:\n            return None\n        var = torch.vstack(self.var_list)\n        mean = torch.vstack(self.mean_list)\n        max_value = max(self.pixel_list)\n        pixels = torch.tensor(\n            self.pixel_list, dtype=torch.float32, device=devices.device) / max_value\n        sum_pixels = torch.sum(pixels)\n        pixels = pixels.unsqueeze(\n            1) / sum_pixels\n        var = torch.sum(\n            var * pixels, dim=0)\n        mean = torch.sum(\n            mean * pixels, dim=0)\n        return lambda x:  custom_group_norm(x, 32, mean, var, self.weight, self.bias)\n\n    @staticmethod\n    def from_tile(tile, norm):\n        \"\"\"\n        create a function from a single tile without summary\n        \"\"\"\n        var, mean = get_var_mean(tile, 32)\n        if var.dtype == torch.float16 and var.isinf().any():\n            fp32_tile = tile.float()\n            var, mean = get_var_mean(fp32_tile, 32)\n            # if it is a macbook, we need to convert back to float16\n            if var.device.type == 'mps':\n                # clamp to avoid overflow\n                var = torch.clamp(var, 0, 60000)\n                var = var.half()\n                mean = mean.half()\n        if hasattr(norm, 'weight'):\n            weight = norm.weight\n            bias = norm.bias\n        else:\n            weight = None\n            bias = None\n\n        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):\n            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)\n        return group_norm_func\n\n\nclass VAEHook:\n    def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):\n        self.net = net                  # encoder | decoder\n        self.tile_size = tile_size\n        self.is_decoder = is_decoder\n        self.fast_mode = (fast_encoder and not is_decoder) or (\n            fast_decoder and is_decoder)\n        self.color_fix = color_fix 
and not is_decoder\n        self.to_gpu = to_gpu\n        self.pad = 11 if is_decoder else 32\n\n    def __call__(self, x):\n        B, C, H, W = x.shape\n        original_device = next(self.net.parameters()).device\n        try:\n            if self.to_gpu:\n                self.net.to(devices.get_optimal_device())\n            if max(H, W) <= self.pad * 2 + self.tile_size:\n                print(\"[Tiled VAE]: the input size is tiny and unnecessary to tile.\")\n                x_type = x.dtype\n                x = self.net.original_forward(x)\n                x = x.to(dtype=x_type)\n                return x\n            else:\n                x_type = x.dtype\n                x = self.vae_tile_forward(x)\n                x = x.to(dtype=x_type)\n                return x\n        finally:\n            self.net.to(original_device)\n\n    def get_best_tile_size(self, lowerbound, upperbound):\n        \"\"\"\n        Get the best tile size for GPU memory\n        \"\"\"\n        divider = 32\n        while divider >= 2:\n            remainer = lowerbound % divider\n            if remainer == 0:\n                return lowerbound\n            candidate = lowerbound - remainer + divider\n            if candidate <= upperbound:\n                return candidate\n            divider //= 2\n        return lowerbound\n\n    def split_tiles(self, h, w):\n        \"\"\"\n        Tool function to split the image into tiles\n        @param h: height of the image\n        @param w: width of the image\n        @return: tile_input_bboxes, tile_output_bboxes\n        \"\"\"\n        tile_input_bboxes, tile_output_bboxes = [], []\n        tile_size = self.tile_size\n        pad = self.pad\n        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)\n        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)\n        # If any of the numbers are 0, we let it be 1\n        # This is to deal with long and thin images\n        num_height_tiles = max(num_height_tiles, 1)\n        
num_width_tiles = max(num_width_tiles, 1)\n\n        # Suggestions from https://github.com/Kahsolt: auto shrink the tile size\n        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)\n        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)\n        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)\n        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)\n\n        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +\n              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')\n\n        for i in range(num_height_tiles):\n            for j in range(num_width_tiles):\n                # bbox: [x1, x2, y1, y2]\n                # the padding is is unnessary for image borders. So we directly start from (32, 32)\n                input_bbox = [\n                    pad + j * real_tile_width,\n                    min(pad + (j + 1) * real_tile_width, w),\n                    pad + i * real_tile_height,\n                    min(pad + (i + 1) * real_tile_height, h),\n                ]\n\n                # if the output bbox is close to the image boundary, we extend it to the image boundary\n                output_bbox = [\n                    input_bbox[0] if input_bbox[0] > pad else 0,\n                    input_bbox[1] if input_bbox[1] < w - pad else w,\n                    input_bbox[2] if input_bbox[2] > pad else 0,\n                    input_bbox[3] if input_bbox[3] < h - pad else h,\n                ]\n\n                # scale to get the final output bbox\n                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]\n                tile_output_bboxes.append(output_bbox)\n\n                # indistinguishable expand the input bbox by pad pixels\n                tile_input_bboxes.append([\n                    max(0, input_bbox[0] - pad),\n                 
   min(w, input_bbox[1] + pad),\n                    max(0, input_bbox[2] - pad),\n                    min(h, input_bbox[3] + pad),\n                ])\n\n        return tile_input_bboxes, tile_output_bboxes\n\n    @torch.no_grad()\n    def estimate_group_norm(self, z, task_queue, color_fix):\n        device = z.device\n        tile = z\n        last_id = len(task_queue) - 1\n        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':\n            last_id -= 1\n        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':\n            raise ValueError('No group norm found in the task queue')\n        # estimate until the last group norm\n        for i in range(last_id + 1):\n            task = task_queue[i]\n            if task[0] == 'pre_norm':\n                group_norm_func = GroupNormParam.from_tile(tile, task[1])\n                task_queue[i] = ('apply_norm', group_norm_func)\n                if i == last_id:\n                    return True\n                tile = group_norm_func(tile)\n            elif task[0] == 'store_res':\n                task_id = i + 1\n                while task_id < last_id and task_queue[task_id][0] != 'add_res':\n                    task_id += 1\n                if task_id >= last_id:\n                    continue\n                task_queue[task_id][1] = task[1](tile)\n            elif task[0] == 'add_res':\n                tile += task[1].to(device)\n                task[1] = None\n            elif color_fix and task[0] == 'downsample':\n                for j in range(i, last_id + 1):\n                    if task_queue[j][0] == 'store_res':\n                        task_queue[j] = ('store_res_cpu', task_queue[j][1])\n                return True\n            else:\n                tile = task[1](tile)\n            try:\n                devices.test_for_nans(tile, \"vae\")\n            except:\n                print(f'Nan detected in fast mode estimation. 
Fast mode disabled.')\n                return False\n\n        raise IndexError('Should not reach here')\n\n    @perfcount\n    @torch.no_grad()\n    def vae_tile_forward(self, z):\n        \"\"\"\n        Decode a latent vector z into an image in a tiled manner.\n        @param z: latent vector\n        @return: image\n        \"\"\"\n        device = next(self.net.parameters()).device\n        net = self.net\n        tile_size = self.tile_size\n        is_decoder = self.is_decoder\n\n        z = z.detach() # detach the input to avoid backprop\n\n        N, height, width = z.shape[0], z.shape[2], z.shape[3]\n        net.last_z_shape = z.shape\n\n        # Split the input into tiles and build a task queue for each tile\n        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')\n\n        in_bboxes, out_bboxes = self.split_tiles(height, width)\n\n        # Prepare tiles by split the input latents\n        tiles = []\n        for input_bbox in in_bboxes:\n            tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()\n            tiles.append(tile)\n\n        num_tiles = len(tiles)\n        num_completed = 0\n\n        # Build task queues\n        single_task_queue = build_task_queue(net, is_decoder)\n        #print(single_task_queue)\n        if self.fast_mode:\n            # Fast mode: downsample the input image to the tile size,\n            # then estimate the group norm parameters on the downsampled image\n            scale_factor = tile_size / max(height, width)\n            z = z.to(device)\n            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')\n            # use nearest-exact to keep statictics as close as possible\n            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')\n\n            # ======= Special thanks to @Kahsolt for distribution shift issue ======= #\n  
          # The downsampling will heavily distort its mean and std, so we need to recover it.\n            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)\n            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)\n            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old\n            del std_old, mean_old, std_new, mean_new\n            # occasionally the std_new is too small or too large, which exceeds the range of float16\n            # so we need to clamp it to max z's range.\n            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())\n            estimate_task_queue = clone_task_queue(single_task_queue)\n            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):\n                single_task_queue = estimate_task_queue\n            del downsampled_z\n\n        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]\n\n        # Dummy result\n        result = None\n        result_approx = None\n        #try:\n        #    with devices.autocast():\n        #        result_approx = torch.cat([F.interpolate(cheap_approximation(x).unsqueeze(0), scale_factor=opt_f, mode='nearest-exact') for x in z], dim=0).cpu()\n        #except: pass\n        # Free memory of input latent tensor\n        del z\n\n        # Task queue execution\n        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f\"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: \")\n\n        # execute the task back and forth when switch tiles so that we always\n        # keep one tile on the GPU to reduce unnecessary data transfer\n        forward = True\n        interrupted = False\n        #state.interrupted = interrupted\n        while True:\n            #if state.interrupted: interrupted = True ; break\n\n            group_norm_param = GroupNormParam()\n            for i in range(num_tiles) if forward 
else reversed(range(num_tiles)):\n                #if state.interrupted: interrupted = True ; break\n\n                tile = tiles[i].to(device)\n                input_bbox = in_bboxes[i]\n                task_queue = task_queues[i]\n\n                interrupted = False\n                while len(task_queue) > 0:\n                    #if state.interrupted: interrupted = True ; break\n\n                    # DEBUG: current task\n                    # print('Running task: ', task_queue[0][0], ' on tile ', i, '/', num_tiles, ' with shape ', tile.shape)\n                    task = task_queue.pop(0)\n                    if task[0] == 'pre_norm':\n                        group_norm_param.add_tile(tile, task[1])\n                        break\n                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':\n                        task_id = 0\n                        res = task[1](tile)\n                        if not self.fast_mode or task[0] == 'store_res_cpu':\n                            res = res.cpu()\n                        while task_queue[task_id][0] != 'add_res':\n                            task_id += 1\n                        task_queue[task_id][1] = res\n                    elif task[0] == 'add_res':\n                        tile += task[1].to(device)\n                        task[1] = None\n                    else:\n                        tile = task[1](tile)\n                        #print(tiles[i].shape, tile.shape, task)\n                    pbar.update(1)\n\n                if interrupted: break\n\n                # check for NaNs in the tile.\n                # If there are NaNs, we abort the process to save user's time\n                #devices.test_for_nans(tile, \"vae\")\n\n                #print(tiles[i].shape, tile.shape, i, num_tiles)\n                if len(task_queue) == 0:\n                    tiles[i] = None\n                    num_completed += 1\n                    if result is None:      # NOTE: dim C varies from 
different cases, can only be inited dynamically\n                        result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)\n                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)\n                    del tile\n                elif i == num_tiles - 1 and forward:\n                    forward = False\n                    tiles[i] = tile\n                elif i == 0 and not forward:\n                    forward = True\n                    tiles[i] = tile\n                else:\n                    tiles[i] = tile.cpu()\n                    del tile\n\n            if interrupted: break\n            if num_completed == num_tiles: break\n\n            # insert the group norm task to the head of each task queue\n            group_norm_func = group_norm_param.summary()\n            if group_norm_func is not None:\n                for i in range(num_tiles):\n                    task_queue = task_queues[i]\n                    task_queue.insert(0, ('apply_norm', group_norm_func))\n\n        # Done!\n        pbar.close()\n        return result if result is not None else result_approx.to(device)"
  },
  {
    "path": "utils/wavelet_color_fix.py",
    "content": "'''\n# --------------------------------------------------------------------------------\n#   Color fixed script from Li Yi (https://github.com/pkuliyi2015/sd-webui-stablesr/blob/master/srmodule/colorfix.py)\n# --------------------------------------------------------------------------------\n'''\n\nimport torch\nfrom PIL import Image\nfrom torch import Tensor\nfrom torch.nn import functional as F\n\nfrom torchvision.transforms import ToTensor, ToPILImage\n\ndef adain_color_fix(target: Image, source: Image):\n    # Convert images to tensors\n    to_tensor = ToTensor()\n    target_tensor = to_tensor(target).unsqueeze(0)\n    source_tensor = to_tensor(source).unsqueeze(0)\n\n    # Apply adaptive instance normalization\n    result_tensor = adaptive_instance_normalization(target_tensor, source_tensor)\n\n    # Convert tensor back to image\n    to_image = ToPILImage()\n    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))\n\n    return result_image\n\ndef wavelet_color_fix(target: Image, source: Image):\n    # Convert images to tensors\n    to_tensor = ToTensor()\n    target_tensor = to_tensor(target).unsqueeze(0)\n    source_tensor = to_tensor(source).unsqueeze(0)\n\n    # Apply wavelet reconstruction\n    result_tensor = wavelet_reconstruction(target_tensor, source_tensor)\n\n    # Convert tensor back to image\n    to_image = ToPILImage()\n    result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0))\n\n    return result_image\n\ndef calc_mean_std(feat: Tensor, eps=1e-5):\n    \"\"\"Calculate mean and std for adaptive_instance_normalization.\n    Args:\n        feat (Tensor): 4D tensor.\n        eps (float): A small value added to the variance to avoid\n            divide-by-zero. 
Default: 1e-5.\n    \"\"\"\n    size = feat.size()\n    assert len(size) == 4, 'The input feature should be 4D tensor.'\n    b, c = size[:2]\n    feat_var = feat.reshape(b, c, -1).var(dim=2) + eps\n    feat_std = feat_var.sqrt().reshape(b, c, 1, 1)\n    feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1)\n    return feat_mean, feat_std\n\ndef adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor):\n    \"\"\"Adaptive instance normalization.\n    Adjust the reference features to have the similar color and illuminations\n    as those in the degradate features.\n    Args:\n        content_feat (Tensor): The reference feature.\n        style_feat (Tensor): The degradate features.\n    \"\"\"\n    size = content_feat.size()\n    style_mean, style_std = calc_mean_std(style_feat)\n    content_mean, content_std = calc_mean_std(content_feat)\n    normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)\n    return normalized_feat * style_std.expand(size) + style_mean.expand(size)\n\ndef wavelet_blur(image: Tensor, radius: int):\n    \"\"\"\n    Apply wavelet blur to the input tensor.\n    \"\"\"\n    # input shape: (1, 3, H, W)\n    # convolution kernel\n    kernel_vals = [\n        [0.0625, 0.125, 0.0625],\n        [0.125, 0.25, 0.125],\n        [0.0625, 0.125, 0.0625],\n    ]\n    kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)\n    # add channel dimensions to the kernel to make it a 4D tensor\n    kernel = kernel[None, None]\n    # repeat the kernel across all input channels\n    kernel = kernel.repeat(3, 1, 1, 1)\n    image = F.pad(image, (radius, radius, radius, radius), mode='replicate')\n    # apply convolution\n    output = F.conv2d(image, kernel, groups=3, dilation=radius)\n    return output\n\ndef wavelet_decomposition(image: Tensor, levels=5):\n    \"\"\"\n    Apply wavelet decomposition to the input tensor.\n    This function only returns the low frequency & the high 
frequency.\n    \"\"\"\n    high_freq = torch.zeros_like(image)\n    for i in range(levels):\n        radius = 2 ** i\n        low_freq = wavelet_blur(image, radius)\n        high_freq += (image - low_freq)\n        image = low_freq\n\n    return high_freq, low_freq\n\ndef wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor):\n    \"\"\"\n    Apply wavelet decomposition, so that the content will have the same color as the style.\n    \"\"\"\n    # calculate the wavelet decomposition of the content feature\n    content_high_freq, content_low_freq = wavelet_decomposition(content_feat)\n    del content_low_freq\n    # calculate the wavelet decomposition of the style feature\n    style_high_freq, style_low_freq = wavelet_decomposition(style_feat)\n    del style_high_freq\n    # reconstruct the content feature with the style's high frequency\n    return content_high_freq + style_low_freq\n"
  }
]