Repository: Huage001/CLEAR Branch: main Commit: ddca13160b7c Files: 16 Total size: 177.7 KB Directory structure: gitextract_7vgr3kj8/ ├── LICENSE ├── README.md ├── attention_processor.py ├── cache_latent_codes.py ├── cache_latent_codes.sh ├── cache_prompt_embeds.py ├── cache_prompt_embeds.sh ├── dataset.py ├── deepspeed_config.yaml ├── distill.py ├── distill.sh ├── inference_t2i.ipynb ├── inference_t2i_highres.ipynb ├── pipeline_flux_img2img.py ├── requirements.txt └── transformer_flux.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================
# CLEAR

> **CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up**
>
> [Songhua Liu](http://huage001.github.io/), [Zhenxiong Tan](https://scholar.google.com/citations?user=HP9Be6UAAAAJ&hl=en), and [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
>
> NeurIPS 2025
>
> [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore
![](./assets/teaser.png)

## 🔥News

**[2025/9/18]** CLEAR is accepted at NeurIPS 2025.

**[2024/12/20]** We release the training and inference code of CLEAR, a simple yet effective strategy to linearize the complexity of pre-trained diffusion transformers such as FLUX and SD3.

## Introduction

Diffusion Transformers (DiTs) have become a leading architecture for image generation. However, the quadratic complexity of attention, which models token-wise relationships, results in significant latency when generating high-resolution images. To address this issue, we propose a linear attention mechanism that reduces the complexity of pre-trained DiTs to linear. We begin with a comprehensive review of existing efficient attention mechanisms and identify four key factors crucial for successfully linearizing pre-trained DiTs: locality, formulation consistency, high-rank attention maps, and feature integrity. Based on these insights, we introduce CLEAR, a convolution-like local attention strategy that limits feature interactions to a local window around each query token and thus achieves linear complexity. Our experiments show that by fine-tuning the attention layers on only 10K self-generated samples for 10K iterations, we can effectively transfer knowledge from a pre-trained DiT to a student model with linear complexity, yielding results comparable to the teacher model. For 8K-resolution images, it also reduces attention computation by 99.5% and accelerates generation by 6.3×. Furthermore, we investigate favorable properties of the distilled attention layers, such as zero-shot generalization across various models and plugins, and improved support for multi-GPU parallel inference.
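The locality idea above can be made concrete with a small sketch. The code below is an illustrative, dense restatement in plain Python of the mask predicate that `attention_processor.py` builds for FlexAttention in the local-only variants. It assumes text tokens come first, followed by a row-major grid of image tokens, with `window_size` acting as a radius on that grid, as in the repository's predicate. The helper names `local_mask` and `keys_per_query` are hypothetical. Because it enumerates every key, it only shows which token pairs interact; the released code instead relies on FlexAttention block masks so that masked blocks are not computed.

```python
# Minimal sketch (not the repository's implementation): a dense, pure-Python
# restatement of the local-window mask predicate used in attention_processor.py.
# Assumed token layout: `text_length` text tokens first, then a
# height x width grid of image tokens in row-major order.


def local_mask(q_idx: int, kv_idx: int, width: int, text_length: int, window_size: int) -> bool:
    """Return True if query token `q_idx` may attend to key token `kv_idx`."""
    # Text tokens keep global attention in both directions.
    if q_idx < text_length or kv_idx < text_length:
        return True
    # Recover the 2D grid coordinates of the two image tokens.
    q_y, q_x = divmod(q_idx - text_length, width)
    kv_y, kv_x = divmod(kv_idx - text_length, width)
    # Circular window: squared Euclidean distance strictly below window_size ** 2.
    return (q_y - kv_y) ** 2 + (q_x - kv_x) ** 2 < window_size ** 2


def keys_per_query(height: int, width: int, text_length: int, window_size: int, q_y: int, q_x: int) -> int:
    """Count how many keys the image query at grid position (q_y, q_x) may attend to."""
    q_idx = text_length + q_y * width + q_x
    total = text_length + height * width
    # Dense enumeration over all keys: fine for illustration, far too slow for real use.
    return sum(local_mask(q_idx, kv, width, text_length, window_size) for kv in range(total))


if __name__ == "__main__":
    # An interior query sees the same number of keys on a 16x16 and a 32x32 grid.
    print(keys_per_query(16, 16, 4, 3, 8, 8))    # 29 = 4 text tokens + 25 window tokens
    print(keys_per_query(32, 32, 4, 3, 16, 16))  # 29
```

Since the per-query key count of an image token depends only on `text_length` and `window_size`, not on the grid size, the total attention work grows linearly with the number of image tokens for a fixed text length.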
**TL;DR**: For pre-trained diffusion transformers, restricting each image token to interact only with tokens within **a local window** effectively reduces the complexity of the original model to a linear scale.

## Installation

* Clone this repo to your project directory:

```bash
git clone https://github.com/Huage001/CLEAR.git
cd CLEAR
```

* CLEAR requires ``torch>=2.5.0``, ``diffusers>=0.31.0``, and the other packages listed in ``requirements.txt``. You can set up a new environment with:

```bash
conda create -n CLEAR python=3.12
conda activate CLEAR
pip install -r requirements.txt
```

## Supported Models

We release a series of linearized [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) variants with various local window sizes.

We find experimentally that when the local window is small, e.g., 8, the model can produce repetitive patterns in many cases. To alleviate this problem, some variants also include down-sampled key-value tokens in addition to the local tokens for attention interaction.

The supported models and their download links are:

| window_size | down_factor | link |
| :---------: | :---------: | :----------------------------------------------------------: |
| 32 | NA | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_32.safetensors) |
| 16 | NA | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_16.safetensors) |
| 8 | NA | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8.safetensors) |
| 16 | 4 | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_16_down_4.safetensors) |
| 8 | 4 | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8_down_4.safetensors) |

We recommend downloading the weights you need into ``ckpt`` beforehand.
For example:

```bash
mkdir ckpt
wget -P ckpt https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8_down_4.safetensors
```

## Inference

* To compare the linearized FLUX with the original model, try ``inference_t2i.ipynb``.
* To use CLEAR for high-resolution acceleration, try ``inference_t2i_highres.ipynb``. We currently adopt the [SDEdit](https://huggingface.co/docs/diffusers/v0.30.2/en/api/pipelines/stable_diffusion/img2img#image-to-image) strategy: we first generate a low-resolution result and then progressively upscale it.
* Configure ``down_factor`` and ``window_size`` in the notebooks to use different variants of CLEAR. If you do not want to include down-sampled key-value tokens, set ``down_factor=1``. The models are downloaded automatically to ``ckpt`` if they are not already there.
* A GPU with 48 GB of VRAM is currently recommended for high-resolution generation.

## Training

* Configure ``/path/to/t2i_1024`` in the ``.sh`` files.
* Download the training images from [here](https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_1024_10K/data_000000.tar), which contain 10K 1024-resolution images generated by ``FLUX.1-dev`` itself, and extract them to ``/path/to/t2i_1024``:

```
tar -xvf data_000000.tar -C /path/to/t2i_1024
```

* [Optional but Recommended] Cache the T5 and CLIP text embeddings and the VAE features beforehand:

```bash
bash cache_prompt_embeds.sh
bash cache_latent_codes.sh
```

* Start training:

```bash
bash distill.sh
```

By default, training uses four 80 GB GPUs with ``train_batch_size=2`` and ``gradient_accumulation_steps=4``. Feel free to adjust these settings in ``distill.sh`` and ``deepspeed_config.yaml`` to suit your hardware.

## Acknowledgement

* [FLUX](https://blackforestlabs.ai/announcing-black-forest-labs/) for the source models.
* [flexattention](https://pytorch.org/blog/flexattention/) for kernel implementation.
* [diffusers](https://github.com/huggingface/diffusers) for the code base.
* [DeepSpeed](https://github.com/microsoft/DeepSpeed) for the training framework.
* [SDEdit](https://huggingface.co/docs/diffusers/v0.30.2/en/api/pipelines/stable_diffusion/img2img#image-to-image) for high-resolution image generation.
* [@Weihao Yu](https://github.com/yuweihao) and [@Xinyin Ma](https://github.com/horseee) for valuable discussions.
* NUS IT’s Research Computing group, under grant number NUSREC-HPC-00001.

## Citation

If you find this repo helpful, please consider citing:

```bib
@article{liu2024clear,
  title   = {CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up},
  author  = {Liu, Songhua and Tan, Zhenxiong and Wang, Xinchao},
  journal = {NeurIPS},
  year    = {2025},
}
```

================================================
FILE: attention_processor.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
create_block_mask = torch.compile(create_block_mask)
from diffusers.models.attention_processor import Attention
from typing import Optional
from functools import partial, lru_cache
from diffusers.models.embeddings import apply_rotary_emb

attn_outputs_teacher = []
attn_outputs = []
BLOCK_MASK = None
HEIGHT = None
WIDTH = None


class FluxAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self, distill=False):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
        self.distill = distill

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        proportional_attention=False
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        train_seq_len = 64 ** 2 + 512
        if proportional_attention:
            attention_scale = math.sqrt(math.log(key.size(2), train_seq_len) / head_dim)
        else:
            attention_scale = math.sqrt(1 / head_dim)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, scale=attention_scale)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            if self.distill:
                attn_outputs_teacher.append(hidden_states)
            return hidden_states


@lru_cache
def init_local_downsample_mask_flex(height, width, text_length, window_size, down_factor, device):
    def local_downsample_mask(b, h, q_idx, kv_idx):
        q_y = (q_idx - text_length) // width
        q_x = (q_idx - text_length) % width
        kv_y = (kv_idx - text_length) // width
        kv_x = (kv_idx - text_length) % width
        return torch.logical_or(
            # Text queries attend to all text and image tokens (but not to the down-sampled tokens).
            torch.logical_and(
                q_idx < text_length,
                kv_idx < text_length + height * width
            ),
            # Image queries attend to text tokens, down-sampled tokens, and image tokens inside the local window.
            torch.logical_and(
                q_idx >= text_length,
                torch.logical_or(
                    torch.logical_or(kv_idx < text_length, kv_idx >= text_length + height * width),
                    (q_y - kv_y) ** 2 + (q_x - kv_x) ** 2 < window_size ** 2
                )
            )
        )

    global BLOCK_MASK, HEIGHT, WIDTH
    BLOCK_MASK = create_block_mask(local_downsample_mask, B=None, H=None, device=device,
                                   Q_LEN=text_length + height * width,
                                   KV_LEN=text_length + height * width + (height // down_factor) * (width // down_factor),
                                   _compile=True)
    HEIGHT = height
    WIDTH = width


class LocalDownsampleFlexAttnProcessor(nn.Module):
    def __init__(self, down_factor=4, distill=False):
        super().__init__()
        assert BLOCK_MASK is not None
        self.flex_attn = partial(flex_attention, block_mask=BLOCK_MASK)
        self.flex_attn = torch.compile(self.flex_attn, dynamic=False)
        self.down_factor = down_factor
        self.spatial_weight = nn.Parameter(torch.ones(1, 1, 1, down_factor, 1, down_factor, 1) / (down_factor * down_factor))
        self.distill = distill

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        proportional_attention=False
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        train_seq_len = 64 ** 2 + 512
        if proportional_attention:
            attention_scale = math.sqrt(math.log(10 * key.size(2), train_seq_len) / head_dim)
        else:
            attention_scale = math.sqrt(1 / head_dim)

        # Down-sample the image keys/values: the first 512 tokens are treated as text tokens, the rest are split into
        # (HEIGHT // down_factor, down_factor, WIDTH // down_factor, down_factor), and each down_factor x down_factor
        # patch is pooled with the learnable `spatial_weight` (initialized to a uniform average).
        key_downsample = (key[:, :, 512:].unflatten(2, (HEIGHT // self.down_factor, self.down_factor, WIDTH // self.down_factor, self.down_factor)) * self.spatial_weight).sum(dim=(3, 5)).flatten(2, 3)
        value_downsample = (value[:, :, 512:].unflatten(2, (HEIGHT // self.down_factor, self.down_factor, WIDTH // self.down_factor, self.down_factor)) * self.spatial_weight).sum(dim=(3, 5)).flatten(2, 3)
        hidden_states = self.flex_attn(query, torch.cat([key, key_downsample], dim=2), torch.cat([value, value_downsample], dim=2), scale=attention_scale)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            if self.distill:
                attn_outputs.append(hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            if self.distill:
                attn_outputs.append(hidden_states)
            return hidden_states


@lru_cache
def init_local_mask_flex(height, width, text_length, window_size, device):
    def local_mask(b, h, q_idx, kv_idx):
        q_y = (q_idx - text_length) // width
        q_x = (q_idx - text_length) % width
        kv_y = (kv_idx - text_length) // width
        kv_x = (kv_idx - text_length) % width
        # Text tokens attend and are attended globally; image tokens only within a circular local window.
        return torch.logical_or(torch.logical_or(q_idx < text_length, kv_idx < text_length), (q_y - kv_y) ** 2 + (q_x - kv_x) ** 2 < window_size ** 2)

    global BLOCK_MASK, HEIGHT, WIDTH
    BLOCK_MASK = create_block_mask(local_mask, B=None, H=None, device=device,
                                   Q_LEN=text_length + height * width,
                                   KV_LEN=text_length + height * width,
                                   _compile=True)
    HEIGHT = height
    WIDTH = width


class LocalFlexAttnProcessor:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self, distill=False):
        super().__init__()
        self.flex_attn = partial(flex_attention, block_mask=BLOCK_MASK)
        self.flex_attn = torch.compile(self.flex_attn, dynamic=False)
        self.distill = distill

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
        proportional_attention=False
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        if encoder_hidden_states is not None:
            # `context` projections.
            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        train_seq_len = 64 ** 2 + 512
        if proportional_attention:
            attention_scale = math.sqrt(math.log(key.size(2), train_seq_len) / head_dim)
        else:
            attention_scale = math.sqrt(1 / head_dim)
        hidden_states = self.flex_attn(query, key, value, scale=attention_scale)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            # The joint sequence is [text tokens, image tokens]; split it back apart.
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            # store this layer's output when it takes part in distillation
            if self.distill:
                attn_outputs.append(hidden_states)
            return hidden_states, encoder_hidden_states
        else:
            if self.distill:
                attn_outputs.append(hidden_states)
            return hidden_states

================================================
FILE: cache_latent_codes.py
================================================
import argparse
import os
import numpy as np
from PIL import Image
import math
from safetensors.torch import save_file
import torch
import tqdm
from diffusers import AutoencoderKL


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Cache VAE latent distributions (mean and std) for training images.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16.",
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default=None,
        help="Directory containing the .jpg/.png training images.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
        help="Side length in pixels that every image is resized to (as a square) before encoding.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="flux-lora",
        help="Directory created at startup. The latent files themselves are written next to the images in --data_root.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=4, help="Number of images encoded per batch on each process."
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Precision of the VAE: fp16, bf16 (bfloat16), or fp32 when unset or 'no'. Bf16 requires PyTorch >= 1.10"
            " and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=0,
        help="Ignored: the rank is read from the LOCAL_RANK environment variable set by torchrun (default 0).",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="Number of processes the file list is split across (should match --nproc_per_node).",
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()
    args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    return args


def main(args):
    if torch.cuda.is_available():
        torch.cuda.set_device(args.local_rank)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    if args.mixed_precision == 'fp16':
        dtype = torch.float16
    elif args.mixed_precision == 'bf16':
        dtype = torch.bfloat16
    else:
        dtype = torch.float32
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
        cache_dir=args.cache_dir,
    ).to(device, dtype)
    all_info = [os.path.join(args.data_root, i) for i in sorted(os.listdir(args.data_root)) if '.jpg' in i or '.png' in i]
    os.makedirs(args.output_dir, exist_ok=True)
    # Split the sorted file list into contiguous shards, one per process.
    work_load = math.ceil(len(all_info) / args.num_workers)
    shard_start = work_load * args.local_rank
    shard_end = min(work_load * (args.local_rank + 1), len(all_info))
    for idx in tqdm.tqdm(range(shard_start, shard_end, args.batch_size)):
        # Clip the batch to this shard so that no file is encoded by two processes.
        # Entries of all_info are already full paths, so they are not joined with data_root again.
        batch_files = all_info[idx:min(idx + args.batch_size, shard_end)]
        paths = [item[:item.rfind('.')] + '_latent_code.safetensors' for item in batch_files]
        images = []
        for item in batch_files:
            img = Image.open(item).convert('RGB')
            img = img.resize((args.resolution, args.resolution))
            # map pixels from [0, 255] to [-1, 1], then HWC -> CHW
            img = torch.from_numpy((np.array(img) / 127.5) - 1)
            img = img.permute(2, 0, 1)
            images.append(img)
        with torch.no_grad():
            images = torch.stack(images, dim=0)
            data = vae.encode(images.to(device, vae.dtype)).latent_dist
            means = data.mean.cpu().data
            stds = data.std.cpu().data
            # store the posterior mean and std of each image
            for path, mean, std in zip(paths, means.unbind(), stds.unbind()):
                save_file(
                    {'mean': mean, 'std': std},
                    path
                )


if __name__ == '__main__':
    main(parse_args())

================================================
FILE: cache_latent_codes.sh
================================================
export NUM_WORKERS=4
export MODEL_NAME="black-forest-labs/FLUX.1-dev"

torchrun --nproc_per_node=$NUM_WORKERS cache_latent_codes.py \
  --data_root="/path/to/t2i_1024" \
  --batch_size=16 \
  --num_workers=$NUM_WORKERS \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --mixed_precision='bf16' \
  --output_dir="/path/to/t2i_1024"

================================================
FILE: cache_prompt_embeds.py
================================================
import argparse
import os
import json
import math
from safetensors.torch import save_file
import torch
import tqdm
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, T5EncoderModel


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Cache T5 and CLIP prompt embeddings for training captions.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16.",
    )
    parser.add_argument(
        "--data_root",
        type=str,
        default=None,
        help="Directory containing the .json caption files.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=512,
        help="Maximum sequence length to use with the T5 text encoder.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="flux-lora",
        help="Directory created at startup. The embedding files themselves are written next to the captions in --data_root.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=4, help="Number of captions encoded per batch on each process."
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Precision of the text encoders: fp16, bf16 (bfloat16), or fp32 when unset or 'no'. Bf16 requires"
            " PyTorch >= 1.10 and an Nvidia Ampere GPU."
), ) parser.add_argument("--local_rank", type=int, default=0, help="For distributed training: local_rank") parser.add_argument( "--num_workers", type=int, default=1, help="Number of workers", ) if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() args.local_rank = int(os.environ.get("LOCAL_RANK", 0)) return args def tokenize_prompt(tokenizer, prompt, max_sequence_length): text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids return text_input_ids def _encode_prompt_with_t5( text_encoder, tokenizer, max_sequence_length=512, prompt=None, num_images_per_prompt=1, device=None, text_input_ids=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids else: if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device))[0] dtype = text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds def _encode_prompt_with_clip( text_encoder, tokenizer, prompt: str, device=None, text_input_ids=None, num_images_per_prompt: int = 1, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: text_inputs = tokenizer( prompt, 
padding="max_length", max_length=77, truncation=True, return_overflowing_tokens=False, return_length=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids else: if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds def encode_prompt( text_encoders, tokenizers, prompt: str, max_sequence_length, device=None, num_images_per_prompt: int = 1, text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, ) prompt_embeds = _encode_prompt_with_t5( text_encoder=text_encoders[1], tokenizer=tokenizers[1], max_sequence_length=max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, device=device if device is not None else text_encoders[1].device, text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) return prompt_embeds, pooled_prompt_embeds def main(args): if torch.cuda.is_available(): torch.cuda.set_device(args.local_rank) device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') if args.mixed_precision == 'fp16': dtype = torch.float16 elif args.mixed_precision == 'bf16': dtype = torch.bfloat16 else: dtype = torch.float32 
    tokenizer_one = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, variant=args.variant, cache_dir=args.cache_dir
    )
    tokenizer_two = T5TokenizerFast.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, variant=args.variant, cache_dir=args.cache_dir
    )
    text_encoder_one = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder", variant=args.variant, cache_dir=args.cache_dir
    ).to(device, dtype)
    text_encoder_two = T5EncoderModel.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, subfolder="text_encoder_2", variant=args.variant, cache_dir=args.cache_dir
    ).to(device, dtype)
    tokenizers = [tokenizer_one, tokenizer_two]
    text_encoders = [text_encoder_one, text_encoder_two]
    all_info = [os.path.join(args.data_root, i) for i in sorted(os.listdir(args.data_root)) if '.json' in i]
    os.makedirs(args.output_dir, exist_ok=True)
    # Split the sorted file list into contiguous shards, one per process.
    work_load = math.ceil(len(all_info) / args.num_workers)
    shard_start = work_load * args.local_rank
    shard_end = min(work_load * (args.local_rank + 1), len(all_info))
    for idx in tqdm.tqdm(range(shard_start, shard_end, args.batch_size)):
        # Clip the batch to this shard so that no caption is encoded by two processes.
        # Entries of all_info are already full paths, so they are not joined with data_root again.
        batch_files = all_info[idx:min(idx + args.batch_size, shard_end)]
        texts = []
        for item in batch_files:
            with open(item) as f:
                texts.append(json.load(f)['prompt'])
        paths = [item[:item.rfind('.')] + '_prompt_embed.safetensors' for item in batch_files]
        with torch.no_grad():
            prompt_embeds, pooled_prompt_embeds = encode_prompt(
                text_encoders, tokenizers, texts, args.max_sequence_length
            )
            prompt_embeds = prompt_embeds.cpu().data
            pooled_prompt_embeds = pooled_prompt_embeds.cpu().data
            # one file per caption: T5 sequence features plus the pooled CLIP vector
            for path, prompt_embed, pooled_prompt_embed in zip(paths, prompt_embeds.unbind(), pooled_prompt_embeds.unbind()):
                save_file(
                    {'caption_feature_t5': prompt_embed, 'caption_feature_clip': pooled_prompt_embed},
                    path
                )


if __name__ == '__main__':
    main(parse_args())
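Both caching scripts split the sorted file list among `torchrun` processes the same way: each rank takes a contiguous block of `work_load = ceil(N / num_workers)` files and walks it in steps of `--batch_size`. The sketch below reproduces that arithmetic in plain Python so the split can be checked without a GPU. It is an illustration under that assumption, not code from the repository, and the helper names `shard_bounds` and `shard_batches` are hypothetical. One detail worth checking: an unclipped slice `all_info[idx:idx + batch_size]` can spill into the next rank's shard whenever `work_load` is not a multiple of the batch size, so the sketch clips every batch to the end of its shard.

```python
import math
from typing import List, Tuple


def shard_bounds(num_items: int, num_workers: int, rank: int) -> Tuple[int, int]:
    """Half-open [start, stop) range of the sorted file list handled by `rank`.

    Uses the same split as the caching scripts: work_load = ceil(N / num_workers).
    """
    work_load = math.ceil(num_items / num_workers)
    start = min(work_load * rank, num_items)
    stop = min(work_load * (rank + 1), num_items)
    return start, stop


def shard_batches(num_items: int, num_workers: int, rank: int, batch_size: int) -> List[Tuple[int, int]]:
    """(begin, end) index pairs for each batch of `rank`, clipped to the shard end
    so that no file is assigned to two ranks."""
    start, stop = shard_bounds(num_items, num_workers, rank)
    return [(i, min(i + batch_size, stop)) for i in range(start, stop, batch_size)]


if __name__ == "__main__":
    # 10 files, 4 ranks, batch size 2 -> work_load = 3
    for rank in range(4):
        print(rank, shard_batches(10, 4, rank, 2))
    # prints:
    # 0 [(0, 2), (2, 3)]
    # 1 [(3, 5), (5, 6)]
    # 2 [(6, 8), (8, 9)]
    # 3 [(9, 10)]
```

With the same 10 files, 4 ranks and batch size 2, an unclipped loop would give rank 0 the slices `[0:2]` and `[2:4]`, so file 3 would also be processed by rank 1. Clipping keeps the union of all batches equal to the file list with no overlap.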
================================================
FILE: cache_prompt_embeds.sh
================================================
export NUM_WORKERS=4
export MODEL_NAME="black-forest-labs/FLUX.1-dev"

torchrun --nproc_per_node=$NUM_WORKERS cache_prompt_embeds.py \
  --data_root="/path/to/t2i_1024" \
  --batch_size=256 \
  --num_workers=$NUM_WORKERS \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --mixed_precision='bf16' \
  --output_dir="/path/to/t2i_1024"

================================================
FILE: dataset.py
================================================
import os
import io
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import json
import random
import cv2
from safetensors import safe_open


def image_resize(img, max_size=512):
    # Scale the longer side to max_size while keeping the aspect ratio.
    w, h = img.size
    if w >= h:
        new_w = max_size
        new_h = int((max_size / w) * h)
    else:
        new_h = max_size
        new_w = int((max_size / h) * w)
    # Round the scaled size down to a multiple of 32.
    new_w = (new_w // 32) * 32
    new_h = (new_h // 32) * 32
    return img.resize((new_w, new_h))


def c_crop(image):
    # center crop to a square
    width, height = image.size
    new_size = min(width, height)
    left = (width - new_size) / 2
    top = (height - new_size) / 2
    right = (width + new_size) / 2
    bottom = (height + new_size) / 2
    return image.crop((left, top, right, bottom))


def crop_to_aspect_ratio(image, ratio="16:9"):
    # center crop to the requested aspect ratio
    width, height = image.size
    ratio_map = {
        "16:9": (16, 9),
        "4:3": (4, 3),
        "1:1": (1, 1)
    }
    target_w, target_h = ratio_map[ratio]
    target_ratio_value = target_w / target_h

    current_ratio = width / height

    if current_ratio > target_ratio_value:
        new_width = int(height * target_ratio_value)
        offset = (width - new_width) // 2
        crop_box = (offset, 0, offset + new_width, height)
    else:
        new_height = int(width / target_ratio_value)
        offset = (height - new_height) // 2
        crop_box = (0, offset, width, offset + new_height)

    cropped_img = image.crop(crop_box)
    return cropped_img


class CustomImageDataset(Dataset):
    def __init__(self, img_dir, img_size=512, caption_type='json', random_ratio=False,
                 use_cached_prompt_embeds=False, use_cached_latent_codes=False):
        self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
        self.images.sort()
        self.img_size = img_size
        self.caption_type = caption_type
        self.random_ratio = random_ratio
        self.use_cached_prompt_embeds = use_cached_prompt_embeds
        self.use_cached_latent_codes = use_cached_latent_codes

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        try:
            batch = {}
            # Path without its final extension, shared by the caption and cache files.
            # Only the last '.' is stripped, so dots in directory names are preserved.
            stem = self.images[idx][:self.images[idx].rfind('.')]
            if not self.use_cached_latent_codes:
                img = Image.open(self.images[idx]).convert('RGB')
                if self.random_ratio:
                    ratio = random.choice(["16:9", "default", "1:1", "4:3"])
                    if ratio != "default":
                        img = crop_to_aspect_ratio(img, ratio)
                img = image_resize(img, self.img_size)
                img = torch.from_numpy((np.array(img) / 127.5) - 1)
                img = img.permute(2, 0, 1)
                batch['images'] = img
            else:
                with safe_open(stem + '_latent_code.safetensors', framework="pt") as f:
                    batch['latent_codes_mean'] = f.get_tensor('mean')
                    batch['latent_codes_std'] = f.get_tensor('std')
            if not self.use_cached_prompt_embeds:
                json_path = stem + '.' + self.caption_type
                with open(json_path) as f:
                    if self.caption_type == "json":
                        prompt = json.load(f)['prompt']
                    else:
                        prompt = f.read()
                batch['prompts'] = prompt
            else:
                with safe_open(stem + '_prompt_embed.safetensors', framework="pt") as f:
                    batch['prompt_embeds_t5'] = f.get_tensor('caption_feature_t5')
                    batch['prompt_embeds_clip'] = f.get_tensor('caption_feature_clip')
            return batch
        except Exception as e:
            # on an unreadable or missing file, log the error and fall back to a random sample
            print(e)
            return self.__getitem__(random.randint(0, len(self.images) - 1))


def loader(train_batch_size, num_workers, **args):
    dataset = CustomImageDataset(**args)
    return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)

================================================
FILE: deepspeed_config.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 4
  gradient_clipping: 1.0
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

================================================
FILE: distill.py
================================================
import argparse
import copy
import logging
import math
import os
import shutil
from contextlib import nullcontext
from pathlib import Path

import numpy as np
import torch
import torch.utils.checkpoint
from torch.utils.data import RandomSampler
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast

import diffusers
from diffusers import (
AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel, ) from diffusers.optimization import get_scheduler from diffusers.training_utils import ( cast_training_params, compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory, ) from diffusers.utils import ( check_min_version, is_wandb_available, ) from diffusers.utils.torch_utils import is_compiled_module from safetensors.torch import save_file from safetensors import safe_open from dataset import loader from attention_processor import FluxAttnProcessor2_0, LocalFlexAttnProcessor, LocalDownsampleFlexAttnProcessor, init_local_mask_flex, init_local_downsample_mask_flex from attention_processor import attn_outputs_teacher, attn_outputs if is_wandb_available(): import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.31.0.dev0") logger = get_logger(__name__) def load_text_encoders(class_one, class_two): text_encoder_one = class_one.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant ) text_encoder_two = class_two.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant ) return text_encoder_one, text_encoder_two def log_validation( pipeline, args, accelerator, pipeline_args, epoch, torch_dtype, is_final_validation=False, ): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." 
) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() with autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" if tracker.name == "tensorboard": np_images = np.stack([np.asarray(img) for img in images]) tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC") if tracker.name == "wandb": tracker.log( { phase_name: [ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) ] } ) del pipeline if torch.cuda.is_available(): torch.cuda.empty_cache() return images def import_model_class_from_model_name_or_path( pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder=subfolder, revision=revision ) model_class = text_encoder_config.architectures[0] if model_class == "CLIPTextModel": from transformers import CLIPTextModel return CLIPTextModel elif model_class == "T5EncoderModel": from transformers import T5EncoderModel return T5EncoderModel else: raise ValueError(f"{model_class} is not supported.") def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, default=None, required=True, help="Path to pretrained model or model identifier from huggingface.co/models.", ) parser.add_argument( "--revision", type=str, default=None, required=False, help="Revision of pretrained model identifier from 
huggingface.co/models.",
    )
    parser.add_argument(
        "--variant",
        type=str,
        default=None,
        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16.",
    )
    parser.add_argument(
        "--data_root", type=str, default=None
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=512,
        help="Maximum sequence length to use with the T5 text encoder.",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
            "Run validation every X epochs. Each validation run generates `args.num_validation_images` images"
            " from `args.validation_prompt`."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="flux-attn",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=1024,
        help=(
            "The resolution of the input images in pixels. It also sets the token grid"
            " (resolution // 16 per side) used to build the local attention masks."
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
) parser.add_argument( "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." ) parser.add_argument("--num_train_epochs", type=int, default=1) parser.add_argument( "--max_train_steps", type=int, default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( "--window_size", type=int, default=16, help=( "Size of local window for attention sampling." ), ) parser.add_argument( "--down_factor", type=int, default=1, help=( "Factor of downsampling for key-value tokens." ), ) parser.add_argument( "--checkpointing_steps", type=int, default=10000, help=( "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" " training using `--resume_from_checkpoint`." ), ) parser.add_argument( "--checkpoints_total_limit", type=int, default=None, help=("Max number of checkpoints to store."), ) parser.add_argument( "--resume_from_checkpoint", type=str, default=None, help=( "Whether training should be resumed from a previous checkpoint. Use a path saved by" ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
), ) parser.add_argument( "--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( "--gradient_checkpointing", action="store_true", help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", ) parser.add_argument( "--learning_rate", type=float, default=1e-4, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( "--guidance_scale", type=float, default=3.5, help="the FLUX.1 dev variant is a guidance distilled model", ) parser.add_argument( "--scale_lr", action="store_true", default=False, help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", ) parser.add_argument( "--lr_scheduler", type=str, default="constant", help=( 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' ' "constant", "constant_with_warmup"]' ), ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( "--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") parser.add_argument( "--dataloader_num_workers", type=int, default=0, help=( "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) parser.add_argument( "--weighting_scheme", type=str, default="none", choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"], help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'), ) parser.add_argument( "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." 
    )
    parser.add_argument(
        "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
    )
    parser.add_argument(
        "--mode_scale",
        type=float,
        default=1.29,
        help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="AdamW",
        help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
    )
    parser.add_argument(
        "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
    )
    parser.add_argument(
        "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
    )
    parser.add_argument(
        "--prodigy_beta3",
        type=float,
        default=None,
        help="Coefficient for computing the Prodigy step size using running averages. If set to None, "
        "uses the square root of beta2. Ignored if optimizer is AdamW",
    )
    parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for transformer params")
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
    )
    parser.add_argument(
        "--prodigy_use_bias_correction",
        type=bool,
        default=True,
        help="Turn on Adam's bias correction. True by default. Ignored if optimizer is AdamW",
    )
    parser.add_argument(
        "--prodigy_safeguard_warmup",
        type=bool,
        default=True,
        help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
        "Ignored if optimizer is AdamW",
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--use_cached_latent",
        action="store_true",
        default=False,
        help="Use VAE latents cached by cache_latent_codes.py instead of encoding the images",
    )
    parser.add_argument(
        "--use_cached_prompt_embed",
        action="store_true",
        default=False,
        help="Use T5 and CLIP features cached by cache_prompt_embeds.py instead of encoding the prompts",
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or"
            " the flag passed with the `accelerate launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--upcast_before_saving",
        action="store_true",
        default=False,
        help=(
            "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). 
" "Defaults to precision dtype used for training to save memory" ), ) parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") if input_args is not None: args = parser.parse_args(input_args) else: args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank return args def tokenize_prompt(tokenizer, prompt, max_sequence_length): text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids return text_input_ids def _encode_prompt_with_t5( text_encoder, tokenizer, max_sequence_length=512, prompt=None, num_images_per_prompt=1, device=None, text_input_ids=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if tokenizer is not None: text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids else: if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device))[0] dtype = text_encoder.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds def _encode_prompt_with_clip( text_encoder, tokenizer, prompt: str, device=None, text_input_ids=None, num_images_per_prompt: int = 1, ): prompt = [prompt] if isinstance(prompt, str) else prompt 
batch_size = len(prompt) if tokenizer is not None: text_inputs = tokenizer( prompt, padding="max_length", max_length=77, truncation=True, return_overflowing_tokens=False, return_length=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids else: if text_input_ids is None: raise ValueError("text_input_ids must be provided when the tokenizer is not specified") prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False) # Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds def encode_prompt( text_encoders, tokenizers, prompt: str, max_sequence_length, device=None, num_images_per_prompt: int = 1, text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) dtype = text_encoders[0].dtype pooled_prompt_embeds = _encode_prompt_with_clip( text_encoder=text_encoders[0], tokenizer=tokenizers[0], prompt=prompt, device=device if device is not None else text_encoders[0].device, num_images_per_prompt=num_images_per_prompt, text_input_ids=text_input_ids_list[0] if text_input_ids_list else None, ) prompt_embeds = _encode_prompt_with_t5( text_encoder=text_encoders[1], tokenizer=tokenizers[1], max_sequence_length=max_sequence_length, prompt=prompt, num_images_per_prompt=num_images_per_prompt, device=device if device is not None else text_encoders[1].device, text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids def main(args): if torch.backends.mps.is_available() and args.mixed_precision == "bf16": # due 
to pytorch#99272, MPS does not yet support bfloat16. raise ValueError( "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) logging_dir = Path(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, log_with=args.report_to, project_config=accelerator_project_config ) # Disable AMP for MPS. if torch.backends.mps.is_available(): accelerator.native_amp = False if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() # If passed along, set the training seed now. 
if args.seed is not None: set_seed(args.seed) # Handle the repository creation if accelerator.is_main_process: if args.output_dir is not None: os.makedirs(args.output_dir, exist_ok=True) # Load the tokenizers tokenizer_one = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, ) tokenizer_two = T5TokenizerFast.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, ) # import correct text encoder classes text_encoder_cls_one = import_model_class_from_model_name_or_path( args.pretrained_model_name_or_path, args.revision ) text_encoder_cls_two = import_model_class_from_model_name_or_path( args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" ) # Load scheduler and models noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( args.pretrained_model_name_or_path, subfolder="scheduler" ) text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant, ) transformer = FluxTransformer2DModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) transformer_teacher = FluxTransformer2DModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant ) if args.down_factor == 1: init_local_mask_flex(args.resolution // 16, args.resolution // 16, text_length=args.max_sequence_length, window_size=args.window_size, device=accelerator.device) attn_processors = {} for idx, name in enumerate(transformer.attn_processors): attn_processors[name] = LocalFlexAttnProcessor(distill='single' in name and idx % 4 == 0) else: init_local_downsample_mask_flex(args.resolution // 16, args.resolution // 16, text_length=args.max_sequence_length, 
window_size=args.window_size, down_factor=args.down_factor, device=accelerator.device) attn_processors = {} for idx, name in enumerate(transformer.attn_processors): attn_processors[name] = LocalDownsampleFlexAttnProcessor(down_factor=args.down_factor, distill='single' in name and idx % 4 == 0) attn_processors_teacher = {} for idx, name in enumerate(transformer_teacher.attn_processors.keys()): attn_processors_teacher[name] = FluxAttnProcessor2_0(distill='single' in name and idx % 4 == 0) transformer.set_attn_processor(attn_processors) transformer_teacher.set_attn_processor(attn_processors_teacher) # We only train the Attn layers transformer.requires_grad_(False) transformer_teacher.requires_grad_(False) vae.requires_grad_(False) text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision # as these weights are only used for inference, keeping weights in full precision is not required. weight_dtype = torch.float32 if accelerator.mixed_precision == "fp16": weight_dtype = torch.float16 elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16: # due to pytorch#99272, MPS does not yet support bfloat16. raise ValueError( "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." ) vae.to(accelerator.device, dtype=weight_dtype) transformer.to(accelerator.device, dtype=weight_dtype) transformer_teacher.to(accelerator.device, dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) if args.gradient_checkpointing: transformer.enable_gradient_checkpointing() for _name, _param in transformer.named_parameters(): if '.attn.to_q.' in _name or '.attn.to_k.' in _name or '.attn.to_v.' in _name or '.attn.to_out.' 
in _name or 'spatial_weight' in _name: _param.requires_grad = True def unwrap_model(model): model = accelerator.unwrap_model(model) model = model._orig_mod if is_compiled_module(model) else model return model # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32 and torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True if args.scale_lr: args.learning_rate = ( args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Make sure the trainable params are in float32. if args.mixed_precision == "fp16": models = [transformer] # only upcast trainable parameters (Attn) into fp32 cast_training_params(models, dtype=torch.float32) # Optimization parameters transformer_attn_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters())) transformer_parameters_with_lr = {"params": transformer_attn_parameters, "lr": args.learning_rate} params_to_optimize = [transformer_parameters_with_lr] # Optimizer creation if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): logger.warning( f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. " "Defaulting to adamW." ) args.optimizer = "adamw" if args.use_8bit_adam and not args.optimizer.lower() == "adamw": logger.warning( f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " f"set to {args.optimizer.lower()}" ) if args.optimizer.lower() == "adamw": if args.use_8bit_adam: try: import bitsandbytes as bnb except ImportError: raise ImportError( "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
) optimizer_class = bnb.optim.AdamW8bit else: optimizer_class = torch.optim.AdamW optimizer = optimizer_class( params_to_optimize, betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, ) if args.optimizer.lower() == "prodigy": try: import prodigyopt except ImportError: raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") optimizer_class = prodigyopt.Prodigy if args.learning_rate <= 0.1: logger.warning( "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" ) optimizer = optimizer_class( params_to_optimize, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2), beta3=args.prodigy_beta3, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, decouple=args.prodigy_decouple, use_bias_correction=args.prodigy_use_bias_correction, safeguard_warmup=args.prodigy_safeguard_warmup, ) # Dataset and DataLoaders creation: train_dataloader = loader(train_batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers, img_dir=args.data_root, img_size=args.resolution, use_cached_prompt_embeds=args.use_cached_prompt_embed, use_cached_latent_codes=args.use_cached_latent) tokenizers = [tokenizer_one, tokenizer_two] text_encoders = [text_encoder_one, text_encoder_two] def compute_text_embeddings(prompt, text_encoders, tokenizers): with torch.no_grad(): prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt( text_encoders, tokenizers, prompt, args.max_sequence_length ) prompt_embeds = prompt_embeds.to(accelerator.device) pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) text_ids = text_ids.to(accelerator.device) return prompt_embeds, pooled_prompt_embeds, text_ids # Clear the memory here if args.use_cached_prompt_embed: del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, text_encoders, tokenizers free_memory() vae_config_shift_factor = vae.config.shift_factor 
vae_config_scaling_factor = vae.config.scaling_factor vae_config_block_out_channels = vae.config.block_out_channels if args.use_cached_latent: del vae free_memory() # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if args.max_train_steps is None: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, num_training_steps=args.max_train_steps * accelerator.num_processes, num_cycles=args.lr_num_cycles, power=args.lr_power, ) # Prepare everything with our `accelerator`. guidance_embeds = transformer.config.guidance_embeds transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( transformer, optimizer, train_dataloader, lr_scheduler ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) if overrode_max_train_steps: args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch # Afterwards we recalculate our number of training epochs args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) # We need to initialize the trackers we use, and also store our configuration. # The trackers initialize automatically on the main process. if accelerator.is_main_process: tracker_name = "flux-dev-attn" accelerator.init_trackers(tracker_name, config=vars(args)) # Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") logger.info(f" Num batches each epoch = {len(train_dataloader)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = args.resume_from_checkpoint else: # Get the most recent checkpoint; os.listdir returns bare names, so join them with output_dir dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = os.path.join(args.output_dir, dirs[-1]) if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." ) args.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(path) global_step = int(path.split("-")[-1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 progress_bar = tqdm( range(0, args.max_train_steps), initial=initial_global_step, desc="Steps", # Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process, ) def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype) schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device) timesteps = timesteps.to(accelerator.device) step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] sigma = sigmas[step_indices].flatten() while len(sigma.shape) < n_dim: sigma = sigma.unsqueeze(-1) return sigma for epoch in range(first_epoch, args.num_train_epochs): transformer.train() for _, batch in enumerate(train_dataloader): with accelerator.accumulate(transformer): if not args.use_cached_prompt_embed: prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( batch['prompts'], text_encoders, tokenizers ) else: prompt_embeds = batch['prompt_embeds_t5'].to(dtype=weight_dtype) pooled_prompt_embeds = batch['prompt_embeds_clip'].to(dtype=weight_dtype) text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=weight_dtype) with torch.no_grad(): # Convert images to latent space if args.use_cached_latent: mean = batch['latent_codes_mean'].to(dtype=weight_dtype) std = batch['latent_codes_std'].to(dtype=weight_dtype) sample = torch.randn_like(mean) model_input = mean + std * sample else: model_input = vae.encode(batch[0].to(dtype=vae.dtype)).latent_dist.sample() model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor vae_scale_factor = 2 ** (len(vae_config_block_out_channels)) latent_image_ids = FluxPipeline._prepare_latent_image_ids( model_input.shape[0], model_input.shape[2], model_input.shape[3], accelerator.device, weight_dtype, ) # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) bsz = model_input.shape[0] # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly u = compute_density_for_timestep_sampling( weighting_scheme=args.weighting_scheme, 
batch_size=bsz, logit_mean=args.logit_mean, logit_std=args.logit_std, mode_scale=args.mode_scale, ) indices = (u * noise_scheduler.config.num_train_timesteps).long() timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device) # Add noise according to flow matching. # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise packed_noisy_model_input = FluxPipeline._pack_latents( noisy_model_input, batch_size=model_input.shape[0], num_channels_latents=model_input.shape[1], height=model_input.shape[2], width=model_input.shape[3], ) # handle guidance if guidance_embeds: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) else: guidance = None teacher_pred = transformer_teacher( hidden_states=packed_noisy_model_input, # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) timestep=timesteps / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, return_dict=False, )[0] teacher_pred = FluxPipeline._unpack_latents( teacher_pred, height=int(model_input.shape[2] * vae_scale_factor / 2), width=int(model_input.shape[3] * vae_scale_factor / 2), vae_scale_factor=vae_scale_factor, ) # Predict the noise residual model_pred = transformer( hidden_states=packed_noisy_model_input, # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing) timestep=timesteps / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, return_dict=False, )[0] model_pred =
FluxPipeline._unpack_latents( model_pred, height=int(model_input.shape[2] * vae_scale_factor / 2), width=int(model_input.shape[3] * vae_scale_factor / 2), vae_scale_factor=vae_scale_factor, ) # these weighting schemes use a uniform timestep sampling # and instead post-weight the loss weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) # flow matching loss target = noise - model_input # Compute regular loss. loss_fm = (weighting.float() * (model_pred.float() - target.float()) ** 2).mean() loss_distill = (weighting.float() * (model_pred.float() - teacher_pred.float()) ** 2).mean() loss_attn = sum([(weighting.float().squeeze(-1) * (attn_output.float() - attn_output_teacher.float()) ** 2).mean() for attn_output, attn_output_teacher in zip(attn_outputs, attn_outputs_teacher)]) / len(attn_outputs) loss = loss_fm + loss_distill * 0.5 + loss_attn * 0.5 accelerator.backward(loss) if accelerator.sync_gradients: params_to_clip = ( transformer.parameters() ) accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() lr_scheduler.step() optimizer.zero_grad() attn_outputs.clear() attn_outputs_teacher.clear() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: progress_bar.update(1) global_step += 1 if accelerator.is_main_process: if global_step % args.checkpointing_steps == 0: # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` if args.checkpoints_total_limit is not None: checkpoints = os.listdir(args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints if len(checkpoints) >= args.checkpoints_total_limit: num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 removing_checkpoints = 
checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) if global_step % args.checkpointing_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") logs = {"loss_fm": loss_fm.detach().item(), "loss_distill": loss_distill.detach().item(), "loss_attn": loss_attn.detach().item(), "loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break free_memory() if accelerator.is_main_process: if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # create pipeline if args.use_cached_prompt_embed: text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two) if args.use_cached_latent: vae = AutoencoderKL.from_pretrained( args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant, ) pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, vae=vae, text_encoder=accelerator.unwrap_model(text_encoder_one), text_encoder_2=accelerator.unwrap_model(text_encoder_two), transformer=accelerator.unwrap_model(transformer), revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, ) pipeline_args = {"prompt": args.validation_prompt} images = log_validation( pipeline=pipeline, args=args, accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, torch_dtype=weight_dtype, ) if args.use_cached_prompt_embed: del text_encoder_one, text_encoder_two if args.use_cached_latent: del vae free_memory() # Save the attn weights 
accelerator.wait_for_everyone() if accelerator.is_main_process: transformer = unwrap_model(transformer) if args.upcast_before_saving: transformer.to(torch.float32) else: transformer = transformer.to(weight_dtype) state_dict = {} for _name, _param in transformer.named_parameters(): if '.attn.to_q.' in _name or '.attn.to_k.' in _name or '.attn.to_v.' in _name or '.attn.to_out.' in _name or 'spatial_weight' in _name: state_dict[_name] = _param save_file( state_dict, os.path.join(args.output_dir, 'attn_weights.safetensors') ) # Final inference # Load previous pipeline free_memory() pipeline = FluxPipeline.from_pretrained( args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, ) # load attention processors (same file name as the one written by save_file above) state_dict = {} with safe_open(os.path.join(args.output_dir, 'attn_weights.safetensors'), framework="pt") as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) # strict=False: the file holds only the attention weights, not the full transformer missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(state_dict, strict=False) missing_keys = list(filter(lambda p: ('.attn.to_q.' in p or '.attn.to_k.' in p or '.attn.to_v.' in p or '.attn.to_out.' in p or 'spatial_weight' in p), missing_keys)) if len(missing_keys) != 0 or len(unexpected_keys) != 0: logger.warning( f"Loading attn weights from state_dict led to unexpected keys: {unexpected_keys}" f" and missing keys: {missing_keys}."
) # run inference images = [] if args.validation_prompt and args.num_validation_images > 0: pipeline_args = {"prompt": args.validation_prompt} images = log_validation( pipeline=pipeline, args=args, accelerator=accelerator, pipeline_args=pipeline_args, epoch=epoch, is_final_validation=True, torch_dtype=weight_dtype, ) for idx, image in enumerate(images): image.save(os.path.join(args.output_dir, 'validation_idx_%02d.jpg' % idx)) accelerator.end_training() if __name__ == "__main__": args = parse_args() main(args) ================================================ FILE: distill.sh ================================================ export MODEL_NAME="black-forest-labs/FLUX.1-dev" export DATAPATH="/path/to/t2i_1024" export OUTPUT_DIR="ckpt/training_exp" export PRECISION="bf16" accelerate launch --config_file deepspeed_config.yaml distill.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --data_root=$DATAPATH \ --output_dir=$OUTPUT_DIR \ --mixed_precision=$PRECISION \ --dataloader_num_workers=8 \ --resolution=1024 \ --train_batch_size=2 \ --gradient_accumulation_steps=4 \ --optimizer="prodigy" \ --learning_rate=1. 
\ --report_to="wandb" \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --max_train_steps=50000 \ --validation_epochs=1 \ --seed="0" \ --checkpointing_steps=5000 \ --use_cached_prompt_embed \ --use_cached_latent \ --gradient_checkpointing \ --down_factor=1 \ --window_size=16 ================================================ FILE: inference_t2i.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import os\n", "import requests\n", "from safetensors.torch import load_file\n", "from diffusers import FluxPipeline\n", "from attention_processor import LocalFlexAttnProcessor, LocalDownsampleFlexAttnProcessor, init_local_mask_flex, init_local_downsample_mask_flex" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n", "device = torch.device('cuda')\n", "dtype = torch.bfloat16\n", "pipe = FluxPipeline.from_pretrained(bfl_repo, torch_dtype=dtype).to(device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "height = 1024\n", "width = 1024\n", "down_factor, window_size = 4, 8\n", "# Supported Configurations:\n", "# down_factor, window_size = 1, 8\n", "# down_factor, window_size = 1, 16\n", "# down_factor, window_size = 1, 32\n", "# down_factor, window_size = 4, 16\n", "# down_factor, window_size = 4, 8\n", "if down_factor == 1:\n", " init_local_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, device=device)\n", " attn_processors = {}\n", " for k in pipe.transformer.attn_processors.keys():\n", " attn_processors[k] = LocalFlexAttnProcessor()\n", "else:\n", " init_local_downsample_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, down_factor=down_factor, device=device)\n", " attn_processors = {}\n", " for k in pipe.transformer.attn_processors.keys():\n", " 
attn_processors[k] = LocalDownsampleFlexAttnProcessor(down_factor=down_factor).to(device, dtype)\n", "pipe.transformer.set_attn_processor(attn_processors)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if not os.path.exists('ckpt'):\n", " os.mkdir('ckpt')\n", "if down_factor == 1:\n", " if not os.path.exists(f'ckpt/clear_local_{window_size}.safetensors'):\n", " print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}.safetensors')\n", " response = requests.get(f\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}.safetensors\")\n", " response.raise_for_status()\n", " with open(f'ckpt/clear_local_{window_size}.safetensors', 'wb') as f:\n", " f.write(response.content)\n", " state_dict = load_file(f'ckpt/clear_local_{window_size}.safetensors')\n", "else:\n", " if not os.path.exists(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors'):\n", " print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\n", " response = requests.get(f\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}_down_{down_factor}.safetensors\")\n", " response.raise_for_status()\n", " with open(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors', 'wb') as f:\n", " f.write(response.content)\n", " state_dict = load_file(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\n", "\n", "missing_keys, unexpected_keys = pipe.transformer.load_state_dict(state_dict, strict=False)\n", "\n", "missing_keys = list(filter(lambda p: ('.attn.to_q.' in p or \n", " '.attn.to_k.' in p or \n", " '.attn.to_v.' in p or \n", " '.attn.to_out.' 
in p or \n", " 'spatial_weight' in p), missing_keys))\n", "\n", "if len(missing_keys) != 0 or len(unexpected_keys) != 0:\n", " print(\n", " f\"Loading attn weights from state_dict led to unexpected keys: {unexpected_keys}\"\n", " f\" and missing keys: {missing_keys}.\"\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "prompt = \"A Mickey is eating a pie\"\n", "height = 1024\n", "width = 1024\n", "image = pipe(\n", " prompt,\n", " height=height,\n", " width=width,\n", " guidance_scale=3.5,\n", " num_inference_steps=20,\n", " max_sequence_length=512,\n", " generator=torch.Generator(\"cpu\").manual_seed(0)\n", ").images[0]\n", "image" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "pt25", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: inference_t2i_highres.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import os\n", "import gc\n", "import requests\n", "from safetensors.torch import load_file\n", "from diffusers import FluxPipeline\n", "from pipeline_flux_img2img import FluxImg2ImgPipeline\n", "from transformer_flux import FluxTransformer2DModel\n", "from attention_processor import LocalFlexAttnProcessor, LocalDownsampleFlexAttnProcessor, init_local_mask_flex, init_local_downsample_mask_flex" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n", "device = torch.device('cuda')\n", 
"dtype = torch.bfloat16\n", "pipe = FluxPipeline.from_pretrained(bfl_repo, torch_dtype=dtype).to(device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "prompt = \"enchanted forest, glowing plants, towering ancient trees, a mystical girl, magical aura, \" \\\n", " \"fantasy style, vibrant colors, ethereal lighting, bokeh effect, ultra-detailed, painterly, ultra HD, 8K, \" \\\n", " \"soft glowing lights, mist and fog, otherworldly ambiance, glowing mushrooms, sparkling particles\"\n", "height = 512\n", "width = 1024\n", "image = pipe(\n", " prompt,\n", " height=height,\n", " width=width,\n", " guidance_scale=3.5,\n", " num_inference_steps=20,\n", " max_sequence_length=512,\n", " generator=torch.Generator(\"cpu\").manual_seed(0)\n", ").images[0]\n", "image" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "del pipe\n", "torch.cuda.empty_cache()\n", "gc.collect()\n", "transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n", "pipe = FluxImg2ImgPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=torch.bfloat16)\n", "pipe.transformer = transformer\n", "pipe.scheduler.config.use_dynamic_shifting = False\n", "pipe.scheduler.config.time_shift = 10\n", "pipe.vae.enable_tiling()\n", "pipe = pipe.to(device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "height = 2048\n", "width = 4096\n", "down_factor, window_size = 4, 8\n", "# Supported Configurations:\n", "# down_factor, window_size = 1, 8\n", "# down_factor, window_size = 1, 16\n", "# down_factor, window_size = 1, 32\n", "# down_factor, window_size = 4, 16\n", "# down_factor, window_size = 4, 8\n", "if down_factor == 1:\n", " init_local_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, device=device)\n", " attn_processors = {}\n", " for k in 
pipe.transformer.attn_processors.keys():\n", " attn_processors[k] = LocalFlexAttnProcessor()\n", "else:\n", " init_local_downsample_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, down_factor=down_factor, device=device)\n", " attn_processors = {}\n", " for k in pipe.transformer.attn_processors.keys():\n", " attn_processors[k] = LocalDownsampleFlexAttnProcessor(down_factor=down_factor).to(device, dtype)\n", "pipe.transformer.set_attn_processor(attn_processors)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if not os.path.exists('ckpt'):\n", " os.mkdir('ckpt')\n", "if down_factor == 1:\n", " if not os.path.exists(f'ckpt/clear_local_{window_size}.safetensors'):\n", " print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}.safetensors')\n", " response = requests.get(f\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}.safetensors\")\n", " response.raise_for_status()\n", " with open(f'ckpt/clear_local_{window_size}.safetensors', 'wb') as f:\n", " f.write(response.content)\n", " state_dict = load_file(f'ckpt/clear_local_{window_size}.safetensors')\n", "else:\n", " if not os.path.exists(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors'):\n", " print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\n", " response = requests.get(f\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}_down_{down_factor}.safetensors\")\n", " response.raise_for_status()\n", " with open(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors', 'wb') as f:\n", " f.write(response.content)\n", " state_dict = load_file(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\n", "\n", "missing_keys, unexpected_keys = pipe.transformer.load_state_dict(state_dict, strict=False)\n", "\n", "missing_keys = list(filter(lambda p: ('.attn.to_q.' 
in p or \n", " '.attn.to_k.' in p or \n", " '.attn.to_v.' in p or \n", " '.attn.to_out.' in p or \n", " 'spatial_weight' in p), missing_keys))\n", "\n", "if len(missing_keys) != 0 or len(unexpected_keys) != 0:\n", " print(\n", " f\"Loading attn weights from state_dict led to unexpected keys: {unexpected_keys}\"\n", " f\" and missing keys: {missing_keys}.\"\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "strength = 0.7\n", "image_hr = pipe(prompt=prompt,\n", " image=image.resize((width, height)),\n", " strength=strength,\n", " num_inference_steps=20, \n", " guidance_scale=7.5, \n", " height=height,\n", " width=width,\n", " ntk_factor=10,\n", " proportional_attention=True,\n", " generator=torch.Generator(\"cpu\").manual_seed(0)\n", " ).images[0]\n", "image_hr" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "pt25", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: pipeline_flux_img2img.py ================================================ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. import inspect from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast from diffusers.image_processor import PipelineImageInput, VaeImageProcessor from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models.autoencoders import AutoencoderKL from diffusers.models.transformers import FluxTransformer2DModel from diffusers.schedulers import FlowMatchEulerDiscreteScheduler from diffusers.utils import ( USE_PEFT_BACKEND, is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers, ) from diffusers.utils.torch_utils import randn_tensor from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput if is_torch_xla_available(): import torch_xla.core.xla_model as xm XLA_AVAILABLE = True else: XLA_AVAILABLE = False logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ Examples: ```py >>> import torch >>> from diffusers import FluxImg2ImgPipeline >>> from diffusers.utils import load_image >>> device = "cuda" >>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16) >>> pipe = pipe.to(device) >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" >>> init_image = load_image(url).resize((1024, 1024)) >>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k" >>> images = pipe( ... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0 ... 
).images[0] ``` """ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.16, ): m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len mu = image_seq_len * m + b return mu # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": return encoder_output.latent_dist.mode() elif hasattr(encoder_output, "latents"): return encoder_output.latents else: raise AttributeError("Could not access latents of provided encoder_output") # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None, timesteps: Optional[List[int]] = None, sigmas: Optional[List[float]] = None, **kwargs, ): r""" Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. Args: scheduler (`SchedulerMixin`): The scheduler to get timesteps from. num_inference_steps (`int`): The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. timesteps (`List[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. 
If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. sigmas (`List[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if timesteps is not None: accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_timesteps: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" timestep schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) elif sigmas is not None: accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accept_sigmas: raise ValueError( f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f" sigmas schedules. Please check whether you are using the correct scheduler." ) scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) timesteps = scheduler.timesteps num_inference_steps = len(timesteps) else: scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) timesteps = scheduler.timesteps return timesteps, num_inference_steps class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin): r""" The Flux pipeline for image-to-image generation.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/ Args: transformer ([`FluxTransformer2DModel`]): Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. scheduler ([`FlowMatchEulerDiscreteScheduler`]): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. text_encoder ([`CLIPTextModel`]): [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. text_encoder_2 ([`T5EncoderModel`]): [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. tokenizer (`CLIPTokenizer`): Tokenizer of class [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer). tokenizer_2 (`T5TokenizerFast`): Second Tokenizer of class [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds"] def __init__( self, scheduler: FlowMatchEulerDiscreteScheduler, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, text_encoder_2: T5EncoderModel, tokenizer_2: T5TokenizerFast, transformer: FluxTransformer2DModel, ): super().__init__() self.register_modules( vae=vae, text_encoder=text_encoder, text_encoder_2=text_encoder_2, tokenizer=tokenizer, tokenizer_2=tokenizer_2, transformer=transformer, scheduler=scheduler, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16 ) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) self.default_sample_size = 64 # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( self, prompt: Union[str, List[str]] = None, num_images_per_prompt: int = 1, max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2) text_inputs = self.tokenizer_2( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text 
= self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to" f" {max_sequence_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] dtype = self.text_encoder_2.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, device: Optional[torch.device] = None, ): device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) if isinstance(self, TextualInversionLoaderMixin): prompt = self.maybe_convert_prompt(prompt, self.tokenizer) text_inputs = self.tokenizer( prompt, padding="max_length", max_length=self.tokenizer_max_length, truncation=True, return_overflowing_tokens=False, return_length=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1]) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {self.tokenizer_max_length} tokens: {removed_text}" ) prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False) # 
Use pooled output of CLIPTextModel prompt_embeds = prompt_embeds.pooler_output prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) return prompt_embeds # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt def encode_prompt( self, prompt: Union[str, List[str]], prompt_2: Union[str, List[str]], device: Optional[torch.device] = None, num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 512, lora_scale: Optional[float] = None, ): r""" Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to be encoded. prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in all text encoders. device (`torch.device`): The torch device. num_images_per_prompt (`int`): The number of images that should be generated per prompt. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. lora_scale (`float`, *optional*): A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
""" device = device or self._execution_device # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): self._lora_scale = lora_scale # dynamically adjust the LoRA scale if self.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt if prompt_embeds is None: prompt_2 = prompt_2 or prompt prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 # We only use the pooled prompt output from the CLIPTextModel pooled_prompt_embeds = self._get_clip_prompt_embeds( prompt=prompt, device=device, num_images_per_prompt=num_images_per_prompt, ) prompt_embeds = self._get_t5_prompt_embeds( prompt=prompt_2, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) if self.text_encoder is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder, lora_scale) if self.text_encoder_2 is not None: if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder_2, lora_scale) dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ retrieve_latents(self.vae.encode(image[i : i + 1]), 
generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: image_latents = retrieve_latents(self.vae.encode(image), generator=generator) image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor return image_latents # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(num_inference_steps * strength, num_inference_steps) t_start = int(max(num_inference_steps - init_timestep, 0)) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] if hasattr(self.scheduler, "set_begin_index"): self.scheduler.set_begin_index(t_start * self.scheduler.order) return timesteps, num_inference_steps - t_start def check_inputs( self, prompt, prompt_2, strength, height, width, prompt_embeds=None, pooled_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}") if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two."
) elif prompt_2 is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" " only forward one of the two." ) elif prompt is None and prompt_embeds is None: raise ValueError( "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." ) elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
) if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids def _prepare_latent_image_ids(batch_size, height, width, device, dtype): latent_image_ids = torch.zeros(height // 2, width // 2, 3) latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape latent_image_ids = latent_image_ids.reshape( latent_image_id_height * latent_image_id_width, latent_image_id_channels ) return latent_image_ids.to(device=device, dtype=dtype) @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) return latents @staticmethod # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape height = height // vae_scale_factor width = width // vae_scale_factor latents = latents.view(batch_size, height, width, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2) return latents def prepare_latents( self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None, ): if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a 
list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) height = 2 * (int(height) // self.vae_scale_factor) width = 2 * (int(width) // self.vae_scale_factor) shape = (batch_size, num_channels_latents, height, width) latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype) if latents is not None: return latents.to(device=device, dtype=dtype), latent_image_ids, None image = image.to(device=device, dtype=dtype) image_latents = self._encode_vae_image(image=image, generator=generator) if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: # expand init_latents for batch_size additional_image_per_prompt = batch_size // image_latents.shape[0] image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: raise ValueError( f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
) else: image_latents = torch.cat([image_latents], dim=0) noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = self.scheduler.scale_noise(image_latents, timestep, noise) latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width) return latents, latent_image_ids, image_latents @property def guidance_scale(self): return self._guidance_scale @property def joint_attention_kwargs(self): return self._joint_attention_kwargs @property def num_timesteps(self): return self._num_timesteps @property def interrupt(self): return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, strength: float = 0.6, num_inference_steps: int = 28, timesteps: List[int] = None, guidance_scale: float = 7.0, num_images_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, ntk_factor: float = 10.0, proportional_attention: bool = True ): r""" Function invoked when calling the pipeline for generation. Args: prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, `prompt_embeds` must be passed instead.
prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will be used instead. image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both numpy arrays and PyTorch tensors, the expected value range is `[0, 1]`. If it is a tensor or a list of tensors, the expected shape is `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape is `(B, H, W, C)` or `(H, W, C)`. It can also accept image latents as `image`, but if passing latents directly it is not encoded again. height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`): The width in pixels of the generated image. This is set to 1024 by default for the best results. strength (`float`, *optional*, defaults to 0.6): Indicates the extent to which the reference `image` is transformed. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 essentially ignores `image`. num_inference_steps (`int`, *optional*, defaults to 28): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. timesteps (`List[int]`, *optional*): Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument in their `set_timesteps` method.
If not defined, the default behavior when `num_inference_steps` is passed will be used. Must be in descending order. guidance_scale (`float`, *optional*, defaults to 7.0): Guidance strength. This pipeline does not run [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598) with a separate unconditional pass; when `transformer.config.guidance_embeds` is `True`, the value is passed to the transformer as an embedded guidance input, and it is ignored otherwise. Higher values encourage images that are more closely linked to the text `prompt`, usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will be generated by sampling using the supplied random `generator`. prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input argument. pooled_prompt_embeds (`torch.FloatTensor`, *optional*): Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`, or pass `"latent"` to return the packed latents without VAE decoding.
return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). callback_on_step_end (`Callable`, *optional*): A function that is called at the end of each denoising step during inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include all tensors specified by `callback_on_step_end_tensor_inputs`. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list will be passed as the `callback_kwargs` argument. You can only include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class. max_sequence_length (`int`, defaults to 512): Maximum sequence length to use with the `prompt`; values above 512 are rejected. ntk_factor (`float`, *optional*, defaults to 10.0): Forwarded unchanged to the transformer's forward call. proportional_attention (`bool`, *optional*, defaults to `True`): Forwarded unchanged to the transformer's forward call. Examples: Returns: [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor # 1. Check inputs.
Raise error if not correct self.check_inputs( prompt, prompt_2, strength, height, width, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) self._guidance_scale = guidance_scale self._joint_attention_kwargs = joint_attention_kwargs self._interrupt = False # 2. Preprocess image init_image = self.image_processor.preprocess(image, height=height, width=width) init_image = init_image.to(dtype=torch.float32) # 3. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] device = self._execution_device lora_scale = ( self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None ) ( prompt_embeds, pooled_prompt_embeds, text_ids, ) = self.encode_prompt( prompt=prompt, prompt_2=prompt_2, prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, device=device, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, lora_scale=lora_scale, ) # 4. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) timesteps, num_inference_steps_ = retrieve_timesteps( self.scheduler, num_inference_steps, device, timesteps, sigmas ) timesteps, num_inference_steps_ = self.get_timesteps(num_inference_steps_, strength, device) if num_inference_steps_ < 1: raise ValueError( f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" f" steps is {num_inference_steps_} which is < 1 and not appropriate for this pipeline." ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 5.
Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 latents, latent_image_ids, image_latents = self.prepare_latents( init_image, latent_timestep, batch_size * num_images_per_prompt, num_channels_latents, height, width, prompt_embeds.dtype, device, generator, latents, ) num_warmup_steps = max(len(timesteps) - num_inference_steps_ * self.scheduler.order, 0) self._num_timesteps = len(timesteps) # handle guidance if self.transformer.config.guidance_embeds: guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) guidance = guidance.expand(latents.shape[0]) else: guidance = None # 6. Denoising loop with self.progress_bar(total=num_inference_steps_) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, pooled_projections=pooled_prompt_embeds, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, ntk_factor=ntk_factor, proportional_attention=proportional_attention, )[0] # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] if latents.dtype != latents_dtype: if torch.backends.mps.is_available(): # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) # update the progress bar if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() if output_type == "latent": image = latents else: latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] image = self.image_processor.postprocess(image, output_type=output_type) # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return FluxPipelineOutput(images=image) ================================================ FILE: requirements.txt ================================================ torch>=2.5.0 diffusers>=0.31.0 transformers safetensors opencv-python accelerate tqdm deepspeed wandb prodigyopt sentencepiece ================================================ FILE: transformer_flux.py ================================================ # Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Dict, Optional, Tuple, Union, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import (
    Attention,
    AttentionProcessor
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput

from attention_processor import FluxAttnProcessor2_0


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        mlp_ratio (`float`, *optional*, defaults to 4.0): Ratio of the MLP hidden dimension to `dim`.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        processor = FluxAttnProcessor2_0()
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        proportional_attention: bool = False,
        joint_attention_kwargs=None,
    ):
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            proportional_attention=proportional_attention,
            **joint_attention_kwargs,
        )

        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = residual + hidden_states
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return hidden_states


@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
    r"""
    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        qk_norm (`str`, *optional*, defaults to `"rms_norm"`): Normalization applied to the attention queries and keys.
        eps (`float`, *optional*, defaults to 1e-6): Epsilon used by the query/key normalization.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)

        self.norm1_context = AdaLayerNormZero(dim)

        if hasattr(F, "scaled_dot_product_attention"):
            processor = FluxAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=processor,
            qk_norm=qk_norm,
            eps=eps,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        proportional_attention: bool = False,
        joint_attention_kwargs=None,
    ):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
        joint_attention_kwargs = joint_attention_kwargs or {}
        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            proportional_attention=proportional_attention,
            **joint_attention_kwargs,
        )

        # Process attention outputs for the `hidden_states`.
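        # Shape note (illustrative): the AdaLayerNormZero gate/shift/scale terms have shape (batch, dim);
        # `unsqueeze(1)` and `[:, None]` broadcast them over the token axis, so a (B, D) gate times a
        # (B, N, D) attention output yields a (B, N, D) tensor.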
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.

        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states


class FluxPosEmbed(nn.Module):
    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor, ntk_factor=1) -> Tuple[torch.Tensor, torch.Tensor]:
        n_axes = ids.shape[-1]
        cos_out = []
        sin_out = []
        pos = ids.float()
        is_mps = ids.device.type == "mps"
        freqs_dtype = torch.float32 if is_mps else torch.float64
        for i in range(n_axes):
            # pass the configured base explicitly so `self.theta` is not silently ignored
            cos, sin = get_1d_rotary_pos_embed(
                self.axes_dim[i],
                pos[:, i],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                freqs_dtype=freqs_dtype,
                ntk_factor=ntk_factor,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin


class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
                             FromOriginalModelMixin):
    """
    The Transformer model introduced in Flux.

    Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

    Parameters:
        patch_size (`int`, *optional*, defaults to 1): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 64): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 19): The number of layers of MMDiT blocks to use.
        num_single_layers (`int`, *optional*, defaults to 38): The number of layers of single DiT blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 128): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
        joint_attention_dim (`int`, *optional*, defaults to 4096): The number of `encoder_hidden_states` dimensions to use.
        pooled_projection_dim (`int`, *optional*, defaults to 768): Number of dimensions to use when projecting the
            `pooled_projections`.
        guidance_embeds (`bool`, *optional*, defaults to `False`): Whether to use guidance embeddings.
        axes_dims_rope (`Tuple[int]`, *optional*, defaults to `(16, 56, 56)`): The rotary-embedding dimensions
            assigned to each positional axis of `img_ids` / `txt_ids`.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )

        self.context_embedder = nn.Linear(self.config.joint_attention_dim,
                                          self.inner_dim)
        self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for i in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        This API is 🧪 experimental.
        """
        # `original_attn_processors` is only set by a QKV-fusion step, which this file does not define,
        # so read it defensively instead of raising AttributeError.
        if getattr(self, "original_attn_processors", None) is not None:
            self.set_attn_processor(self.original_attn_processors)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        ntk_factor: float = 1,
        proportional_attention: bool = False,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`FluxTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.FloatTensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states` (packed latent tokens).
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            pooled_projections (`torch.FloatTensor` of shape `(batch_size, pooled_projection_dim)`):
                Embeddings projected from the embeddings of input conditions.
            timestep (`torch.Tensor` of shape `(batch_size,)`):
                Denoising timestep, scaled to [0, 1]; it is multiplied by 1000 before being embedded.
            img_ids (`torch.Tensor` of shape `(image_sequence_length, len(axes_dims_rope))`):
                Positional ids of the image tokens, used to build the rotary embeddings. Passing a 3D tensor with a
                batch dimension is deprecated.
            txt_ids (`torch.Tensor` of shape `(text_sequence_length, len(axes_dims_rope))`):
                Positional ids of the text tokens. Passing a 3D tensor with a batch dimension is deprecated.
            guidance (`torch.Tensor`, *optional*):
                Guidance values embedded together with the timestep. Pass them when (and only when) the model was
                configured with `guidance_embeds=True`.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            controlnet_block_samples (`list` of `torch.Tensor`, *optional*):
                Residuals added to the image hidden states after the double-stream transformer blocks.
            controlnet_single_block_samples (`list` of `torch.Tensor`, *optional*):
                Residuals added to the image tokens after the single-stream transformer blocks.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.
            ntk_factor (`float`, *optional*, defaults to 1):
                Forwarded to `FluxPosEmbed`; multiplies the base `theta` of the rotary position embeddings.
            proportional_attention (`bool`, *optional*, defaults to `False`):
                Forwarded to every `FluxAttnProcessor2_0` call.
            controlnet_blocks_repeat (`bool`, *optional*, defaults to `False`):
                If `True` (Xlabs ControlNet), cycle through `controlnet_block_samples` with
                `index_block % len(controlnet_block_samples)`; otherwise each sample is shared by
                `ceil(num_layers / len(controlnet_block_samples))` consecutive blocks.
        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.
        """
        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.x_embedder(hidden_states)

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        else:
            guidance = None
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids, ntk_factor=ntk_factor)

        for index_block, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    proportional_attention,
                    **ckpt_kwargs,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    proportional_attention=proportional_attention,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                # For Xlabs ControlNet.
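                # Worked example (hypothetical sample count): with the default 19 double-stream blocks and
                # 5 ControlNet samples, interval_control = ceil(19 / 5) = 4, so the non-repeat branch adds
                # sample index_block // 4 (blocks 0-3 -> 0, 4-7 -> 1, ..., 16-18 -> 4), whereas the Xlabs
                # branch (controlnet_blocks_repeat=True) cycles index_block % 5: 0, 1, 2, 3, 4, 0, 1, ...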
                if controlnet_blocks_repeat:
                    hidden_states = (
                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                    )
                else:
                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    proportional_attention,
                    **ckpt_kwargs,
                )

            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    proportional_attention=proportional_attention,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            # controlnet residual
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + controlnet_single_block_samples[index_block // interval_control]
                )

        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)