[
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. 
For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. 
Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. (Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n# CLEAR\n<a href=\"https://arxiv.org/abs/2412.16112\"><img src=\"https://img.shields.io/badge/arXiv-2412.16112-A42C25.svg\" alt=\"arXiv\"></a> \n</div>\n\n\n> **CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up**\n> <br>\n> [Songhua Liu](http://huage001.github.io/), \n> [Zhenxiong Tan](https://scholar.google.com/citations?user=HP9Be6UAAAAJ&hl=en), \n> and \n> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)\n> <br>\n> NeurIPS 2025\n> <br>\n> [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore\n> <br>\n\n![](./assets/teaser.png)\n\n## 🔥News\n\n**[2025/9/18]** CLEAR is accepted at NeurIPS 2025.\n\n**[2024/12/20]** We release training and inference codes of CLEAR, a simple yet effective strategy to linearize the complexity of pre-trained diffusion transformers, such as FLUX and SD3.\n\n## Introduction\n\nDiffusion Transformers (DiT) have become a leading architecture in image generation. However, the quadratic complexity of attention mechanisms, which are responsible for modeling token-wise relationships, results in significant latency when generating high-resolution images. To address this issue, we aim at a linear attention mechanism in this paper that reduces the complexity of pre-trained DiTs to linear. We begin our exploration with a comprehensive summary of existing efficient attention mechanisms and identify four key factors crucial for successful linearization of pre-trained DiTs: locality, formulation consistency, high-rank attention maps, and feature integrity. Based on these insights, we introduce a convolution-like local attention strategy termed CLEAR, which limits feature interactions to a local window around each query token, and thus achieves linear complexity. \nOur experiments indicate that, by fine-tuning the attention layer on merely 10K self-generated samples for 10K iterations, we can effectively transfer knowledge from a pre-trained DiT to a student model with linear complexity, yielding results comparable to the teacher model. Simultaneously, it reduces attention computations by 99.5% and accelerates generation by 6.3 times for generating 8K-resolution images. Furthermore, we investigate favorable properties in the distilled attention layers, such as zero-shot generalization across various models and plugins, and improved support for multi-GPU parallel inference.\n\n**TL;DR**: For pre-trained diffusion transformers, enforcing an image token to interact with only tokens within **a local window** can effectively reduce the complexity of the original models to a linear scale.\n\n## Installation\n\n* CLEAR requires ``torch>=2.5.0``, ``diffusers>=0.31.0``, and other packages listed in ``requirements.txt``. You can set up a new experiment with:\n\n  ```bash\n  conda create -n CLEAR python=3.12\n  conda activate CLEAR\n  pip install -r requirements.txt\n  ```\n\n* Clone this repo to your project directory:\n\n  ``` bash\n  git clone https://github.com/Huage001/CLEAR.git\n  ```\n\n## Supported Models\n\nWe release a series of variants for linearized [FLUX-1.dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with various local window sizes. \n\nWe experimentally find that when the local window size is small, e.g., 8, the model can produce repetitive patterns in many cases. 
\n## Installation\n\n* CLEAR requires ``torch>=2.5.0``, ``diffusers>=0.31.0``, and other packages listed in ``requirements.txt``. You can set up a new environment with:\n\n  ```bash\n  conda create -n CLEAR python=3.12\n  conda activate CLEAR\n  pip install -r requirements.txt\n  ```\n\n* Clone this repo to your project directory:\n\n  ```bash\n  git clone https://github.com/Huage001/CLEAR.git\n  ```\n\n## Supported Models\n\nWe release a series of variants for linearized [FLUX-1.dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with various local window sizes.\n\nWe experimentally find that when the local window size is small, e.g., 8, the model can produce repetitive patterns in many cases. To alleviate the problem, in some variants, we also include down-sampled key-value tokens besides local tokens for attention interaction.\n\nThe supported models and the download links are:\n\n| window_size | down_factor |                             link                             |\n| :---------: | :---------: | :----------------------------------------------------------: |\n|     32      |     NA      | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_32.safetensors) |\n|     16      |     NA      | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_16.safetensors) |\n|      8      |     NA      | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8.safetensors) |\n|     16      |      4      | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_16_down_4.safetensors) |\n|      8      |      4      | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8_down_4.safetensors) |\n\nYou are encouraged to download the model weights you need to the ``ckpt`` directory beforehand. For example:\n\n```bash\nmkdir ckpt\nwget -P ckpt https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8_down_4.safetensors\n```\n\n## Inference\n\n* If you want to compare the linearized FLUX with the original model, please try ``inference_t2i.ipynb``.\n\n* If you want to use CLEAR for high-resolution acceleration, please try ``inference_t2i_highres.ipynb``. We currently adopt the strategy of [SDEdit](https://huggingface.co/docs/diffusers/v0.30.2/en/api/pipelines/stable_diffusion/img2img#image-to-image): the basic idea is to first generate a low-resolution result and then gradually upscale it.\n\n* Please configure ``down_factor`` and ``window_size`` in the notebooks to use different variants of CLEAR. If you do not want to include down-sampled key-value tokens, specify ``down_factor=1``. The models will be downloaded automatically to ``ckpt`` if they are not already there.\n\n* Currently, a GPU with 48 GB of VRAM is recommended for high-resolution generation.\n\n
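* For a rough idea of what the notebooks do, here is a minimal sketch that applies a pure local variant to FLUX. It assumes the released ``.safetensors`` weights can be merged into the transformer via ``load_state_dict(..., strict=False)`` and that every attention layer is switched to the CLEAR processor; please refer to ``inference_t2i.ipynb`` for the exact procedure.\n\n  ```python\n  import torch\n  from diffusers import FluxPipeline\n  from safetensors.torch import load_file\n\n  from attention_processor import LocalFlexAttnProcessor, init_local_mask_flex\n\n  pipe = FluxPipeline.from_pretrained(\n      \"black-forest-labs/FLUX.1-dev\", torch_dtype=torch.bfloat16\n  ).to(\"cuda\")\n\n  # 1024x1024 images correspond to a 64x64 grid of packed latent tokens,\n  # and FLUX.1-dev uses up to 512 T5 text tokens.\n  init_local_mask_flex(height=64, width=64, text_length=512, window_size=8, device=\"cuda\")\n  pipe.transformer.set_attn_processor(LocalFlexAttnProcessor())\n  # Assumption: the checkpoint keys line up with the transformer state dict.\n  pipe.transformer.load_state_dict(load_file(\"ckpt/clear_local_8.safetensors\"), strict=False)\n\n  image = pipe(\"a photo of a cat\", height=1024, width=1024).images[0]\n  image.save(\"cat.png\")\n  ```\n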
\n## Training\n\n* Replace ``/path/to/t2i_1024`` with your data directory in the ``.sh`` scripts.\n\n* Download training images from [here](https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_1024_10K/data_000000.tar), which contains 10K 1024-resolution images generated by ``FLUX-1.dev`` itself, and extract it to ``/path/to/t2i_1024``:\n\n  ```bash\n  tar -xvf data_000000.tar -C /path/to/t2i_1024\n  ```\n\n* [Optional but Recommended] Cache T5 and CLIP text embeddings and VAE features beforehand:\n\n  ```bash\n  bash cache_prompt_embeds.sh\n  bash cache_latent_codes.sh\n  ```\n\n* Start Training:\n\n  ```bash\n  bash distill.sh\n  ```\n\n  By default, it uses 4 GPUs with 80 GB of VRAM each, ``train_batch_size=2``, and ``gradient_accumulation_steps=4``. Feel free to adjust these settings in ``distill.sh`` and ``deepspeed_config.yaml`` according to your setup.\n\n## Acknowledgement\n\n* [FLUX](https://blackforestlabs.ai/announcing-black-forest-labs/) for the source models.\n* [flexattention](https://pytorch.org/blog/flexattention/) for the kernel implementation.\n* [diffusers](https://github.com/huggingface/diffusers) for the code base.\n* [DeepSpeed](https://github.com/microsoft/DeepSpeed) for the training framework.\n* [SDEdit](https://huggingface.co/docs/diffusers/v0.30.2/en/api/pipelines/stable_diffusion/img2img#image-to-image) for high-resolution image generation.\n* [@Weihao Yu](https://github.com/yuweihao) and [@Xinyin Ma](https://github.com/horseee) for valuable discussions.\n* NUS IT’s Research Computing group for support under grant number NUSREC-HPC-00001.\n\n## Citation\n\nIf you find this repo helpful, please consider citing:\n\n```bib\n@inproceedings{liu2024clear,\n    title     = {CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up},\n    author    = {Liu, Songhua and Tan, Zhenxiong and Wang, Xinchao},\n    booktitle = {NeurIPS},\n    year      = {2025},\n}\n```\n"
  },
  {
    "path": "attention_processor.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom torch.nn.attention.flex_attention import create_block_mask, flex_attention\ncreate_block_mask = torch.compile(create_block_mask)\nfrom diffusers.models.attention_processor import Attention\nfrom typing import Optional\nfrom functools import partial, lru_cache\nfrom diffusers.models.embeddings import apply_rotary_emb\n\n\nattn_outputs_teacher = []\nattn_outputs = []\nBLOCK_MASK = None\nHEIGHT = None\nWIDTH = None\n\n\nclass FluxAttnProcessor2_0:\n    \"\"\"Attention processor used typically in processing the SD3-like self-attention projections.\"\"\"\n\n    def __init__(self, distill=False):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n        self.distill = distill\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: torch.FloatTensor = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        image_rotary_emb: Optional[torch.Tensor] = None,\n        proportional_attention=False\n    ) -> torch.FloatTensor:\n        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n\n        # `sample` projections.\n        query = attn.to_q(hidden_states)\n        key = attn.to_k(hidden_states)\n        value = attn.to_v(hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        if attn.norm_q is not None:\n            query = attn.norm_q(query)\n        if attn.norm_k is not None:\n            key = attn.norm_k(key)\n\n        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`\n        if encoder_hidden_states is not None:\n            # `context` projections.\n            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)\n            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)\n            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)\n\n            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(\n                batch_size, -1, attn.heads, head_dim\n            ).transpose(1, 2)\n            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(\n                batch_size, -1, attn.heads, head_dim\n            ).transpose(1, 2)\n            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(\n                batch_size, -1, attn.heads, head_dim\n            ).transpose(1, 2)\n\n            if attn.norm_added_q is not None:\n                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)\n            if attn.norm_added_k is not None:\n                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)\n\n            # attention\n            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)\n            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)\n            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)\n\n        if 
image_rotary_emb is not None:\n            query = apply_rotary_emb(query, image_rotary_emb)\n            key = apply_rotary_emb(key, image_rotary_emb)\n\n        train_seq_len = 64 ** 2 + 512\n        if proportional_attention:\n            attention_scale = math.sqrt(math.log(key.size(2), train_seq_len) / head_dim)\n        else:\n            attention_scale = math.sqrt(1 / head_dim)\n\n        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, scale=attention_scale)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        if encoder_hidden_states is not None:\n            encoder_hidden_states, hidden_states = (\n                hidden_states[:, : encoder_hidden_states.shape[1]],\n                hidden_states[:, encoder_hidden_states.shape[1] :],\n            )\n\n            # linear proj\n            hidden_states = attn.to_out[0](hidden_states)\n            # dropout\n            hidden_states = attn.to_out[1](hidden_states)\n            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)\n\n            return hidden_states, encoder_hidden_states\n        else:\n            if self.distill:\n                attn_outputs_teacher.append(hidden_states)\n            return hidden_states\n\n\n@lru_cache\ndef init_local_downsample_mask_flex(height, width, text_length, window_size, down_factor, device):\n    \n    def local_dwonsample_mask(b, h, q_idx, kv_idx):\n        q_y = (q_idx - text_length) // width\n        q_x = (q_idx - text_length) % width\n        kv_y = (kv_idx - text_length) // width\n        kv_x = (kv_idx - text_length) % width\n        return torch.logical_or(\n            torch.logical_and(\n                q_idx < text_length, \n                kv_idx < text_length + height * width\n            ),\n            torch.logical_and(\n                q_idx >= text_length, \n                torch.logical_or(\n                    torch.logical_or(kv_idx < text_length, kv_idx >= text_length + height * width),\n                    (q_y - kv_y) ** 2 + (q_x - kv_x) ** 2 < window_size ** 2)\n            )\n        )\n    \n    global BLOCK_MASK, HEIGHT, WIDTH\n    BLOCK_MASK = create_block_mask(local_dwonsample_mask, B=None, H=None, device=device,\n                                   Q_LEN=text_length + height * width, \n                                   KV_LEN=text_length + height * width + (height // down_factor) * (width // down_factor), _compile=True)\n    HEIGHT = height\n    WIDTH = width\n\n\nclass LocalDownsampleFlexAttnProcessor(nn.Module):\n    \n    def __init__(self, down_factor=4, distill=False):\n        super().__init__()\n        assert BLOCK_MASK is not None\n        self.flex_attn = partial(flex_attention, block_mask=BLOCK_MASK)\n        self.flex_attn = torch.compile(self.flex_attn, dynamic=False)\n        self.down_factor = down_factor\n        self.spatial_weight = nn.Parameter(torch.ones(1, 1, 1, down_factor, 1, down_factor, 1) / (down_factor * down_factor))\n        self.distill = distill\n        \n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: torch.FloatTensor = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        image_rotary_emb: Optional[torch.Tensor] = None,\n        proportional_attention=False\n    ) -> torch.FloatTensor:\n        batch_size, _, _ = hidden_states.shape if 
encoder_hidden_states is None else encoder_hidden_states.shape\n\n        # `sample` projections.\n        query = attn.to_q(hidden_states)\n        key = attn.to_k(hidden_states)\n        value = attn.to_v(hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        if attn.norm_q is not None:\n            query = attn.norm_q(query)\n        if attn.norm_k is not None:\n            key = attn.norm_k(key)\n\n        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`\n        if encoder_hidden_states is not None:\n            # `context` projections.\n            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)\n            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)\n            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)\n\n            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(\n                batch_size, -1, attn.heads, head_dim\n            ).transpose(1, 2)\n            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(\n                batch_size, -1, attn.heads, head_dim\n            ).transpose(1, 2)\n            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(\n                batch_size, -1, attn.heads, head_dim\n            ).transpose(1, 2)\n\n            if attn.norm_added_q is not None:\n                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)\n            if attn.norm_added_k is not None:\n                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)\n\n            # attention\n            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)\n            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)\n            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)\n\n        if image_rotary_emb is not None:\n            query = apply_rotary_emb(query, image_rotary_emb)\n            key = apply_rotary_emb(key, image_rotary_emb)\n\n        train_seq_len = 64 ** 2 + 512\n        if proportional_attention:\n            attention_scale = math.sqrt(math.log(10 * key.size(2), train_seq_len) / head_dim)\n        else:\n            attention_scale = math.sqrt(1 / head_dim)\n        \n        key_downsample = (key[:, :, 512:].unflatten(2, (HEIGHT // self.down_factor, self.down_factor, \n                                                        WIDTH // self.down_factor, self.down_factor)) * self.spatial_weight).sum(dim=(3, 5)).flatten(2, 3)\n        value_downsample = (value[:, :, 512:].unflatten(2, (HEIGHT // self.down_factor, self.down_factor, \n                                                            WIDTH // self.down_factor, self.down_factor)) * self.spatial_weight).sum(dim=(3, 5)).flatten(2, 3)\n\n        hidden_states = self.flex_attn(query, torch.cat([key, key_downsample], dim=2), torch.cat([value, value_downsample], dim=2), scale=attention_scale)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        if encoder_hidden_states is not None:\n            
encoder_hidden_states, hidden_states = (\n                hidden_states[:, : encoder_hidden_states.shape[1]],\n                hidden_states[:, encoder_hidden_states.shape[1] :],\n            )\n\n            # linear proj\n            hidden_states = attn.to_out[0](hidden_states)\n            # dropout\n            hidden_states = attn.to_out[1](hidden_states)\n            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)\n            if self.distill:\n                attn_outputs.append(hidden_states)\n\n            return hidden_states, encoder_hidden_states\n        else:\n            if self.distill:\n                attn_outputs.append(hidden_states)\n            return hidden_states\n\n\n@lru_cache\ndef init_local_mask_flex(height, width, text_length, window_size, device):\n    \n    def local_mask(b, h, q_idx, kv_idx):\n        q_y = (q_idx - text_length) // width\n        q_x = (q_idx - text_length) % width\n        kv_y = (kv_idx - text_length) // width\n        kv_x = (kv_idx - text_length) % width\n        return torch.logical_or(torch.logical_or(q_idx < text_length, kv_idx < text_length),\n                                (q_y - kv_y) ** 2 + (q_x - kv_x) ** 2 < window_size ** 2)\n    \n    global BLOCK_MASK, HEIGHT, WIDTH\n    BLOCK_MASK = create_block_mask(local_mask, B=None, H=None, device=device,\n                                   Q_LEN=text_length + height * width, \n                                   KV_LEN=text_length + height * width, _compile=True)\n    HEIGHT = height\n    WIDTH = width\n\n\nclass LocalFlexAttnProcessor:\n    \"\"\"Attention processor used typically in processing the SD3-like self-attention projections.\"\"\"\n\n    def __init__(self, distill=False):\n        super().__init__()\n        self.flex_attn = partial(flex_attention, block_mask=BLOCK_MASK)\n        self.flex_attn = torch.compile(self.flex_attn, dynamic=False)\n        self.distill = distill\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: torch.FloatTensor = None,\n        attention_mask: Optional[torch.FloatTensor] = None,\n        image_rotary_emb: Optional[torch.Tensor] = None,\n        proportional_attention=False\n    ) -> torch.FloatTensor:\n        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n\n        # `sample` projections.\n        query = attn.to_q(hidden_states)\n        key = attn.to_k(hidden_states)\n        value = attn.to_v(hidden_states)\n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        if attn.norm_q is not None:\n            query = attn.norm_q(query)\n        if attn.norm_k is not None:\n            key = attn.norm_k(key)\n\n        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`\n        if encoder_hidden_states is not None:\n            # `context` projections.\n            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)\n            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)\n            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)\n\n            encoder_hidden_states_query_proj = 
encoder_hidden_states_query_proj.view(\n                batch_size, -1, attn.heads, head_dim\n            ).transpose(1, 2)\n            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(\n                batch_size, -1, attn.heads, head_dim\n            ).transpose(1, 2)\n            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(\n                batch_size, -1, attn.heads, head_dim\n            ).transpose(1, 2)\n\n            if attn.norm_added_q is not None:\n                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)\n            if attn.norm_added_k is not None:\n                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)\n\n            # attention\n            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)\n            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)\n            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)\n\n        if image_rotary_emb is not None:\n            query = apply_rotary_emb(query, image_rotary_emb)\n            key = apply_rotary_emb(key, image_rotary_emb)\n\n        train_seq_len = 64 ** 2 + 512\n        if proportional_attention:\n            attention_scale = math.sqrt(math.log(key.size(2), train_seq_len) / head_dim)\n        else:\n            attention_scale = math.sqrt(1 / head_dim)\n\n        hidden_states = self.flex_attn(query, key, value, scale=attention_scale)\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        if encoder_hidden_states is not None:\n            encoder_hidden_states, hidden_states = (\n                hidden_states[:, : encoder_hidden_states.shape[1]],\n                hidden_states[:, encoder_hidden_states.shape[1] :],\n            )\n\n            # linear proj\n            hidden_states = attn.to_out[0](hidden_states)\n            # dropout\n            hidden_states = attn.to_out[1](hidden_states)\n            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)\n            if self.distill:\n                attn_outputs.append(hidden_states)\n\n            return hidden_states, encoder_hidden_states\n        else:\n            if self.distill:\n                attn_outputs.append(hidden_states)\n            return hidden_states"
  },
  {
    "path": "cache_latent_codes.py",
    "content": "import argparse\nimport os\nimport numpy as np\nfrom PIL import Image\nimport math\nfrom safetensors.torch import save_file\nimport torch\nimport tqdm\nfrom diffusers import AutoencoderKL\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--data_root\",\n        type=str,\n        default=None\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=\"resolution\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=0, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=1,\n        help=\"Number of workers\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    args.local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n    \n    return args\n\n\ndef main(args):\n    if torch.cuda.is_available():\n        torch.cuda.set_device(args.local_rank)\n    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n    if args.mixed_precision == 'fp16':\n        dtype = torch.float16\n    elif args.mixed_precision == 'bf16':\n        dtype = torch.bfloat16\n    else:\n        dtype = torch.float32\n\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    ).to(device, dtype)\n    \n    all_info = [os.path.join(args.data_root, i) for i in sorted(os.listdir(args.data_root)) if '.jpg' in i or '.png' in i]\n\n    os.makedirs(args.output_dir, exist_ok=True)\n\n    work_load = math.ceil(len(all_info) / args.num_workers)\n    for idx in tqdm.tqdm(range(work_load * args.local_rank, min(work_load * (args.local_rank + 1), len(all_info)), args.batch_size)):\n        images = []\n        paths = [os.path.join(args.data_root, item[:item.rfind('.')] + '_latent_code.safetensors') for item in all_info[idx:idx + args.batch_size]]\n        for item in all_info[idx:idx + args.batch_size]:\n            img = Image.open(os.path.join(args.data_root, item)).convert('RGB')\n            img = img.resize((args.resolution, args.resolution))\n            img = torch.from_numpy((np.array(img) / 127.5) - 1)\n            img = img.permute(2, 0, 1)\n            images.append(img)\n        with torch.no_grad():\n            images = torch.stack(images, dim=0)\n            data = vae.encode(images.to(device, vae.dtype)).latent_dist\n            means = data.mean.cpu().data\n            stds = data.std.cpu().data\n        for path, mean, std in zip(paths, means.unbind(), stds.unbind()):\n            save_file(\n                {'mean': mean, 'std': std},\n                path\n            )\n\n\nif __name__ == '__main__':\n    main(parse_args())"
  },
  {
    "path": "cache_latent_codes.sh",
    "content": "export NUM_WORKERS=4\nexport MODEL_NAME=\"black-forest-labs/FLUX.1-dev\"\ntorchrun --nproc_per_node=$NUM_WORKERS cache_latent_codes.py \\\n    --data_root=\"/path/to/t2i_1024\" \\\n    --batch_size=16 \\\n    --num_worker=$NUM_WORKERS \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --mixed_precision='bf16' \\\n    --output_dir=\"/path/to/t2i_1024\"\n    \n"
  },
  {
    "path": "cache_prompt_embeds.py",
    "content": "import argparse\nimport os\nimport json\nimport math\nfrom safetensors.torch import save_file\nimport torch\nimport tqdm\nfrom transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, T5EncoderModel\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16\",\n    )\n    parser.add_argument(\n        \"--data_root\",\n        type=str,\n        default=None\n    )\n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-lora\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\n        \"--batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=0, help=\"For distributed training: local_rank\")\n    parser.add_argument(\n        \"--num_workers\",\n        type=int,\n        default=1,\n        help=\"Number of workers\",\n    )\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    args.local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n    \n    return args\n\n\ndef tokenize_prompt(tokenizer, prompt, max_sequence_length):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=max_sequence_length,\n        truncation=True,\n        return_length=False,\n        return_overflowing_tokens=False,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\ndef _encode_prompt_with_t5(\n    text_encoder,\n    tokenizer,\n    max_sequence_length=512,\n    prompt=None,\n    num_images_per_prompt=1,\n    device=None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n\n    dtype = text_encoder.dtype\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    _, seq_len, _ = prompt_embeds.shape\n\n    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\ndef _encode_prompt_with_clip(\n    text_encoder,\n    tokenizer,\n    prompt: str,\n    device=None,\n    text_input_ids=None,\n    num_images_per_prompt: int = 1,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n    # Use pooled output of CLIPTextModel\n    prompt_embeds = prompt_embeds.pooler_output\n    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n    return 
prompt_embeds\n\n\ndef encode_prompt(\n    text_encoders,\n    tokenizers,\n    prompt: str,\n    max_sequence_length,\n    device=None,\n    num_images_per_prompt: int = 1,\n    text_input_ids_list=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    \n    pooled_prompt_embeds = _encode_prompt_with_clip(\n        text_encoder=text_encoders[0],\n        tokenizer=tokenizers[0],\n        prompt=prompt,\n        device=device if device is not None else text_encoders[0].device,\n        num_images_per_prompt=num_images_per_prompt,\n        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,\n    )\n\n    prompt_embeds = _encode_prompt_with_t5(\n        text_encoder=text_encoders[1],\n        tokenizer=tokenizers[1],\n        max_sequence_length=max_sequence_length,\n        prompt=prompt,\n        num_images_per_prompt=num_images_per_prompt,\n        device=device if device is not None else text_encoders[1].device,\n        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,\n    )\n\n    return prompt_embeds, pooled_prompt_embeds\n\n\ndef main(args):\n    if torch.cuda.is_available():\n        torch.cuda.set_device(args.local_rank)\n    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n    if args.mixed_precision == 'fp16':\n        dtype = torch.float16\n    elif args.mixed_precision == 'bf16':\n        dtype = torch.bfloat16\n    else:\n        dtype = torch.float32\n    tokenizer_one = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n        variant=args.variant,\n        cache_dir=args.cache_dir\n    )\n    tokenizer_two = T5TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n        variant=args.variant,\n        cache_dir=args.cache_dir\n    )\n\n    text_encoder_one = CLIPTextModel.from_pretrained(\n        args.pretrained_model_name_or_path, \n        revision=args.revision, \n        subfolder=\"text_encoder\",\n        variant=args.variant,\n        cache_dir=args.cache_dir\n    ).to(device, dtype)\n    text_encoder_two = T5EncoderModel.from_pretrained(\n        args.pretrained_model_name_or_path, \n        revision=args.revision, \n        subfolder=\"text_encoder_2\",\n        variant=args.variant,\n        cache_dir=args.cache_dir\n    ).to(device, dtype)\n    tokenizers = [tokenizer_one, tokenizer_two]\n    text_encoders = [text_encoder_one, text_encoder_two]\n\n    all_info = [os.path.join(args.data_root, i) for i in sorted(os.listdir(args.data_root)) if '.json' in i]\n\n    os.makedirs(args.output_dir, exist_ok=True)\n\n    work_load = math.ceil(len(all_info) / args.num_workers)\n    for idx in tqdm.tqdm(range(work_load * args.local_rank, min(work_load * (args.local_rank + 1), len(all_info)), args.batch_size)):\n        texts = []\n        for item in all_info[idx:idx + args.batch_size]:\n            with open(os.path.join(args.data_root, item)) as f:\n                texts.append(json.load(f)['prompt'])\n        paths = [os.path.join(args.data_root, item[:item.rfind('.')] + '_prompt_embed.safetensors') for item in all_info[idx:idx + args.batch_size]]\n        with torch.no_grad():\n            prompt_embeds, pooled_prompt_embeds = encode_prompt(\n                text_encoders, tokenizers, texts, args.max_sequence_length\n            )\n            prompt_embeds = prompt_embeds.cpu().data\n    
        pooled_prompt_embeds = pooled_prompt_embeds.cpu().data\n        for path, prompt_embed, pooled_prompt_embed in zip(paths, prompt_embeds.unbind(), pooled_prompt_embeds.unbind()):\n            save_file(\n                {'caption_feature_t5': prompt_embed, 'caption_feature_clip': pooled_prompt_embed},\n                path\n            )\n\n\nif __name__ == '__main__':\n    main(parse_args())"
  },
  {
    "path": "cache_prompt_embeds.sh",
    "content": "export NUM_WORKERS=4\nexport MODEL_NAME=\"black-forest-labs/FLUX.1-dev\"\ntorchrun --nproc_per_node=$NUM_WORKERS cache_prompt_embeds.py \\\n    --data_root=\"/path/to/t2i_1024\" \\\n    --batch_size=256 \\\n    --num_worker=$NUM_WORKERS \\\n    --pretrained_model_name_or_path=$MODEL_NAME \\\n    --mixed_precision='bf16' \\\n    --output_dir=\"/path/to/t2i_1024\"\n    \n"
  },
  {
    "path": "dataset.py",
    "content": "import os\nimport io\nimport numpy as np\nfrom PIL import Image\nimport torch\nfrom torch.utils.data import Dataset, DataLoader\nimport json\nimport random\nimport cv2\nfrom safetensors import safe_open\n\n\ndef image_resize(img, max_size=512):\n    w, h = img.size\n    if w >= h:\n        new_w = max_size\n        new_h = int((max_size / w) * h)\n    else:\n        new_h = max_size\n        new_w = int((max_size / h) * w)\n    new_w = (w // 32) * 32\n    new_h = (h // 32) * 32\n    return img.resize((new_w, new_h))\n\n\ndef c_crop(image):\n    width, height = image.size\n    new_size = min(width, height)\n    left = (width - new_size) / 2\n    top = (height - new_size) / 2\n    right = (width + new_size) / 2\n    bottom = (height + new_size) / 2\n    return image.crop((left, top, right, bottom))\n\n\ndef crop_to_aspect_ratio(image, ratio=\"16:9\"):\n    width, height = image.size\n    ratio_map = {\n        \"16:9\": (16, 9),\n        \"4:3\": (4, 3),\n        \"1:1\": (1, 1)\n    }\n    target_w, target_h = ratio_map[ratio]\n    target_ratio_value = target_w / target_h\n\n    current_ratio = width / height\n\n    if current_ratio > target_ratio_value:\n        new_width = int(height * target_ratio_value)\n        offset = (width - new_width) // 2\n        crop_box = (offset, 0, offset + new_width, height)\n    else:\n        new_height = int(width / target_ratio_value)\n        offset = (height - new_height) // 2\n        crop_box = (0, offset, width, offset + new_height)\n\n    cropped_img = image.crop(crop_box)\n    return cropped_img\n\n\nclass CustomImageDataset(Dataset):\n    def __init__(self, img_dir, img_size=512, caption_type='json', \n                 random_ratio=False, use_cached_prompt_embeds=False, use_cached_latent_codes=False):\n        self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]\n        self.images.sort()\n        self.img_size = img_size\n        self.caption_type = caption_type\n        self.random_ratio = random_ratio\n        self.use_cached_prompt_embeds = use_cached_prompt_embeds\n        self.use_cached_latent_codes = use_cached_latent_codes\n\n    def __len__(self):\n        return len(self.images)\n\n    def __getitem__(self, idx):\n        try:\n            batch = {}\n            if not self.use_cached_latent_codes:\n                img = Image.open(self.images[idx]).convert('RGB')\n                if self.random_ratio:\n                    ratio = random.choice([\"16:9\", \"default\", \"1:1\", \"4:3\"])\n                    if ratio != \"default\":\n                        img = crop_to_aspect_ratio(img, ratio)\n                img = image_resize(img, self.img_size)\n                img = torch.from_numpy((np.array(img) / 127.5) - 1)\n                img = img.permute(2, 0, 1)\n                batch['images'] = img\n            else:\n                with safe_open(self.images[idx][:self.images[idx].rfind('.')] + '_latent_code.safetensors', framework=\"pt\") as f:\n                    batch['latent_codes_mean'] = f.get_tensor('mean')\n                    batch['latent_codes_std'] = f.get_tensor('std')\n            if not self.use_cached_prompt_embeds:\n                json_path = self.images[idx].split('.')[0] + '.' 
+ self.caption_type\n                if self.caption_type == \"json\":\n                    prompt = json.load(open(json_path))['prompt']\n                else:\n                    prompt = open(json_path).read()\n                batch['prompts'] = prompt\n            else:\n                with safe_open(self.images[idx][:self.images[idx].rfind('.')] + '_prompt_embed.safetensors', framework=\"pt\") as f:\n                    batch['prompt_embeds_t5'] = f.get_tensor('caption_feature_t5')\n                    batch['prompt_embeds_clip'] = f.get_tensor('caption_feature_clip')\n            return batch\n        except Exception as e:\n            print(e)\n            return self.__getitem__(random.randint(0, len(self.images) - 1))\n        \n\ndef loader(train_batch_size, num_workers, **args):\n    dataset = CustomImageDataset(**args)\n    return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)"
  },
  {
    "path": "deepspeed_config.yaml",
    "content": "compute_environment: LOCAL_MACHINE\ndebug: false\ndeepspeed_config:\n  gradient_accumulation_steps: 4\n  gradient_clipping: 1.0\n  offload_optimizer_device: none\n  offload_param_device: none\n  zero3_init_flag: false\n  zero_stage: 2\ndistributed_type: DEEPSPEED\ndowncast_bf16: 'no'\nenable_cpu_affinity: false\nmachine_rank: 0\nmain_training_function: main\nmixed_precision: bf16\nnum_machines: 1\nnum_processes: 4\nrdzv_backend: static\nsame_network: true\ntpu_env: []\ntpu_use_cluster: false\ntpu_use_sudo: false\nuse_cpu: false\n"
  },
  {
    "path": "distill.py",
    "content": "import argparse\nimport copy\nimport logging\nimport math\nimport os\nimport shutil\nfrom contextlib import nullcontext\nfrom pathlib import Path\n\nimport numpy as np\nimport torch\nimport torch.utils.checkpoint\nfrom torch.utils.data import RandomSampler\nimport transformers\nfrom accelerate import Accelerator\nfrom accelerate.logging import get_logger\nfrom accelerate.utils import ProjectConfiguration, set_seed\n\nfrom tqdm.auto import tqdm\nfrom transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast\n\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    FlowMatchEulerDiscreteScheduler,\n    FluxPipeline,\n    FluxTransformer2DModel,\n)\nfrom diffusers.optimization import get_scheduler\nfrom diffusers.training_utils import (\n    cast_training_params,\n    compute_density_for_timestep_sampling,\n    compute_loss_weighting_for_sd3,\n    free_memory,\n)\nfrom diffusers.utils import (\n    check_min_version,\n    is_wandb_available,\n)\nfrom diffusers.utils.torch_utils import is_compiled_module\nfrom safetensors.torch import save_file\nfrom safetensors import safe_open\nfrom dataset import loader\nfrom attention_processor import FluxAttnProcessor2_0, LocalFlexAttnProcessor, LocalDownsampleFlexAttnProcessor, init_local_mask_flex, init_local_downsample_mask_flex\nfrom attention_processor import attn_outputs_teacher, attn_outputs\n\n\nif is_wandb_available():\n    import wandb\n\n# Will error if the minimal version of diffusers is not installed. Remove at your own risks.\ncheck_min_version(\"0.31.0.dev0\")\n\nlogger = get_logger(__name__)\n\n\ndef load_text_encoders(class_one, class_two):\n    text_encoder_one = class_one.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder\", revision=args.revision, variant=args.variant\n    )\n    text_encoder_two = class_two.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"text_encoder_2\", revision=args.revision, variant=args.variant\n    )\n    return text_encoder_one, text_encoder_two\n\n\ndef log_validation(\n    pipeline,\n    args,\n    accelerator,\n    pipeline_args,\n    epoch,\n    torch_dtype,\n    is_final_validation=False,\n):\n    logger.info(\n        f\"Running validation... 
\\n Generating {args.num_validation_images} images with prompt:\"\n        f\" {args.validation_prompt}.\"\n    )\n    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)\n    pipeline.set_progress_bar_config(disable=True)\n\n    # run inference\n    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None\n    # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()\n    autocast_ctx = nullcontext()\n\n    with autocast_ctx:\n        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]\n\n    for tracker in accelerator.trackers:\n        phase_name = \"test\" if is_final_validation else \"validation\"\n        if tracker.name == \"tensorboard\":\n            np_images = np.stack([np.asarray(img) for img in images])\n            tracker.writer.add_images(phase_name, np_images, epoch, dataformats=\"NHWC\")\n        if tracker.name == \"wandb\":\n            tracker.log(\n                {\n                    phase_name: [\n                        wandb.Image(image, caption=f\"{i}: {args.validation_prompt}\") for i, image in enumerate(images)\n                    ]\n                }\n            )\n\n    del pipeline\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n\n    return images\n\n\ndef import_model_class_from_model_name_or_path(\n    pretrained_model_name_or_path: str, revision: str, subfolder: str = \"text_encoder\"\n):\n    text_encoder_config = PretrainedConfig.from_pretrained(\n        pretrained_model_name_or_path, subfolder=subfolder, revision=revision\n    )\n    model_class = text_encoder_config.architectures[0]\n    if model_class == \"CLIPTextModel\":\n        from transformers import CLIPTextModel\n\n        return CLIPTextModel\n    elif model_class == \"T5EncoderModel\":\n        from transformers import T5EncoderModel\n\n        return T5EncoderModel\n    else:\n        raise ValueError(f\"{model_class} is not supported.\")\n\n\ndef parse_args(input_args=None):\n    parser = argparse.ArgumentParser(description=\"Simple example of a training script.\")\n    parser.add_argument(\n        \"--pretrained_model_name_or_path\",\n        type=str,\n        default=None,\n        required=True,\n        help=\"Path to pretrained model or model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--revision\",\n        type=str,\n        default=None,\n        required=False,\n        help=\"Revision of pretrained model identifier from huggingface.co/models.\",\n    )\n    parser.add_argument(\n        \"--variant\",\n        type=str,\n        default=None,\n        help=\"Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16\",\n    )\n    parser.add_argument(\n        \"--data_root\",\n        type=str,\n        default=None\n    )\n    \n    parser.add_argument(\n        \"--cache_dir\",\n        type=str,\n        default=None,\n        help=\"The directory where the downloaded models and datasets will be stored.\",\n    )\n\n    parser.add_argument(\"--repeats\", type=int, default=1, help=\"How many times to repeat the training data.\")\n\n    parser.add_argument(\n        \"--max_sequence_length\",\n        type=int,\n        default=512,\n        help=\"Maximum sequence length to use with with the T5 text encoder\",\n    )\n    parser.add_argument(\n        \"--validation_prompt\",\n        type=str,\n        default=None,\n        help=\"A prompt that is used during validation to verify that the model is learning.\",\n    )\n    parser.add_argument(\n        \"--num_validation_images\",\n        type=int,\n        default=4,\n        help=\"Number of images that should be generated during validation with `validation_prompt`.\",\n    )\n    parser.add_argument(\n        \"--validation_epochs\",\n        type=int,\n        default=1,\n        help=(\n            \"Run validation every X epochs. Validation consists of running the prompt\"\n            \" `args.validation_prompt` multiple times: `args.num_validation_images`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--output_dir\",\n        type=str,\n        default=\"flux-attn\",\n        help=\"The output directory where the model predictions and checkpoints will be written.\",\n    )\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"A seed for reproducible training.\")\n    parser.add_argument(\n        \"--resolution\",\n        type=int,\n        default=1024,\n        help=(\n            \"The resolution for input images\"\n        )\n    )\n    parser.add_argument(\n        \"--center_crop\",\n        default=False,\n        action=\"store_true\",\n        help=(\n            \"Whether to center crop the input images to the resolution. If not set, the images will be randomly\"\n            \" cropped. The images will be resized to the resolution first before cropping.\"\n        ),\n    )\n    parser.add_argument(\n        \"--train_batch_size\", type=int, default=4, help=\"Batch size (per device) for the training dataloader.\"\n    )\n    parser.add_argument(\n        \"--sample_batch_size\", type=int, default=4, help=\"Batch size (per device) for sampling images.\"\n    )\n    parser.add_argument(\"--num_train_epochs\", type=int, default=1)\n    parser.add_argument(\n        \"--max_train_steps\",\n        type=int,\n        default=None,\n        help=\"Total number of training steps to perform.  If provided, overrides num_train_epochs.\",\n    )\n    parser.add_argument(\n        \"--window_size\",\n        type=int,\n        default=16,\n        help=(\n            \"Size of local window for attention sampling.\"\n        ),\n    )\n    parser.add_argument(\n        \"--down_factor\",\n        type=int,\n        default=1,\n        help=(\n            \"Factor of downsampling for key-value tokens.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpointing_steps\",\n        type=int,\n        default=10000,\n        help=(\n            \"Save a checkpoint of the training state every X updates. 
These checkpoints can be used both as final\"\n            \" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming\"\n            \" training using `--resume_from_checkpoint`.\"\n        ),\n    )\n    parser.add_argument(\n        \"--checkpoints_total_limit\",\n        type=int,\n        default=None,\n        help=(\"Max number of checkpoints to store.\"),\n    )\n    parser.add_argument(\n        \"--resume_from_checkpoint\",\n        type=str,\n        default=None,\n        help=(\n            \"Whether training should be resumed from a previous checkpoint. Use a path saved by\"\n            ' `--checkpointing_steps`, or `\"latest\"` to automatically select the last available checkpoint.'\n        ),\n    )\n    parser.add_argument(\n        \"--gradient_accumulation_steps\",\n        type=int,\n        default=1,\n        help=\"Number of updates steps to accumulate before performing a backward/update pass.\",\n    )\n    parser.add_argument(\n        \"--gradient_checkpointing\",\n        action=\"store_true\",\n        help=\"Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.\",\n    )\n    parser.add_argument(\n        \"--learning_rate\",\n        type=float,\n        default=1e-4,\n        help=\"Initial learning rate (after the potential warmup period) to use.\",\n    )\n\n    parser.add_argument(\n        \"--guidance_scale\",\n        type=float,\n        default=3.5,\n        help=\"the FLUX.1 dev variant is a guidance distilled model\",\n    )\n\n    parser.add_argument(\n        \"--scale_lr\",\n        action=\"store_true\",\n        default=False,\n        help=\"Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.\",\n    )\n    parser.add_argument(\n        \"--lr_scheduler\",\n        type=str,\n        default=\"constant\",\n        help=(\n            'The scheduler type to use. Choose between [\"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\",'\n            ' \"constant\", \"constant_with_warmup\"]'\n        ),\n    )\n    parser.add_argument(\n        \"--lr_warmup_steps\", type=int, default=500, help=\"Number of steps for the warmup in the lr scheduler.\"\n    )\n    parser.add_argument(\n        \"--lr_num_cycles\",\n        type=int,\n        default=1,\n        help=\"Number of hard resets of the lr in cosine_with_restarts scheduler.\",\n    )\n    parser.add_argument(\"--lr_power\", type=float, default=1.0, help=\"Power factor of the polynomial scheduler.\")\n    parser.add_argument(\n        \"--dataloader_num_workers\",\n        type=int,\n        default=0,\n        help=(\n            \"Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.\"\n        ),\n    )\n    parser.add_argument(\n        \"--weighting_scheme\",\n        type=str,\n        default=\"none\",\n        choices=[\"sigma_sqrt\", \"logit_normal\", \"mode\", \"cosmap\", \"none\"],\n        help=('We default to the \"none\" weighting scheme for uniform sampling and uniform loss'),\n    )\n    parser.add_argument(\n        \"--logit_mean\", type=float, default=0.0, help=\"mean to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--logit_std\", type=float, default=1.0, help=\"std to use when using the `'logit_normal'` weighting scheme.\"\n    )\n    parser.add_argument(\n        \"--mode_scale\",\n        type=float,\n        default=1.29,\n        help=\"Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.\",\n    )\n    parser.add_argument(\n        \"--optimizer\",\n        type=str,\n        default=\"AdamW\",\n        help=('The optimizer type to use. Choose between [\"AdamW\", \"prodigy\"]'),\n    )\n\n    parser.add_argument(\n        \"--use_8bit_adam\",\n        action=\"store_true\",\n        help=\"Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW\",\n    )\n\n    parser.add_argument(\n        \"--adam_beta1\", type=float, default=0.9, help=\"The beta1 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--adam_beta2\", type=float, default=0.999, help=\"The beta2 parameter for the Adam and Prodigy optimizers.\"\n    )\n    parser.add_argument(\n        \"--prodigy_beta3\",\n        type=float,\n        default=None,\n        help=\"coefficients for computing the Prodigy stepsize using running averages. If set to None, \"\n        \"uses the value of square root of beta2. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--prodigy_decouple\", type=bool, default=True, help=\"Use AdamW style decoupled weight decay\")\n    parser.add_argument(\"--adam_weight_decay\", type=float, default=1e-04, help=\"Weight decay to use for unet params\")\n    \n    parser.add_argument(\n        \"--adam_epsilon\",\n        type=float,\n        default=1e-08,\n        help=\"Epsilon value for the Adam optimizer and Prodigy optimizers.\",\n    )\n\n    parser.add_argument(\n        \"--prodigy_use_bias_correction\",\n        type=bool,\n        default=True,\n        help=\"Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\n        \"--prodigy_safeguard_warmup\",\n        type=bool,\n        default=True,\n        help=\"Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. \"\n        \"Ignored if optimizer is adamW\",\n    )\n    parser.add_argument(\"--max_grad_norm\", default=1.0, type=float, help=\"Max gradient norm.\")\n    parser.add_argument(\n        \"--logging_dir\",\n        type=str,\n        default=\"logs\",\n        help=(\n            \"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to\"\n            \" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***.\"\n        ),\n    )\n    parser.add_argument(\n        \"--allow_tf32\",\n        action=\"store_true\",\n        help=(\n            \"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. 
For more information, see\"\n            \" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\"\n        ),\n    )\n    parser.add_argument(\n        \"--use_cached_latent\",\n        action=\"store_true\",\n        default=False,\n        help=\"Cache the VAE latents\",\n    )\n    parser.add_argument(\n        \"--use_cached_prompt_embed\",\n        action=\"store_true\",\n        default=False,\n        help=\"Use cached T5 and CLIP features\",\n    )\n    parser.add_argument(\n        \"--report_to\",\n        type=str,\n        default=\"tensorboard\",\n        help=(\n            'The integration to report the results and logs to. Supported platforms are `\"tensorboard\"`'\n            ' (default), `\"wandb\"` and `\"comet_ml\"`. Use `\"all\"` to report to all integrations.'\n        ),\n    )\n    parser.add_argument(\n        \"--mixed_precision\",\n        type=str,\n        default=None,\n        choices=[\"no\", \"fp16\", \"bf16\"],\n        help=(\n            \"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=\"\n            \" 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the\"\n            \" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config.\"\n        ),\n    )\n    parser.add_argument(\n        \"--upcast_before_saving\",\n        action=\"store_true\",\n        default=False,\n        help=(\n            \"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). \"\n            \"Defaults to precision dtype used for training to save memory\"\n        ),\n    )\n    parser.add_argument(\"--local_rank\", type=int, default=-1, help=\"For distributed training: local_rank\")\n\n    if input_args is not None:\n        args = parser.parse_args(input_args)\n    else:\n        args = parser.parse_args()\n\n    env_local_rank = int(os.environ.get(\"LOCAL_RANK\", -1))\n    if env_local_rank != -1 and env_local_rank != args.local_rank:\n        args.local_rank = env_local_rank\n\n    return args\n\n\ndef tokenize_prompt(tokenizer, prompt, max_sequence_length):\n    text_inputs = tokenizer(\n        prompt,\n        padding=\"max_length\",\n        max_length=max_sequence_length,\n        truncation=True,\n        return_length=False,\n        return_overflowing_tokens=False,\n        return_tensors=\"pt\",\n    )\n    text_input_ids = text_inputs.input_ids\n    return text_input_ids\n\n\ndef _encode_prompt_with_t5(\n    text_encoder,\n    tokenizer,\n    max_sequence_length=512,\n    prompt=None,\n    num_images_per_prompt=1,\n    device=None,\n    text_input_ids=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device))[0]\n\n    dtype = text_encoder.dtype\n    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n    _, 
seq_len, _ = prompt_embeds.shape\n\n    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n    return prompt_embeds\n\n\ndef _encode_prompt_with_clip(\n    text_encoder,\n    tokenizer,\n    prompt: str,\n    device=None,\n    text_input_ids=None,\n    num_images_per_prompt: int = 1,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n\n    if tokenizer is not None:\n        text_inputs = tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=77,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n    else:\n        if text_input_ids is None:\n            raise ValueError(\"text_input_ids must be provided when the tokenizer is not specified\")\n\n    prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n    # Use pooled output of CLIPTextModel\n    prompt_embeds = prompt_embeds.pooler_output\n    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)\n\n    # duplicate text embeddings for each generation per prompt, using mps friendly method\n    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n    return prompt_embeds\n\n\ndef encode_prompt(\n    text_encoders,\n    tokenizers,\n    prompt: str,\n    max_sequence_length,\n    device=None,\n    num_images_per_prompt: int = 1,\n    text_input_ids_list=None,\n):\n    prompt = [prompt] if isinstance(prompt, str) else prompt\n    batch_size = len(prompt)\n    dtype = text_encoders[0].dtype\n\n    pooled_prompt_embeds = _encode_prompt_with_clip(\n        text_encoder=text_encoders[0],\n        tokenizer=tokenizers[0],\n        prompt=prompt,\n        device=device if device is not None else text_encoders[0].device,\n        num_images_per_prompt=num_images_per_prompt,\n        text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,\n    )\n\n    prompt_embeds = _encode_prompt_with_t5(\n        text_encoder=text_encoders[1],\n        tokenizer=tokenizers[1],\n        max_sequence_length=max_sequence_length,\n        prompt=prompt,\n        num_images_per_prompt=num_images_per_prompt,\n        device=device if device is not None else text_encoders[1].device,\n        text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,\n    )\n\n    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n    \n    return prompt_embeds, pooled_prompt_embeds, text_ids\n\n\ndef main(args):\n    if torch.backends.mps.is_available() and args.mixed_precision == \"bf16\":\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. 
Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    logging_dir = Path(args.output_dir, args.logging_dir)\n\n    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)\n    accelerator = Accelerator(\n        gradient_accumulation_steps=args.gradient_accumulation_steps,\n        mixed_precision=args.mixed_precision,\n        log_with=args.report_to,\n        project_config=accelerator_project_config\n    )\n\n    # Disable AMP for MPS.\n    if torch.backends.mps.is_available():\n        accelerator.native_amp = False\n\n    if args.report_to == \"wandb\":\n        if not is_wandb_available():\n            raise ImportError(\"Make sure to install wandb if you want to use it for logging during training.\")\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    logger.info(accelerator.state, main_process_only=False)\n    if accelerator.is_local_main_process:\n        transformers.utils.logging.set_verbosity_warning()\n        diffusers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n        diffusers.utils.logging.set_verbosity_error()\n\n    # If passed along, set the training seed now.\n    if args.seed is not None:\n        set_seed(args.seed)\n\n    # Handle the repository creation\n    if accelerator.is_main_process:\n        if args.output_dir is not None:\n            os.makedirs(args.output_dir, exist_ok=True)\n\n    # Load the tokenizers\n    tokenizer_one = CLIPTokenizer.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer\",\n        revision=args.revision,\n    )\n    tokenizer_two = T5TokenizerFast.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"tokenizer_2\",\n        revision=args.revision,\n    )\n\n    # import correct text encoder classes\n    text_encoder_cls_one = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision\n    )\n    text_encoder_cls_two = import_model_class_from_model_name_or_path(\n        args.pretrained_model_name_or_path, args.revision, subfolder=\"text_encoder_2\"\n    )\n\n    # Load scheduler and models\n    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"scheduler\"\n    )\n    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)\n    vae = AutoencoderKL.from_pretrained(\n        args.pretrained_model_name_or_path,\n        subfolder=\"vae\",\n        revision=args.revision,\n        variant=args.variant,\n    )\n    transformer = FluxTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n    transformer_teacher = FluxTransformer2DModel.from_pretrained(\n        args.pretrained_model_name_or_path, subfolder=\"transformer\", revision=args.revision, variant=args.variant\n    )\n\n    if args.down_factor == 1:\n        init_local_mask_flex(args.resolution // 16, args.resolution // 16, text_length=args.max_sequence_length, window_size=args.window_size, device=accelerator.device)\n        attn_processors = {}\n        for idx, name in enumerate(transformer.attn_processors):\n            
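# distill=True makes this processor record its attention output for the attention-distillation loss against the teacher (single-stream blocks at every 4th processor index)\n            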
attn_processors[name] = LocalFlexAttnProcessor(distill='single' in name and idx % 4 == 0)\n    else:\n        init_local_downsample_mask_flex(args.resolution // 16, args.resolution // 16, text_length=args.max_sequence_length, window_size=args.window_size, \n                                        down_factor=args.down_factor, device=accelerator.device)\n        attn_processors = {}\n        for idx, name in enumerate(transformer.attn_processors):\n            attn_processors[name] = LocalDownsampleFlexAttnProcessor(down_factor=args.down_factor, distill='single' in name and idx % 4 == 0)\n    attn_processors_teacher = {}\n    for idx, name in enumerate(transformer_teacher.attn_processors.keys()):\n        attn_processors_teacher[name] = FluxAttnProcessor2_0(distill='single' in name and idx % 4 == 0)\n    transformer.set_attn_processor(attn_processors)\n    transformer_teacher.set_attn_processor(attn_processors_teacher)\n\n    # We only train the Attn layers\n    transformer.requires_grad_(False)\n    transformer_teacher.requires_grad_(False)\n    vae.requires_grad_(False)\n    text_encoder_one.requires_grad_(False)\n    text_encoder_two.requires_grad_(False)\n\n    # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision\n    # as these weights are only used for inference, keeping weights in full precision is not required.\n    weight_dtype = torch.float32\n    if accelerator.mixed_precision == \"fp16\":\n        weight_dtype = torch.float16\n    elif accelerator.mixed_precision == \"bf16\":\n        weight_dtype = torch.bfloat16\n        \n    if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:\n        # due to pytorch#99272, MPS does not yet support bfloat16.\n        raise ValueError(\n            \"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead.\"\n        )\n\n    vae.to(accelerator.device, dtype=weight_dtype)\n    transformer.to(accelerator.device, dtype=weight_dtype)\n    transformer_teacher.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_one.to(accelerator.device, dtype=weight_dtype)\n    text_encoder_two.to(accelerator.device, dtype=weight_dtype)\n    \n    if args.gradient_checkpointing:\n        transformer.enable_gradient_checkpointing()\n        \n    for _name, _param in transformer.named_parameters():\n        if '.attn.to_q.' in _name or '.attn.to_k.' in _name or '.attn.to_v.' in _name or '.attn.to_out.' 
in _name or 'spatial_weight' in _name:\n            _param.requires_grad = True\n    \n    def unwrap_model(model):\n        model = accelerator.unwrap_model(model)\n        model = model._orig_mod if is_compiled_module(model) else model\n        return model\n\n    # Enable TF32 for faster training on Ampere GPUs,\n    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices\n    if args.allow_tf32 and torch.cuda.is_available():\n        torch.backends.cuda.matmul.allow_tf32 = True\n\n    if args.scale_lr:\n        args.learning_rate = (\n            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes\n        )\n\n    # Make sure the trainable params are in float32.\n    if args.mixed_precision == \"fp16\":\n        models = [transformer]\n        # only upcast trainable parameters (Attn) into fp32\n        cast_training_params(models, dtype=torch.float32)\n\n    # Optimization parameters\n    transformer_attn_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))\n    transformer_parameters_with_lr = {\"params\": transformer_attn_parameters, \"lr\": args.learning_rate}\n    \n    params_to_optimize = [transformer_parameters_with_lr]\n\n    # Optimizer creation\n    if not (args.optimizer.lower() == \"prodigy\" or args.optimizer.lower() == \"adamw\"):\n        logger.warning(\n            f\"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy].\"\n            \"Defaulting to adamW\"\n        )\n        args.optimizer = \"adamw\"\n\n    if args.use_8bit_adam and not args.optimizer.lower() == \"adamw\":\n        logger.warning(\n            f\"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was \"\n            f\"set to {args.optimizer.lower()}\"\n        )\n\n    if args.optimizer.lower() == \"adamw\":\n        if args.use_8bit_adam:\n            try:\n                import bitsandbytes as bnb\n            except ImportError:\n                raise ImportError(\n                    \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n                )\n\n            optimizer_class = bnb.optim.AdamW8bit\n        else:\n            optimizer_class = torch.optim.AdamW\n\n        optimizer = optimizer_class(\n            params_to_optimize,\n            betas=(args.adam_beta1, args.adam_beta2),\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n        )\n\n    if args.optimizer.lower() == \"prodigy\":\n        try:\n            import prodigyopt\n        except ImportError:\n            raise ImportError(\"To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`\")\n\n        optimizer_class = prodigyopt.Prodigy\n\n        if args.learning_rate <= 0.1:\n            logger.warning(\n                \"Learning rate is too low. 
When using prodigy, it's generally better to set learning rate around 1.0\"\n            )\n        \n        optimizer = optimizer_class(\n            params_to_optimize,\n            lr=args.learning_rate,\n            betas=(args.adam_beta1, args.adam_beta2),\n            beta3=args.prodigy_beta3,\n            weight_decay=args.adam_weight_decay,\n            eps=args.adam_epsilon,\n            decouple=args.prodigy_decouple,\n            use_bias_correction=args.prodigy_use_bias_correction,\n            safeguard_warmup=args.prodigy_safeguard_warmup,\n        )\n\n    # Dataset and DataLoaders creation:\n    train_dataloader = loader(train_batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers,\n                              img_dir=args.data_root, img_size=args.resolution, \n                              use_cached_prompt_embeds=args.use_cached_prompt_embed,\n                              use_cached_latent_codes=args.use_cached_latent)\n\n    tokenizers = [tokenizer_one, tokenizer_two]\n    text_encoders = [text_encoder_one, text_encoder_two]\n\n    def compute_text_embeddings(prompt, text_encoders, tokenizers):\n        with torch.no_grad():\n            prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(\n                text_encoders, tokenizers, prompt, args.max_sequence_length\n            )\n            prompt_embeds = prompt_embeds.to(accelerator.device)\n            pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)\n            text_ids = text_ids.to(accelerator.device)\n        return prompt_embeds, pooled_prompt_embeds, text_ids\n   \n    # Clear the memory here\n    if args.use_cached_prompt_embed:\n        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, text_encoders, tokenizers\n        free_memory()\n        \n    vae_config_shift_factor = vae.config.shift_factor\n    vae_config_scaling_factor = vae.config.scaling_factor\n    vae_config_block_out_channels = vae.config.block_out_channels\n    if args.use_cached_latent:\n        del vae\n        free_memory()\n\n    # Scheduler and math around the number of training steps.\n    overrode_max_train_steps = False\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if args.max_train_steps is None:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n        overrode_max_train_steps = True\n\n    lr_scheduler = get_scheduler(\n        args.lr_scheduler,\n        optimizer=optimizer,\n        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,\n        num_training_steps=args.max_train_steps * accelerator.num_processes,\n        num_cycles=args.lr_num_cycles,\n        power=args.lr_power,\n    )\n\n    # Prepare everything with our `accelerator`.\n    guidance_embeds = transformer.config.guidance_embeds\n    \n    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n        transformer, optimizer, train_dataloader, lr_scheduler\n    )\n\n    # We need to recalculate our total training steps as the size of the training dataloader may have changed.\n    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)\n    if overrode_max_train_steps:\n        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch\n    # Afterwards we recalculate our number of training epochs\n    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)\n\n    # We need to 
initialize the trackers we use, and also store our configuration.\n    # The trackers are initialized automatically on the main process.\n    if accelerator.is_main_process:\n        tracker_name = \"flux-dev-attn\"\n        accelerator.init_trackers(tracker_name, config=vars(args))\n\n    # Train!\n    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n    logger.info(f\"  Num Epochs = {args.num_train_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {args.train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n    logger.info(f\"  Gradient Accumulation steps = {args.gradient_accumulation_steps}\")\n    logger.info(f\"  Total optimization steps = {args.max_train_steps}\")\n    global_step = 0\n    first_epoch = 0\n\n    # Potentially load in the weights and states from a previous save\n    if args.resume_from_checkpoint:\n        if args.resume_from_checkpoint != \"latest\":\n            path = args.resume_from_checkpoint\n        else:\n            # Get the most recent checkpoint\n            dirs = os.listdir(args.output_dir)\n            dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n            dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n            path = dirs[-1] if len(dirs) > 0 else None\n\n        if path is None:\n            accelerator.print(\n                f\"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run.\"\n            )\n            args.resume_from_checkpoint = None\n            initial_global_step = 0\n        else:\n            accelerator.print(f\"Resuming from checkpoint {path}\")\n            accelerator.load_state(path)\n            global_step = int(path.split(\"-\")[-1])\n\n            initial_global_step = global_step\n            first_epoch = global_step // num_update_steps_per_epoch\n\n    else:\n        initial_global_step = 0\n\n    progress_bar = tqdm(\n        range(0, args.max_train_steps),\n        initial=initial_global_step,\n        desc=\"Steps\",\n        # Only show the progress bar once on each machine.\n        disable=not accelerator.is_local_main_process,\n    )\n\n    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):\n        sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)\n        schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)\n        timesteps = timesteps.to(accelerator.device)\n        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]\n\n        sigma = sigmas[step_indices].flatten()\n        while len(sigma.shape) < n_dim:\n            sigma = sigma.unsqueeze(-1)\n        return sigma\n\n    for epoch in range(first_epoch, args.num_train_epochs):\n        transformer.train()\n        \n        for _, batch in enumerate(train_dataloader):\n            with accelerator.accumulate(transformer):\n\n                if not args.use_cached_prompt_embed:\n                    prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(\n                        batch['prompts'], text_encoders, tokenizers\n                    )\n                else:\n                    prompt_embeds = batch['prompt_embeds_t5'].to(dtype=weight_dtype)\n                    pooled_prompt_embeds = 
batch['prompt_embeds_clip'].to(dtype=weight_dtype)\n                    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=weight_dtype)\n\n                with torch.no_grad():\n                    # Convert images to latent space\n                    if args.use_cached_latent:\n                        mean = batch['latent_codes_mean'].to(dtype=weight_dtype)\n                        std = batch['latent_codes_std'].to(dtype=weight_dtype)\n                        sample = torch.randn_like(mean)\n                        model_input = mean + std * sample\n                    else:\n                        model_input = vae.encode(batch[0].to(dtype=vae.dtype)).latent_dist.sample()\n                    model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor\n                    \n                    vae_scale_factor = 2 ** (len(vae_config_block_out_channels))\n\n                    latent_image_ids = FluxPipeline._prepare_latent_image_ids(\n                        model_input.shape[0],\n                        model_input.shape[2],\n                        model_input.shape[3],\n                        accelerator.device,\n                        weight_dtype,\n                    )\n                    # Sample noise that we'll add to the latents\n                    noise = torch.randn_like(model_input)\n                    bsz = model_input.shape[0]\n\n                    # Sample a random timestep for each image\n                    # for weighting schemes where we sample timesteps non-uniformly\n                    u = compute_density_for_timestep_sampling(\n                        weighting_scheme=args.weighting_scheme,\n                        batch_size=bsz,\n                        logit_mean=args.logit_mean,\n                        logit_std=args.logit_std,\n                        mode_scale=args.mode_scale,\n                    )\n                    indices = (u * noise_scheduler.config.num_train_timesteps).long()\n                    timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)\n\n                    # Add noise according to flow matching.\n                    # zt = (1 - texp) * x + texp * z1\n                    sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)\n                    noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise\n\n                    packed_noisy_model_input = FluxPipeline._pack_latents(\n                        noisy_model_input,\n                        batch_size=model_input.shape[0],\n                        num_channels_latents=model_input.shape[1],\n                        height=model_input.shape[2],\n                        width=model_input.shape[3],\n                    )\n\n                    # handle guidance\n                    if guidance_embeds:\n                        guidance = torch.tensor([args.guidance_scale], device=accelerator.device)\n                        guidance = guidance.expand(model_input.shape[0])\n                    else:\n                        guidance = None\n\n                    teacher_pred = transformer_teacher(\n                        hidden_states=packed_noisy_model_input,\n                        # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)\n                        timestep=timesteps / 1000,\n                        guidance=guidance,\n        
                pooled_projections=pooled_prompt_embeds,\n                        encoder_hidden_states=prompt_embeds,\n                        txt_ids=text_ids,\n                        img_ids=latent_image_ids,\n                        return_dict=False,\n                    )[0]\n                    teacher_pred = FluxPipeline._unpack_latents(\n                        teacher_pred,\n                        height=int(model_input.shape[2] * vae_scale_factor / 2),\n                        width=int(model_input.shape[3] * vae_scale_factor / 2),\n                        vae_scale_factor=vae_scale_factor,\n                    )\n\n                # Predict the noise residual\n                model_pred = transformer(\n                    hidden_states=packed_noisy_model_input,\n                    # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)\n                    timestep=timesteps / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    return_dict=False,\n                )[0]\n                model_pred = FluxPipeline._unpack_latents(\n                    model_pred,\n                    height=int(model_input.shape[2] * vae_scale_factor / 2),\n                    width=int(model_input.shape[3] * vae_scale_factor / 2),\n                    vae_scale_factor=vae_scale_factor,\n                )\n\n                # these weighting schemes use a uniform timestep sampling\n                # and instead post-weight the loss\n                weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)\n\n                # flow matching loss\n                target = noise - model_input\n\n                # Compute regular loss.\n                loss_fm = (weighting.float() * (model_pred.float() - target.float()) ** 2).mean()\n                loss_distill = (weighting.float() * (model_pred.float() - teacher_pred.float()) ** 2).mean()\n                loss_attn = sum([(weighting.float().squeeze(-1) * (attn_output.float() - attn_output_teacher.float()) ** 2).mean()\n                    for attn_output, attn_output_teacher in zip(attn_outputs, attn_outputs_teacher)]) / len(attn_outputs)\n                loss = loss_fm + loss_distill * 0.5 + loss_attn * 0.5\n                \n                accelerator.backward(loss)\n                if accelerator.sync_gradients:\n                    params_to_clip = (\n                        transformer.parameters()\n                    )\n                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)\n\n                optimizer.step()\n                lr_scheduler.step()\n                optimizer.zero_grad()\n                attn_outputs.clear()\n                attn_outputs_teacher.clear()\n\n            # Checks if the accelerator has performed an optimization step behind the scenes\n            if accelerator.sync_gradients:\n                progress_bar.update(1)\n                global_step += 1\n                \n                if accelerator.is_main_process:\n                    \n                    if global_step % args.checkpointing_steps == 0:\n                        # _before_ saving state, check if this save would set us over the 
`checkpoints_total_limit`\n                        if args.checkpoints_total_limit is not None:\n                            checkpoints = os.listdir(args.output_dir)\n                            checkpoints = [d for d in checkpoints if d.startswith(\"checkpoint\")]\n                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split(\"-\")[1]))\n\n                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints\n                            if len(checkpoints) >= args.checkpoints_total_limit:\n                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1\n                                removing_checkpoints = checkpoints[0:num_to_remove]\n\n                                logger.info(\n                                    f\"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints\"\n                                )\n                                logger.info(f\"removing checkpoints: {', '.join(removing_checkpoints)}\")\n\n                                for removing_checkpoint in removing_checkpoints:\n                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)\n                                    shutil.rmtree(removing_checkpoint)\n\n                if global_step % args.checkpointing_steps == 0:\n                    save_path = os.path.join(args.output_dir, f\"checkpoint-{global_step}\")\n                    accelerator.save_state(save_path)\n\n                    logger.info(f\"Saved state to {save_path}\")\n\n            logs = {\"loss_fm\": loss_fm.detach().item(), \n                    \"loss_distill\": loss_distill.detach().item(), \n                    \"loss_attn\": loss_attn.detach().item(), \n                    \"loss\": loss.detach().item(), \n                    \"lr\": lr_scheduler.get_last_lr()[0]}\n            progress_bar.set_postfix(**logs)\n            accelerator.log(logs, step=global_step)\n\n            if global_step >= args.max_train_steps:\n                break\n\n        free_memory()\n        if accelerator.is_main_process:\n            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:\n                # create pipeline\n                if args.use_cached_prompt_embed:\n                    text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)\n                if args.use_cached_latent:\n                    vae = AutoencoderKL.from_pretrained(\n                        args.pretrained_model_name_or_path,\n                        subfolder=\"vae\",\n                        revision=args.revision,\n                        variant=args.variant,\n                    )\n                pipeline = FluxPipeline.from_pretrained(\n                    args.pretrained_model_name_or_path,\n                    vae=vae,\n                    text_encoder=accelerator.unwrap_model(text_encoder_one),\n                    text_encoder_2=accelerator.unwrap_model(text_encoder_two),\n                    transformer=accelerator.unwrap_model(transformer),\n                    revision=args.revision,\n                    variant=args.variant,\n                    torch_dtype=weight_dtype,\n                )\n                pipeline_args = {\"prompt\": args.validation_prompt}\n                images = log_validation(\n                    pipeline=pipeline,\n                    args=args,\n               
     accelerator=accelerator,\n                    pipeline_args=pipeline_args,\n                    epoch=epoch,\n                    torch_dtype=weight_dtype,\n                )\n                if args.use_cached_prompt_embed:\n                    del text_encoder_one, text_encoder_two\n                if args.use_cached_latent:\n                    del vae\n                free_memory()\n\n    # Save the attn weights\n    accelerator.wait_for_everyone()\n    if accelerator.is_main_process:\n        transformer = unwrap_model(transformer)\n        if args.upcast_before_saving:\n            transformer.to(torch.float32)\n        else:\n            transformer = transformer.to(weight_dtype)\n\n        state_dict = {}\n        for _name, _param in transformer.named_parameters():\n            if '.attn.to_q.' in _name or '.attn.to_k.' in _name or '.attn.to_v.' in _name or '.attn.to_out.' in _name or 'spatial_weight' in _name:\n                state_dict[_name] = _param\n\n        save_file(\n            state_dict,\n            os.path.join(args.output_dir, 'attn_weights.safetensors')\n        )\n\n        # Final inference\n        # Load previous pipeline\n        free_memory()\n        pipeline = FluxPipeline.from_pretrained(\n            args.pretrained_model_name_or_path,\n            revision=args.revision,\n            variant=args.variant,\n            torch_dtype=weight_dtype,\n        )\n        # load the distilled attention weights saved above (non-strict: only attn projections and spatial_weight are present)\n        state_dict = {}\n        with safe_open(os.path.join(args.output_dir, 'attn_weights.safetensors'), framework=\"pt\") as f:\n            for k in f.keys():\n                state_dict[k] = f.get_tensor(k)\n\n        missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(state_dict, strict=False)\n\n        missing_keys = list(filter(lambda p: ('.attn.to_q.' in p or \n                                              '.attn.to_k.' in p or \n                                              '.attn.to_v.' in p or \n                                              '.attn.to_out.' in p or \n                                              'spatial_weight' in p), missing_keys))\n\n        if len(missing_keys) != 0 or len(unexpected_keys) != 0:\n            logger.warning(\n                f\"Loading attn weights from state_dict led to unexpected keys: {unexpected_keys}\"\n                f\" and missing keys: {missing_keys}.\"\n            )\n\n        # run inference\n        images = []\n        if args.validation_prompt and args.num_validation_images > 0:\n            pipeline_args = {\"prompt\": args.validation_prompt}\n            images = log_validation(\n                pipeline=pipeline,\n                args=args,\n                accelerator=accelerator,\n                pipeline_args=pipeline_args,\n                epoch=epoch,\n                is_final_validation=True,\n                torch_dtype=weight_dtype,\n            )\n\n            for idx, image in enumerate(images):\n                image.save(os.path.join(args.output_dir, 'validation_idx_%02d.jpg' % idx))\n\n    accelerator.end_training()\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    main(args)\n    \n"
  },
  {
    "path": "distill.sh",
    "content": "export MODEL_NAME=\"black-forest-labs/FLUX.1-dev\"\nexport DATAPATH=\"/path/to/t2i_1024\"\nexport OUTPUT_DIR=\"ckpt/training_exp\"\nexport PRECISION=\"bf16\"\n\naccelerate launch --config_file deepspeed_config.yaml distill.py \\\n--pretrained_model_name_or_path=$MODEL_NAME  \\\n  --data_root=$DATAPATH \\\n  --output_dir=$OUTPUT_DIR \\\n  --mixed_precision=$PRECISION \\\n  --dataloader_num_workers=8 \\\n  --resolution=1024 \\\n  --train_batch_size=2 \\\n  --gradient_accumulation_steps=4 \\\n  --optimizer=\"prodigy\" \\\n  --learning_rate=1. \\\n  --report_to=\"wandb\" \\\n  --lr_scheduler=\"constant\" \\\n  --lr_warmup_steps=0 \\\n  --max_train_steps=50000 \\\n  --validation_epochs=1 \\\n  --seed=\"0\" \\\n  --checkpointing_steps=5000 \\\n  --use_cached_prompt_embed \\\n  --use_cached_latent \\\n  --gradient_checkpointing \\\n  --down_factor=1 \\\n  --window_size=16\n  \n"
  },
  {
    "path": "inference_t2i.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"import os\\n\",\n    \"import requests\\n\",\n    \"from safetensors.torch import load_file\\n\",\n    \"from diffusers import FluxPipeline\\n\",\n    \"from attention_processor import LocalFlexAttnProcessor, LocalDownsampleFlexAttnProcessor, init_local_mask_flex, init_local_downsample_mask_flex\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bfl_repo=\\\"black-forest-labs/FLUX.1-dev\\\"\\n\",\n    \"device = torch.device('cuda')\\n\",\n    \"dtype = torch.bfloat16\\n\",\n    \"pipe = FluxPipeline.from_pretrained(bfl_repo, torch_dtype=dtype).to(device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"height = 1024\\n\",\n    \"width = 1024\\n\",\n    \"down_factor, window_size = 4, 8\\n\",\n    \"# Supported Configurations:\\n\",\n    \"# down_factor, window_size = 1, 8\\n\",\n    \"# down_factor, window_size = 1, 16\\n\",\n    \"# down_factor, window_size = 1, 32\\n\",\n    \"# down_factor, window_size = 4, 16\\n\",\n    \"# down_factor, window_size = 4, 8\\n\",\n    \"if down_factor == 1:\\n\",\n    \"    init_local_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, device=device)\\n\",\n    \"    attn_processors = {}\\n\",\n    \"    for k in pipe.transformer.attn_processors.keys():\\n\",\n    \"        attn_processors[k] = LocalFlexAttnProcessor()\\n\",\n    \"else:\\n\",\n    \"    init_local_downsample_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, down_factor=down_factor, device=device)\\n\",\n    \"    attn_processors = {}\\n\",\n    \"    for k in pipe.transformer.attn_processors.keys():\\n\",\n    \"        attn_processors[k] = LocalDownsampleFlexAttnProcessor(down_factor=down_factor).to(device, dtype)\\n\",\n    \"pipe.transformer.set_attn_processor(attn_processors)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not os.path.exists('ckpt'):\\n\",\n    \"    os.mkdir('ckpt')\\n\",\n    \"if down_factor == 1:\\n\",\n    \"    if not os.path.exists(f'ckpt/clear_local_{window_size}.safetensors'):\\n\",\n    \"        print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}.safetensors')\\n\",\n    \"        response = requests.get(f\\\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}.safetensors\\\")\\n\",\n    \"        response.raise_for_status()\\n\",\n    \"        with open(f'ckpt/clear_local_{window_size}.safetensors', 'wb') as f:\\n\",\n    \"            f.write(response.content)\\n\",\n    \"    state_dict = load_file(f'ckpt/clear_local_{window_size}.safetensors')\\n\",\n    \"else:\\n\",\n    \"    if not os.path.exists(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors'):\\n\",\n    \"        print(f'Checkpoint not found. 
Downloading checkpoint to ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\\n\",\n    \"        response = requests.get(f\\\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}_down_{down_factor}.safetensors\\\")\\n\",\n    \"        response.raise_for_status()\\n\",\n    \"        with open(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors', 'wb') as f:\\n\",\n    \"            f.write(response.content)\\n\",\n    \"    state_dict = load_file(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\\n\",\n    \"\\n\",\n    \"missing_keys, unexpected_keys = pipe.transformer.load_state_dict(state_dict, strict=False)\\n\",\n    \"\\n\",\n    \"missing_keys = list(filter(lambda p: ('.attn.to_q.' in p or \\n\",\n    \"                                      '.attn.to_k.' in p or \\n\",\n    \"                                      '.attn.to_v.' in p or \\n\",\n    \"                                      '.attn.to_out.' in p or \\n\",\n    \"                                      'spatial_weight' in p), missing_keys))\\n\",\n    \"\\n\",\n    \"if len(missing_keys) != 0 or len(unexpected_keys) != 0:\\n\",\n    \"    print(\\n\",\n    \"        f\\\"Loading attn weights from state_dict led to unexpected keys: {unexpected_keys}\\\"\\n\",\n    \"        f\\\" and missing keys: {missing_keys}.\\\"\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompt = \\\"A Mickey is eating a pie\\\"\\n\",\n    \"height = 1024\\n\",\n    \"width = 1024\\n\",\n    \"image = pipe(\\n\",\n    \"    prompt,\\n\",\n    \"    height=height,\\n\",\n    \"    width=width,\\n\",\n    \"    guidance_scale=3.5,\\n\",\n    \"    num_inference_steps=20,\\n\",\n    \"    max_sequence_length=512,\\n\",\n    \"    generator=torch.Generator(\\\"cpu\\\").manual_seed(0)\\n\",\n    \").images[0]\\n\",\n    \"image\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"pt25\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.12.7\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "inference_t2i_highres.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"import os\\n\",\n    \"import gc\\n\",\n    \"import requests\\n\",\n    \"from safetensors.torch import load_file\\n\",\n    \"from diffusers import FluxPipeline\\n\",\n    \"from pipeline_flux_img2img import FluxImg2ImgPipeline\\n\",\n    \"from transformer_flux import FluxTransformer2DModel\\n\",\n    \"from attention_processor import LocalFlexAttnProcessor, LocalDownsampleFlexAttnProcessor, init_local_mask_flex, init_local_downsample_mask_flex\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"bfl_repo=\\\"black-forest-labs/FLUX.1-dev\\\"\\n\",\n    \"device = torch.device('cuda')\\n\",\n    \"dtype = torch.bfloat16\\n\",\n    \"pipe = FluxPipeline.from_pretrained(bfl_repo, torch_dtype=dtype).to(device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"prompt = \\\"enchanted forest, glowing plants, towering ancient trees, a mystical girl, magical aura, \\\" \\\\\\n\",\n    \"         \\\"fantasy style, vibrant colors, ethereal lighting, bokeh effect, ultra-detailed, painterly, ultra HD, 8K, \\\" \\\\\\n\",\n    \"         \\\"soft glowing lights, mist and fog, otherworldly ambiance, glowing mushrooms, sparkling particles\\\"\\n\",\n    \"height = 512\\n\",\n    \"width = 1024\\n\",\n    \"image = pipe(\\n\",\n    \"    prompt,\\n\",\n    \"    height=height,\\n\",\n    \"    width=width,\\n\",\n    \"    guidance_scale=3.5,\\n\",\n    \"    num_inference_steps=20,\\n\",\n    \"    max_sequence_length=512,\\n\",\n    \"    generator=torch.Generator(\\\"cpu\\\").manual_seed(0)\\n\",\n    \").images[0]\\n\",\n    \"image\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"del pipe\\n\",\n    \"torch.cuda.empty_cache()\\n\",\n    \"gc.collect()\\n\",\n    \"transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\\\"transformer\\\", torch_dtype=torch.bfloat16)\\n\",\n    \"pipe = FluxImg2ImgPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=torch.bfloat16)\\n\",\n    \"pipe.transformer = transformer\\n\",\n    \"pipe.scheduler.config.use_dynamic_shifting = False\\n\",\n    \"pipe.scheduler.config.time_shift = 10\\n\",\n    \"pipe.vae.enable_tiling()\\n\",\n    \"pipe = pipe.to(device)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"height = 2048\\n\",\n    \"width = 4096\\n\",\n    \"down_factor, window_size = 4, 8\\n\",\n    \"# Supported Configurations:\\n\",\n    \"# down_factor, window_size = 1, 8\\n\",\n    \"# down_factor, window_size = 1, 16\\n\",\n    \"# down_factor, window_size = 1, 32\\n\",\n    \"# down_factor, window_size = 4, 16\\n\",\n    \"# down_factor, window_size = 4, 8\\n\",\n    \"if down_factor == 1:\\n\",\n    \"    init_local_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, device=device)\\n\",\n    \"    attn_processors = {}\\n\",\n    \"    for k in pipe.transformer.attn_processors.keys():\\n\",\n    \"        attn_processors[k] = LocalFlexAttnProcessor()\\n\",\n    \"else:\\n\",\n    \"    init_local_downsample_mask_flex(height // 16, 
width // 16, text_length=512, window_size=window_size, down_factor=down_factor, device=device)\\n\",\n    \"    attn_processors = {}\\n\",\n    \"    for k in pipe.transformer.attn_processors.keys():\\n\",\n    \"        attn_processors[k] = LocalDownsampleFlexAttnProcessor(down_factor=down_factor).to(device, dtype)\\n\",\n    \"pipe.transformer.set_attn_processor(attn_processors)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"if not os.path.exists('ckpt'):\\n\",\n    \"    os.mkdir('ckpt')\\n\",\n    \"if down_factor == 1:\\n\",\n    \"    if not os.path.exists(f'ckpt/clear_local_{window_size}.safetensors'):\\n\",\n    \"        print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}.safetensors')\\n\",\n    \"        response = requests.get(f\\\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}.safetensors\\\")\\n\",\n    \"        response.raise_for_status()\\n\",\n    \"        with open(f'ckpt/clear_local_{window_size}.safetensors', 'wb') as f:\\n\",\n    \"            f.write(response.content)\\n\",\n    \"    state_dict = load_file(f'ckpt/clear_local_{window_size}.safetensors')\\n\",\n    \"else:\\n\",\n    \"    if not os.path.exists(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors'):\\n\",\n    \"        print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\\n\",\n    \"        response = requests.get(f\\\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}_down_{down_factor}.safetensors\\\")\\n\",\n    \"        response.raise_for_status()\\n\",\n    \"        with open(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors', 'wb') as f:\\n\",\n    \"            f.write(response.content)\\n\",\n    \"    state_dict = load_file(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\\n\",\n    \"\\n\",\n    \"missing_keys, unexpected_keys = pipe.transformer.load_state_dict(state_dict, strict=False)\\n\",\n    \"\\n\",\n    \"missing_keys = list(filter(lambda p: ('.attn.to_q.' in p or \\n\",\n    \"                                      '.attn.to_k.' in p or \\n\",\n    \"                                      '.attn.to_v.' in p or \\n\",\n    \"                                      '.attn.to_out.' 
in p or \\n\",\n    \"                                      'spatial_weight' in p), missing_keys))\\n\",\n    \"\\n\",\n    \"if len(missing_keys) != 0 or len(unexpected_keys) != 0:\\n\",\n    \"    print(\\n\",\n    \"        f\\\"Loading attn weights from state_dict led to unexpected keys: {unexpected_keys}\\\"\\n\",\n    \"        f\\\" and missing keys: {missing_keys}.\\\"\\n\",\n    \"    )\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"strength = 0.7\\n\",\n    \"image_hr = pipe(prompt=prompt,\\n\",\n    \"                image=image.resize((width, height)),\\n\",\n    \"                strength=strength,\\n\",\n    \"                num_inference_steps=20, \\n\",\n    \"                guidance_scale=7.5, \\n\",\n    \"                height=height,\\n\",\n    \"                width=width,\\n\",\n    \"                ntk_factor=10,\\n\",\n    \"                proportional_attention=True,\\n\",\n    \"                generator=torch.Generator(\\\"cpu\\\").manual_seed(0)\\n\",\n    \"                ).images[0]\\n\",\n    \"image_hr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"pt25\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.12.7\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
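  {
    "path": "clear_upsample_example.py",
    "content": "# Minimal, self-contained sketch of the two-stage CLEAR upscaling workflow demonstrated in the\n# accompanying notebook: generate a 1024x512 base image with the stock FLUX pipeline, then\n# re-denoise it at 4096x2048 with the local-attention (window_size=8, down_factor=1) processors.\n# The file name, the fixed resolutions, the shortened prompt, and the assumption that\n# ckpt/clear_local_8.safetensors has already been downloaded (the notebook shows how) are\n# illustrative choices rather than requirements.\nimport torch\nfrom safetensors.torch import load_file\nfrom diffusers import FluxPipeline\nfrom pipeline_flux_img2img import FluxImg2ImgPipeline\nfrom transformer_flux import FluxTransformer2DModel\nfrom attention_processor import LocalFlexAttnProcessor, init_local_mask_flex\n\nbfl_repo = \"black-forest-labs/FLUX.1-dev\"\ndevice, dtype = torch.device(\"cuda\"), torch.bfloat16\nprompt = \"enchanted forest, glowing plants, towering ancient trees, a mystical girl, ultra HD\"\n\n# Stage 1: base-resolution generation with the unmodified FLUX pipeline.\npipe = FluxPipeline.from_pretrained(bfl_repo, torch_dtype=dtype).to(device)\nimage = pipe(\n    prompt,\n    height=512,\n    width=1024,\n    guidance_scale=3.5,\n    num_inference_steps=20,\n    max_sequence_length=512,\n    generator=torch.Generator(\"cpu\").manual_seed(0),\n).images[0]\ndel pipe\ntorch.cuda.empty_cache()\n\n# Stage 2: image-to-image upscaling with CLEAR local attention.\nheight, width, window_size = 2048, 4096, 8\ntransformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=dtype)\npipe = FluxImg2ImgPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=dtype)\npipe.transformer = transformer\npipe.scheduler.config.use_dynamic_shifting = False\npipe.scheduler.config.time_shift = 10\npipe.vae.enable_tiling()\npipe = pipe.to(device)\n\n# Build the local attention mask for the target latent grid and swap in the trained processors.\ninit_local_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, device=device)\npipe.transformer.set_attn_processor({k: LocalFlexAttnProcessor() for k in pipe.transformer.attn_processors})\npipe.transformer.load_state_dict(load_file(f\"ckpt/clear_local_{window_size}.safetensors\"), strict=False)\n\n# Re-denoise the upsampled base image; ntk_factor and proportional_attention follow the notebook.\nimage_hr = pipe(\n    prompt=prompt,\n    image=image.resize((width, height)),\n    strength=0.7,\n    num_inference_steps=20,\n    guidance_scale=7.5,\n    height=height,\n    width=width,\n    ntk_factor=10,\n    proportional_attention=True,\n    generator=torch.Generator(\"cpu\").manual_seed(0),\n).images[0]\nimage_hr.save(\"clear_upsampled.png\")\n"
  },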
  {
    "path": "pipeline_flux_img2img.py",
    "content": "# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast\n\nfrom diffusers.image_processor import PipelineImageInput, VaeImageProcessor\nfrom diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin\nfrom diffusers.models.autoencoders import AutoencoderKL\nfrom diffusers.models.transformers import FluxTransformer2DModel\nfrom diffusers.schedulers import FlowMatchEulerDiscreteScheduler\nfrom diffusers.utils import (\n    USE_PEFT_BACKEND,\n    is_torch_xla_available,\n    logging,\n    replace_example_docstring,\n    scale_lora_layers,\n    unscale_lora_layers,\n)\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput\n\n\nif is_torch_xla_available():\n    import torch_xla.core.xla_model as xm\n\n    XLA_AVAILABLE = True\nelse:\n    XLA_AVAILABLE = False\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\nEXAMPLE_DOC_STRING = \"\"\"\n    Examples:\n        ```py\n        >>> import torch\n\n        >>> from diffusers import FluxImg2ImgPipeline\n        >>> from diffusers.utils import load_image\n\n        >>> device = \"cuda\"\n        >>> pipe = FluxImg2ImgPipeline.from_pretrained(\"black-forest-labs/FLUX.1-schnell\", torch_dtype=torch.bfloat16)\n        >>> pipe = pipe.to(device)\n\n        >>> url = \"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg\"\n        >>> init_image = load_image(url).resize((1024, 1024))\n\n        >>> prompt = \"cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k\"\n\n        >>> images = pipe(\n        ...     prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0\n        ... 
).images[0]\n        ```\n\"\"\"\n\n\n# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift\ndef calculate_shift(\n    image_seq_len,\n    base_seq_len: int = 256,\n    max_seq_len: int = 4096,\n    base_shift: float = 0.5,\n    max_shift: float = 1.16,\n):\n    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)\n    b = base_shift - m * base_seq_len\n    mu = image_seq_len * m + b\n    return mu\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents\ndef retrieve_latents(\n    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = \"sample\"\n):\n    if hasattr(encoder_output, \"latent_dist\") and sample_mode == \"sample\":\n        return encoder_output.latent_dist.sample(generator)\n    elif hasattr(encoder_output, \"latent_dist\") and sample_mode == \"argmax\":\n        return encoder_output.latent_dist.mode()\n    elif hasattr(encoder_output, \"latents\"):\n        return encoder_output.latents\n    else:\n        raise AttributeError(\"Could not access latents of provided encoder_output\")\n\n\n# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps\ndef retrieve_timesteps(\n    scheduler,\n    num_inference_steps: Optional[int] = None,\n    device: Optional[Union[str, torch.device]] = None,\n    timesteps: Optional[List[int]] = None,\n    sigmas: Optional[List[float]] = None,\n    **kwargs,\n):\n    r\"\"\"\n    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles\n    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.\n\n    Args:\n        scheduler (`SchedulerMixin`):\n            The scheduler to get timesteps from.\n        num_inference_steps (`int`):\n            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`\n            must be `None`.\n        device (`str` or `torch.device`, *optional*):\n            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.\n        timesteps (`List[int]`, *optional*):\n            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,\n            `num_inference_steps` and `sigmas` must be `None`.\n        sigmas (`List[float]`, *optional*):\n            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,\n            `num_inference_steps` and `timesteps` must be `None`.\n\n    Returns:\n        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the\n        second element is the number of inference steps.\n    \"\"\"\n    if timesteps is not None and sigmas is not None:\n        raise ValueError(\"Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values\")\n    if timesteps is not None:\n        accepts_timesteps = \"timesteps\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accepts_timesteps:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" timestep schedules. 
Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    elif sigmas is not None:\n        accept_sigmas = \"sigmas\" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())\n        if not accept_sigmas:\n            raise ValueError(\n                f\"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom\"\n                f\" sigmas schedules. Please check whether you are using the correct scheduler.\"\n            )\n        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n        num_inference_steps = len(timesteps)\n    else:\n        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)\n        timesteps = scheduler.timesteps\n    return timesteps, num_inference_steps\n\n\nclass FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):\n    r\"\"\"\n    The Flux pipeline for image inpainting.\n\n    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/\n\n    Args:\n        transformer ([`FluxTransformer2DModel`]):\n            Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.\n        scheduler ([`FlowMatchEulerDiscreteScheduler`]):\n            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        text_encoder ([`CLIPTextModel`]):\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically\n            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        text_encoder_2 ([`T5EncoderModel`]):\n            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically\n            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.\n        tokenizer (`CLIPTokenizer`):\n            Tokenizer of class\n            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).\n        tokenizer_2 (`T5TokenizerFast`):\n            Second Tokenizer of class\n            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).\n    \"\"\"\n\n    model_cpu_offload_seq = \"text_encoder->text_encoder_2->transformer->vae\"\n    _optional_components = []\n    _callback_tensor_inputs = [\"latents\", \"prompt_embeds\"]\n\n    def __init__(\n        self,\n        scheduler: FlowMatchEulerDiscreteScheduler,\n        vae: AutoencoderKL,\n        text_encoder: CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        text_encoder_2: T5EncoderModel,\n        tokenizer_2: T5TokenizerFast,\n        transformer: FluxTransformer2DModel,\n    ):\n        super().__init__()\n\n        self.register_modules(\n            vae=vae,\n            text_encoder=text_encoder,\n            text_encoder_2=text_encoder_2,\n            tokenizer=tokenizer,\n            tokenizer_2=tokenizer_2,\n            transformer=transformer,\n            scheduler=scheduler,\n        )\n        self.vae_scale_factor = (\n            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, \"vae\") and 
self.vae is not None else 16\n        )\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.tokenizer_max_length = (\n            self.tokenizer.model_max_length if hasattr(self, \"tokenizer\") and self.tokenizer is not None else 77\n        )\n        self.default_sample_size = 64\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds\n    def _get_t5_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]] = None,\n        num_images_per_prompt: int = 1,\n        max_sequence_length: int = 512,\n        device: Optional[torch.device] = None,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        device = device or self._execution_device\n        dtype = dtype or self.text_encoder.dtype\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        if isinstance(self, TextualInversionLoaderMixin):\n            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)\n\n        text_inputs = self.tokenizer_2(\n            prompt,\n            padding=\"max_length\",\n            max_length=max_sequence_length,\n            truncation=True,\n            return_length=False,\n            return_overflowing_tokens=False,\n            return_tensors=\"pt\",\n        )\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer_2(prompt, padding=\"longest\", return_tensors=\"pt\").input_ids\n\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because `max_sequence_length` is set to \"\n                f\" {max_sequence_length} tokens: {removed_text}\"\n            )\n\n        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]\n\n        dtype = self.text_encoder_2.dtype\n        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)\n\n        _, seq_len, _ = prompt_embeds.shape\n\n        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds\n    def _get_clip_prompt_embeds(\n        self,\n        prompt: Union[str, List[str]],\n        num_images_per_prompt: int = 1,\n        device: Optional[torch.device] = None,\n    ):\n        device = device or self._execution_device\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n        batch_size = len(prompt)\n\n        if isinstance(self, TextualInversionLoaderMixin):\n            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)\n\n        text_inputs = self.tokenizer(\n            prompt,\n            padding=\"max_length\",\n            max_length=self.tokenizer_max_length,\n            truncation=True,\n            return_overflowing_tokens=False,\n            return_length=False,\n            return_tensors=\"pt\",\n        )\n\n        text_input_ids = text_inputs.input_ids\n        untruncated_ids = self.tokenizer(prompt, padding=\"longest\", 
return_tensors=\"pt\").input_ids\n        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):\n            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])\n            logger.warning(\n                \"The following part of your input was truncated because CLIP can only handle sequences up to\"\n                f\" {self.tokenizer_max_length} tokens: {removed_text}\"\n            )\n        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)\n\n        # Use pooled output of CLIPTextModel\n        prompt_embeds = prompt_embeds.pooler_output\n        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)\n\n        # duplicate text embeddings for each generation per prompt, using mps friendly method\n        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)\n        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)\n\n        return prompt_embeds\n\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt\n    def encode_prompt(\n        self,\n        prompt: Union[str, List[str]],\n        prompt_2: Union[str, List[str]],\n        device: Optional[torch.device] = None,\n        num_images_per_prompt: int = 1,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        max_sequence_length: int = 512,\n        lora_scale: Optional[float] = None,\n    ):\n        r\"\"\"\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                prompt to be encoded\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is\n                used in all text-encoders\n            device: (`torch.device`):\n                torch device\n            num_images_per_prompt (`int`):\n                number of images that should be generated per prompt\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            lora_scale (`float`, *optional*):\n                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.\n        \"\"\"\n        device = device or self._execution_device\n\n        # set lora scale so that monkey patched LoRA\n        # function of text encoder can correctly access it\n        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):\n            self._lora_scale = lora_scale\n\n            # dynamically adjust the LoRA scale\n            if self.text_encoder is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder, lora_scale)\n            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:\n                scale_lora_layers(self.text_encoder_2, lora_scale)\n\n        prompt = [prompt] if isinstance(prompt, str) else prompt\n\n        if prompt_embeds is None:\n            prompt_2 = prompt_2 or prompt\n            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2\n\n            # We only use the pooled prompt output from the CLIPTextModel\n            pooled_prompt_embeds = self._get_clip_prompt_embeds(\n                prompt=prompt,\n                device=device,\n                num_images_per_prompt=num_images_per_prompt,\n            )\n            prompt_embeds = self._get_t5_prompt_embeds(\n                prompt=prompt_2,\n                num_images_per_prompt=num_images_per_prompt,\n                max_sequence_length=max_sequence_length,\n                device=device,\n            )\n\n        if self.text_encoder is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder, lora_scale)\n\n        if self.text_encoder_2 is not None:\n            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:\n                # Retrieve the original scale by scaling back the LoRA layers\n                unscale_lora_layers(self.text_encoder_2, lora_scale)\n\n        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype\n        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)\n\n        return prompt_embeds, pooled_prompt_embeds, text_ids\n\n    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image\n    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):\n        if isinstance(generator, list):\n            image_latents = [\n                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])\n                for i in range(image.shape[0])\n            ]\n            image_latents = torch.cat(image_latents, dim=0)\n        else:\n            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)\n\n        image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor\n\n        return image_latents\n\n    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps\n    def get_timesteps(self, num_inference_steps, strength, device):\n        # get the original timestep using 
init_timestep\n        init_timestep = min(num_inference_steps * strength, num_inference_steps)\n\n        t_start = int(max(num_inference_steps - init_timestep, 0))\n        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]\n        if hasattr(self.scheduler, \"set_begin_index\"):\n            self.scheduler.set_begin_index(t_start * self.scheduler.order)\n\n        return timesteps, num_inference_steps - t_start\n\n    def check_inputs(\n        self,\n        prompt,\n        prompt_2,\n        strength,\n        height,\n        width,\n        prompt_embeds=None,\n        pooled_prompt_embeds=None,\n        callback_on_step_end_tensor_inputs=None,\n        max_sequence_length=None,\n    ):\n        if strength < 0 or strength > 1:\n            raise ValueError(f\"The value of strength should in [0.0, 1.0] but is {strength}\")\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\")\n\n        if callback_on_step_end_tensor_inputs is not None and not all(\n            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs\n        ):\n            raise ValueError(\n                f\"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}\"\n            )\n\n        if prompt is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt_2 is not None and prompt_embeds is not None:\n            raise ValueError(\n                f\"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to\"\n                \" only forward one of the two.\"\n            )\n        elif prompt is None and prompt_embeds is None:\n            raise ValueError(\n                \"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.\"\n            )\n        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):\n            raise ValueError(f\"`prompt` has to be of type `str` or `list` but is {type(prompt)}\")\n        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):\n            raise ValueError(f\"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}\")\n\n        if prompt_embeds is not None and pooled_prompt_embeds is None:\n            raise ValueError(\n                \"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. 
Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`.\"\n            )\n\n        if max_sequence_length is not None and max_sequence_length > 512:\n            raise ValueError(f\"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}\")\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids\n    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):\n        latent_image_ids = torch.zeros(height // 2, width // 2, 3)\n        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]\n        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]\n\n        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape\n\n        latent_image_ids = latent_image_ids.reshape(\n            latent_image_id_height * latent_image_id_width, latent_image_id_channels\n        )\n\n        return latent_image_ids.to(device=device, dtype=dtype)\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents\n    def _pack_latents(latents, batch_size, num_channels_latents, height, width):\n        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)\n        latents = latents.permute(0, 2, 4, 1, 3, 5)\n        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)\n\n        return latents\n\n    @staticmethod\n    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents\n    def _unpack_latents(latents, height, width, vae_scale_factor):\n        batch_size, num_patches, channels = latents.shape\n\n        height = height // vae_scale_factor\n        width = width // vae_scale_factor\n\n        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)\n        latents = latents.permute(0, 3, 1, 4, 2, 5)\n\n        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)\n\n        return latents\n\n    def prepare_latents(\n        self,\n        image,\n        timestep,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        height = 2 * (int(height) // self.vae_scale_factor)\n        width = 2 * (int(width) // self.vae_scale_factor)\n\n        shape = (batch_size, num_channels_latents, height, width)\n        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)\n\n        if latents is not None:\n            return latents.to(device=device, dtype=dtype), latent_image_ids\n\n        image = image.to(device=device, dtype=dtype)\n        image_latents = self._encode_vae_image(image=image, generator=generator)\n        if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:\n            # expand init_latents for batch_size\n            additional_image_per_prompt = batch_size // image_latents.shape[0]\n            image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)\n        elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:\n            raise ValueError(\n                f\"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.\"\n            )\n        else:\n            image_latents = torch.cat([image_latents], dim=0)\n\n        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)\n        latents = self.scheduler.scale_noise(image_latents, timestep, noise)\n        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)\n        image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)\n        return latents, latent_image_ids, image_latents\n\n    @property\n    def guidance_scale(self):\n        return self._guidance_scale\n\n    @property\n    def joint_attention_kwargs(self):\n        return self._joint_attention_kwargs\n\n    @property\n    def num_timesteps(self):\n        return self._num_timesteps\n\n    @property\n    def interrupt(self):\n        return self._interrupt\n\n    @torch.no_grad()\n    @replace_example_docstring(EXAMPLE_DOC_STRING)\n    def __call__(\n        self,\n        prompt: Union[str, List[str]] = None,\n        prompt_2: Optional[Union[str, List[str]]] = None,\n        image: PipelineImageInput = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        strength: float = 0.6,\n        num_inference_steps: int = 28,\n        timesteps: List[int] = None,\n        guidance_scale: float = 7.0,\n        num_images_per_prompt: Optional[int] = 1,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        prompt_embeds: Optional[torch.FloatTensor] = None,\n        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,\n        callback_on_step_end_tensor_inputs: List[str] = [\"latents\"],\n        max_sequence_length: int = 512,\n        ntk_factor: float = 10.0,\n        proportional_attention: bool = True\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            prompt (`str` or `List[str]`, *optional*):\n                The prompt or prompts to guide the image generation. 
If not defined, one has to pass `prompt_embeds`\n                instead.\n            prompt_2 (`str` or `List[str]`, *optional*):\n                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`\n                will be used instead.\n            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):\n                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both\n                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list\n                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a\n                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image\n                latents as `image`, but if passing latents directly it is not encoded again.\n            height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image. This is set to 1024 by default for the best results.\n            width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image. This is set to 1024 by default for the best results.\n            strength (`float`, *optional*, defaults to 0.6):\n                Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a\n                starting point and more noise is added the higher the `strength`. The number of denoising steps depends\n                on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising\n                process runs for the full number of iterations specified in `num_inference_steps`. A value of 1\n                essentially ignores `image`.\n            num_inference_steps (`int`, *optional*, defaults to 28):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            timesteps (`List[int]`, *optional*):\n                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument\n                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is\n                passed will be used. Must be in descending order.\n            guidance_scale (`float`, *optional*, defaults to 7.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).\n                `guidance_scale` is defined as `w` of equation 2 of [Imagen\n                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >\n                1`. 
A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,\n                usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per prompt.\n            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents\n                tensor will be generated by sampling using the supplied random `generator`.\n            prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not\n                provided, text embeddings will be generated from `prompt` input argument.\n            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):\n                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.\n                If not provided, pooled text embeddings will be generated from `prompt` input argument.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            callback_on_step_end (`Callable`, *optional*):\n                A function that is called at the end of each denoising step during inference. The function is called\n                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,\n                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by\n                `callback_on_step_end_tensor_inputs`.\n            callback_on_step_end_tensor_inputs (`List`, *optional*):\n                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list\n                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the\n                `._callback_tensor_inputs` attribute of your pipeline class.\n            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.\n\n        Examples:\n\n        Returns:\n            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`\n            is True, otherwise a `tuple`. 
When returning a tuple, the first element is a list with the generated\n            images.\n        \"\"\"\n\n        height = height or self.default_sample_size * self.vae_scale_factor\n        width = width or self.default_sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise error if not correct\n        self.check_inputs(\n            prompt,\n            prompt_2,\n            strength,\n            height,\n            width,\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,\n            max_sequence_length=max_sequence_length,\n        )\n\n        self._guidance_scale = guidance_scale\n        self._joint_attention_kwargs = joint_attention_kwargs\n        self._interrupt = False\n\n        # 2. Preprocess image\n        init_image = self.image_processor.preprocess(image, height=height, width=width)\n        init_image = init_image.to(dtype=torch.float32)\n\n        # 3. Define call parameters\n        if prompt is not None and isinstance(prompt, str):\n            batch_size = 1\n        elif prompt is not None and isinstance(prompt, list):\n            batch_size = len(prompt)\n        else:\n            batch_size = prompt_embeds.shape[0]\n\n        device = self._execution_device\n\n        lora_scale = (\n            self.joint_attention_kwargs.get(\"scale\", None) if self.joint_attention_kwargs is not None else None\n        )\n        (\n            prompt_embeds,\n            pooled_prompt_embeds,\n            text_ids,\n        ) = self.encode_prompt(\n            prompt=prompt,\n            prompt_2=prompt_2,\n            prompt_embeds=prompt_embeds,\n            pooled_prompt_embeds=pooled_prompt_embeds,\n            device=device,\n            num_images_per_prompt=num_images_per_prompt,\n            max_sequence_length=max_sequence_length,\n            lora_scale=lora_scale,\n        )\n\n        # 4.Prepare timesteps\n        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)\n        timesteps, num_inference_steps_ = retrieve_timesteps(\n            self.scheduler,\n            num_inference_steps,\n            device,\n            timesteps,\n            sigmas\n        )\n        timesteps, num_inference_steps_ = self.get_timesteps(num_inference_steps_, strength, device)\n\n        if num_inference_steps_ < 1:\n            raise ValueError(\n                f\"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline\"\n                f\"steps is {num_inference_steps_} which is < 1 and not appropriate for this pipeline.\"\n            )\n        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)\n\n        # 5. 
Prepare latent variables\n        num_channels_latents = self.transformer.config.in_channels // 4\n\n        latents, latent_image_ids, image_latents = self.prepare_latents(\n            init_image,\n            latent_timestep,\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            prompt_embeds.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        num_warmup_steps = max(len(timesteps) - num_inference_steps_ * self.scheduler.order, 0)\n        self._num_timesteps = len(timesteps)\n\n        # handle guidance\n        if self.transformer.config.guidance_embeds:\n            guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)\n            guidance = guidance.expand(latents.shape[0])\n        else:\n            guidance = None\n\n        # 6. Denoising loop\n        with self.progress_bar(total=num_inference_steps_) as progress_bar:\n            for i, t in enumerate(timesteps):\n                if self.interrupt:\n                    continue\n\n                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML\n                timestep = t.expand(latents.shape[0]).to(latents.dtype)\n                noise_pred = self.transformer(\n                    hidden_states=latents,\n                    timestep=timestep / 1000,\n                    guidance=guidance,\n                    pooled_projections=pooled_prompt_embeds,\n                    encoder_hidden_states=prompt_embeds,\n                    txt_ids=text_ids,\n                    img_ids=latent_image_ids,\n                    joint_attention_kwargs=self.joint_attention_kwargs,\n                    return_dict=False,\n                    ntk_factor=ntk_factor,\n                    proportional_attention=proportional_attention,\n                )[0]\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents_dtype = latents.dtype\n                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]\n\n                if latents.dtype != latents_dtype:\n                    if torch.backends.mps.is_available():\n                        # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272\n                        latents = latents.to(latents_dtype)\n\n                if callback_on_step_end is not None:\n                    callback_kwargs = {}\n                    for k in callback_on_step_end_tensor_inputs:\n                        callback_kwargs[k] = locals()[k]\n                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)\n\n                    latents = callback_outputs.pop(\"latents\", latents)\n                    prompt_embeds = callback_outputs.pop(\"prompt_embeds\", prompt_embeds)\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):\n                    progress_bar.update()\n\n                if XLA_AVAILABLE:\n                    xm.mark_step()\n\n        if output_type == \"latent\":\n            image = latents\n\n        else:\n            latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)\n            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor\n            image = self.vae.decode(latents, return_dict=False)[0]\n            image = self.image_processor.postprocess(image, output_type=output_type)\n\n        # Offload all models\n        self.maybe_free_model_hooks()\n\n        if not return_dict:\n            return (image,)\n\n        return FluxPipelineOutput(images=image)\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch>=2.5.0\ndiffusers>=0.31.0\ntransformers\nsafetensors\nopencv-python\naccelerate\ntqdm\ndeepspeed\nwandb\nprodigyopt\nsentencepiece\n"
  },
  {
    "path": "transformer_flux.py",
    "content": "# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\nfrom typing import Any, Dict, Optional, Tuple, Union, List\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom diffusers.configuration_utils import ConfigMixin, register_to_config\nfrom diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin\nfrom diffusers.models.attention import FeedForward\nfrom diffusers.models.attention_processor import (\n    Attention,\n    AttentionProcessor\n)\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle\nfrom diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers\nfrom diffusers.utils.torch_utils import maybe_allow_in_graph\nfrom diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed\nfrom diffusers.models.modeling_outputs import Transformer2DModelOutput\nfrom attention_processor import FluxAttnProcessor2_0\n\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\n@maybe_allow_in_graph\nclass FluxSingleTransformerBlock(nn.Module):\n    r\"\"\"\n    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.\n\n    Reference: https://arxiv.org/abs/2403.03206\n\n    Parameters:\n        dim (`int`): The number of channels in the input and output.\n        num_attention_heads (`int`): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`): The number of channels in each head.\n        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the\n            processing of `context` conditions.\n    \"\"\"\n\n    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):\n        super().__init__()\n        self.mlp_hidden_dim = int(dim * mlp_ratio)\n\n        self.norm = AdaLayerNormZeroSingle(dim)\n        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)\n        self.act_mlp = nn.GELU(approximate=\"tanh\")\n        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)\n\n        processor = FluxAttnProcessor2_0()\n        self.attn = Attention(\n            query_dim=dim,\n            cross_attention_dim=None,\n            dim_head=attention_head_dim,\n            heads=num_attention_heads,\n            out_dim=dim,\n            bias=True,\n            processor=processor,\n            qk_norm=\"rms_norm\",\n            eps=1e-6,\n            pre_only=True,\n        )\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        temb: torch.FloatTensor,\n        image_rotary_emb=None,\n        proportional_attention: bool = False,\n        joint_attention_kwargs=None\n    ):\n        residual = hidden_states\n       
 norm_hidden_states, gate = self.norm(hidden_states, emb=temb)\n        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))\n        joint_attention_kwargs = joint_attention_kwargs or {}\n        attn_output = self.attn(\n            hidden_states=norm_hidden_states,\n            image_rotary_emb=image_rotary_emb,\n            proportional_attention=proportional_attention,\n            **joint_attention_kwargs,\n        )\n\n        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)\n        gate = gate.unsqueeze(1)\n        hidden_states = gate * self.proj_out(hidden_states)\n        hidden_states = residual + hidden_states\n        if hidden_states.dtype == torch.float16:\n            hidden_states = hidden_states.clip(-65504, 65504)\n\n        return hidden_states\n\n\n@maybe_allow_in_graph\nclass FluxTransformerBlock(nn.Module):\n    r\"\"\"\n    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.\n\n    Reference: https://arxiv.org/abs/2403.03206\n\n    Parameters:\n        dim (`int`): The number of channels in the input and output.\n        num_attention_heads (`int`): The number of heads to use for multi-head attention.\n        attention_head_dim (`int`): The number of channels in each head.\n        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the\n            processing of `context` conditions.\n    \"\"\"\n\n    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm=\"rms_norm\", eps=1e-6):\n        super().__init__()\n\n        self.norm1 = AdaLayerNormZero(dim)\n\n        self.norm1_context = AdaLayerNormZero(dim)\n\n        if hasattr(F, \"scaled_dot_product_attention\"):\n            processor = FluxAttnProcessor2_0()\n        else:\n            raise ValueError(\n                \"The current PyTorch version does not support the `scaled_dot_product_attention` function.\"\n            )\n        self.attn = Attention(\n            query_dim=dim,\n            cross_attention_dim=None,\n            added_kv_proj_dim=dim,\n            dim_head=attention_head_dim,\n            heads=num_attention_heads,\n            out_dim=dim,\n            context_pre_only=False,\n            bias=True,\n            processor=processor,\n            qk_norm=qk_norm,\n            eps=eps,\n        )\n\n        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)\n        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn=\"gelu-approximate\")\n\n        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)\n        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn=\"gelu-approximate\")\n\n        # let chunk size default to None\n        self._chunk_size = None\n        self._chunk_dim = 0\n\n    def forward(\n        self,\n        hidden_states: torch.FloatTensor,\n        encoder_hidden_states: torch.FloatTensor,\n        temb: torch.FloatTensor,\n        image_rotary_emb=None,\n        proportional_attention: bool = False,\n        joint_attention_kwargs=None,\n    ):\n        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)\n\n        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(\n            encoder_hidden_states, emb=temb\n        )\n        joint_attention_kwargs = joint_attention_kwargs or {}\n        # Attention.\n        attn_output, context_attn_output = self.attn(\n            
hidden_states=norm_hidden_states,\n            encoder_hidden_states=norm_encoder_hidden_states,\n            image_rotary_emb=image_rotary_emb,\n            proportional_attention=proportional_attention,\n            **joint_attention_kwargs,\n        )\n\n        # Process attention outputs for the `hidden_states`.\n        attn_output = gate_msa.unsqueeze(1) * attn_output\n        hidden_states = hidden_states + attn_output\n\n        norm_hidden_states = self.norm2(hidden_states)\n        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]\n\n        ff_output = self.ff(norm_hidden_states)\n        ff_output = gate_mlp.unsqueeze(1) * ff_output\n\n        hidden_states = hidden_states + ff_output\n\n        # Process attention outputs for the `encoder_hidden_states`.\n\n        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output\n        encoder_hidden_states = encoder_hidden_states + context_attn_output\n\n        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)\n        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]\n\n        context_ff_output = self.ff_context(norm_encoder_hidden_states)\n        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output\n        if encoder_hidden_states.dtype == torch.float16:\n            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)\n\n        return encoder_hidden_states, hidden_states\n    \n\nclass FluxPosEmbed(nn.Module):\n    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11\n    def __init__(self, theta: int, axes_dim: List[int]):\n        super().__init__()\n        self.theta = theta\n        self.axes_dim = axes_dim\n\n    def forward(self, ids: torch.Tensor, ntk_factor=1) -> torch.Tensor:\n        n_axes = ids.shape[-1]\n        cos_out = []\n        sin_out = []\n        pos = ids.float()\n        is_mps = ids.device.type == \"mps\"\n        freqs_dtype = torch.float32 if is_mps else torch.float64\n        for i in range(n_axes):\n            cos, sin = get_1d_rotary_pos_embed(\n                self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype,\n                ntk_factor=ntk_factor\n            )\n            cos_out.append(cos)\n            sin_out.append(sin)\n        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)\n        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)\n        return freqs_cos, freqs_sin\n\n\nclass FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):\n    \"\"\"\n    The Transformer model introduced in Flux.\n\n    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/\n\n    Parameters:\n        patch_size (`int`): Patch size to turn the input data into small patches.\n        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.\n        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.\n        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.\n        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.\n        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.\n        joint_attention_dim (`int`, 
*optional*): The number of `encoder_hidden_states` dimensions to use.\n        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.\n        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.\n    \"\"\"\n\n    _supports_gradient_checkpointing = True\n    _no_split_modules = [\"FluxTransformerBlock\", \"FluxSingleTransformerBlock\"]\n\n    @register_to_config\n    def __init__(\n        self,\n        patch_size: int = 1,\n        in_channels: int = 64,\n        num_layers: int = 19,\n        num_single_layers: int = 38,\n        attention_head_dim: int = 128,\n        num_attention_heads: int = 24,\n        joint_attention_dim: int = 4096,\n        pooled_projection_dim: int = 768,\n        guidance_embeds: bool = False,\n        axes_dims_rope: Tuple[int] = (16, 56, 56),\n    ):\n        super().__init__()\n        self.out_channels = in_channels\n        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim\n\n        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)\n\n        text_time_guidance_cls = (\n            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings\n        )\n        self.time_text_embed = text_time_guidance_cls(\n            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim\n        )\n\n        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)\n        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)\n\n        self.transformer_blocks = nn.ModuleList(\n            [\n                FluxTransformerBlock(\n                    dim=self.inner_dim,\n                    num_attention_heads=self.config.num_attention_heads,\n                    attention_head_dim=self.config.attention_head_dim,\n                )\n                for i in range(self.config.num_layers)\n            ]\n        )\n\n        self.single_transformer_blocks = nn.ModuleList(\n            [\n                FluxSingleTransformerBlock(\n                    dim=self.inner_dim,\n                    num_attention_heads=self.config.num_attention_heads,\n                    attention_head_dim=self.config.attention_head_dim,\n                )\n                for i in range(self.config.num_single_layers)\n            ]\n        )\n\n        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)\n        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)\n\n        self.gradient_checkpointing = False\n\n    @property\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors\n    def attn_processors(self) -> Dict[str, AttentionProcessor]:\n        r\"\"\"\n        Returns:\n            `dict` of attention processors: A dictionary containing all attention processors used in the model with\n            indexed by its weight name.\n        \"\"\"\n        # set recursively\n        processors = {}\n\n        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):\n            if hasattr(module, \"get_processor\"):\n                processors[f\"{name}.processor\"] = module.get_processor()\n\n            for sub_name, child in module.named_children():\n                fn_recursive_add_processors(f\"{name}.{sub_name}\", child, processors)\n\n            
return processors\n\n        for name, module in self.named_children():\n            fn_recursive_add_processors(name, module, processors)\n\n        return processors\n\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor\n    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):\n        r\"\"\"\n        Sets the attention processor to use to compute attention.\n\n        Parameters:\n            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):\n                The instantiated processor class or a dictionary of processor classes that will be set as the processor\n                for **all** `Attention` layers.\n\n                If `processor` is a dict, the key needs to define the path to the corresponding cross attention\n                processor. This is strongly recommended when setting trainable attention processors.\n\n        \"\"\"\n        count = len(self.attn_processors.keys())\n\n        if isinstance(processor, dict) and len(processor) != count:\n            raise ValueError(\n                f\"A dict of processors was passed, but the number of processors {len(processor)} does not match the\"\n                f\" number of attention layers: {count}. Please make sure to pass {count} processor classes.\"\n            )\n\n        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):\n            if hasattr(module, \"set_processor\"):\n                if not isinstance(processor, dict):\n                    module.set_processor(processor)\n                else:\n                    module.set_processor(processor.pop(f\"{name}.processor\"))\n\n            for sub_name, child in module.named_children():\n                fn_recursive_attn_processor(f\"{name}.{sub_name}\", child, processor)\n\n        for name, module in self.named_children():\n            fn_recursive_attn_processor(name, module, processor)\n\n    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections\n    def unfuse_qkv_projections(self):\n        \"\"\"Disables the fused QKV projection if enabled.\n\n        <Tip warning={true}>\n\n        This API is 🧪 experimental.\n\n        </Tip>\n\n        \"\"\"\n        if self.original_attn_processors is not None:\n            self.set_attn_processor(self.original_attn_processors)\n\n    def _set_gradient_checkpointing(self, module, value=False):\n        if hasattr(module, \"gradient_checkpointing\"):\n            module.gradient_checkpointing = value\n\n    def forward(\n        self,\n        hidden_states: torch.Tensor,\n        encoder_hidden_states: torch.Tensor = None,\n        pooled_projections: torch.Tensor = None,\n        timestep: torch.LongTensor = None,\n        img_ids: torch.Tensor = None,\n        txt_ids: torch.Tensor = None,\n        guidance: torch.Tensor = None,\n        joint_attention_kwargs: Optional[Dict[str, Any]] = None,\n        controlnet_block_samples=None,\n        controlnet_single_block_samples=None,\n        return_dict: bool = True,\n        ntk_factor: float = 1,\n        proportional_attention: bool = False,\n        controlnet_blocks_repeat: bool = False,\n    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:\n        \"\"\"\n        The [`FluxTransformer2DModel`] forward method.\n\n        Args:\n            hidden_states (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, in_channels)`):\n                Input `hidden_states` (packed latent patches).\n       
     encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_len, embed_dims)`):\n                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.\n            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected\n                from the embeddings of input conditions.\n            timestep (`torch.LongTensor`):\n                Used to indicate the denoising step.\n            controlnet_block_samples (`list` of `torch.Tensor`, *optional*):\n                A list of tensors that, if specified, are added to the residuals of the transformer blocks.\n            controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):\n                A list of tensors that, if specified, are added to the residuals of the single transformer blocks.\n            joint_attention_kwargs (`dict`, *optional*):\n                A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under\n                `self.processor` in\n                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain\n                tuple.\n            ntk_factor (`float`, *optional*, defaults to 1):\n                Scaling factor forwarded to the rotary positional embedding (`self.pos_embed`).\n            proportional_attention (`bool`, *optional*, defaults to `False`):\n                Flag forwarded to the attention computation of every transformer block.\n            controlnet_blocks_repeat (`bool`, *optional*, defaults to `False`):\n                If `True`, the ControlNet block residuals are applied cyclically over the transformer blocks\n                (XLabs-style ControlNet).\n\n        Returns:\n            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a\n            `tuple` where the first element is the sample tensor.\n        \"\"\"\n\n        if txt_ids.ndim == 3:\n            logger.warning(\n                \"Passing `txt_ids` as a 3d torch.Tensor is deprecated. \"\n                \"Please remove the batch dimension and pass it as a 2d torch.Tensor.\"\n            )\n            txt_ids = txt_ids[0]\n        if img_ids.ndim == 3:\n            logger.warning(\n                \"Passing `img_ids` as a 3d torch.Tensor is deprecated. \"\n                \"Please remove the batch dimension and pass it as a 2d torch.Tensor.\"\n            )\n            img_ids = img_ids[0]\n\n        if joint_attention_kwargs is not None:\n            joint_attention_kwargs = joint_attention_kwargs.copy()\n            lora_scale = joint_attention_kwargs.pop(\"scale\", 1.0)\n        else:\n            lora_scale = 1.0\n\n        if USE_PEFT_BACKEND:\n            # weight the lora layers by setting `lora_scale` for each PEFT layer\n            scale_lora_layers(self, lora_scale)\n        else:\n            if joint_attention_kwargs is not None and joint_attention_kwargs.get(\"scale\", None) is not None:\n                logger.warning(\n                    \"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective.\"\n                )\n        hidden_states = self.x_embedder(hidden_states)\n\n        timestep = timestep.to(hidden_states.dtype) * 1000\n        if guidance is not None:\n            guidance = guidance.to(hidden_states.dtype) * 1000\n        temb = (\n            self.time_text_embed(timestep, pooled_projections)\n            if guidance is None\n            else self.time_text_embed(timestep, guidance, pooled_projections)\n        )\n        encoder_hidden_states = self.context_embedder(encoder_hidden_states)\n\n        ids = torch.cat((txt_ids, img_ids), dim=0)\n        image_rotary_emb = self.pos_embed(ids, ntk_factor=ntk_factor)\n\n        for index_block, block in enumerate(self.transformer_blocks):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n 
                   def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(block),\n                    hidden_states,\n                    encoder_hidden_states,\n                    temb,\n                    image_rotary_emb,\n                    proportional_attention,\n                    **ckpt_kwargs,\n                )\n\n            else:\n                encoder_hidden_states, hidden_states = block(\n                    hidden_states=hidden_states,\n                    encoder_hidden_states=encoder_hidden_states,\n                    temb=temb,\n                    image_rotary_emb=image_rotary_emb,\n                    proportional_attention=proportional_attention,\n                    joint_attention_kwargs=joint_attention_kwargs,\n                )\n\n            # controlnet residual\n            if controlnet_block_samples is not None:\n                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)\n                interval_control = int(np.ceil(interval_control))\n                # For Xlabs ControlNet.\n                if controlnet_blocks_repeat:\n                    hidden_states = (\n                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]\n                    )\n                else:\n                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]\n\n        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)\n\n        for index_block, block in enumerate(self.single_transformer_blocks):\n            if self.training and self.gradient_checkpointing:\n\n                def create_custom_forward(module, return_dict=None):\n                    def custom_forward(*inputs):\n                        if return_dict is not None:\n                            return module(*inputs, return_dict=return_dict)\n                        else:\n                            return module(*inputs)\n\n                    return custom_forward\n\n                ckpt_kwargs: Dict[str, Any] = {\"use_reentrant\": False} if is_torch_version(\">=\", \"1.11.0\") else {}\n                hidden_states = torch.utils.checkpoint.checkpoint(\n                    create_custom_forward(block),\n                    hidden_states,\n                    temb,\n                    image_rotary_emb,\n                    proportional_attention,\n                    **ckpt_kwargs,\n                )\n\n            else:\n                hidden_states = block(\n                    hidden_states=hidden_states,\n                    temb=temb,\n                    image_rotary_emb=image_rotary_emb,\n                    proportional_attention=proportional_attention,\n                    joint_attention_kwargs=joint_attention_kwargs,\n                )\n\n            # controlnet residual\n            if controlnet_single_block_samples is not None:\n                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)\n                
interval_control = int(np.ceil(interval_control))\n                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (\n                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]\n                    + controlnet_single_block_samples[index_block // interval_control]\n                )\n\n        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]\n\n        hidden_states = self.norm_out(hidden_states, temb)\n        output = self.proj_out(hidden_states)\n\n        if USE_PEFT_BACKEND:\n            # remove `lora_scale` from each PEFT layer\n            unscale_lora_layers(self, lora_scale)\n\n        if not return_dict:\n            return (output,)\n\n        return Transformer2DModelOutput(sample=output)\n"
  }
]