Repository: Huage001/CLEAR
Branch: main
Commit: ddca13160b7c
Files: 16
Total size: 177.7 KB
Directory structure:
gitextract_7vgr3kj8/
├── LICENSE
├── README.md
├── attention_processor.py
├── cache_latent_codes.py
├── cache_latent_codes.sh
├── cache_prompt_embeds.py
├── cache_prompt_embeds.sh
├── dataset.py
├── deepspeed_config.yaml
├── distill.py
├── distill.sh
├── inference_t2i.ipynb
├── inference_t2i_highres.ipynb
├── pipeline_flux_img2img.py
├── requirements.txt
└── transformer_flux.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# CLEAR
> **CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up**
>
> [Songhua Liu](http://huage001.github.io/),
> [Zhenxiong Tan](https://scholar.google.com/citations?user=HP9Be6UAAAAJ&hl=en),
> and
> [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
>
> NeurIPS 2025
>
> [xML Lab](https://sites.google.com/view/xml-nus), National University of Singapore
>

## 🔥News
**[2025/9/18]** CLEAR is accepted at NeurIPS 2025.
**[2024/12/20]** We release training and inference codes of CLEAR, a simple yet effective strategy to linearize the complexity of pre-trained diffusion transformers, such as FLUX and SD3.
## Introduction
Diffusion Transformers (DiT) have become a leading architecture in image generation. However, the quadratic complexity of attention mechanisms, which are responsible for modeling token-wise relationships, results in significant latency when generating high-resolution images. To address this issue, we aim at a linear attention mechanism in this paper that reduces the complexity of pre-trained DiTs to linear. We begin our exploration with a comprehensive summary of existing efficient attention mechanisms and identify four key factors crucial for successful linearization of pre-trained DiTs: locality, formulation consistency, high-rank attention maps, and feature integrity. Based on these insights, we introduce a convolution-like local attention strategy termed CLEAR, which limits feature interactions to a local window around each query token, and thus achieves linear complexity.
Our experiments indicate that, by fine-tuning the attention layer on merely 10K self-generated samples for 10K iterations, we can effectively transfer knowledge from a pre-trained DiT to a student model with linear complexity, yielding results comparable to the teacher model. Simultaneously, it reduces attention computations by 99.5% and accelerates generation by 6.3 times for generating 8K-resolution images. Furthermore, we investigate favorable properties in the distilled attention layers, such as zero-shot generalization across various models and plugins, and improved support for multi-GPU parallel inference.
**TL;DR**: For pre-trained diffusion transformers, enforcing an image token to interact with only tokens within **a local window** can effectively reduce the complexity of the original models to a linear scale.
## Installation
* CLEAR requires ``torch>=2.5.0``, ``diffusers>=0.31.0``, and other packages listed in ``requirements.txt``. You can set up a new experiment with:
```bash
conda create -n CLEAR python=3.12
conda activate CLEAR
pip install -r requirements.txt
```
* Clone this repo to your project directory:
``` bash
git clone https://github.com/Huage001/CLEAR.git
```
## Supported Models
We release a series of variants for linearized [FLUX-1.dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with various local window sizes.
We experimentally find that when the local window size is small, e.g., 8, the model can produce repetitive patterns in many cases. To alleviate the problem, in some variants, we also include down-sampled key-value tokens besides local tokens for attention interaction.
The supported models and the download links are:
| window_size | down_factor | link |
| :---------: | :---------: | :----------------------------------------------------------: |
| 32 | NA | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_32.safetensors) |
| 16 | NA | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_16.safetensors) |
| 8 | NA | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8.safetensors) |
| 16 | 4 | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_16_down_4.safetensors) |
| 8 | 4 | [here](https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8_down_4.safetensors) |
You are encouraged to download the model weights you need to ``ckpt`` beforehand. For example:
```bash
mkdir ckpt
wget https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_8_down_4.safetensors
```
## Inference
* If you want to compare the linearized FLUX with the original model, please try ``inference_t2i.ipynb``.
* If you want to use CLEAR for high-resolution acceleration, please try ``inference_t2i_highres.ipynb``. We currently adopt the strategy of [SDEdit](https://huggingface.co/docs/diffusers/v0.30.2/en/api/pipelines/stable_diffusion/img2img#image-to-image). The basic idea is to generate a low-resolution result at first, based on which we gradually upscale the image.
* Please configure ``down_factor`` and ``window_size`` in the notebooks to use different variants of CLEAR. If you do not want to include down-sampled key-value tokens, specify ``down_factor=1``. The models will be downloaded automatically to ``ckpt`` if not downloaded.
* Currently, a GPU card with 48 GB VMem is recommended for high-resolution generation.
## Training
* Configure ``/path/to/t2i_1024`` in multiple ``.sh`` files.
* Download training images from [here](https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_1024_10K/data_000000.tar), which contains 10K 1024-resolution images generated by ``FLUX-1.dev`` itself, and unzip it to ``/path/to/t2i_1024``:
```
tar -xvf data_000000.tar -C /path/to/t2i_1024
```
* [Optional but Recommended] Cache T5 and CLIP text embeddings and VAE features beforehand:
```bash
bash cache_prompt_embeds.sh
bash cache_latent_codes.sh
```
* Start Training:
```bash
bash distill.sh
```
By default, it uses 4 80G-VMem GPUs with ``train_batch_size=2`` and ``gradient_accumulation_steps=4``. Please feel free to configure them in ``distill.sh`` and ``deepspeed_config.yaml`` according to your situations.
## Acknowledgement
* [FLUX](https://blackforestlabs.ai/announcing-black-forest-labs/) for the source models.
* [flexattention](https://pytorch.org/blog/flexattention/) for kernel implementation.
* [diffusers](https://github.com/huggingface/diffusers) for the code base.
* [DeepSpeed](https://github.com/microsoft/DeepSpeed) for the training framework.
* [SDEdit](https://huggingface.co/docs/diffusers/v0.30.2/en/api/pipelines/stable_diffusion/img2img#image-to-image) for high-resolution image generation.
* [@Weihao Yu](https://github.com/yuweihao) and [@Xinyin Ma](https://github.com/horseee) for valuable discussions.
* NUS IT’s Research Computing group using grant numbers NUSREC-HPC-00001.
## Citation
If you find this repo helpful, please consider citing:
```bib
@article{liu2024clear,
title = {CLEAR: Conv-Like Linearization Revs Pre-Trained Diffusion Transformers Up},
author = {Liu, Songhua and Tan, Zhenxiong and Wang, Xinchao},
journal = {NeurIPS},
year = {2025},
}
```
================================================
FILE: attention_processor.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
create_block_mask = torch.compile(create_block_mask)
from diffusers.models.attention_processor import Attention
from typing import Optional
from functools import partial, lru_cache
from diffusers.models.embeddings import apply_rotary_emb
attn_outputs_teacher = []
attn_outputs = []
BLOCK_MASK = None
HEIGHT = None
WIDTH = None
class FluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self, distill=False):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.distill = distill
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
proportional_attention=False
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
train_seq_len = 64 ** 2 + 512
if proportional_attention:
attention_scale = math.sqrt(math.log(key.size(2), train_seq_len) / head_dim)
else:
attention_scale = math.sqrt(1 / head_dim)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, scale=attention_scale)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
else:
if self.distill:
attn_outputs_teacher.append(hidden_states)
return hidden_states
@lru_cache
def init_local_downsample_mask_flex(height, width, text_length, window_size, down_factor, device):
def local_dwonsample_mask(b, h, q_idx, kv_idx):
q_y = (q_idx - text_length) // width
q_x = (q_idx - text_length) % width
kv_y = (kv_idx - text_length) // width
kv_x = (kv_idx - text_length) % width
return torch.logical_or(
torch.logical_and(
q_idx < text_length,
kv_idx < text_length + height * width
),
torch.logical_and(
q_idx >= text_length,
torch.logical_or(
torch.logical_or(kv_idx < text_length, kv_idx >= text_length + height * width),
(q_y - kv_y) ** 2 + (q_x - kv_x) ** 2 < window_size ** 2)
)
)
global BLOCK_MASK, HEIGHT, WIDTH
BLOCK_MASK = create_block_mask(local_dwonsample_mask, B=None, H=None, device=device,
Q_LEN=text_length + height * width,
KV_LEN=text_length + height * width + (height // down_factor) * (width // down_factor), _compile=True)
HEIGHT = height
WIDTH = width
class LocalDownsampleFlexAttnProcessor(nn.Module):
def __init__(self, down_factor=4, distill=False):
super().__init__()
assert BLOCK_MASK is not None
self.flex_attn = partial(flex_attention, block_mask=BLOCK_MASK)
self.flex_attn = torch.compile(self.flex_attn, dynamic=False)
self.down_factor = down_factor
self.spatial_weight = nn.Parameter(torch.ones(1, 1, 1, down_factor, 1, down_factor, 1) / (down_factor * down_factor))
self.distill = distill
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
proportional_attention=False
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
train_seq_len = 64 ** 2 + 512
if proportional_attention:
attention_scale = math.sqrt(math.log(10 * key.size(2), train_seq_len) / head_dim)
else:
attention_scale = math.sqrt(1 / head_dim)
key_downsample = (key[:, :, 512:].unflatten(2, (HEIGHT // self.down_factor, self.down_factor,
WIDTH // self.down_factor, self.down_factor)) * self.spatial_weight).sum(dim=(3, 5)).flatten(2, 3)
value_downsample = (value[:, :, 512:].unflatten(2, (HEIGHT // self.down_factor, self.down_factor,
WIDTH // self.down_factor, self.down_factor)) * self.spatial_weight).sum(dim=(3, 5)).flatten(2, 3)
hidden_states = self.flex_attn(query, torch.cat([key, key_downsample], dim=2), torch.cat([value, value_downsample], dim=2), scale=attention_scale)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if self.distill:
attn_outputs.append(hidden_states)
return hidden_states, encoder_hidden_states
else:
if self.distill:
attn_outputs.append(hidden_states)
return hidden_states
@lru_cache
def init_local_mask_flex(height, width, text_length, window_size, device):
def local_mask(b, h, q_idx, kv_idx):
q_y = (q_idx - text_length) // width
q_x = (q_idx - text_length) % width
kv_y = (kv_idx - text_length) // width
kv_x = (kv_idx - text_length) % width
return torch.logical_or(torch.logical_or(q_idx < text_length, kv_idx < text_length),
(q_y - kv_y) ** 2 + (q_x - kv_x) ** 2 < window_size ** 2)
global BLOCK_MASK, HEIGHT, WIDTH
BLOCK_MASK = create_block_mask(local_mask, B=None, H=None, device=device,
Q_LEN=text_length + height * width,
KV_LEN=text_length + height * width, _compile=True)
HEIGHT = height
WIDTH = width
class LocalFlexAttnProcessor:
"""Attention processor used typically in processing the SD3-like self-attention projections."""
def __init__(self, distill=False):
super().__init__()
self.flex_attn = partial(flex_attention, block_mask=BLOCK_MASK)
self.flex_attn = torch.compile(self.flex_attn, dynamic=False)
self.distill = distill
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
proportional_attention=False
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
train_seq_len = 64 ** 2 + 512
if proportional_attention:
attention_scale = math.sqrt(math.log(key.size(2), train_seq_len) / head_dim)
else:
attention_scale = math.sqrt(1 / head_dim)
hidden_states = self.flex_attn(query, key, value, scale=attention_scale)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
if self.distill:
attn_outputs.append(hidden_states)
return hidden_states, encoder_hidden_states
else:
if self.distill:
attn_outputs.append(hidden_states)
return hidden_states
================================================
FILE: cache_latent_codes.py
================================================
import argparse
import os
import numpy as np
from PIL import Image
import math
from safetensors.torch import save_file
import torch
import tqdm
from diffusers import AutoencoderKL
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--data_root",
type=str,
default=None
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument(
"--resolution",
type=int,
default=1024,
help="resolution",
)
parser.add_argument(
"--output_dir",
type=str,
default="flux-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument("--local_rank", type=int, default=0, help="For distributed training: local_rank")
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="Number of workers",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
return args
def main(args):
if torch.cuda.is_available():
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
if args.mixed_precision == 'fp16':
dtype = torch.float16
elif args.mixed_precision == 'bf16':
dtype = torch.bfloat16
else:
dtype = torch.float32
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
variant=args.variant,
).to(device, dtype)
all_info = [os.path.join(args.data_root, i) for i in sorted(os.listdir(args.data_root)) if '.jpg' in i or '.png' in i]
os.makedirs(args.output_dir, exist_ok=True)
work_load = math.ceil(len(all_info) / args.num_workers)
for idx in tqdm.tqdm(range(work_load * args.local_rank, min(work_load * (args.local_rank + 1), len(all_info)), args.batch_size)):
images = []
paths = [os.path.join(args.data_root, item[:item.rfind('.')] + '_latent_code.safetensors') for item in all_info[idx:idx + args.batch_size]]
for item in all_info[idx:idx + args.batch_size]:
img = Image.open(os.path.join(args.data_root, item)).convert('RGB')
img = img.resize((args.resolution, args.resolution))
img = torch.from_numpy((np.array(img) / 127.5) - 1)
img = img.permute(2, 0, 1)
images.append(img)
with torch.no_grad():
images = torch.stack(images, dim=0)
data = vae.encode(images.to(device, vae.dtype)).latent_dist
means = data.mean.cpu().data
stds = data.std.cpu().data
for path, mean, std in zip(paths, means.unbind(), stds.unbind()):
save_file(
{'mean': mean, 'std': std},
path
)
if __name__ == '__main__':
main(parse_args())
================================================
FILE: cache_latent_codes.sh
================================================
export NUM_WORKERS=4
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
torchrun --nproc_per_node=$NUM_WORKERS cache_latent_codes.py \
--data_root="/path/to/t2i_1024" \
--batch_size=16 \
--num_worker=$NUM_WORKERS \
--pretrained_model_name_or_path=$MODEL_NAME \
--mixed_precision='bf16' \
--output_dir="/path/to/t2i_1024"
================================================
FILE: cache_prompt_embeds.py
================================================
import argparse
import os
import json
import math
from safetensors.torch import save_file
import torch
import tqdm
from transformers import CLIPTokenizer, T5TokenizerFast, CLIPTextModel, T5EncoderModel
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--data_root",
type=str,
default=None
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument(
"--max_sequence_length",
type=int,
default=512,
help="Maximum sequence length to use with with the T5 text encoder",
)
parser.add_argument(
"--output_dir",
type=str,
default="flux-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument("--local_rank", type=int, default=0, help="For distributed training: local_rank")
parser.add_argument(
"--num_workers",
type=int,
default=1,
help="Number of workers",
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
args.local_rank = int(os.environ.get("LOCAL_RANK", 0))
return args
def tokenize_prompt(tokenizer, prompt, max_sequence_length):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
return text_input_ids
def _encode_prompt_with_t5(
text_encoder,
tokenizer,
max_sequence_length=512,
prompt=None,
num_images_per_prompt=1,
device=None,
text_input_ids=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def _encode_prompt_with_clip(
text_encoder,
tokenizer,
prompt: str,
device=None,
text_input_ids=None,
num_images_per_prompt: int = 1,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
def encode_prompt(
text_encoders,
tokenizers,
prompt: str,
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
tokenizer=tokenizers[0],
prompt=prompt,
device=device if device is not None else text_encoders[0].device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
)
prompt_embeds = _encode_prompt_with_t5(
text_encoder=text_encoders[1],
tokenizer=tokenizers[1],
max_sequence_length=max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[1].device,
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
)
return prompt_embeds, pooled_prompt_embeds
def main(args):
if torch.cuda.is_available():
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
if args.mixed_precision == 'fp16':
dtype = torch.float16
elif args.mixed_precision == 'bf16':
dtype = torch.bfloat16
else:
dtype = torch.float32
tokenizer_one = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
variant=args.variant,
cache_dir=args.cache_dir
)
tokenizer_two = T5TokenizerFast.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
variant=args.variant,
cache_dir=args.cache_dir
)
text_encoder_one = CLIPTextModel.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
subfolder="text_encoder",
variant=args.variant,
cache_dir=args.cache_dir
).to(device, dtype)
text_encoder_two = T5EncoderModel.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
subfolder="text_encoder_2",
variant=args.variant,
cache_dir=args.cache_dir
).to(device, dtype)
tokenizers = [tokenizer_one, tokenizer_two]
text_encoders = [text_encoder_one, text_encoder_two]
all_info = [os.path.join(args.data_root, i) for i in sorted(os.listdir(args.data_root)) if '.json' in i]
os.makedirs(args.output_dir, exist_ok=True)
work_load = math.ceil(len(all_info) / args.num_workers)
for idx in tqdm.tqdm(range(work_load * args.local_rank, min(work_load * (args.local_rank + 1), len(all_info)), args.batch_size)):
texts = []
for item in all_info[idx:idx + args.batch_size]:
with open(os.path.join(args.data_root, item)) as f:
texts.append(json.load(f)['prompt'])
paths = [os.path.join(args.data_root, item[:item.rfind('.')] + '_prompt_embed.safetensors') for item in all_info[idx:idx + args.batch_size]]
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders, tokenizers, texts, args.max_sequence_length
)
prompt_embeds = prompt_embeds.cpu().data
pooled_prompt_embeds = pooled_prompt_embeds.cpu().data
for path, prompt_embed, pooled_prompt_embed in zip(paths, prompt_embeds.unbind(), pooled_prompt_embeds.unbind()):
save_file(
{'caption_feature_t5': prompt_embed, 'caption_feature_clip': pooled_prompt_embed},
path
)
if __name__ == '__main__':
main(parse_args())
================================================
FILE: cache_prompt_embeds.sh
================================================
export NUM_WORKERS=4
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
torchrun --nproc_per_node=$NUM_WORKERS cache_prompt_embeds.py \
--data_root="/path/to/t2i_1024" \
--batch_size=256 \
--num_worker=$NUM_WORKERS \
--pretrained_model_name_or_path=$MODEL_NAME \
--mixed_precision='bf16' \
--output_dir="/path/to/t2i_1024"
================================================
FILE: dataset.py
================================================
import os
import io
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import json
import random
import cv2
from safetensors import safe_open
def image_resize(img, max_size=512):
w, h = img.size
if w >= h:
new_w = max_size
new_h = int((max_size / w) * h)
else:
new_h = max_size
new_w = int((max_size / h) * w)
new_w = (w // 32) * 32
new_h = (h // 32) * 32
return img.resize((new_w, new_h))
def c_crop(image):
width, height = image.size
new_size = min(width, height)
left = (width - new_size) / 2
top = (height - new_size) / 2
right = (width + new_size) / 2
bottom = (height + new_size) / 2
return image.crop((left, top, right, bottom))
def crop_to_aspect_ratio(image, ratio="16:9"):
width, height = image.size
ratio_map = {
"16:9": (16, 9),
"4:3": (4, 3),
"1:1": (1, 1)
}
target_w, target_h = ratio_map[ratio]
target_ratio_value = target_w / target_h
current_ratio = width / height
if current_ratio > target_ratio_value:
new_width = int(height * target_ratio_value)
offset = (width - new_width) // 2
crop_box = (offset, 0, offset + new_width, height)
else:
new_height = int(width / target_ratio_value)
offset = (height - new_height) // 2
crop_box = (0, offset, width, offset + new_height)
cropped_img = image.crop(crop_box)
return cropped_img
class CustomImageDataset(Dataset):
def __init__(self, img_dir, img_size=512, caption_type='json',
random_ratio=False, use_cached_prompt_embeds=False, use_cached_latent_codes=False):
self.images = [os.path.join(img_dir, i) for i in os.listdir(img_dir) if '.jpg' in i or '.png' in i]
self.images.sort()
self.img_size = img_size
self.caption_type = caption_type
self.random_ratio = random_ratio
self.use_cached_prompt_embeds = use_cached_prompt_embeds
self.use_cached_latent_codes = use_cached_latent_codes
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
try:
batch = {}
if not self.use_cached_latent_codes:
img = Image.open(self.images[idx]).convert('RGB')
if self.random_ratio:
ratio = random.choice(["16:9", "default", "1:1", "4:3"])
if ratio != "default":
img = crop_to_aspect_ratio(img, ratio)
img = image_resize(img, self.img_size)
img = torch.from_numpy((np.array(img) / 127.5) - 1)
img = img.permute(2, 0, 1)
batch['images'] = img
else:
with safe_open(self.images[idx][:self.images[idx].rfind('.')] + '_latent_code.safetensors', framework="pt") as f:
batch['latent_codes_mean'] = f.get_tensor('mean')
batch['latent_codes_std'] = f.get_tensor('std')
if not self.use_cached_prompt_embeds:
json_path = self.images[idx].split('.')[0] + '.' + self.caption_type
if self.caption_type == "json":
prompt = json.load(open(json_path))['prompt']
else:
prompt = open(json_path).read()
batch['prompts'] = prompt
else:
with safe_open(self.images[idx][:self.images[idx].rfind('.')] + '_prompt_embed.safetensors', framework="pt") as f:
batch['prompt_embeds_t5'] = f.get_tensor('caption_feature_t5')
batch['prompt_embeds_clip'] = f.get_tensor('caption_feature_clip')
return batch
except Exception as e:
print(e)
return self.__getitem__(random.randint(0, len(self.images) - 1))
def loader(train_batch_size, num_workers, **args):
dataset = CustomImageDataset(**args)
return DataLoader(dataset, batch_size=train_batch_size, num_workers=num_workers, shuffle=True)
================================================
FILE: deepspeed_config.yaml
================================================
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
gradient_accumulation_steps: 4
gradient_clipping: 1.0
offload_optimizer_device: none
offload_param_device: none
zero3_init_flag: false
zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
================================================
FILE: distill.py
================================================
import argparse
import copy
import logging
import math
import os
import shutil
from contextlib import nullcontext
from pathlib import Path
import numpy as np
import torch
import torch.utils.checkpoint
from torch.utils.data import RandomSampler
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
import diffusers
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxPipeline,
FluxTransformer2DModel,
)
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
)
from diffusers.utils import (
check_min_version,
is_wandb_available,
)
from diffusers.utils.torch_utils import is_compiled_module
from safetensors.torch import save_file
from safetensors import safe_open
from dataset import loader
from attention_processor import FluxAttnProcessor2_0, LocalFlexAttnProcessor, LocalDownsampleFlexAttnProcessor, init_local_mask_flex, init_local_downsample_mask_flex
from attention_processor import attn_outputs_teacher, attn_outputs
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.31.0.dev0")
logger = get_logger(__name__)
def load_text_encoders(class_one, class_two):
text_encoder_one = class_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
)
text_encoder_two = class_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
)
return text_encoder_one, text_encoder_two
def log_validation(
pipeline,
args,
accelerator,
pipeline_args,
epoch,
torch_dtype,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()
with autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
phase_name: [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
]
}
)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
return images
def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path, subfolder=subfolder, revision=revision
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "T5EncoderModel":
from transformers import T5EncoderModel
return T5EncoderModel
else:
raise ValueError(f"{model_class} is not supported.")
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--data_root",
type=str,
default=None
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
parser.add_argument(
"--max_sequence_length",
type=int,
default=512,
help="Maximum sequence length to use with with the T5 text encoder",
)
parser.add_argument(
"--validation_prompt",
type=str,
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
parser.add_argument(
"--num_validation_images",
type=int,
default=4,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_epochs",
type=int,
default=1,
help=(
"Run validation every X epochs. Validation consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="flux-attn",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=1024,
help=(
"The resolution for input images"
)
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--window_size",
type=int,
default=16,
help=(
"Size of local window for attention sampling."
),
)
parser.add_argument(
"--down_factor",
type=int,
default=1,
help=(
"Factor of downsampling for key-value tokens."
),
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=10000,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--guidance_scale",
type=float,
default=3.5,
help="the FLUX.1 dev variant is a guidance distilled model",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--weighting_scheme",
type=str,
default="none",
choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
)
parser.add_argument(
"--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
)
parser.add_argument(
"--mode_scale",
type=float,
default=1.29,
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
)
parser.add_argument(
"--optimizer",
type=str,
default="AdamW",
help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
)
parser.add_argument(
"--use_8bit_adam",
action="store_true",
help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
)
parser.add_argument(
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--prodigy_beta3",
type=float,
default=None,
help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
"uses the value of square root of beta2. Ignored if optimizer is adamW",
)
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
"--adam_epsilon",
type=float,
default=1e-08,
help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
)
parser.add_argument(
"--prodigy_use_bias_correction",
type=bool,
default=True,
help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
)
parser.add_argument(
"--prodigy_safeguard_warmup",
type=bool,
default=True,
help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
"Ignored if optimizer is adamW",
)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--use_cached_latent",
action="store_true",
default=False,
help="Cache the VAE latents",
)
parser.add_argument(
"--use_cached_prompt_embed",
action="store_true",
default=False,
help="Use cached T5 and CLIP features",
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--upcast_before_saving",
action="store_true",
default=False,
help=(
"Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
"Defaults to precision dtype used for training to save memory"
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def tokenize_prompt(tokenizer, prompt, max_sequence_length):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
return text_input_ids
def _encode_prompt_with_t5(
text_encoder,
tokenizer,
max_sequence_length=512,
prompt=None,
num_images_per_prompt=1,
device=None,
text_input_ids=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
prompt_embeds = text_encoder(text_input_ids.to(device))[0]
dtype = text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
def _encode_prompt_with_clip(
text_encoder,
tokenizer,
prompt: str,
device=None,
text_input_ids=None,
num_images_per_prompt: int = 1,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if tokenizer is not None:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
else:
if text_input_ids is None:
raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
def encode_prompt(
text_encoders,
tokenizers,
prompt: str,
max_sequence_length,
device=None,
num_images_per_prompt: int = 1,
text_input_ids_list=None,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
dtype = text_encoders[0].dtype
pooled_prompt_embeds = _encode_prompt_with_clip(
text_encoder=text_encoders[0],
tokenizer=tokenizers[0],
prompt=prompt,
device=device if device is not None else text_encoders[0].device,
num_images_per_prompt=num_images_per_prompt,
text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
)
prompt_embeds = _encode_prompt_with_t5(
text_encoder=text_encoders[1],
tokenizer=tokenizers[1],
max_sequence_length=max_sequence_length,
prompt=prompt,
num_images_per_prompt=num_images_per_prompt,
device=device if device is not None else text_encoders[1].device,
text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
def main(args):
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config
)
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizers
tokenizer_one = CLIPTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
)
tokenizer_two = T5TokenizerFast.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer_2",
revision=args.revision,
)
# import correct text encoder classes
text_encoder_cls_one = import_model_class_from_model_name_or_path(
args.pretrained_model_name_or_path, args.revision
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
)
# Load scheduler and models
noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler"
)
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
variant=args.variant,
)
transformer = FluxTransformer2DModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
)
transformer_teacher = FluxTransformer2DModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
)
if args.down_factor == 1:
init_local_mask_flex(args.resolution // 16, args.resolution // 16, text_length=args.max_sequence_length, window_size=args.window_size, device=accelerator.device)
attn_processors = {}
for idx, name in enumerate(transformer.attn_processors):
attn_processors[name] = LocalFlexAttnProcessor(distill='single' in name and idx % 4 == 0)
else:
init_local_downsample_mask_flex(args.resolution // 16, args.resolution // 16, text_length=args.max_sequence_length, window_size=args.window_size,
down_factor=args.down_factor, device=accelerator.device)
attn_processors = {}
for idx, name in enumerate(transformer.attn_processors):
attn_processors[name] = LocalDownsampleFlexAttnProcessor(down_factor=args.down_factor, distill='single' in name and idx % 4 == 0)
attn_processors_teacher = {}
for idx, name in enumerate(transformer_teacher.attn_processors.keys()):
attn_processors_teacher[name] = FluxAttnProcessor2_0(distill='single' in name and idx % 4 == 0)
transformer.set_attn_processor(attn_processors)
transformer_teacher.set_attn_processor(attn_processors_teacher)
# We only train the Attn layers
transformer.requires_grad_(False)
transformer_teacher.requires_grad_(False)
vae.requires_grad_(False)
text_encoder_one.requires_grad_(False)
text_encoder_two.requires_grad_(False)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
# due to pytorch#99272, MPS does not yet support bfloat16.
raise ValueError(
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
vae.to(accelerator.device, dtype=weight_dtype)
transformer.to(accelerator.device, dtype=weight_dtype)
transformer_teacher.to(accelerator.device, dtype=weight_dtype)
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
for _name, _param in transformer.named_parameters():
if '.attn.to_q.' in _name or '.attn.to_k.' in _name or '.attn.to_v.' in _name or '.attn.to_out.' in _name or 'spatial_weight' in _name:
_param.requires_grad = True
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32 and torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Make sure the trainable params are in float32.
if args.mixed_precision == "fp16":
models = [transformer]
# only upcast trainable parameters (Attn) into fp32
cast_training_params(models, dtype=torch.float32)
# Optimization parameters
transformer_attn_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
transformer_parameters_with_lr = {"params": transformer_attn_parameters, "lr": args.learning_rate}
params_to_optimize = [transformer_parameters_with_lr]
# Optimizer creation
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
logger.warning(
f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
"Defaulting to adamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
logger.warning(
f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}"
)
if args.optimizer.lower() == "adamw":
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
if args.optimizer.lower() == "prodigy":
try:
import prodigyopt
except ImportError:
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
# Dataset and DataLoaders creation:
train_dataloader = loader(train_batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers,
img_dir=args.data_root, img_size=args.resolution,
use_cached_prompt_embeds=args.use_cached_prompt_embed,
use_cached_latent_codes=args.use_cached_latent)
tokenizers = [tokenizer_one, tokenizer_two]
text_encoders = [text_encoder_one, text_encoder_two]
def compute_text_embeddings(prompt, text_encoders, tokenizers):
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
text_encoders, tokenizers, prompt, args.max_sequence_length
)
prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
text_ids = text_ids.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds, text_ids
# Clear the memory here
if args.use_cached_prompt_embed:
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, text_encoders, tokenizers
free_memory()
vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
vae_config_block_out_channels = vae.config.block_out_channels
if args.use_cached_latent:
del vae
free_memory()
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
guidance_embeds = transformer.config.guidance_embeds
transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
transformer, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_name = "flux-dev-attn"
accelerator.init_trackers(tracker_name, config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = args.resume_from_checkpoint
else:
# Get the mos recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(path)
global_step = int(path.split("-")[-1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
sigmas = noise_scheduler.sigmas.to(device=accelerator.device, dtype=dtype)
schedule_timesteps = noise_scheduler.timesteps.to(accelerator.device)
timesteps = timesteps.to(accelerator.device)
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < n_dim:
sigma = sigma.unsqueeze(-1)
return sigma
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
for _, batch in enumerate(train_dataloader):
with accelerator.accumulate(transformer):
if not args.use_cached_prompt_embed:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
batch['prompts'], text_encoders, tokenizers
)
else:
prompt_embeds = batch['prompt_embeds_t5'].to(dtype=weight_dtype)
pooled_prompt_embeds = batch['prompt_embeds_clip'].to(dtype=weight_dtype)
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=weight_dtype)
with torch.no_grad():
# Convert images to latent space
if args.use_cached_latent:
mean = batch['latent_codes_mean'].to(dtype=weight_dtype)
std = batch['latent_codes_std'].to(dtype=weight_dtype)
sample = torch.randn_like(mean)
model_input = mean + std * sample
else:
model_input = vae.encode(batch[0].to(dtype=vae.dtype)).latent_dist.sample()
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
vae_scale_factor = 2 ** (len(vae_config_block_out_channels))
latent_image_ids = FluxPipeline._prepare_latent_image_ids(
model_input.shape[0],
model_input.shape[2],
model_input.shape[3],
accelerator.device,
weight_dtype,
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
# Sample a random timestep for each image
# for weighting schemes where we sample timesteps non-uniformly
u = compute_density_for_timestep_sampling(
weighting_scheme=args.weighting_scheme,
batch_size=bsz,
logit_mean=args.logit_mean,
logit_std=args.logit_std,
mode_scale=args.mode_scale,
)
indices = (u * noise_scheduler.config.num_train_timesteps).long()
timesteps = noise_scheduler.timesteps[indices].to(device=model_input.device)
# Add noise according to flow matching.
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
packed_noisy_model_input = FluxPipeline._pack_latents(
noisy_model_input,
batch_size=model_input.shape[0],
num_channels_latents=model_input.shape[1],
height=model_input.shape[2],
width=model_input.shape[3],
)
# handle guidance
if guidance_embeds:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
guidance = None
teacher_pred = transformer_teacher(
hidden_states=packed_noisy_model_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
teacher_pred = FluxPipeline._unpack_latents(
teacher_pred,
height=int(model_input.shape[2] * vae_scale_factor / 2),
width=int(model_input.shape[3] * vae_scale_factor / 2),
vae_scale_factor=vae_scale_factor,
)
# Predict the noise residual
model_pred = transformer(
hidden_states=packed_noisy_model_input,
# YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
timestep=timesteps / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
return_dict=False,
)[0]
model_pred = FluxPipeline._unpack_latents(
model_pred,
height=int(model_input.shape[2] * vae_scale_factor / 2),
width=int(model_input.shape[3] * vae_scale_factor / 2),
vae_scale_factor=vae_scale_factor,
)
# these weighting schemes use a uniform timestep sampling
# and instead post-weight the loss
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
# flow matching loss
target = noise - model_input
# Compute regular loss.
loss_fm = (weighting.float() * (model_pred.float() - target.float()) ** 2).mean()
loss_distill = (weighting.float() * (model_pred.float() - teacher_pred.float()) ** 2).mean()
loss_attn = sum([(weighting.float().squeeze(-1) * (attn_output.float() - attn_output_teacher.float()) ** 2).mean()
for attn_output, attn_output_teacher in zip(attn_outputs, attn_outputs_teacher)]) / len(attn_outputs)
loss = loss_fm + loss_distill * 0.5 + loss_attn * 0.5
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = (
transformer.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
attn_outputs.clear()
attn_outputs_teacher.clear()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
if global_step % args.checkpointing_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"loss_fm": loss_fm.detach().item(),
"loss_distill": loss_distill.detach().item(),
"loss_attn": loss_attn.detach().item(),
"loss": loss.detach().item(),
"lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
free_memory()
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
# create pipeline
if args.use_cached_prompt_embed:
text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
if args.use_cached_latent:
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
variant=args.variant,
)
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one),
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
transformer=accelerator.unwrap_model(transformer),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
torch_dtype=weight_dtype,
)
if args.use_cached_prompt_embed:
del text_encoder_one, text_encoder_two
if args.use_cached_latent:
del vae
free_memory()
# Save the attn weights
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = unwrap_model(transformer)
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
state_dict = {}
for _name, _param in transformer.named_parameters():
if '.attn.to_q.' in _name or '.attn.to_k.' in _name or '.attn.to_v.' in _name or '.attn.to_out.' in _name or 'spatial_weight' in _name:
state_dict[_name] = _param
save_file(
state_dict,
os.path.join(args.output_dir, 'attn_weights.safetensors')
)
# Final inference
# Load previous pipeline
free_memory()
pipeline = FluxPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
# load attention processors
state_dict = {}
with safe_open(os.path.join(args.output_dir, 'attn_weights.safetensor'), framework="pt") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
missing_keys, unexpected_keys = pipeline.transformer.load_state_dict(state_dict)
missing_keys = list(filter(lambda p: ('.attn.to_q.' in p or
'.attn.to_k.' in p or
'.attn.to_v.' in p or
'.attn.to_out.' in p or
'spatial_weight' in p), missing_keys))
if len(missing_keys) != 0 or len(unexpected_keys) != 0:
logger.warning(
f"Loading attn weights from state_dict led to unexpected keys: {unexpected_keys}"
f" and missing keys: {missing_keys}."
)
# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
images = log_validation(
pipeline=pipeline,
args=args,
accelerator=accelerator,
pipeline_args=pipeline_args,
epoch=epoch,
is_final_validation=True,
torch_dtype=weight_dtype,
)
for idx, image in enumerate(images):
image.save(os.path.join(args.output_dir, 'validation_idx_%02d.jpg' % idx))
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)
================================================
FILE: distill.sh
================================================
export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export DATAPATH="/path/to/t2i_1024"
export OUTPUT_DIR="ckpt/training_exp"
export PRECISION="bf16"
accelerate launch --config_file deepspeed_config.yaml distill.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--data_root=$DATAPATH \
--output_dir=$OUTPUT_DIR \
--mixed_precision=$PRECISION \
--dataloader_num_workers=8 \
--resolution=1024 \
--train_batch_size=2 \
--gradient_accumulation_steps=4 \
--optimizer="prodigy" \
--learning_rate=1. \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=50000 \
--validation_epochs=1 \
--seed="0" \
--checkpointing_steps=5000 \
--use_cached_prompt_embed \
--use_cached_latent \
--gradient_checkpointing \
--down_factor=1 \
--window_size=16
================================================
FILE: inference_t2i.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import os\n",
"import requests\n",
"from safetensors.torch import load_file\n",
"from diffusers import FluxPipeline\n",
"from attention_processor import LocalFlexAttnProcessor, LocalDownsampleFlexAttnProcessor, init_local_mask_flex, init_local_downsample_mask_flex"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n",
"device = torch.device('cuda')\n",
"dtype = torch.bfloat16\n",
"pipe = FluxPipeline.from_pretrained(bfl_repo, torch_dtype=dtype).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"height = 1024\n",
"width = 1024\n",
"down_factor, window_size = 4, 8\n",
"# Supported Configurations:\n",
"# down_factor, window_size = 1, 8\n",
"# down_factor, window_size = 1, 16\n",
"# down_factor, window_size = 1, 32\n",
"# down_factor, window_size = 4, 16\n",
"# down_factor, window_size = 4, 8\n",
"if down_factor == 1:\n",
" init_local_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, device=device)\n",
" attn_processors = {}\n",
" for k in pipe.transformer.attn_processors.keys():\n",
" attn_processors[k] = LocalFlexAttnProcessor()\n",
"else:\n",
" init_local_downsample_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, down_factor=down_factor, device=device)\n",
" attn_processors = {}\n",
" for k in pipe.transformer.attn_processors.keys():\n",
" attn_processors[k] = LocalDownsampleFlexAttnProcessor(down_factor=down_factor).to(device, dtype)\n",
"pipe.transformer.set_attn_processor(attn_processors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not os.path.exists('ckpt'):\n",
" os.mkdir('ckpt')\n",
"if down_factor == 1:\n",
" if not os.path.exists(f'ckpt/clear_local_{window_size}.safetensors'):\n",
" print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}.safetensors')\n",
" response = requests.get(f\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}.safetensors\")\n",
" response.raise_for_status()\n",
" with open(f'ckpt/clear_local_{window_size}.safetensors', 'wb') as f:\n",
" f.write(response.content)\n",
" state_dict = load_file(f'ckpt/clear_local_{window_size}.safetensors')\n",
"else:\n",
" if not os.path.exists(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors'):\n",
" print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\n",
" response = requests.get(f\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}_down_{down_factor}.safetensors\")\n",
" response.raise_for_status()\n",
" with open(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors', 'wb') as f:\n",
" f.write(response.content)\n",
" state_dict = load_file(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\n",
"\n",
"missing_keys, unexpected_keys = pipe.transformer.load_state_dict(state_dict, strict=False)\n",
"\n",
"missing_keys = list(filter(lambda p: ('.attn.to_q.' in p or \n",
" '.attn.to_k.' in p or \n",
" '.attn.to_v.' in p or \n",
" '.attn.to_out.' in p or \n",
" 'spatial_weight' in p), missing_keys))\n",
"\n",
"if len(missing_keys) != 0 or len(unexpected_keys) != 0:\n",
" print(\n",
" f\"Loading attn weights from state_dict led to unexpected keys: {unexpected_keys}\"\n",
" f\" and missing keys: {missing_keys}.\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"A Mickey is eating a pie\"\n",
"height = 1024\n",
"width = 1024\n",
"image = pipe(\n",
" prompt,\n",
" height=height,\n",
" width=width,\n",
" guidance_scale=3.5,\n",
" num_inference_steps=20,\n",
" max_sequence_length=512,\n",
" generator=torch.Generator(\"cpu\").manual_seed(0)\n",
").images[0]\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pt25",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: inference_t2i_highres.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import os\n",
"import gc\n",
"import requests\n",
"from safetensors.torch import load_file\n",
"from diffusers import FluxPipeline\n",
"from pipeline_flux_img2img import FluxImg2ImgPipeline\n",
"from transformer_flux import FluxTransformer2DModel\n",
"from attention_processor import LocalFlexAttnProcessor, LocalDownsampleFlexAttnProcessor, init_local_mask_flex, init_local_downsample_mask_flex"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bfl_repo=\"black-forest-labs/FLUX.1-dev\"\n",
"device = torch.device('cuda')\n",
"dtype = torch.bfloat16\n",
"pipe = FluxPipeline.from_pretrained(bfl_repo, torch_dtype=dtype).to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"enchanted forest, glowing plants, towering ancient trees, a mystical girl, magical aura, \" \\\n",
" \"fantasy style, vibrant colors, ethereal lighting, bokeh effect, ultra-detailed, painterly, ultra HD, 8K, \" \\\n",
" \"soft glowing lights, mist and fog, otherworldly ambiance, glowing mushrooms, sparkling particles\"\n",
"height = 512\n",
"width = 1024\n",
"image = pipe(\n",
" prompt,\n",
" height=height,\n",
" width=width,\n",
" guidance_scale=3.5,\n",
" num_inference_steps=20,\n",
" max_sequence_length=512,\n",
" generator=torch.Generator(\"cpu\").manual_seed(0)\n",
").images[0]\n",
"image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"del pipe\n",
"torch.cuda.empty_cache()\n",
"gc.collect()\n",
"transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder=\"transformer\", torch_dtype=torch.bfloat16)\n",
"pipe = FluxImg2ImgPipeline.from_pretrained(bfl_repo, transformer=None, torch_dtype=torch.bfloat16)\n",
"pipe.transformer = transformer\n",
"pipe.scheduler.config.use_dynamic_shifting = False\n",
"pipe.scheduler.config.time_shift = 10\n",
"pipe.vae.enable_tiling()\n",
"pipe = pipe.to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"height = 2048\n",
"width = 4096\n",
"down_factor, window_size = 4, 8\n",
"# Supported Configurations:\n",
"# down_factor, window_size = 1, 8\n",
"# down_factor, window_size = 1, 16\n",
"# down_factor, window_size = 1, 32\n",
"# down_factor, window_size = 4, 16\n",
"# down_factor, window_size = 4, 8\n",
"if down_factor == 1:\n",
" init_local_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, device=device)\n",
" attn_processors = {}\n",
" for k in pipe.transformer.attn_processors.keys():\n",
" attn_processors[k] = LocalFlexAttnProcessor()\n",
"else:\n",
" init_local_downsample_mask_flex(height // 16, width // 16, text_length=512, window_size=window_size, down_factor=down_factor, device=device)\n",
" attn_processors = {}\n",
" for k in pipe.transformer.attn_processors.keys():\n",
" attn_processors[k] = LocalDownsampleFlexAttnProcessor(down_factor=down_factor).to(device, dtype)\n",
"pipe.transformer.set_attn_processor(attn_processors)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if not os.path.exists('ckpt'):\n",
" os.mkdir('ckpt')\n",
"if down_factor == 1:\n",
" if not os.path.exists(f'ckpt/clear_local_{window_size}.safetensors'):\n",
" print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}.safetensors')\n",
" response = requests.get(f\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}.safetensors\")\n",
" response.raise_for_status()\n",
" with open(f'ckpt/clear_local_{window_size}.safetensors', 'wb') as f:\n",
" f.write(response.content)\n",
" state_dict = load_file(f'ckpt/clear_local_{window_size}.safetensors')\n",
"else:\n",
" if not os.path.exists(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors'):\n",
" print(f'Checkpoint not found. Downloading checkpoint to ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\n",
" response = requests.get(f\"https://huggingface.co/Huage001/CLEAR/resolve/main/clear_local_{window_size}_down_{down_factor}.safetensors\")\n",
" response.raise_for_status()\n",
" with open(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors', 'wb') as f:\n",
" f.write(response.content)\n",
" state_dict = load_file(f'ckpt/clear_local_{window_size}_down_{down_factor}.safetensors')\n",
"\n",
"missing_keys, unexpected_keys = pipe.transformer.load_state_dict(state_dict, strict=False)\n",
"\n",
"missing_keys = list(filter(lambda p: ('.attn.to_q.' in p or \n",
" '.attn.to_k.' in p or \n",
" '.attn.to_v.' in p or \n",
" '.attn.to_out.' in p or \n",
" 'spatial_weight' in p), missing_keys))\n",
"\n",
"if len(missing_keys) != 0 or len(unexpected_keys) != 0:\n",
" print(\n",
" f\"Loading attn weights from state_dict led to unexpected keys: {unexpected_keys}\"\n",
" f\" and missing keys: {missing_keys}.\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"strength = 0.7\n",
"image_hr = pipe(prompt=prompt,\n",
" image=image.resize((width, height)),\n",
" strength=strength,\n",
" num_inference_steps=20, \n",
" guidance_scale=7.5, \n",
" height=height,\n",
" width=width,\n",
" ntk_factor=10,\n",
" proportional_attention=True,\n",
" generator=torch.Generator(\"cpu\").manual_seed(0)\n",
" ).images[0]\n",
"image_hr"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "pt25",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: pipeline_flux_img2img.py
================================================
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import FluxTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_xla_available,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import FluxImg2ImgPipeline
>>> from diffusers.utils import load_image
>>> device = "cuda"
>>> pipe = FluxImg2ImgPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
>>> pipe = pipe.to(device)
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
>>> init_image = load_image(url).resize((1024, 1024))
>>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
>>> images = pipe(
... prompt=prompt, image=init_image, num_inference_steps=4, strength=0.95, guidance_scale=0.0
... ).images[0]
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin):
r"""
The Flux pipeline for image inpainting.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Args:
transformer ([`FluxTransformer2DModel`]):
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_sample_size = 64
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
text_inputs = self.tokenizer_2(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
dtype = self.text_encoder_2.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
def _get_clip_prompt_embeds(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
r"""
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in all text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
"""
device = device or self._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if self.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(self.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# We only use the pooled prompt output from the CLIPTextModel
pooled_prompt_embeds = self._get_clip_prompt_embeds(
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = self._get_t5_prompt_embeds(
prompt=prompt_2,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
if self.text_encoder is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2, lora_scale)
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
return image_latents
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
def check_inputs(
self,
prompt,
prompt_2,
strength,
height,
width,
prompt_embeds=None,
pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
return latents
@staticmethod
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
def _unpack_latents(latents, height, width, vae_scale_factor):
batch_size, num_patches, channels = latents.shape
height = height // vae_scale_factor
width = width // vae_scale_factor
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
latents = latents.permute(0, 3, 1, 4, 2, 5)
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
return latents
def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
height = 2 * (int(height) // self.vae_scale_factor)
width = 2 * (int(width) // self.vae_scale_factor)
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
if latents is not None:
return latents.to(device=device, dtype=dtype), latent_image_ids
image = image.to(device=device, dtype=dtype)
image_latents = self._encode_vae_image(image=image, generator=generator)
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
# expand init_latents for batch_size
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
else:
image_latents = torch.cat([image_latents], dim=0)
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
return latents, latent_image_ids, image_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.6,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 7.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
ntk_factor: float = 10.0,
proportional_attention: bool = True
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
will be used instead
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
strength (`float`, *optional*, defaults to 1.0):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 7.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
strength,
height,
width,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
# 2. Preprocess image
init_image = self.image_processor.preprocess(image, height=height, width=width)
init_image = init_image.to(dtype=torch.float32)
# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
(
prompt_embeds,
pooled_prompt_embeds,
text_ids,
) = self.encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4.Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
timesteps, num_inference_steps_ = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas
)
timesteps, num_inference_steps_ = self.get_timesteps(num_inference_steps_, strength, device)
if num_inference_steps_ < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
f"steps is {num_inference_steps_} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
latents, latent_image_ids, image_latents = self.prepare_latents(
init_image,
latent_timestep,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps_ * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
guidance = None
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps_) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
ntk_factor=ntk_factor,
proportional_attention=proportional_attention,
)[0]
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
image = latents
else:
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return FluxPipelineOutput(images=image)
================================================
FILE: requirements.txt
================================================
torch>=2.5.0
diffusers>=0.31.0
transformers
safetensors
opencv-python
accelerate
tqdm
deepspeed
wandb
prodigyopt
sentencepiece
================================================
FILE: transformer_flux.py
================================================
# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import (
Attention,
AttentionProcessor
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from attention_processor import FluxAttnProcessor2_0
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
super().__init__()
self.mlp_hidden_dim = int(dim * mlp_ratio)
self.norm = AdaLayerNormZeroSingle(dim)
self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
self.act_mlp = nn.GELU(approximate="tanh")
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
processor = FluxAttnProcessor2_0()
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
def forward(
self,
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
proportional_attention: bool = False,
joint_attention_kwargs=None
):
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
proportional_attention=proportional_attention,
**joint_attention_kwargs,
)
hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
gate = gate.unsqueeze(1)
hidden_states = gate * self.proj_out(hidden_states)
hidden_states = residual + hidden_states
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
r"""
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
Reference: https://arxiv.org/abs/2403.03206
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
processing of `context` conditions.
"""
def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
super().__init__()
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
if hasattr(F, "scaled_dot_product_attention"):
processor = FluxAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)
self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
processor=processor,
qk_norm=qk_norm,
eps=eps,
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
image_rotary_emb=None,
proportional_attention: bool = False,
joint_attention_kwargs=None,
):
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
proportional_attention=proportional_attention,
**joint_attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
# Process attention outputs for the `encoder_hidden_states`.
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return encoder_hidden_states, hidden_states
class FluxPosEmbed(nn.Module):
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
def __init__(self, theta: int, axes_dim: List[int]):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor, ntk_factor=1) -> torch.Tensor:
n_axes = ids.shape[-1]
cos_out = []
sin_out = []
pos = ids.float()
is_mps = ids.device.type == "mps"
freqs_dtype = torch.float32 if is_mps else torch.float64
for i in range(n_axes):
cos, sin = get_1d_rotary_pos_embed(
self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype,
ntk_factor=ntk_factor
)
cos_out.append(cos)
sin_out.append(sin)
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
return freqs_cos, freqs_sin
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
"""
The Transformer model introduced in Flux.
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
Parameters:
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int] = (16, 56, 56),
):
super().__init__()
self.out_channels = in_channels
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
)
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_layers)
]
)
self.single_transformer_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
attention_head_dim=self.config.attention_head_dim,
)
for i in range(self.config.num_single_layers)
]
)
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
return_dict: bool = True,
ntk_factor: float = 1,
proportional_attention: bool = False,
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids, ntk_factor=ntk_factor)
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
proportional_attention,
**ckpt_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
proportional_attention=proportional_attention,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
# For Xlabs ControlNet.
if controlnet_blocks_repeat:
hidden_states = (
hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
temb,
image_rotary_emb,
proportional_attention,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
proportional_attention=proportional_attention,
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)