Repository: black-forest-labs/flux Branch: main Commit: 802fb4713906 Files: 51 Total size: 319.8 KB Directory structure: gitextract_5wlz5y8p/ ├── .github/ │ └── workflows/ │ └── ci.yaml ├── .gitignore ├── LICENSE ├── README.md ├── demo_gr.py ├── demo_st.py ├── demo_st_fill.py ├── docs/ │ ├── fill.md │ ├── image-editing.md │ ├── image-variation.md │ ├── structural-conditioning.md │ └── text-to-image.md ├── model_cards/ │ ├── FLUX.1-Krea-dev.md │ ├── FLUX.1-dev.md │ ├── FLUX.1-kontext-dev.md │ └── FLUX.1-schnell.md ├── model_licenses/ │ ├── LICENSE-FLUX1-dev │ └── LICENSE-FLUX1-schnell ├── pyproject.toml ├── setup.py └── src/ └── flux/ ├── __init__.py ├── __main__.py ├── cli.py ├── cli_control.py ├── cli_fill.py ├── cli_kontext.py ├── cli_redux.py ├── content_filters.py ├── math.py ├── model.py ├── modules/ │ ├── autoencoder.py │ ├── conditioner.py │ ├── image_embedders.py │ ├── layers.py │ └── lora.py ├── sampling.py ├── trt/ │ ├── __init__.py │ ├── engine/ │ │ ├── __init__.py │ │ ├── base_engine.py │ │ ├── clip_engine.py │ │ ├── t5_engine.py │ │ ├── transformer_engine.py │ │ └── vae_engine.py │ ├── trt_config/ │ │ ├── __init__.py │ │ ├── base_trt_config.py │ │ ├── clip_trt_config.py │ │ ├── t5_trt_config.py │ │ ├── transformer_trt_config.py │ │ └── vae_trt_config.py │ └── trt_manager.py └── util.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/ci.yaml ================================================
# Lint-only CI: ruff lint, import ordering, and formatting checks on every push.
name: CI
on: push
jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      # v2 of these actions runs on the deprecated Node 12 runtime; use current majors.
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff==0.6.8
      - name: Run Ruff
        run: ruff check --output-format=github .
      - name: Check imports
        run: ruff check --select I --output-format=github .
      - name: Check formatting
        run: ruff format --check .
================================================ FILE: .gitignore ================================================ # Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python # Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python ### Linux ### *~ # temporary files which can be created if a process still has a handle open of a deleted file .fuse_hidden* # KDE directory preferences .directory # Linux trash folder which might appear on any partition or disk .Trash-* # .nfs files are created when an open file is removed but is still being accessed .nfs* ### macOS ### # General .DS_Store .AppleDouble .LSOverride # Icon must end with two \r Icon # Thumbnails ._* # Files that might appear in the root of a volume .DocumentRevisions-V100 .fseventsd .Spotlight-V100 .TemporaryItems .Trashes .VolumeIcon.icns .com.apple.timemachine.donotpresent # Directories potentially created on remote AFP share .AppleDB .AppleDesktop Network Trash Folder Temporary Items .apdisk ### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: # .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ ### VisualStudioCode ### .vscode/* !.vscode/settings.json !.vscode/tasks.json !.vscode/launch.json !.vscode/extensions.json *.code-workspace # Local History for Visual Studio Code .history/ ### VisualStudioCode Patch ### # Ignore all local history of files .history .ionide ### Windows ### # Windows thumbnail cache files Thumbs.db Thumbs.db:encryptable ehthumbs.db ehthumbs_vista.db # Dump file *.stackdump # Folder config file [Dd]esktop.ini # Recycle Bin used on file shares $RECYCLE.BIN/ # Windows Installer files *.cab *.msi *.msix *.msm *.msp # Windows shortcuts *.lnk # End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python output/ ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # FLUX by Black Forest Labs: https://bfl.ai. Documentation for our API can be found here: [docs.bfl.ai](https://docs.bfl.ai/). ![grid](assets/grid.jpg) This repo contains minimal inference code to run image generation & editing with our Flux open-weight models. ## Local installation ```bash cd $HOME && git clone https://github.com/black-forest-labs/flux cd $HOME/flux python3.10 -m venv .venv source .venv/bin/activate pip install -e ".[all]" ``` ### Local installation with TensorRT support If you would like to install the repository with [TensorRT](https://github.com/NVIDIA/TensorRT) support, you currently need to install a PyTorch image from NVIDIA instead. First install [enroot](https://github.com/NVIDIA/enroot), next follow the steps below: ```bash cd $HOME && git clone https://github.com/black-forest-labs/flux enroot import 'docker://$oauthtoken@nvcr.io#nvidia/pytorch:25.01-py3' enroot create -n pti2501 nvidia+pytorch+25.01-py3.sqsh enroot start --rw -m ${PWD}/flux:/workspace/flux -r pti2501 cd flux pip install -e ".[tensorrt]" --extra-index-url https://pypi.nvidia.com ``` ### Open-weight models We are offering an extensive suite of open-weight models. For more information about the individual models, please refer to the link under **Usage**. 
| Name | Usage | HuggingFace repo | License | | --------------------------- | ---------------------------------------------------------- | -------------------------------------------------------------- | --------------------------------------------------------------------- | | `FLUX.1 [schnell]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-schnell | [apache-2.0](model_licenses/LICENSE-FLUX1-schnell) | | `FLUX.1 [dev]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | | `FLUX.1 Fill [dev]` | [In/Out-painting](docs/fill.md) | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | | `FLUX.1 Canny [dev]` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | | `FLUX.1 Depth [dev]` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | | `FLUX.1 Canny [dev] LoRA` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | | `FLUX.1 Depth [dev] LoRA` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | | `FLUX.1 Redux [dev]` | [Image variation](docs/image-variation.md) | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | | `FLUX.1 Kontext [dev]` | [Image editing](docs/image-editing.md) | 
https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | | `FLUX.1 Krea [dev]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-Krea-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above. ## API usage Our API offers access to all models, including our Pro-tier models whose weights are not openly released. Check out our API documentation [docs.bfl.ai](https://docs.bfl.ai/) to learn more. ## Licensing models for commercial use You can license our models for commercial use here: https://bfl.ai/pricing/licensing As the fee is based on monthly usage, we provide code to automatically track your usage via the BFL API. To enable usage tracking, pass `--track_usage` to the CLI or tick the corresponding checkbox in our provided demos. ### Example: Using FLUX.1 Kontext with usage tracking We provide a reference implementation for running FLUX.1 with usage tracking enabled for commercial licensing. This can be customized as needed, as long as the usage reporting is accurate. For the reporting logic to work, you need to set your API key as an environment variable before running: ```bash export BFL_API_KEY="your_api_key_here" ``` You can call `FLUX.1 Kontext [dev]` like this with tracking activated: ```bash python -m flux kontext --track_usage --loop ``` For a single generation: ```bash python -m flux kontext --track_usage --prompt "replace the logo with the text 'Black Forest Labs'" ``` The above reporting logic works similarly for FLUX.1 [dev] and FLUX.1 Tools [dev]. **Note that this is only required when using one or more of our open-weight models commercially.
More information on the commercial licensing can be found at the [BFL Helpdesk](https://help.bfl.ai/collections/6939000511-licensing).** ## Citation If you find the provided code or models useful for your research, consider citing them as: ```bib @misc{labs2025flux1kontextflowmatching, title={FLUX.1 Kontext: Flow Matching for In-Context Image Generation and Editing in Latent Space}, author={Black Forest Labs and Stephen Batifol and Andreas Blattmann and Frederic Boesel and Saksham Consul and Cyril Diagne and Tim Dockhorn and Jack English and Zion English and Patrick Esser and Sumith Kulal and Kyle Lacey and Yam Levi and Cheng Li and Dominik Lorenz and Jonas Müller and Dustin Podell and Robin Rombach and Harry Saini and Axel Sauer and Luke Smith}, year={2025}, eprint={2506.15742}, archivePrefix={arXiv}, primaryClass={cs.GR}, url={https://arxiv.org/abs/2506.15742}, } @misc{flux2024, author={Black Forest Labs}, title={FLUX}, year={2024}, howpublished={\url{https://github.com/black-forest-labs/flux}}, } ``` ================================================ FILE: demo_gr.py ================================================ import os import time import uuid import gradio as gr import numpy as np import torch from einops import rearrange from PIL import ExifTags, Image from transformers import pipeline from flux.cli import SamplingOptions from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack from flux.util import ( configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5, track_usage_via_api, ) NSFW_THRESHOLD = 0.85 def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool): t5 = load_t5(device, max_length=256 if is_schnell else 512) clip = load_clip(device) model = load_flow_model(name, device="cpu" if offload else device) ae = load_ae(name, device="cpu" if offload else device) nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) return model, ae, t5, clip, 
nsfw_classifier class FluxGenerator: def __init__(self, model_name: str, device: str, offload: bool, track_usage: bool): self.device = torch.device(device) self.offload = offload self.model_name = model_name self.is_schnell = model_name == "flux-schnell" self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models( model_name, device=self.device, offload=self.offload, is_schnell=self.is_schnell, ) self.track_usage = track_usage @torch.inference_mode() def generate_image( self, width, height, num_steps, guidance, seed, prompt, init_image=None, image2image_strength=0.0, add_sampling_metadata=True, ): seed = int(seed) if seed == -1: seed = None opts = SamplingOptions( prompt=prompt, width=width, height=height, num_steps=num_steps, guidance=guidance, seed=seed, ) if opts.seed is None: opts.seed = torch.Generator(device="cpu").seed() print(f"Generating '{opts.prompt}' with seed {opts.seed}") t0 = time.perf_counter() if init_image is not None: if isinstance(init_image, np.ndarray): init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 255.0 init_image = init_image.unsqueeze(0) init_image = init_image.to(self.device) init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width)) if self.offload: self.ae.encoder.to(self.device) init_image = self.ae.encode(init_image.to()) if self.offload: self.ae = self.ae.cpu() torch.cuda.empty_cache() # prepare input x = get_noise( 1, opts.height, opts.width, device=self.device, dtype=torch.bfloat16, seed=opts.seed, ) timesteps = get_schedule( opts.num_steps, x.shape[-1] * x.shape[-2] // 4, shift=(not self.is_schnell), ) if init_image is not None: t_idx = int((1 - image2image_strength) * num_steps) t = timesteps[t_idx] timesteps = timesteps[t_idx:] x = t * x + (1.0 - t) * init_image.to(x.dtype) if self.offload: self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt) # offload TEs to CPU, load model to gpu 
if self.offload: self.t5, self.clip = self.t5.cpu(), self.clip.cpu() torch.cuda.empty_cache() self.model = self.model.to(self.device) # denoise initial noise x = denoise(self.model, **inp, timesteps=timesteps, guidance=opts.guidance) # offload model, load autoencoder to gpu if self.offload: self.model.cpu() torch.cuda.empty_cache() self.ae.decoder.to(x.device) # decode latents to pixel space x = unpack(x.float(), opts.height, opts.width) with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): x = self.ae.decode(x) if self.offload: self.ae.decoder.cpu() torch.cuda.empty_cache() t1 = time.perf_counter() print(f"Done in {t1 - t0:.1f}s.") # bring into PIL format x = x.clamp(-1, 1) x = embed_watermark(x.float()) x = rearrange(x[0], "c h w -> h w c") img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) nsfw_score = [x["score"] for x in self.nsfw_classifier(img) if x["label"] == "nsfw"][0] if nsfw_score < NSFW_THRESHOLD: filename = f"output/gradio/{uuid.uuid4()}.jpg" os.makedirs(os.path.dirname(filename), exist_ok=True) exif_data = Image.Exif() if init_image is None: exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" else: exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux" exif_data[ExifTags.Base.Make] = "Black Forest Labs" exif_data[ExifTags.Base.Model] = self.model_name if add_sampling_metadata: exif_data[ExifTags.Base.ImageDescription] = prompt img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0) if self.track_usage: track_usage_via_api(self.model_name, 1) return img, str(opts.seed), filename, None else: return None, str(opts.seed), None, "Your generated image may contain NSFW content." 
def create_demo( model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False, track_usage: bool = False, ): generator = FluxGenerator(model_name, device, offload, track_usage) is_schnell = model_name == "flux-schnell" with gr.Blocks() as demo: gr.Markdown(f"# Flux Image Generation Demo - Model: {model_name}") with gr.Row(): with gr.Column(): prompt = gr.Textbox( label="Prompt", value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture', ) do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell) init_image = gr.Image(label="Input Image", visible=False) image2image_strength = gr.Slider( 0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False ) with gr.Accordion("Advanced Options", open=False): width = gr.Slider(128, 8192, 1360, step=16, label="Width") height = gr.Slider(128, 8192, 768, step=16, label="Height") num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps") guidance = gr.Slider( 1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell ) seed = gr.Textbox(-1, label="Seed (-1 for random)") add_sampling_metadata = gr.Checkbox( label="Add sampling parameters to metadata?", value=True ) generate_btn = gr.Button("Generate") with gr.Column(): output_image = gr.Image(label="Generated Image") seed_output = gr.Number(label="Used Seed") warning_text = gr.Textbox(label="Warning", visible=False) download_btn = gr.File(label="Download full-resolution") def update_img2img(do_img2img): return { init_image: gr.update(visible=do_img2img), image2image_strength: gr.update(visible=do_img2img), } do_img2img.change(update_img2img, do_img2img, [init_image, image2image_strength]) generate_btn.click( fn=generator.generate_image, inputs=[ width, height, num_steps, guidance, seed, prompt, init_image, image2image_strength, add_sampling_metadata, ], 
outputs=[output_image, seed_output, download_btn, warning_text], ) return demo if __name__ == "__main__": import argparse parser = argparse.ArgumentParser(description="Flux") parser.add_argument( "--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name" ) parser.add_argument( "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use" ) parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use") parser.add_argument("--share", action="store_true", help="Create a public link to your demo") parser.add_argument("--track_usage", action="store_true", help="Track usage for licensing purposes") args = parser.parse_args() demo = create_demo(args.name, args.device, args.offload, args.track_usage) demo.launch(share=args.share) ================================================ FILE: demo_st.py ================================================ import os import re import time from glob import iglob from io import BytesIO import streamlit as st import torch from einops import rearrange from fire import Fire from PIL import ExifTags, Image from st_keyup import st_keyup from torchvision import transforms from transformers import pipeline from flux.cli import SamplingOptions from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack from flux.util import ( configs, embed_watermark, load_ae, load_clip, load_flow_model, load_t5, track_usage_via_api, ) NSFW_THRESHOLD = 0.85 @st.cache_resource() def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool): t5 = load_t5(device, max_length=256 if is_schnell else 512) clip = load_clip(device) model = load_flow_model(name, device="cpu" if offload else device) ae = load_ae(name, device="cpu" if offload else device) nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) return model, ae, t5, clip, nsfw_classifier def get_image() -> torch.Tensor | 
None: image = st.file_uploader("Input", type=["jpg", "JPEG", "png"]) if image is None: return None image = Image.open(image).convert("RGB") transform = transforms.Compose( [ transforms.ToTensor(), transforms.Lambda(lambda x: 2.0 * x - 1.0), ] ) img: torch.Tensor = transform(image) return img[None, ...] @torch.inference_mode() def main( device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False, output_dir: str = "output", track_usage: bool = False, ): torch_device = torch.device(device) names = list(configs.keys()) name = st.selectbox("Which model to load?", names) if name is None or not st.checkbox("Load model", False): return is_schnell = name == "flux-schnell" model, ae, t5, clip, nsfw_classifier = get_models( name, device=torch_device, offload=offload, is_schnell=is_schnell, ) do_img2img = ( st.checkbox( "Image to Image", False, disabled=is_schnell, help="Partially noise an image and denoise again to get variations.\n\nOnly works for flux-dev", ) and not is_schnell ) if do_img2img: init_image = get_image() if init_image is None: st.warning("Please add an image to do image to image") image2image_strength = st.number_input("Noising strength", min_value=0.0, max_value=1.0, value=0.8) if init_image is not None: h, w = init_image.shape[-2:] st.write(f"Got image of size {w}x{h} ({h * w / 1e6:.2f}MP)") resize_img = st.checkbox("Resize image", False) or init_image is None else: init_image = None resize_img = True image2image_strength = 0.0 # allow for packing and conversion to latent space width = int( 16 * (st.number_input("Width", min_value=128, value=1360, step=16, disabled=not resize_img) // 16) ) height = int( 16 * (st.number_input("Height", min_value=128, value=768, step=16, disabled=not resize_img) // 16) ) num_steps = int(st.number_input("Number of steps", min_value=1, value=(4 if is_schnell else 50))) guidance = float(st.number_input("Guidance", min_value=1.0, value=3.5, disabled=is_schnell)) seed_str = st.text_input("Seed", 
disabled=is_schnell) if seed_str.isdecimal(): seed = int(seed_str) else: st.info("No seed set, set to positive integer to enable") seed = None save_samples = st.checkbox("Save samples?", not is_schnell) add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True) default_prompt = ( "a photo of a forest with mist swirling around the tree trunks. The word " '"FLUX" is painted over it in big, red brush strokes with visible texture' ) prompt = st_keyup("Enter a prompt", value=default_prompt, debounce=300, key="interactive_text") output_name = os.path.join(output_dir, "img_{idx}.jpg") if not os.path.exists(output_dir): os.makedirs(output_dir) idx = 0 else: fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] if len(fns) > 0: idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 else: idx = 0 rng = torch.Generator(device="cpu") if "seed" not in st.session_state: st.session_state.seed = rng.seed() def increment_counter(): st.session_state.seed += 1 def decrement_counter(): if st.session_state.seed > 0: st.session_state.seed -= 1 opts = SamplingOptions( prompt=prompt, width=width, height=height, num_steps=num_steps, guidance=guidance, seed=seed, ) if name == "flux-schnell": cols = st.columns([5, 1, 1, 5]) with cols[1]: st.button("↩", on_click=increment_counter) with cols[2]: st.button("↪", on_click=decrement_counter) if is_schnell or st.button("Sample"): if is_schnell: opts.seed = st.session_state.seed elif opts.seed is None: opts.seed = rng.seed() print(f"Generating '{opts.prompt}' with seed {opts.seed}") t0 = time.perf_counter() if init_image is not None: if resize_img: init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width)) else: h, w = init_image.shape[-2:] init_image = init_image[..., : 16 * (h // 16), : 16 * (w // 16)] opts.height = init_image.shape[-2] opts.width = init_image.shape[-1] if offload: ae.encoder.to(torch_device) init_image = 
ae.encode(init_image.to(torch_device)) if offload: ae = ae.cpu() torch.cuda.empty_cache() # prepare input x = get_noise( 1, opts.height, opts.width, device=torch_device, dtype=torch.bfloat16, seed=opts.seed, ) # divide pixel space by 16**2 to account for latent space conversion timesteps = get_schedule( opts.num_steps, (x.shape[-1] * x.shape[-2]) // 4, shift=(not is_schnell), ) if init_image is not None: t_idx = int((1 - image2image_strength) * num_steps) t = timesteps[t_idx] timesteps = timesteps[t_idx:] x = t * x + (1.0 - t) * init_image.to(x.dtype) if offload: t5, clip = t5.to(torch_device), clip.to(torch_device) inp = prepare(t5=t5, clip=clip, img=x, prompt=opts.prompt) # offload TEs to CPU, load model to gpu if offload: t5, clip = t5.cpu(), clip.cpu() torch.cuda.empty_cache() model = model.to(torch_device) # denoise initial noise x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) # offload model, load autoencoder to gpu if offload: model.cpu() torch.cuda.empty_cache() ae.decoder.to(x.device) # decode latents to pixel space x = unpack(x.float(), opts.height, opts.width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = ae.decode(x) if offload: ae.decoder.cpu() torch.cuda.empty_cache() t1 = time.perf_counter() fn = output_name.format(idx=idx) print(f"Done in {t1 - t0:.1f}s.") # bring into PIL format and save x = x.clamp(-1, 1) x = embed_watermark(x.float()) x = rearrange(x[0], "c h w -> h w c") img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] if nsfw_score < NSFW_THRESHOLD: buffer = BytesIO() exif_data = Image.Exif() if init_image is None: exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" else: exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux" exif_data[ExifTags.Base.Make] = "Black Forest Labs" exif_data[ExifTags.Base.Model] = name if add_sampling_metadata: 
exif_data[ExifTags.Base.ImageDescription] = prompt img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0) img_bytes = buffer.getvalue() if save_samples: print(f"Saving {fn}") with open(fn, "wb") as file: file.write(img_bytes) idx += 1 if track_usage: track_usage_via_api(name, 1) st.session_state["samples"] = { "prompt": opts.prompt, "img": img, "seed": opts.seed, "bytes": img_bytes, } opts.seed = None else: st.warning("Your generated image may contain NSFW content.") st.session_state["samples"] = None samples = st.session_state.get("samples", None) if samples is not None: st.image(samples["img"], caption=samples["prompt"]) st.download_button( "Download full-resolution", samples["bytes"], file_name="generated.jpg", mime="image/jpg", ) st.write(f"Seed: {samples['seed']}") def app(): Fire(main) if __name__ == "__main__": app() ================================================ FILE: demo_st_fill.py ================================================ import os import re import tempfile import time from glob import iglob from io import BytesIO import numpy as np import streamlit as st import torch from einops import rearrange from fire import Fire from PIL import ExifTags, Image from st_keyup import st_keyup from streamlit_drawable_canvas import st_canvas from transformers import pipeline from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack from flux.util import ( embed_watermark, load_ae, load_clip, load_flow_model, load_t5, track_usage_via_api, ) NSFW_THRESHOLD = 0.85 def add_border_and_mask(image, zoom_all=1.0, zoom_left=0, zoom_right=0, zoom_up=0, zoom_down=0, overlap=0): """Adds a black border around the image with individual side control and mask overlap""" orig_width, orig_height = image.size # Calculate padding for each side (in pixels) left_pad = int(orig_width * zoom_left) right_pad = int(orig_width * zoom_right) top_pad = int(orig_height * zoom_up) bottom_pad = int(orig_height * zoom_down) # Calculate overlap in 
pixels overlap_left = int(orig_width * overlap) overlap_right = int(orig_width * overlap) overlap_top = int(orig_height * overlap) overlap_bottom = int(orig_height * overlap) # If using the all-sides zoom, add it to each side if zoom_all > 1.0: extra_each_side = (zoom_all - 1.0) / 2 left_pad += int(orig_width * extra_each_side) right_pad += int(orig_width * extra_each_side) top_pad += int(orig_height * extra_each_side) bottom_pad += int(orig_height * extra_each_side) # Calculate new dimensions (ensure they're multiples of 32) new_width = 32 * round((orig_width + left_pad + right_pad) / 32) new_height = 32 * round((orig_height + top_pad + bottom_pad) / 32) # Create new image with black border bordered_image = Image.new("RGB", (new_width, new_height), (0, 0, 0)) # Paste original image in position paste_x = left_pad paste_y = top_pad bordered_image.paste(image, (paste_x, paste_y)) # Create mask (white where the border is, black where the original image was) mask = Image.new("L", (new_width, new_height), 255) # White background # Paste black rectangle with overlap adjustment mask.paste( 0, ( paste_x + overlap_left, # Left edge moves right paste_y + overlap_top, # Top edge moves down paste_x + orig_width - overlap_right, # Right edge moves left paste_y + orig_height - overlap_bottom, # Bottom edge moves up ), ) return bordered_image, mask @st.cache_resource() def get_models(name: str, device: torch.device, offload: bool): t5 = load_t5(device, max_length=128) clip = load_clip(device) model = load_flow_model(name, device="cpu" if offload else device) ae = load_ae(name, device="cpu" if offload else device) nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) return model, ae, t5, clip, nsfw_classifier def resize(img: Image.Image, min_mp: float = 0.5, max_mp: float = 2.0) -> Image.Image: width, height = img.size mp = (width * height) / 1_000_000 # Current megapixels if min_mp <= mp <= max_mp: # Even if MP is in range, 
ensure dimensions are multiples of 32 new_width = int(32 * round(width / 32)) new_height = int(32 * round(height / 32)) if new_width != width or new_height != height: return img.resize((new_width, new_height), Image.Resampling.LANCZOS) return img # Calculate scaling factor if mp < min_mp: scale = (min_mp / mp) ** 0.5 else: # mp > max_mp scale = (max_mp / mp) ** 0.5 new_width = int(32 * round(width * scale / 32)) new_height = int(32 * round(height * scale / 32)) return img.resize((new_width, new_height), Image.Resampling.LANCZOS) def clear_canvas_state(): """Clear all canvas-related state""" keys_to_clear = ["canvas", "last_image_dims"] for key in keys_to_clear: if key in st.session_state: del st.session_state[key] def set_new_image(img: Image.Image): """Safely set a new image and clear relevant state""" st.session_state["current_image"] = img clear_canvas_state() st.rerun() def downscale_image(img: Image.Image, scale_factor: float) -> Image.Image: """Downscale image by a given factor while maintaining 32-pixel multiple dimensions""" if scale_factor >= 1.0: return img width, height = img.size new_width = int(32 * round(width * scale_factor / 32)) new_height = int(32 * round(height * scale_factor / 32)) # Ensure minimum dimensions new_width = max(64, new_width) # minimum 64 pixels new_height = max(64, new_height) # minimum 64 pixels return img.resize((new_width, new_height), Image.Resampling.LANCZOS) @torch.inference_mode() def main( device: str = "cuda" if torch.cuda.is_available() else "cpu", offload: bool = False, output_dir: str = "output", track_usage: bool = False, ): torch_device = torch.device(device) st.title("Flux Fill: Inpainting & Outpainting") # Model selection and loading name = "flux-dev-fill" if not st.checkbox("Load model", False): return try: model, ae, t5, clip, nsfw_classifier = get_models( name, device=torch_device, offload=offload, ) except Exception as e: st.error(f"Error loading models: {e}") return # Mode selection mode = st.radio("Select 
Mode", ["Inpainting", "Outpainting"]) # Image handling - either from previous generation or new upload if "input_image" in st.session_state: image = st.session_state["input_image"] del st.session_state["input_image"] set_new_image(image) st.write("Continuing from previous result") else: uploaded_image = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"]) if uploaded_image is None: st.warning("Please upload an image") return if ( "current_image_name" not in st.session_state or st.session_state["current_image_name"] != uploaded_image.name ): try: image = Image.open(uploaded_image).convert("RGB") st.session_state["current_image_name"] = uploaded_image.name set_new_image(image) except Exception as e: st.error(f"Error loading image: {e}") return else: image = st.session_state.get("current_image") if image is None: st.error("Error: Image state is invalid. Please reupload the image.") clear_canvas_state() return # Add downscale control with st.expander("Image Size Control"): current_mp = (image.size[0] * image.size[1]) / 1_000_000 st.write(f"Current image size: {image.size[0]}x{image.size[1]} ({current_mp:.1f}MP)") scale_factor = st.slider( "Downscale Factor", min_value=0.1, max_value=1.0, value=1.0, step=0.1, help="1.0 = original size, 0.5 = half size, etc.", ) if scale_factor < 1.0 and st.button("Apply Downscaling"): image = downscale_image(image, scale_factor) set_new_image(image) st.rerun() # Resize image with validation try: original_mp = (image.size[0] * image.size[1]) / 1_000_000 image = resize(image) width, height = image.size current_mp = (width * height) / 1_000_000 if width % 32 != 0 or height % 32 != 0: st.error("Error: Image dimensions must be multiples of 32") return st.write(f"Image dimensions: {width}x{height} pixels") if original_mp != current_mp: st.write( f"Image has been resized from {original_mp:.1f}MP to {current_mp:.1f}MP to stay within bounds (0.5MP - 2MP)" ) except Exception as e: st.error(f"Error processing image: {e}") return if mode 
== "Outpainting": # Outpainting controls zoom_all = st.slider("Zoom Out Amount (All Sides)", min_value=1.0, max_value=3.0, value=1.0, step=0.1) with st.expander("Advanced Zoom Controls"): st.info("These controls add additional zoom to specific sides") col1, col2 = st.columns(2) with col1: zoom_left = st.slider("Left", min_value=0.0, max_value=1.0, value=0.0, step=0.1) zoom_right = st.slider("Right", min_value=0.0, max_value=1.0, value=0.0, step=0.1) with col2: zoom_up = st.slider("Up", min_value=0.0, max_value=1.0, value=0.0, step=0.1) zoom_down = st.slider("Down", min_value=0.0, max_value=1.0, value=0.0, step=0.1) overlap = st.slider("Overlap", min_value=0.01, max_value=0.25, value=0.01, step=0.01) # Generate bordered image and mask image_for_generation, mask = add_border_and_mask( image, zoom_all=zoom_all, zoom_left=zoom_left, zoom_right=zoom_right, zoom_up=zoom_up, zoom_down=zoom_down, overlap=overlap, ) width, height = image_for_generation.size # Show preview col1, col2 = st.columns(2) with col1: st.image(image_for_generation, caption="Image with Border") with col2: st.image(mask, caption="Mask (white areas will be generated)") else: # Inpainting mode # Canvas setup with dimension tracking canvas_key = f"canvas_{width}_{height}" if "last_image_dims" not in st.session_state: st.session_state.last_image_dims = (width, height) elif st.session_state.last_image_dims != (width, height): clear_canvas_state() st.session_state.last_image_dims = (width, height) st.rerun() try: canvas_result = st_canvas( fill_color="rgba(255, 255, 255, 0.0)", stroke_width=st.slider("Brush size", 1, 500, 50), stroke_color="#fff", background_image=image, height=height, width=width, drawing_mode="freedraw", key=canvas_key, display_toolbar=True, ) except Exception as e: st.error(f"Error creating canvas: {e}") clear_canvas_state() st.rerun() return # Sampling parameters num_steps = int(st.number_input("Number of steps", min_value=1, value=50)) guidance = float(st.number_input("Guidance", 
min_value=1.0, value=30.0)) seed_str = st.text_input("Seed") if seed_str.isdecimal(): seed = int(seed_str) else: st.info("No seed set, using random seed") seed = None save_samples = st.checkbox("Save samples?", True) add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True) # Prompt input prompt = st_keyup("Enter a prompt", value="", debounce=300, key="interactive_text") # Setup output path output_name = os.path.join(output_dir, "img_{idx}.jpg") if not os.path.exists(output_dir): os.makedirs(output_dir) idx = 0 else: fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] idx = len(fns) if st.button("Generate"): valid_input = False if mode == "Inpainting" and canvas_result.image_data is not None: valid_input = True # Create mask from canvas try: mask = Image.fromarray(canvas_result.image_data) mask = mask.getchannel("A") # Get alpha channel mask_array = np.array(mask) mask_array = (mask_array > 0).astype(np.uint8) * 255 mask = Image.fromarray(mask_array) image_for_generation = image except Exception as e: st.error(f"Error creating mask: {e}") return elif mode == "Outpainting": valid_input = True # image_for_generation and mask are already set above if not valid_input: st.error("Please draw a mask or configure outpainting settings") return # Create temporary files with ( tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img, tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_mask, ): try: image_for_generation.save(tmp_img.name) mask.save(tmp_mask.name) except Exception as e: st.error(f"Error saving temporary files: {e}") return try: # Generate inpainting/outpainting rng = torch.Generator(device="cpu") if seed is None: seed = rng.seed() print(f"Generating with seed {seed}:\n{prompt}") t0 = time.perf_counter() x = get_noise( 1, height, width, device=torch_device, dtype=torch.bfloat16, seed=seed, ) if offload: t5, clip, ae = t5.to(torch_device), clip.to(torch_device), 
ae.to(torch_device) inp = prepare_fill( t5, clip, x, prompt=prompt, ae=ae, img_cond_path=tmp_img.name, mask_path=tmp_mask.name, ) timesteps = get_schedule(num_steps, inp["img"].shape[1], shift=True) if offload: t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() torch.cuda.empty_cache() model = model.to(torch_device) x = denoise(model, **inp, timesteps=timesteps, guidance=guidance) if offload: model.cpu() torch.cuda.empty_cache() ae.decoder.to(x.device) x = unpack(x.float(), height, width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = ae.decode(x) t1 = time.perf_counter() print(f"Done in {t1 - t0:.1f}s") # Process and display result x = x.clamp(-1, 1) x = embed_watermark(x.float()) x = rearrange(x[0], "c h w -> h w c") img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] if nsfw_score < NSFW_THRESHOLD: buffer = BytesIO() exif_data = Image.Exif() exif_data[ExifTags.Base.Software] = "AI generated;inpainting;flux" exif_data[ExifTags.Base.Make] = "Black Forest Labs" exif_data[ExifTags.Base.Model] = name if add_sampling_metadata: exif_data[ExifTags.Base.ImageDescription] = prompt img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0) img_bytes = buffer.getvalue() if save_samples: fn = output_name.format(idx=idx) print(f"Saving {fn}") with open(fn, "wb") as file: file.write(img_bytes) if track_usage: track_usage_via_api(name, 1) st.session_state["samples"] = { "prompt": prompt, "img": img, "seed": seed, "bytes": img_bytes, } else: st.warning("Your generated image may contain NSFW content.") st.session_state["samples"] = None except Exception as e: st.error(f"Error during generation: {e}") return finally: # Clean up temporary files try: os.unlink(tmp_img.name) os.unlink(tmp_mask.name) except Exception as e: print(f"Error cleaning up temporary files: {e}") # Display results samples = st.session_state.get("samples", None) if samples is 
not None: st.image(samples["img"], caption=samples["prompt"]) col1, col2 = st.columns(2) with col1: st.download_button( "Download full-resolution", samples["bytes"], file_name="generated.jpg", mime="image/jpg", ) with col2: if st.button("Continue from this image"): # Store the generated image new_image = samples["img"] # Clear ALL canvas state clear_canvas_state() if "samples" in st.session_state: del st.session_state["samples"] # Set as current image st.session_state["current_image"] = new_image st.rerun() st.write(f"Seed: {samples['seed']}") def app(): Fire(main) if __name__ == "__main__": st.set_page_config(layout="wide") app() ================================================ FILE: docs/fill.md ================================================ ## Open-weight models FLUX.1 Fill introduces advanced inpainting and outpainting capabilities. It allows for seamless edits that integrate naturally with existing images. | Name | HuggingFace repo | License | sha256sum | | ------------------- | -------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- | | `FLUX.1 Fill [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) | 03e289f530df51d014f48e675a9ffa2141bc003259bf5f25d75b957e920a41ca | ## Examples ![inpainting](../assets/docs/inpainting.png) ![outpainting](../assets/docs/outpainting.png) ## Open-weights usage The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. 
Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables: ```bash export FLUX_MODEL= export FLUX_AE= ``` For interactive sampling run ```bash python -m flux fill --loop ``` Or to generate a single sample run ```bash python -m flux fill \ --img_cond_path \ --img_mask_path ``` The input_mask should be an image of the same size as the conditioning image that only contains black and white pixels; see [an example mask](../assets/cup_mask.png) for [this image](../assets/cup.png). We also provide an interactive Streamlit demo. The demo can be run via ```bash streamlit run demo_st_fill.py ``` ================================================ FILE: docs/image-editing.md ================================================ ## Open-weight models We currently offer one open-weight image-editing model. | Name | HuggingFace repo | License | sha256sum | | ------------------------- | ----------------------------------------------------------------| --------------------------------------------------------------------- | ---------------------------------------------------------------- | | `FLUX.1 Kontext [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 843a26dc765d3105dba081c30bce7b14c65b0988f9e8d14e9fbc8856a6deebd5 | ## Examples ![FLUX.1 [dev] Grid](../assets/docs/kontext.png) ## Open-weights usage The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos.
Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables: ```bash export FLUX_MODEL= export FLUX_AE= ``` For interactive sampling run ```bash python -m flux kontext --loop ``` Or to generate a single sample run ```bash python -m flux kontext \ --img_cond_path \ --prompt \ --num_steps 30 --aspect_ratio "16:9" --guidance 2.5 --seed 1 ``` Note that the flags `num_steps`, `aspect_ratio`, `guidance` and `seed` are optional. For more available flags, see [the code](../src/flux/cli_kontext.py). ### TRT engine inference We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md). ```bash python -m flux kontext --loop --trt --trt_transformer_precision ``` where `` is either `bf16`, `fp8`, or `fp4_sdvd32`. ================================================ FILE: docs/image-variation.md ================================================ ## Models FLUX.1 Redux is an adapter for the FLUX.1 text-to-image base models, FLUX.1 [dev] and FLUX.1 [schnell], which can be used to generate image variations. | Name | HuggingFace repo | License | sha256sum | | --------------------------- | ----------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- | | `FLUX.1 Redux [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | a1b3bdcb4bdc58ce04874b9ca776d61fc3e914bb6beab41efb63e4e2694dca45 | ## Examples ![redux](../assets/docs/redux.png) ## Open-weights usage The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos.
Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables: ```bash export FLUX_MODEL= export FLUX_REDUX= export FLUX_AE= ``` For interactive sampling run: ```bash python -m flux redux --name --loop ``` where `name` specifies the base model, which should be one of `flux-dev` or `flux-schnell`. ================================================ FILE: docs/structural-conditioning.md ================================================ ## Models Structural conditioning uses canny edge or depth detection to maintain precise control during image transformations. By preserving the original image's structure through edge or depth maps, users can make text-guided edits while keeping the core composition intact. This is particularly effective for retexturing images. We release four variations: two based on edge maps (full model and LoRA for FLUX.1 [dev]) and two based on depth maps (full model and LoRA for FLUX.1 [dev]). 
| Name | HuggingFace repo | License | sha256sum | | ------------------------- | -------------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- | | `FLUX.1 Canny [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 996876670169591cb412b937fbd46ea14cbed6933aef17c48a2dcd9685c98cdb | | `FLUX.1 Depth [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 41360d1662f44ca45bc1b665fe6387e91802f53911001630d970a4f8be8dac21 | | `FLUX.1 Canny [dev] LoRA` | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 8eaa21b9c43d5e7242844deb64b8cf22ae9010f813f955ca8c05f240b8a98f7e | | `FLUX.1 Depth [dev] LoRA` | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 1938b38ea0fdd98080fa3e48beb2bedfbc7ad102d8b65e6614de704a46d8b907 | ## Examples ![canny](../assets/docs/canny.png) ![depth](../assets/docs/depth.png) ## Open-weights usage The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables: ```bash export FLUX_MODEL= export FLUX_AE= # optional (see below) export FLUX_LORA= ``` Note that the LoRA models (`flux-dev-canny-lora` and `flux-dev-depth-lora`) require the base FLUX.1 [dev] model to be downloaded first. The system will automatically download both the base model and the LoRA adapter when using these variants. 
For interactive sampling run ```bash python -m flux control --name --loop ``` where `name` is one of `flux-dev-canny`, `flux-dev-depth`, `flux-dev-canny-lora`, or `flux-dev-depth-lora`. ### TRT engine inference We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md). ```bash python -m flux control --name= --loop --img_cond_path="assets/robot.webp" --trt --static_shape=False --trt_transformer_precision ``` where `` is either `bf16`, `fp8`, or `fp4`. ## Diffusers usage Flux Control (including the LoRAs) is also compatible with the `diffusers` Python library. Check out the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more. ================================================ FILE: docs/text-to-image.md ================================================ ## Open-weight models We currently offer three open-weight text-to-image models. | Name | HuggingFace repo | License | sha256sum | | --------------------------|----------------------------------------------------------| -------------------------------------------------------------------------|----------------------------------------------------------------- | | `FLUX.1 [schnell]` | https://huggingface.co/black-forest-labs/FLUX.1-schnell | [apache-2.0](../model_licenses/LICENSE-FLUX1-schnell) | 9403429e0052277ac2a87ad800adece5481eecefd9ed334e1f348723621d2a0a | | `FLUX.1 [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 4610115bb0c89560703c892c59ac2742fa821e60ef5871b33493ba544683abd7 | | `FLUX.1 Krea [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Krea-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 4610115bb0c89560703c892c59ac2742fa821e60ef5871b33493ba544683abd7 | ## Open-weights usage The weights will be downloaded automatically to `checkpoints/` from HuggingFace
once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables: ```bash export FLUX_MODEL=<path_to_flux_model_sft_file> export FLUX_AE=<path_to_ae_sft_file> ``` For interactive sampling run ```bash python -m flux t2i --name <name> --loop ``` where `name` is one of `flux-dev` or `flux-schnell`. Or to generate a single sample run ```bash python -m flux t2i --name <name> \ --height <height> --width <width> \ --prompt "<prompt>" ``` ### TRT engine inference We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md). ```bash python -m flux t2i --name=<name> --loop --trt --trt_transformer_precision <precision> ``` where `<precision>` is either `bf16`, `fp8`, or `fp4`. For ONNX exports, `height` and `width` have to be between 768 and 1344. ### Streamlit and Gradio We also provide a Streamlit demo that does both text-to-image and image-to-image. The demo can be run via ```bash streamlit run demo_st.py ``` We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo: ```bash python demo_gr.py --name flux-schnell --device cuda ``` Options: - `--name`: Choose the model to use (options: "flux-schnell", "flux-dev") - `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu") - `--offload`: Offload model to CPU when not in use - `--share`: Create a public link to your demo To run the demo with the dev model and create a public link: ```bash python demo_gr.py --name flux-dev --share ``` ## Diffusers integration `FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library.
To use it with diffusers, install it: ```shell pip install git+https://github.com/huggingface/diffusers.git ``` Then you can use `FluxPipeline` to run the model ```python import torch from diffusers import FluxPipeline model_id = "black-forest-labs/FLUX.1-schnell" #you can also use `black-forest-labs/FLUX.1-dev` pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power prompt = "A cat holding a sign that says hello world" seed = 42 image = pipe( prompt, output_type="pil", num_inference_steps=4, #use a larger number if you are using [dev] generator=torch.Generator("cpu").manual_seed(seed) ).images[0] image.save("flux-schnell.png") ``` To learn more, check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation. ================================================ FILE: model_cards/FLUX.1-Krea-dev.md ================================================ ![FLUX.1 Krea [dev] Grid](../assets/flux-1-krea-dev-grid.png) `FLUX.1 Krea [dev]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. For more information, please read our [blog post](https://bfl.ai/announcements/flux-1-krea-dev). # Key Features 1. Cutting-edge output quality, with a focus on aesthetic photography. 2. Competitive prompt following, matching the performance of closed source alternatives. 3. Trained using guidance distillation, making `FLUX.1 Krea [dev]` more efficient. 4. Open weights to drive new scientific research, and empower artists to develop innovative workflows. 5. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [flux-1-dev-non-commercial-license](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev).
# Usage `FLUX.1 Krea [dev]` can be used as a drop-in replacement in every system that supports the original `FLUX.1 [dev]`. A reference implementation of `FLUX.1 [dev]` is in our dedicated [github repository](https://github.com/black-forest-labs/flux). Developers and creatives looking to build on top of `FLUX.1 [dev]` are encouraged to use this as a starting point. `FLUX.1 Krea [dev]` is also available in both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [Diffusers](https://github.com/huggingface/diffusers). --- # Limitations - This model is not intended or able to provide factual information. - As a statistical model this checkpoint might amplify existing societal biases. - The model may fail to generate output that matches the prompts. - Prompt following is heavily influenced by the prompting-style. --- # Out-of-Scope Use The model and its derivatives may not be used - In any way that violates any applicable national, federal, state, local or international law or regulation. - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content. - To generate or disseminate verifiably false information and/or content with the purpose of harming others. - To generate or disseminate personal identifiable information that can be used to harm an individual. - To harass, abuse, threaten, stalk, or bully individuals or groups of individuals. - To create non-consensual nudity or illegal pornographic content. - For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation. - Generating or facilitating large-scale disinformation campaigns. - Please reference our [content filters](https://github.com/black-forest-labs/flux/blob/main/src/flux/content_filters.py) to avoid such generations. 
--- # Risks Black Forest Labs (BFL) and Krea are committed to the responsible development of generative AI technology. Prior to releasing FLUX.1 Krea [dev], BFL and Krea collaboratively evaluated and mitigated a number of risks in the FLUX.1 Krea [dev] model and services, including the generation of unlawful content. We implemented a series of pre-release mitigations to help prevent misuse by third parties, with additional post-release mitigations to help address residual risks: 1. **Pre-training mitigation.** BFL filtered pre-training data for multiple categories of “not safe for work” (NSFW) and unlawful content to help prevent a user generating unlawful content in response to text prompts or uploaded images. 2. **Post-training mitigation.** BFL has partnered with the Internet Watch Foundation, an independent nonprofit organization dedicated to preventing online abuse, to filter known child sexual abuse material (CSAM) from post-training data. Subsequently, BFL and Krea undertook multiple rounds of targeted fine-tuning to provide additional mitigation against potential abuse. By inhibiting certain behaviors and concepts in the trained model, these techniques can help to prevent a user generating synthetic CSAM or nonconsensual intimate imagery (NCII) from a text prompt. 3. **Pre-release evaluation.** Throughout this process, BFL conducted internal and external third-party evaluations of model checkpoints to identify further opportunities for improvement. The third-party evaluations focused on eliciting CSAM and NCII through adversarial testing of the text-to-image model with text-only prompts. We also conducted internal evaluations of the proposed release checkpoints, comparing the model with other leading openly-available generative image models from other companies. 
The final FLUX.1 Krea [dev] open-weight model checkpoint demonstrated very high resilience against violative inputs, demonstrating higher resilience than other similar open-weight models across these risk categories. Based on these findings, we approved the release of the FLUX.1 Krea [dev] model as openly-available weights under a non-commercial license to support third-party research and development. 4. **Inference filters.** The BFL Github repository for the open FLUX.1 Krea [dev] model includes filters for illegal or infringing content. Filters or manual review must be used with the model under the terms of the FLUX.1 [dev] Non-Commercial License. We may approach known deployers of the FLUX.1 Krea [dev] model at random to verify that filters or manual review processes are in place. 5. **Policies.** Our FLUX.1 [dev] Non-Commercial License prohibits the generation of unlawful content or the use of generated content for unlawful, defamatory, or abusive purposes. Developers and users must consent to these conditions to access the FLUX.1 Krea [dev] model. 6. **Monitoring.** BFL is monitoring for patterns of violative use after release, and may ban developers who we detect intentionally and repeatedly violate our policies. Additionally, BFL provides a dedicated email address (safety@blackforestlabs.ai) to solicit feedback from the community. BFL maintains a reporting relationship with organizations such as the Internet Watch Foundation and the National Center for Missing and Exploited Children, and BFL welcomes ongoing engagement with authorities, developers, and researchers to share intelligence about emerging risks and develop effective mitigations. --- # License This model falls under the [`FLUX.1 [dev]` Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). 
================================================ FILE: model_cards/FLUX.1-dev.md ================================================ ![FLUX.1 [dev] Grid](../assets/dev_grid.jpg) `FLUX.1 [dev]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. For more information, please read our [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/). # Key Features 1. Cutting-edge output quality, second only to our state-of-the-art model `FLUX.1 [pro]`. 2. Competitive prompt following, matching the performance of closed source alternatives. 3. Trained using guidance distillation, making `FLUX.1 [dev]` more efficient. 4. Open weights to drive new scientific research, and empower artists to develop innovative workflows. 5. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [flux-1-dev-non-commercial-license](../model_licenses/LICENSE-FLUX1-dev). # Usage We provide a reference implementation of `FLUX.1 [dev]`, as well as sampling code, in a dedicated [GitHub repository](https://github.com/black-forest-labs/flux). Developers and creatives looking to build on top of `FLUX.1 [dev]` are encouraged to use this as a starting point. ## API Endpoints The FLUX.1 models are also available via API from the following sources: 1. [bfl.ml](https://docs.bfl.ml/) (currently `FLUX.1 [pro]`) 2. [replicate.com](https://replicate.com/collections/flux) 3. [fal.ai](https://fal.ai/models/fal-ai/flux/dev) ## ComfyUI `FLUX.1 [dev]` is also available in [Comfy UI](https://github.com/comfyanonymous/ComfyUI) for local inference with a node-based workflow. --- # Limitations - This model is not intended or able to provide factual information. - As a statistical model this checkpoint might amplify existing societal biases. - The model may fail to generate output that matches the prompts. - Prompt following is heavily influenced by the prompting-style.
# Out-of-Scope Use The model and its derivatives may not be used - In any way that violates any applicable national, federal, state, local or international law or regulation. - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content. - To generate or disseminate verifiably false information and/or content with the purpose of harming others. - To generate or disseminate personal identifiable information that can be used to harm an individual. - To harass, abuse, threaten, stalk, or bully individuals or groups of individuals. - To create non-consensual nudity or illegal pornographic content. - For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation. - Generating or facilitating large-scale disinformation campaigns. # License This model falls under the [`FLUX.1 [dev]` Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md). ================================================ FILE: model_cards/FLUX.1-kontext-dev.md ================================================ ![FLUX.1 Kontext [dev]](../assets/docs/kontext.png) `FLUX.1 Kontext [dev]` is a 12 billion parameter rectified flow transformer capable of editing images based on text instructions. For more information, please read our [blog post](https://bfl.ai/announcements/flux-1-kontext-dev) and our [technical report](https://arxiv.org/abs/2506.15742). You can find information about the `[pro]` version [here](https://bfl.ai/models/flux-kontext). # Key Features 1. Change existing images based on an edit instruction. 2. Have character, style and object reference without any finetuning. 3. Robust consistency allows users to refine an image through multiple successive edits with minimal visual drift. 4.
Trained using guidance distillation, making `FLUX.1 Kontext [dev]` more efficient. 5. Open weights to drive new scientific research, and empower artists to develop innovative workflows. 6. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [FLUX.1 \[dev\] Non-Commercial License](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev). # Usage We provide a reference implementation of `FLUX.1 Kontext [dev]`, as well as sampling code, in a dedicated [GitHub repository](https://github.com/black-forest-labs/flux). Developers and creatives looking to build on top of `FLUX.1 Kontext [dev]` are encouraged to use this as a starting point. `FLUX.1 Kontext [dev]` is also available in both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [Diffusers](https://github.com/huggingface/diffusers). ## API Endpoints The FLUX.1 Kontext models are also available via API from the following sources: - bfl.ai: https://docs.bfl.ai/ - DataCrunch: https://datacrunch.io/flux-kontext - fal: https://fal.ai/flux-kontext - Replicate: https://replicate.com/blog/flux-kontext - https://replicate.com/black-forest-labs/flux-kontext-dev - https://replicate.com/black-forest-labs/flux-kontext-pro - https://replicate.com/black-forest-labs/flux-kontext-max - Runware: https://runware.ai/blog/introducing-flux1-kontext-instruction-based-image-editing-with-ai?utm_source=bfl - TogetherAI: https://www.together.ai/models/flux-1-kontext-dev --- # Risks Black Forest Labs is committed to the responsible development of generative AI technology. Prior to releasing FLUX.1 Kontext, we evaluated and mitigated a number of risks in our models and services, including the generation of unlawful content. We implemented a series of pre-release mitigations to help prevent misuse by third parties, with additional post-release mitigations to help address residual risks: 1. **Pre-training mitigation.**
We filtered pre-training data for multiple categories of “not safe for work” (NSFW) content to help prevent a user generating unlawful content in response to text prompts or uploaded images. 2. **Post-training mitigation.** We have partnered with the Internet Watch Foundation, an independent nonprofit organization dedicated to preventing online abuse, to filter known child sexual abuse material (CSAM) from post-training data. Subsequently, we undertook multiple rounds of targeted fine-tuning to provide additional mitigation against potential abuse. By inhibiting certain behaviors and concepts in the trained model, these techniques can help to prevent a user generating synthetic CSAM or nonconsensual intimate imagery (NCII) from a text prompt, or transforming an uploaded image into synthetic CSAM or NCII. 3. **Pre-release evaluation.** Throughout this process, we conducted multiple internal and external third-party evaluations of model checkpoints to identify further opportunities for improvement. The third-party evaluations—which included 21 checkpoints of FLUX.1 Kontext [pro] and [dev]—focused on eliciting CSAM and NCII through adversarial testing with text-only prompts, as well as uploaded images with text prompts. Next, we conducted a final third-party evaluation of the proposed release checkpoints, focused on text-to-image and image-to-image CSAM and NCII generation. The final FLUX.1 Kontext [pro] (as offered through the FLUX API only) and FLUX.1 Kontext [dev] (released as an open-weight model) checkpoints demonstrated very high resilience against violative inputs, and FLUX.1 Kontext [dev] demonstrated higher resilience than other similar open-weight models across these risk categories. Based on these findings, we approved the release of the FLUX.1 Kontext [pro] model via API, and the release of the FLUX.1 Kontext [dev] model as openly-available weights under a non-commercial license to support third-party research and development. 4. 
**Inference filters.** We are applying multiple filters to intercept text prompts, uploaded images, and output images on the FLUX API for FLUX.1 Kontext [pro]. Filters for CSAM and NCII are provided by Hive, a third-party provider, and cannot be adjusted or removed by developers. We provide filters for other categories of potentially harmful content, including gore, which can be adjusted by developers based on their specific risk profile. Additionally, the repository for the open FLUX.1 Kontext [dev] model includes filters for illegal or infringing content. Filters or manual review must be used with the model under the terms of the FLUX.1 [dev] Non-Commercial License. We may approach known deployers of the FLUX.1 Kontext [dev] model at random to verify that filters or manual review processes are in place. 5. **Content provenance.** The FLUX API applies cryptographically-signed metadata to output content to indicate that images were produced with our model. Our API implements the Coalition for Content Provenance and Authenticity (C2PA) standard for metadata. 6. **Policies.** Access to our API and use of our models are governed by our Developer Terms of Service, Usage Policy, and FLUX.1 [dev] Non-Commercial License, which prohibit the generation of unlawful content or the use of generated content for unlawful, defamatory, or abusive purposes. Developers and users must consent to these conditions to access the FLUX Kontext models. 7. **Monitoring.** We are monitoring for patterns of violative use after release, and may ban developers who we detect intentionally and repeatedly violate our policies via the FLUX API. Additionally, we provide a dedicated email address (safety@blackforestlabs.ai) to solicit feedback from the community. 
We maintain a reporting relationship with organizations such as the Internet Watch Foundation and the National Center for Missing and Exploited Children, and we welcome ongoing engagement with authorities, developers, and researchers to share intelligence about emerging risks and develop effective mitigations. # License This model falls under the [FLUX.1 \[dev\] Non-Commercial License](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev). # Citation ```bib @misc{labs2025flux1kontextflowmatching, title={FLUX.1 Kontext: Flow Matching for In-Context Image Generation and Editing in Latent Space}, author={Black Forest Labs and Stephen Batifol and Andreas Blattmann and Frederic Boesel and Saksham Consul and Cyril Diagne and Tim Dockhorn and Jack English and Zion English and Patrick Esser and Sumith Kulal and Kyle Lacey and Yam Levi and Cheng Li and Dominik Lorenz and Jonas Müller and Dustin Podell and Robin Rombach and Harry Saini and Axel Sauer and Luke Smith}, year={2025}, eprint={2506.15742}, archivePrefix={arXiv}, primaryClass={cs.GR}, url={https://arxiv.org/abs/2506.15742}, } ``` ================================================ FILE: model_cards/FLUX.1-schnell.md ================================================ ![FLUX.1 [schnell] Grid](../assets/schnell_grid.jpg) `FLUX.1 [schnell]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. For more information, please read our [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/). # Key Features 1. Cutting-edge output quality and competitive prompt following, matching the performance of closed source alternatives. 2. Trained using latent adversarial diffusion distillation, `FLUX.1 [schnell]` can generate high-quality images in only 1 to 4 steps. 3. Released under the `apache-2.0` license, the model can be used for personal, scientific, and commercial purposes.
# Usage We provide a reference implementation of `FLUX.1 [schnell]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux). Developers and creatives looking to build on top of `FLUX.1 [schnell]` are encouraged to use this as a starting point. ## API Endpoints The FLUX.1 models are also available via API from the following sources 1. [bfl.ml](https://docs.bfl.ml/) (currently `FLUX.1 [pro]`) 2. [replicate.com](https://replicate.com/collections/flux) 3. [fal.ai](https://fal.ai/models/fal-ai/flux/schnell) ## ComfyUI `FLUX.1 [schnell]` is also available in [Comfy UI](https://github.com/comfyanonymous/ComfyUI) for local inference with a node-based workflow. --- # Limitations - This model is not intended or able to provide factual information. - As a statistical model this checkpoint might amplify existing societal biases. - The model may fail to generate output that matches the prompts. - Prompt following is heavily influenced by the prompting-style. # Out-of-Scope Use The model and its derivatives may not be used - In any way that violates any applicable national, federal, state, local or international law or regulation. - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content. - To generate or disseminate verifiably false information and/or content with the purpose of harming others. - To generate or disseminate personal identifiable information that can be used to harm an individual. - To harass, abuse, threaten, stalk, or bully individuals or groups of individuals. - To create non-consensual nudity or illegal pornographic content. - For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation. - Generating or facilitating large-scale disinformation campaigns. 
================================================ FILE: model_licenses/LICENSE-FLUX1-dev ================================================ FLUX.1 [dev] Non-Commercial License v1.1.1 Black Forest Labs Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] AI models and models denoted as FLUX.1 [dev], including but not limited to FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA, FLUX.1 Depth [dev] LoRA, and FLUX.1 Kontext [dev], and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). Note that we may also make available certain elements of what is included in the definition of “FLUX.1 [dev] Model” under a separate license, such as the inference code, and nothing in this License will be deemed to restrict or limit any other licenses granted by us in such elements. By downloading, accessing, using, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. 
If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity. 1. Definitions. - a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License. - b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be. - c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the FLUX.1 [dev] Model, Derivatives, or FLUX Content Filters (as defined below): (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment; and (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use (a) for revenue-generating activity, (b) in direct interactions with or that has impact on end users, or (c) to train, fine tune or distill other models for commercial use, in each case is not a Non-Commercial Purpose. - d. 
“Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from an input (such as an image input) or prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of the FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters. - e. “you” or “your” means the individual or entity entering into this License with Company. 2. License Grant. - a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models and Derivatives solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein regarding the FLUX.1 [dev] Model also apply to any Derivative you create or that are created on your behalf. - b. Non-Commercial Use Only. You may only access, use, Distribute, or create Derivatives of the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If you want to use a FLUX.1 [dev] Model or a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please see www.bfl.ai if you would like a commercial license. - c. Reserved Rights. 
The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License. - d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model or the FLUX.1 Kontext [dev] Model. - e. You may access, use, Distribute, or create Output of the FLUX.1 [dev] Model or Derivatives if you: (i) (A) implement and maintain content filtering measures (“Content Filters”) for your use of the FLUX.1 [dev] Model or Derivatives to prevent the creation, display, transmission, generation, or dissemination of unlawful or infringing content, which may include Content Filters that we may make available for use with the FLUX.1 [dev] Model (“FLUX Content Filters”), or (B) ensure Output undergoes review for unlawful or infringing content before public or non-public distribution, display, transmission or dissemination; and (ii) ensure Output includes disclosure (or other indication) that the Output was generated or modified using artificial intelligence technologies to the extent required under applicable law. 3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions: - a. 
you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License; - b. you must prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”): “The FLUX.1 [dev] Model is licensed by Black Forest Labs Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs Inc. IN NO EVENT SHALL BLACK FOREST LABS INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.” - c. in the case of Distribution of Derivatives made by you: (i) you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; (ii) any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions and must include disclaimer of warranties and limitation of liability provisions that are at least as protective of Company as those set forth herein; and (iii) you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing. 4. Restrictions. 
You will not, and will not permit, assist or cause any third party to - a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, (i) for any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates (or is likely to infringe, misappropriate, or otherwise violate) any third party’s legal rights, including rights of publicity or “digital replica” rights, (vi) in any unlawful, fraudulent, defamatory, or abusive activity, (vii) to generate unlawful content, including child sexual abuse material, or non-consensual intimate images; or (viii) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, any and all laws governing the processing of biometric information, and the EU Artificial Intelligence Act (Regulation (EU) 2024/1689), as well as all amendments and successor laws to any of the foregoing; - b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model; - c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model; - d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License; - e. violate any applicable U.S. and non-U.S. 
export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model; - f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (i) to any individual, entity, or country prohibited by Export Laws; (ii) to anyone on U.S. or non-U.S. government restricted parties lists; (iii) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; (iv) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (v) will not disguise your location through IP proxying or other methods. 5. DISCLAIMERS. THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS. 6. LIMITATION OF LIABILITY. 
TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, FLUX CONTENT FILTERS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE. 7. INDEMNIFICATION. 
You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (including in connection with any Output, results or data generated from such access or use, or from your access or use of any FLUX Content Filters), including any High-Risk Use; (b) your Content Filters, including your failure to implement any Content Filters where required by this License such as in Section 2(e); (c) your violation of this License; or (d) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties. 8. Termination; Survival. - a. This License will automatically terminate upon any breach by you of the terms of this License. - b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you. - c. 
If you initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model, any Derivative, or FLUX Content Filters, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated. - d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model, any Derivatives, and any FLUX Content Filters. The following sections survive termination of this License 2(c), 2(d), 4-11. 9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk. 10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name, logo or trademark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators. 11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. 
If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter. ================================================ FILE: model_licenses/LICENSE-FLUX1-schnell ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions.
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. 
Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS ================================================ FILE: pyproject.toml ================================================ [project] name = "flux" authors = [ { name = "Black Forest Labs", email = "support@blackforestlabs.ai" }, ] description = "Inference codebase for FLUX" readme = "README.md" requires-python = ">=3.10" license = { file = "LICENSE.md" } dynamic = ["version"] dependencies = [ "accelerate", "einops", "fire >= 0.6.0", "huggingface-hub", "safetensors", "sentencepiece", "transformers >= 4.45.2", "tokenizers", "protobuf", "requests", "invisible-watermark", "ruff == 0.6.8", "accelerate", ] [project.optional-dependencies] torch = [ "torch == 2.6.0", "torchvision", ] streamlit = [ "streamlit", "streamlit-drawable-canvas", "streamlit-keyup", ] gradio = [ "gradio", ] tensorrt = [ "tensorrt-cu12 == 10.12.0.36", "colored", "opencv-python-headless==4.8.0.74", "onnx >=1.18.0", "onnxruntime ~= 1.22.0", "onnxruntime-gpu ~= 1.22.0", "onnx-graphsurgeon", "polygraphy >= 0.49.22", ] all = [ "flux[gradio]", "flux[streamlit]", "flux[torch]", ] [project.scripts] flux = "flux.cli:app" [build-system] build-backend = "setuptools.build_meta" requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"] [tool.ruff] line-length = 110 target-version = "py310" extend-exclude = ["/usr/lib/*"] [tool.ruff.lint] ignore = [ 
"E501", # line too long - will be fixed in format ] [tool.ruff.format] quote-style = "double" indent-style = "space" line-ending = "auto" skip-magic-trailing-comma = false docstring-code-format = true exclude = [ "src/flux/_version.py", # generated by setuptools_scm ] [tool.ruff.lint.isort] combine-as-imports = true force-wrap-aliases = true known-local-folder = ["src"] known-first-party = ["flux"] [tool.pyright] include = ["src"] exclude = [ "**/__pycache__", # cache directories "./typings", # generated type stubs ] stubPath = "./typings" [tool.tomlsort] in_place = true no_sort_tables = true spaces_before_inline_comment = 1 spaces_indent_inline_array = 2 trailing_comma_inline_array = true sort_first = [ "project", "build-system", "tool.setuptools", ] # needs to be last for CI reasons [tool.setuptools_scm] write_to = "src/flux/_version.py" parentdir_prefix_version = "flux-" fallback_version = "0.0.0" version_scheme = "post-release" ================================================ FILE: setup.py ================================================ import setuptools setuptools.setup() ================================================ FILE: src/flux/__init__.py ================================================ try: from ._version import ( version as __version__, # type: ignore version_tuple, ) except ImportError: __version__ = "unknown (no version information available)" version_tuple = (0, 0, "unknown", "noinfo") from pathlib import Path PACKAGE = __package__.replace("_", "-") PACKAGE_ROOT = Path(__file__).parent ================================================ FILE: src/flux/__main__.py ================================================ from fire import Fire from .cli import main as cli_main from .cli_control import main as control_main from .cli_fill import main as fill_main from .cli_kontext import main as kontext_main from .cli_redux import main as redux_main if __name__ == "__main__": Fire( { "t2i": cli_main, "control": control_main, "fill": fill_main, "kontext": 
kontext_main, "redux": redux_main, } ) ================================================ FILE: src/flux/cli.py ================================================ import os import re import time from dataclasses import dataclass from glob import iglob import torch from fire import Fire from transformers import pipeline from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack from flux.util import ( check_onnx_access_for_trt, configs, load_ae, load_clip, load_flow_model, load_t5, save_image, ) NSFW_THRESHOLD = 0.85 @dataclass class SamplingOptions: prompt: str width: int height: int num_steps: int guidance: float seed: int | None def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" usage = ( "Usage: Either write your prompt directly, leave this field empty " "to repeat the prompt or write a command starting with a slash:\n" "- '/w ' will set the width of the generated image\n" "- '/h ' will set the height of the generated image\n" "- '/s ' sets the next seed\n" "- '/g ' sets the guidance (flux-dev only)\n" "- '/n ' sets the number of steps\n" "- '/q' to quit" ) while (prompt := input(user_question)).startswith("/"): if prompt.startswith("/w"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, width = prompt.split() options.width = 16 * (int(width) // 16) print( f"Setting resolution to {options.width} x {options.height} " f"({options.height * options.width / 1e6:.2f}MP)" ) elif prompt.startswith("/h"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, height = prompt.split() options.height = 16 * (int(height) // 16) print( f"Setting resolution to {options.width} x {options.height} " f"({options.height * options.width / 1e6:.2f}MP)" ) elif prompt.startswith("/g"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, guidance = 
prompt.split() options.guidance = float(guidance) print(f"Setting guidance to {options.guidance}") elif prompt.startswith("/s"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, seed = prompt.split() options.seed = int(seed) print(f"Setting seed to {options.seed}") elif prompt.startswith("/n"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, steps = prompt.split() options.num_steps = int(steps) print(f"Setting number of steps to {options.num_steps}") elif prompt.startswith("/q"): print("Quitting") return None else: if not prompt.startswith("/h"): print(f"Got invalid command '{prompt}'\n{usage}") print(usage) if prompt != "": options.prompt = prompt return options @torch.inference_mode() def main( name: str = "flux-dev-krea", width: int = 1360, height: int = 768, seed: int | None = None, prompt: str = ( "a photo of a forest with mist swirling around the tree trunks. The word " '"FLUX" is painted over it in big, red brush strokes with visible texture' ), device: str = "cuda" if torch.cuda.is_available() else "cpu", num_steps: int | None = None, loop: bool = False, guidance: float = 2.5, offload: bool = False, output_dir: str = "output", add_sampling_metadata: bool = True, trt: bool = False, trt_transformer_precision: str = "bf16", track_usage: bool = False, ): """ Sample the flux model. Either interactively (set `--loop`) or run for a single image. 
Args: name: Name of the model to load height: height of the sample in pixels (should be a multiple of 16) width: width of the sample in pixels (should be a multiple of 16) seed: Set a seed for sampling output_name: where to save the output image, `{idx}` will be replaced by the index of the sample prompt: Prompt used for sampling device: Pytorch device num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) loop: start an interactive session and sample multiple times guidance: guidance value used for guidance distillation add_sampling_metadata: Add the prompt to the image Exif metadata trt: use TensorRT backend for optimized inference trt_transformer_precision: specify transformer precision for inference track_usage: track usage of the model for licensing purposes """ prompt = prompt.split("|") if len(prompt) == 1: prompt = prompt[0] additional_prompts = None else: additional_prompts = prompt[1:] prompt = prompt[0] assert not ( (additional_prompts is not None) and loop ), "Do not provide additional prompts and set loop to True" nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) if name not in configs: available = ", ".join(configs.keys()) raise ValueError(f"Got unknown model name: {name}, chose from {available}") torch_device = torch.device(device) if num_steps is None: num_steps = 4 if name == "flux-schnell" else 50 # allow for packing and conversion to latent space height = 16 * (height // 16) width = 16 * (width // 16) output_name = os.path.join(output_dir, "img_{idx}.jpg") if not os.path.exists(output_dir): os.makedirs(output_dir) idx = 0 else: fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] if len(fns) > 0: idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 else: idx = 0 if not trt: t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) clip = load_clip(torch_device) model = 
load_flow_model(name, device="cpu" if offload else torch_device) ae = load_ae(name, device="cpu" if offload else torch_device) else: # lazy import to make install optional from flux.trt.trt_manager import ModuleName, TRTManager # Check if we need ONNX model access (which requires authentication for FLUX models) onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision) trt_ctx_manager = TRTManager( trt_transformer_precision=trt_transformer_precision, trt_t5_precision=os.getenv("TRT_T5_PRECISION", "bf16"), ) engines = trt_ctx_manager.load_engines( model_name=name, module_names={ ModuleName.CLIP, ModuleName.TRANSFORMER, ModuleName.T5, ModuleName.VAE, }, engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"), custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""), trt_image_height=height, trt_image_width=width, trt_batch_size=1, trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None), trt_static_batch=False, trt_static_shape=False, ) ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device) model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device) clip = engines[ModuleName.CLIP].to(torch_device) t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device) rng = torch.Generator(device="cpu") opts = SamplingOptions( prompt=prompt, width=width, height=height, num_steps=num_steps, guidance=guidance, seed=seed, ) if loop: opts = parse_prompt(opts) while opts is not None: if opts.seed is None: opts.seed = rng.seed() print(f"Generating with seed {opts.seed}:\n{opts.prompt}") t0 = time.perf_counter() # prepare input x = get_noise( 1, opts.height, opts.width, device=torch_device, dtype=torch.bfloat16, seed=opts.seed, ) opts.seed = None if offload: ae = ae.cpu() torch.cuda.empty_cache() t5, clip = t5.to(torch_device), clip.to(torch_device) inp = prepare(t5, clip, x, prompt=opts.prompt) timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) # offload TEs 
to CPU, load model to gpu if offload: t5, clip = t5.cpu(), clip.cpu() torch.cuda.empty_cache() model = model.to(torch_device) # denoise initial noise x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) # offload model, load autoencoder to gpu if offload: model.cpu() torch.cuda.empty_cache() ae.decoder.to(x.device) # decode latents to pixel space x = unpack(x.float(), opts.height, opts.width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = ae.decode(x) if torch.cuda.is_available(): torch.cuda.synchronize() t1 = time.perf_counter() fn = output_name.format(idx=idx) print(f"Done in {t1 - t0:.1f}s. Saving {fn}") idx = save_image( nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage ) if loop: print("-" * 80) opts = parse_prompt(opts) elif additional_prompts: next_prompt = additional_prompts.pop(0) opts.prompt = next_prompt else: opts = None if trt: trt_ctx_manager.stop_runtime() if __name__ == "__main__": Fire(main) ================================================ FILE: src/flux/cli_control.py ================================================ import os import re import time from dataclasses import dataclass from glob import iglob import torch from fire import Fire from transformers import pipeline from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image @dataclass class SamplingOptions: prompt: str width: int height: int num_steps: int guidance: float seed: int | None img_cond_path: str lora_scale: float | None def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" usage = ( "Usage: Either write your prompt directly, leave this field empty " "to repeat the prompt or write a command 
starting with a slash:\n" "- '/w ' will set the width of the generated image\n" "- '/h ' will set the height of the generated image\n" "- '/s ' sets the next seed\n" "- '/g ' sets the guidance (flux-dev only)\n" "- '/n ' sets the number of steps\n" "- '/q' to quit" ) while (prompt := input(user_question)).startswith("/"): if prompt.startswith("/w"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, width = prompt.split() options.width = 16 * (int(width) // 16) print( f"Setting resolution to {options.width} x {options.height} " f"({options.height * options.width / 1e6:.2f}MP)" ) elif prompt.startswith("/h"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, height = prompt.split() options.height = 16 * (int(height) // 16) print( f"Setting resolution to {options.width} x {options.height} " f"({options.height * options.width / 1e6:.2f}MP)" ) elif prompt.startswith("/g"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, guidance = prompt.split() options.guidance = float(guidance) print(f"Setting guidance to {options.guidance}") elif prompt.startswith("/s"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, seed = prompt.split() options.seed = int(seed) print(f"Setting seed to {options.seed}") elif prompt.startswith("/n"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, steps = prompt.split() options.num_steps = int(steps) print(f"Setting number of steps to {options.num_steps}") elif prompt.startswith("/q"): print("Quitting") return None else: if not prompt.startswith("/h"): print(f"Got invalid command '{prompt}'\n{usage}") print(usage) if prompt != "": options.prompt = prompt return options def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: if options is None: return None user_question = "Next conditioning image (write /h for help, /q to quit 
and leave empty to repeat):\n" usage = ( "Usage: Either write your prompt directly, leave this field empty " "to repeat the conditioning image or write a command starting with a slash:\n" "- '/q' to quit" ) while True: img_cond_path = input(user_question) if img_cond_path.startswith("/"): if img_cond_path.startswith("/q"): print("Quitting") return None else: if not img_cond_path.startswith("/h"): print(f"Got invalid command '{img_cond_path}'\n{usage}") print(usage) continue if img_cond_path == "": break if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( (".jpg", ".jpeg", ".png", ".webp") ): print(f"File '{img_cond_path}' does not exist or is not a valid image file") continue options.img_cond_path = img_cond_path break return options def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]: changed = False if options is None: return None, changed user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n" usage = ( "Usage: Either write your prompt directly, leave this field empty " "to repeat the lora scale or write a command starting with a slash:\n" "- '/q' to quit" ) while (prompt := input(user_question)).startswith("/"): if prompt.startswith("/q"): print("Quitting") return None, changed else: if not prompt.startswith("/h"): print(f"Got invalid command '{prompt}'\n{usage}") print(usage) if prompt != "": options.lora_scale = float(prompt) changed = True return options, changed @torch.inference_mode() def main( name: str, width: int = 1024, height: int = 1024, seed: int | None = None, prompt: str = "a robot made out of gold", device: str = "cuda" if torch.cuda.is_available() else "cpu", num_steps: int = 50, loop: bool = False, guidance: float | None = None, offload: bool = False, output_dir: str = "output", add_sampling_metadata: bool = True, img_cond_path: str = "assets/robot.webp", lora_scale: float | None = 0.85, trt: bool = False, trt_transformer_precision: str = 
"bf16", track_usage: bool = False, **kwargs: dict | None, ): """ Sample the flux model. Either interactively (set `--loop`) or run for a single image. Args: height: height of the sample in pixels (should be a multiple of 16) width: width of the sample in pixels (should be a multiple of 16) seed: Set a seed for sampling output_name: where to save the output image, `{idx}` will be replaced by the index of the sample prompt: Prompt used for sampling device: Pytorch device num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) loop: start an interactive session and sample multiple times guidance: guidance value used for guidance distillation add_sampling_metadata: Add the prompt to the image Exif metadata img_cond_path: path to conditioning image (jpeg/png/webp) trt: use TensorRT backend for optimized inference trt_transformer_precision: specify transformer precision for inference track_usage: track usage of the model for licensing purposes """ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) if "lora" in name: assert not trt, "TRT does not support LORA" assert name in [ "flux-dev-canny", "flux-dev-depth", "flux-dev-canny-lora", "flux-dev-depth-lora", ], f"Got unknown model name: {name}" if guidance is None: if name in ["flux-dev-canny", "flux-dev-canny-lora"]: guidance = 30.0 elif name in ["flux-dev-depth", "flux-dev-depth-lora"]: guidance = 10.0 else: raise NotImplementedError() if name not in configs: available = ", ".join(configs.keys()) raise ValueError(f"Got unknown model name: {name}, chose from {available}") torch_device = torch.device(device) output_name = os.path.join(output_dir, "img_{idx}.jpg") if not os.path.exists(output_dir): os.makedirs(output_dir) idx = 0 else: fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] if len(fns) > 0: idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 else: idx = 0 if name in 
["flux-dev-depth", "flux-dev-depth-lora"]: img_embedder = DepthImageEncoder(torch_device) elif name in ["flux-dev-canny", "flux-dev-canny-lora"]: img_embedder = CannyImageEncoder(torch_device) else: raise NotImplementedError() if not trt: # init all components t5 = load_t5(torch_device, max_length=512) clip = load_clip(torch_device) model = load_flow_model(name, device="cpu" if offload else torch_device) ae = load_ae(name, device="cpu" if offload else torch_device) else: # lazy import to make install optional from flux.trt.trt_manager import ModuleName, TRTManager trt_ctx_manager = TRTManager( trt_transformer_precision=trt_transformer_precision, trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"), ) engines = trt_ctx_manager.load_engines( model_name=name, module_names={ ModuleName.CLIP, ModuleName.TRANSFORMER, ModuleName.T5, ModuleName.VAE, ModuleName.VAE_ENCODER, }, engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"), custom_onnx_paths=os.environ.get("CUSTOM_ONNX_PATHS", ""), trt_image_height=height, trt_image_width=width, trt_batch_size=1, trt_static_batch=kwargs.get("static_batch", True), trt_static_shape=kwargs.get("static_shape", True), ) ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device) model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device) clip = engines[ModuleName.CLIP].to(torch_device) t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device) # set lora scale if "lora" in name and lora_scale is not None: for _, module in model.named_modules(): if hasattr(module, "set_scale"): module.set_scale(lora_scale) rng = torch.Generator(device="cpu") opts = SamplingOptions( prompt=prompt, width=width, height=height, num_steps=num_steps, guidance=guidance, seed=seed, img_cond_path=img_cond_path, lora_scale=lora_scale, ) if loop: opts = parse_prompt(opts) opts = parse_img_cond_path(opts) if "lora" in name: opts, changed = parse_lora_scale(opts) if changed: # update the lora scale: for 
_, module in model.named_modules(): if hasattr(module, "set_scale"): module.set_scale(opts.lora_scale) while opts is not None: if opts.seed is None: opts.seed = rng.seed() print(f"Generating with seed {opts.seed}:\n{opts.prompt}") t0 = time.perf_counter() # prepare input x = get_noise( 1, opts.height, opts.width, device=torch_device, dtype=torch.bfloat16, seed=opts.seed, ) opts.seed = None if offload: t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) inp = prepare_control( t5, clip, x, prompt=opts.prompt, ae=ae, encoder=img_embedder, img_cond_path=opts.img_cond_path, ) timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) # offload TEs and AE to CPU, load model to gpu if offload: t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() torch.cuda.empty_cache() model = model.to(torch_device) # denoise initial noise x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) # offload model, load autoencoder to gpu if offload: model.cpu() torch.cuda.empty_cache() ae.decoder.to(x.device) # decode latents to pixel space x = unpack(x.float(), opts.height, opts.width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = ae.decode(x) if torch.cuda.is_available(): torch.cuda.synchronize() t1 = time.perf_counter() print(f"Done in {t1 - t0:.1f}s") idx = save_image( nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage ) if loop: print("-" * 80) opts = parse_prompt(opts) opts = parse_img_cond_path(opts) if "lora" in name: opts, changed = parse_lora_scale(opts) if changed: # update the lora scale: for _, module in model.named_modules(): if hasattr(module, "set_scale"): module.set_scale(opts.lora_scale) else: opts = None if trt: trt_ctx_manager.stop_runtime() if __name__ == "__main__": Fire(main) ================================================ FILE: src/flux/cli_fill.py ================================================ import os import re 
import time from dataclasses import dataclass from glob import iglob import torch from fire import Fire from PIL import Image from transformers import pipeline from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image @dataclass class SamplingOptions: prompt: str width: int height: int num_steps: int guidance: float seed: int | None img_cond_path: str img_mask_path: str def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" usage = ( "Usage: Either write your prompt directly, leave this field empty " "to repeat the prompt or write a command starting with a slash:\n" "- '/s ' sets the next seed\n" "- '/g ' sets the guidance (flux-dev only)\n" "- '/n ' sets the number of steps\n" "- '/q' to quit" ) while (prompt := input(user_question)).startswith("/"): if prompt.startswith("/g"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, guidance = prompt.split() options.guidance = float(guidance) print(f"Setting guidance to {options.guidance}") elif prompt.startswith("/s"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, seed = prompt.split() options.seed = int(seed) print(f"Setting seed to {options.seed}") elif prompt.startswith("/n"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, steps = prompt.split() options.num_steps = int(steps) print(f"Setting number of steps to {options.num_steps}") elif prompt.startswith("/q"): print("Quitting") return None else: if not prompt.startswith("/h"): print(f"Got invalid command '{prompt}'\n{usage}") print(usage) if prompt != "": options.prompt = prompt return options def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: if options is None: return None user_question = "Next 
conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" usage = ( "Usage: Either write your prompt directly, leave this field empty " "to repeat the conditioning image or write a command starting with a slash:\n" "- '/q' to quit" ) while True: img_cond_path = input(user_question) if img_cond_path.startswith("/"): if img_cond_path.startswith("/q"): print("Quitting") return None else: if not img_cond_path.startswith("/h"): print(f"Got invalid command '{img_cond_path}'\n{usage}") print(usage) continue if img_cond_path == "": break if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( (".jpg", ".jpeg", ".png", ".webp") ): print(f"File '{img_cond_path}' does not exist or is not a valid image file") continue else: with Image.open(img_cond_path) as img: width, height = img.size if width % 32 != 0 or height % 32 != 0: print(f"Image dimensions must be divisible by 32, got {width}x{height}") continue options.img_cond_path = img_cond_path break return options def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None: if options is None: return None user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n" usage = ( "Usage: Either write your prompt directly, leave this field empty " "to repeat the conditioning mask or write a command starting with a slash:\n" "- '/q' to quit" ) while True: img_mask_path = input(user_question) if img_mask_path.startswith("/"): if img_mask_path.startswith("/q"): print("Quitting") return None else: if not img_mask_path.startswith("/h"): print(f"Got invalid command '{img_mask_path}'\n{usage}") print(usage) continue if img_mask_path == "": break if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith( (".jpg", ".jpeg", ".png", ".webp") ): print(f"File '{img_mask_path}' does not exist or is not a valid image file") continue else: with Image.open(img_mask_path) as img: width, height = img.size if width % 32 != 0 or 
height % 32 != 0: print(f"Image dimensions must be divisible by 32, got {width}x{height}") continue else: with Image.open(options.img_cond_path) as img_cond: img_cond_width, img_cond_height = img_cond.size if width != img_cond_width or height != img_cond_height: print( f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}" ) continue options.img_mask_path = img_mask_path break return options @torch.inference_mode() def main( seed: int | None = None, prompt: str = "a white paper cup", device: str = "cuda" if torch.cuda.is_available() else "cpu", num_steps: int = 50, loop: bool = False, guidance: float = 30.0, offload: bool = False, output_dir: str = "output", add_sampling_metadata: bool = True, img_cond_path: str = "assets/cup.png", img_mask_path: str = "assets/cup_mask.png", track_usage: bool = False, ): """ Sample the flux model. Either interactively (set `--loop`) or run for a single image. This demo assumes that the conditioning image and mask have the same shape and that height and width are divisible by 32. 
Args: seed: Set a seed for sampling output_name: where to save the output image, `{idx}` will be replaced by the index of the sample prompt: Prompt used for sampling device: Pytorch device num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) loop: start an interactive session and sample multiple times guidance: guidance value used for guidance distillation add_sampling_metadata: Add the prompt to the image Exif metadata img_cond_path: path to conditioning image (jpeg/png/webp) img_mask_path: path to conditioning mask (jpeg/png/webp) track_usage: track usage of the model for licensing purposes """ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) name = "flux-dev-fill" if name not in configs: available = ", ".join(configs.keys()) raise ValueError(f"Got unknown model name: {name}, chose from {available}") torch_device = torch.device(device) output_name = os.path.join(output_dir, "img_{idx}.jpg") if not os.path.exists(output_dir): os.makedirs(output_dir) idx = 0 else: fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] if len(fns) > 0: idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 else: idx = 0 # init all components t5 = load_t5(torch_device, max_length=128) clip = load_clip(torch_device) model = load_flow_model(name, device="cpu" if offload else torch_device) ae = load_ae(name, device="cpu" if offload else torch_device) rng = torch.Generator(device="cpu") with Image.open(img_cond_path) as img: width, height = img.size opts = SamplingOptions( prompt=prompt, width=width, height=height, num_steps=num_steps, guidance=guidance, seed=seed, img_cond_path=img_cond_path, img_mask_path=img_mask_path, ) if loop: opts = parse_prompt(opts) opts = parse_img_cond_path(opts) with Image.open(opts.img_cond_path) as img: width, height = img.size opts.height = height opts.width = width opts = parse_img_mask_path(opts) while opts is not 
None: if opts.seed is None: opts.seed = rng.seed() print(f"Generating with seed {opts.seed}:\n{opts.prompt}") t0 = time.perf_counter() # prepare input x = get_noise( 1, opts.height, opts.width, device=torch_device, dtype=torch.bfloat16, seed=opts.seed, ) opts.seed = None if offload: t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) inp = prepare_fill( t5, clip, x, prompt=opts.prompt, ae=ae, img_cond_path=opts.img_cond_path, mask_path=opts.img_mask_path, ) timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) # offload TEs and AE to CPU, load model to gpu if offload: t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() torch.cuda.empty_cache() model = model.to(torch_device) # denoise initial noise x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) # offload model, load autoencoder to gpu if offload: model.cpu() torch.cuda.empty_cache() ae.decoder.to(x.device) # decode latents to pixel space x = unpack(x.float(), opts.height, opts.width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = ae.decode(x) if torch.cuda.is_available(): torch.cuda.synchronize() t1 = time.perf_counter() print(f"Done in {t1 - t0:.1f}s") idx = save_image( nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage ) if loop: print("-" * 80) opts = parse_prompt(opts) opts = parse_img_cond_path(opts) with Image.open(opts.img_cond_path) as img: width, height = img.size opts.height = height opts.width = width opts = parse_img_mask_path(opts) else: opts = None if __name__ == "__main__": Fire(main) ================================================ FILE: src/flux/cli_kontext.py ================================================ import os import re import time from dataclasses import dataclass from glob import iglob import torch from fire import Fire from flux.content_filters import PixtralContentFilter from flux.sampling import denoise, get_schedule, 
prepare_kontext, unpack from flux.util import ( aspect_ratio_to_height_width, check_onnx_access_for_trt, load_ae, load_clip, load_flow_model, load_t5, save_image, ) @dataclass class SamplingOptions: prompt: str width: int | None height: int | None num_steps: int guidance: float seed: int | None img_cond_path: str def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n" usage = ( "Usage: Either write your prompt directly, leave this field empty " "to repeat the prompt or write a command starting with a slash:\n" "- '/ar :' will set the aspect ratio of the generated image\n" "- '/s ' sets the next seed\n" "- '/g ' sets the guidance (flux-dev only)\n" "- '/n ' sets the number of steps\n" "- '/q' to quit" ) while (prompt := input(user_question)).startswith("/"): if prompt.startswith("/ar"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, ratio_prompt = prompt.split() if ratio_prompt == "auto": options.width = None options.height = None print("Setting resolution to input image resolution.") else: options.width, options.height = aspect_ratio_to_height_width(ratio_prompt) print(f"Setting resolution to {options.width} x {options.height}.") elif prompt.startswith("/h"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, height = prompt.split() if height == "auto": options.height = None else: options.height = 16 * (int(height) // 16) if options.height is not None and options.width is not None: print( f"Setting resolution to {options.width} x {options.height} " f"({options.height * options.width / 1e6:.2f}MP)" ) else: print(f"Setting resolution to {options.width} x {options.height}.") elif prompt.startswith("/g"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, guidance = prompt.split() options.guidance = float(guidance) print(f"Setting guidance to 
{options.guidance}") elif prompt.startswith("/s"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, seed = prompt.split() options.seed = int(seed) print(f"Setting seed to {options.seed}") elif prompt.startswith("/n"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, steps = prompt.split() options.num_steps = int(steps) print(f"Setting number of steps to {options.num_steps}") elif prompt.startswith("/q"): print("Quitting") return None else: if not prompt.startswith("/h"): print(f"Got invalid command '{prompt}'\n{usage}") print(usage) if prompt != "": options.prompt = prompt return options def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: if options is None: return None user_question = "Next input image (write /h for help, /q to quit and leave empty to repeat):\n" usage = ( "Usage: Either write a path to an image directly, leave this field empty " "to repeat the last input image or write a command starting with a slash:\n" "- '/q' to quit\n\n" "The input image will be edited by FLUX.1 Kontext creating a new image based" "on your instruction prompt." 
) while True: img_cond_path = input(user_question) if img_cond_path.startswith("/"): if img_cond_path.startswith("/q"): print("Quitting") return None else: if not img_cond_path.startswith("/h"): print(f"Got invalid command '{img_cond_path}'\n{usage}") print(usage) continue if img_cond_path == "": break if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( (".jpg", ".jpeg", ".png", ".webp") ): print(f"File '{img_cond_path}' does not exist or is not a valid image file") continue options.img_cond_path = img_cond_path break return options @torch.inference_mode() def main( name: str = "flux-dev-kontext", aspect_ratio: str | None = None, seed: int | None = None, prompt: str = "replace the logo with the text 'Black Forest Labs'", device: str = "cuda" if torch.cuda.is_available() else "cpu", num_steps: int = 30, loop: bool = False, guidance: float = 2.5, offload: bool = False, output_dir: str = "output", add_sampling_metadata: bool = True, img_cond_path: str = "assets/cup.png", trt: bool = False, trt_transformer_precision: str = "bf16", track_usage: bool = False, ): """ Sample the flux model. Either interactively (set `--loop`) or run for a single image. 
Args: height: height of the sample in pixels (should be a multiple of 16), None defaults to the size of the conditioning width: width of the sample in pixels (should be a multiple of 16), None defaults to the size of the conditioning seed: Set a seed for sampling output_name: where to save the output image, `{idx}` will be replaced by the index of the sample prompt: Prompt used for sampling device: Pytorch device num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) loop: start an interactive session and sample multiple times guidance: guidance value used for guidance distillation add_sampling_metadata: Add the prompt to the image Exif metadata img_cond_path: path to conditioning image (jpeg/png/webp) trt: use TensorRT backend for optimized inference track_usage: track usage of the model for licensing purposes """ assert name == "flux-dev-kontext", f"Got unknown model name: {name}" torch_device = torch.device(device) output_name = os.path.join(output_dir, "img_{idx}.jpg") if not os.path.exists(output_dir): os.makedirs(output_dir) idx = 0 else: fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] if len(fns) > 0: idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1 else: idx = 0 if aspect_ratio is None: width = None height = None else: width, height = aspect_ratio_to_height_width(aspect_ratio) if not trt: t5 = load_t5(torch_device, max_length=512) clip = load_clip(torch_device) model = load_flow_model(name, device="cpu" if offload else torch_device) else: # lazy import to make install optional from flux.trt.trt_manager import ModuleName, TRTManager # Check if we need ONNX model access (which requires authentication for FLUX models) onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision) trt_ctx_manager = TRTManager( trt_transformer_precision=trt_transformer_precision, trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"), ) engines = 
trt_ctx_manager.load_engines( model_name=name, module_names={ ModuleName.CLIP, ModuleName.TRANSFORMER, ModuleName.T5, }, engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"), custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""), trt_image_height=height, trt_image_width=width, trt_batch_size=1, trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None), trt_static_batch=False, trt_static_shape=False, ) model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device) clip = engines[ModuleName.CLIP].to(torch_device) t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device) ae = load_ae(name, device="cpu" if offload else torch_device) content_filter = PixtralContentFilter(torch.device("cpu")) rng = torch.Generator(device="cpu") opts = SamplingOptions( prompt=prompt, width=width, height=height, num_steps=num_steps, guidance=guidance, seed=seed, img_cond_path=img_cond_path, ) if loop: opts = parse_prompt(opts) opts = parse_img_cond_path(opts) while opts is not None: if opts.seed is None: opts.seed = rng.seed() print(f"Generating with seed {opts.seed}:\n{opts.prompt}") t0 = time.perf_counter() if content_filter.test_txt(opts.prompt): print("Your prompt has been automatically flagged. Please choose another prompt.") if loop: print("-" * 80) opts = parse_prompt(opts) opts = parse_img_cond_path(opts) else: opts = None continue if content_filter.test_image(opts.img_cond_path): print("Your input image has been automatically flagged. 
Please choose another image.") if loop: print("-" * 80) opts = parse_prompt(opts) opts = parse_img_cond_path(opts) else: opts = None continue if offload: t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device) inp, height, width = prepare_kontext( t5=t5, clip=clip, prompt=opts.prompt, ae=ae, img_cond_path=opts.img_cond_path, target_width=opts.width, target_height=opts.height, bs=1, seed=opts.seed, device=torch_device, ) from safetensors.torch import save_file save_file({k: v.cpu().contiguous() for k, v in inp.items()}, "output/noise.sft") inp.pop("img_cond_orig") opts.seed = None timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) # offload TEs and AE to CPU, load model to gpu if offload: t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu() torch.cuda.empty_cache() model = model.to(torch_device) # denoise initial noise t00 = time.time() x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) torch.cuda.synchronize() t01 = time.time() print(f"Denoising took {t01 - t00:.3f}s") # offload model, load autoencoder to gpu if offload: model.cpu() torch.cuda.empty_cache() ae.decoder.to(x.device) # decode latents to pixel space x = unpack(x.float(), height, width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): ae_dev_t0 = time.perf_counter() x = ae.decode(x) torch.cuda.synchronize() ae_dev_t1 = time.perf_counter() print(f"AE decode took {ae_dev_t1 - ae_dev_t0:.3f}s") if content_filter.test_image(x.cpu()): print( "Your output image has been automatically flagged. Choose another prompt/image or try again." 
) if loop: print("-" * 80) opts = parse_prompt(opts) opts = parse_img_cond_path(opts) else: opts = None continue if torch.cuda.is_available(): torch.cuda.synchronize() t1 = time.perf_counter() print(f"Done in {t1 - t0:.1f}s") idx = save_image( None, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage ) if loop: print("-" * 80) opts = parse_prompt(opts) opts = parse_img_cond_path(opts) else: opts = None if __name__ == "__main__": Fire(main) ================================================ FILE: src/flux/cli_redux.py ================================================ import os import re import time from dataclasses import dataclass from glob import iglob import torch from fire import Fire from transformers import pipeline from flux.modules.image_embedders import ReduxImageEncoder from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack from flux.util import ( get_checkpoint_path, load_ae, load_clip, load_flow_model, load_t5, save_image, ) @dataclass class SamplingOptions: prompt: str width: int height: int num_steps: int guidance: float seed: int | None img_cond_path: str def parse_prompt(options: SamplingOptions) -> SamplingOptions | None: user_question = "Write /h for help, /q to quit and leave empty to repeat):\n" usage = ( "Usage: Leave this field empty to do nothing " "or write a command starting with a slash:\n" "- '/w ' will set the width of the generated image\n" "- '/h ' will set the height of the generated image\n" "- '/s ' sets the next seed\n" "- '/g ' sets the guidance (flux-dev only)\n" "- '/n ' sets the number of steps\n" "- '/q' to quit" ) while (prompt := input(user_question)).startswith("/"): if prompt.startswith("/w"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, width = prompt.split() options.width = 16 * (int(width) // 16) print( f"Setting resolution to {options.width} x {options.height} " f"({options.height * options.width / 1e6:.2f}MP)" ) elif 
prompt.startswith("/h"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, height = prompt.split() options.height = 16 * (int(height) // 16) print( f"Setting resolution to {options.width} x {options.height} " f"({options.height * options.width / 1e6:.2f}MP)" ) elif prompt.startswith("/g"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, guidance = prompt.split() options.guidance = float(guidance) print(f"Setting guidance to {options.guidance}") elif prompt.startswith("/s"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, seed = prompt.split() options.seed = int(seed) print(f"Setting seed to {options.seed}") elif prompt.startswith("/n"): if prompt.count(" ") != 1: print(f"Got invalid command '{prompt}'\n{usage}") continue _, steps = prompt.split() options.num_steps = int(steps) print(f"Setting number of steps to {options.num_steps}") elif prompt.startswith("/q"): print("Quitting") return None else: if not prompt.startswith("/h"): print(f"Got invalid command '{prompt}'\n{usage}") print(usage) return options def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None: if options is None: return None user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n" usage = ( "Usage: Either write your prompt directly, leave this field empty " "to repeat the conditioning image or write a command starting with a slash:\n" "- '/q' to quit" ) while True: img_cond_path = input(user_question) if img_cond_path.startswith("/"): if img_cond_path.startswith("/q"): print("Quitting") return None else: if not img_cond_path.startswith("/h"): print(f"Got invalid command '{img_cond_path}'\n{usage}") print(usage) continue if img_cond_path == "": break if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith( (".jpg", ".jpeg", ".png", ".webp") ): print(f"File '{img_cond_path}' does not exist 
or is not a valid image file") continue options.img_cond_path = img_cond_path break return options @torch.inference_mode() def main( name: str = "flux-dev", width: int = 1360, height: int = 768, seed: int | None = None, device: str = "cuda" if torch.cuda.is_available() else "cpu", num_steps: int | None = None, loop: bool = False, guidance: float = 2.5, offload: bool = False, output_dir: str = "output", add_sampling_metadata: bool = True, img_cond_path: str = "assets/robot.webp", track_usage: bool = False, ): """ Sample the flux model. Either interactively (set `--loop`) or run for a single image. Args: name: Name of the base model to use (either 'flux-dev' or 'flux-schnell') height: height of the sample in pixels (should be a multiple of 16) width: width of the sample in pixels (should be a multiple of 16) seed: Set a seed for sampling device: Pytorch device num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled) loop: start an interactive session and sample multiple times guidance: guidance value used for guidance distillation offload: offload models to CPU when not in use output_dir: where to save the output images add_sampling_metadata: Add the prompt to the image Exif metadata img_cond_path: path to conditioning image (jpeg/png/webp) track_usage: track usage of the model for licensing purposes """ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device) if name not in (available := ["flux-dev", "flux-schnell"]): raise ValueError(f"Got unknown model name: {name}, chose from {available}") torch_device = torch.device(device) if num_steps is None: num_steps = 4 if name == "flux-schnell" else 50 output_name = os.path.join(output_dir, "img_{idx}.jpg") if not os.path.exists(output_dir): os.makedirs(output_dir) idx = 0 else: fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)] if len(fns) > 0: idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in 
fns) + 1 else: idx = 0 # init all components t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512) clip = load_clip(torch_device) model = load_flow_model(name, device="cpu" if offload else torch_device) ae = load_ae(name, device="cpu" if offload else torch_device) # Download and initialize the Redux adapter redux_path = str( get_checkpoint_path("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "FLUX_REDUX") ) img_embedder = ReduxImageEncoder(torch_device, redux_path=redux_path) rng = torch.Generator(device="cpu") prompt = "" opts = SamplingOptions( prompt=prompt, width=width, height=height, num_steps=num_steps, guidance=guidance, seed=seed, img_cond_path=img_cond_path, ) if loop: opts = parse_prompt(opts) opts = parse_img_cond_path(opts) while opts is not None: if opts.seed is None: opts.seed = rng.seed() print(f"Generating with seed {opts.seed}:\n{opts.prompt}") t0 = time.perf_counter() # prepare input x = get_noise( 1, opts.height, opts.width, device=torch_device, dtype=torch.bfloat16, seed=opts.seed, ) opts.seed = None if offload: ae = ae.cpu() torch.cuda.empty_cache() t5, clip = t5.to(torch_device), clip.to(torch_device) inp = prepare_redux( t5, clip, x, prompt=opts.prompt, encoder=img_embedder, img_cond_path=opts.img_cond_path, ) timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell")) # offload TEs to CPU, load model to gpu if offload: t5, clip = t5.cpu(), clip.cpu() torch.cuda.empty_cache() model = model.to(torch_device) # denoise initial noise x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance) # offload model, load autoencoder to gpu if offload: model.cpu() torch.cuda.empty_cache() ae.decoder.to(x.device) # decode latents to pixel space x = unpack(x.float(), opts.height, opts.width) with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16): x = ae.decode(x) if torch.cuda.is_available(): torch.cuda.synchronize() t1 = time.perf_counter() 
print(f"Done in {t1 - t0:.1f}s") idx = save_image( nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage ) if loop: print("-" * 80) opts = parse_prompt(opts) opts = parse_img_cond_path(opts) else: opts = None if __name__ == "__main__": Fire(main) ================================================ FILE: src/flux/content_filters.py ================================================ import torch from einops import rearrange from PIL import Image from transformers import AutoProcessor, LlavaForConditionalGeneration, pipeline PROMPT_IMAGE_INTEGRITY = """ Task: Analyze an image to identify potential copyright concerns or depictions of public figures. Output: Respond with only "yes" or "no" Criteria for "yes": - The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.) - The image displays a trademarked logo or brand - The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.) Criteria for "no": - All other cases - When you cannot identify the specific copyrighted work or named individual Critical Requirements: 1. You must be able to name the exact copyrighted work or specific person depicted 2. General references to demographics or characteristics are not sufficient 3. Base your decision solely on visual content, not interpretation 4. Provide only the one-word answer: "yes" or "no" """.strip() PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or includes public figures?" PROMPT_TEXT_INTEGRITY = """ Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures. Output: Respond with only "yes" or "no" Criteria for "Yes": - The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.) 
- The prompt explicitly mentions a trademarked logo or brand - The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.) Criteria for "No": - All other cases - When you cannot identify the specific copyrighted work or named individual Critical Requirements: 1. You must be able to name the exact copyrighted work or specific person referenced 2. General demographic descriptions or characteristics are not sufficient 3. Analyze only the prompt text, not potential image outcomes 4. Provide only the one-word answer: "yes" or "no" The prompt to check is: ----- {prompt} ----- Does this prompt have copyright concerns or includes public figures? """.strip() class PixtralContentFilter(torch.nn.Module): def __init__( self, device: torch.device = torch.device("cpu"), nsfw_threshold: float = 0.85, ): super().__init__() model_id = "mistral-community/pixtral-12b" self.processor = AutoProcessor.from_pretrained(model_id) self.model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map=device) self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"]) self.nsfw_classifier = pipeline( "image-classification", model="Falconsai/nsfw_image_detection", device=device ) self.nsfw_threshold = nsfw_threshold def yes_no_logit_processor( self, input_ids: torch.LongTensor, scores: torch.FloatTensor ) -> torch.FloatTensor: """ Sets all tokens but yes/no to the minimum. 
""" scores_yes_token = scores[:, self.yes_token].clone() scores_no_token = scores[:, self.no_token].clone() scores_min = scores.min() scores[:, :] = scores_min - 1 scores[:, self.yes_token] = scores_yes_token scores[:, self.no_token] = scores_no_token return scores def test_image(self, image: Image.Image | str | torch.Tensor) -> bool: if isinstance(image, torch.Tensor): image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c") image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy()) elif isinstance(image, str): image = Image.open(image) classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw") if classification["score"] > self.nsfw_threshold: return True # 512^2 pixels are enough for checking w, h = image.size f = (512**2 / (w * h)) ** 0.5 image = image.resize((int(f * w), int(f * h))) chat = [ { "role": "user", "content": [ { "type": "text", "content": PROMPT_IMAGE_INTEGRITY, }, { "type": "image", "image": image, }, { "type": "text", "content": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP, }, ], } ] inputs = self.processor.apply_chat_template( chat, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(self.model.device) generate_ids = self.model.generate( **inputs, max_new_tokens=1, logits_processor=[self.yes_no_logit_processor], do_sample=False, ) return generate_ids[0, -1].item() == self.yes_token def test_txt(self, txt: str) -> bool: chat = [ { "role": "user", "content": [ { "type": "text", "content": PROMPT_TEXT_INTEGRITY.format(prompt=txt), }, ], } ] inputs = self.processor.apply_chat_template( chat, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(self.model.device) generate_ids = self.model.generate( **inputs, max_new_tokens=1, logits_processor=[self.yes_no_logit_processor], do_sample=False, ) return generate_ids[0, -1].item() == self.yes_token ================================================ FILE: src/flux/math.py 
================================================ import torch from einops import rearrange from torch import Tensor def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: q, k = apply_rope(q, k, pe) x = torch.nn.functional.scaled_dot_product_attention(q, k, v) x = rearrange(x, "B H L D -> B L (H D)") return x def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim omega = 1.0 / (theta**scale) out = torch.einsum("...n,d->...nd", pos, omega) out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.float() def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]: xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2) xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) ================================================ FILE: src/flux/model.py ================================================ from dataclasses import dataclass import torch from torch import Tensor, nn from flux.modules.layers import ( DoubleStreamBlock, EmbedND, LastLayer, MLPEmbedder, SingleStreamBlock, timestep_embedding, ) from flux.modules.lora import LinearLora, replace_linear_with_lora @dataclass class FluxParams: in_channels: int out_channels: int vec_in_dim: int context_in_dim: int hidden_size: int mlp_ratio: float num_heads: int depth: int depth_single_blocks: int axes_dim: list[int] theta: int qkv_bias: bool guidance_embed: bool class Flux(nn.Module): """ Transformer model for flow matching on sequences. 
""" def __init__(self, params: FluxParams): super().__init__() self.params = params self.in_channels = params.in_channels self.out_channels = params.out_channels if params.hidden_size % params.num_heads != 0: raise ValueError( f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" ) pe_dim = params.hidden_size // params.num_heads if sum(params.axes_dim) != pe_dim: raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}") self.hidden_size = params.hidden_size self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) self.guidance_in = ( MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity() ) self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) self.double_blocks = nn.ModuleList( [ DoubleStreamBlock( self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, ) for _ in range(params.depth) ] ) self.single_blocks = nn.ModuleList( [ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio) for _ in range(params.depth_single_blocks) ] ) self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) def forward( self, img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, timesteps: Tensor, y: Tensor, guidance: Tensor | None = None, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") # running on sequences img img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256)) if self.params.guidance_embed: if guidance is None: raise ValueError("Didn't get guidance strength for guidance distilled model.") vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) 
vec = vec + self.vector_in(y) txt = self.txt_in(txt) ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) for block in self.double_blocks: img, txt = block(img=img, txt=txt, vec=vec, pe=pe) img = torch.cat((txt, img), 1) for block in self.single_blocks: img = block(img, vec=vec, pe=pe) img = img[:, txt.shape[1] :, ...] img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) return img class FluxLoraWrapper(Flux): def __init__( self, lora_rank: int = 128, lora_scale: float = 1.0, *args, **kwargs, ) -> None: super().__init__(*args, **kwargs) self.lora_rank = lora_rank replace_linear_with_lora( self, max_rank=lora_rank, scale=lora_scale, ) def set_lora_scale(self, scale: float) -> None: for module in self.modules(): if isinstance(module, LinearLora): module.set_scale(scale=scale) ================================================ FILE: src/flux/modules/autoencoder.py ================================================ from dataclasses import dataclass import torch from einops import rearrange from torch import Tensor, nn @dataclass class AutoEncoderParams: resolution: int in_channels: int ch: int out_ch: int ch_mult: list[int] num_res_blocks: int z_channels: int scale_factor: float shift_factor: float def swish(x: Tensor) -> Tensor: return x * torch.sigmoid(x) class AttnBlock(nn.Module): def __init__(self, in_channels: int): super().__init__() self.in_channels = in_channels self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1) self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1) self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1) self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1) def attention(self, h_: Tensor) -> Tensor: h_ = self.norm(h_) q = self.q(h_) k = self.k(h_) v = self.v(h_) b, c, h, w = q.shape q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous() k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous() 
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous() h_ = nn.functional.scaled_dot_product_attention(q, k, v) return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) def forward(self, x: Tensor) -> Tensor: return x + self.proj_out(self.attention(x)) class ResnetBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels self.out_channels = out_channels self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) if self.in_channels != self.out_channels: self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): h = x h = self.norm1(h) h = swish(h) h = self.conv1(h) h = self.norm2(h) h = swish(h) h = self.conv2(h) if self.in_channels != self.out_channels: x = self.nin_shortcut(x) return x + h class Downsample(nn.Module): def __init__(self, in_channels: int): super().__init__() # no asymmetric padding in torch conv, must do it ourselves self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) def forward(self, x: Tensor): pad = (0, 1, 0, 1) x = nn.functional.pad(x, pad, mode="constant", value=0) x = self.conv(x) return x class Upsample(nn.Module): def __init__(self, in_channels: int): super().__init__() self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) def forward(self, x: Tensor): x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") x = self.conv(x) return x class Encoder(nn.Module): def __init__( self, resolution: int, in_channels: int, ch: int, ch_mult: list[int], num_res_blocks: int, z_channels: int, 
): super().__init__() self.ch = ch self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels # downsampling self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) self.in_ch_mult = in_ch_mult self.down = nn.ModuleList() block_in = self.ch for i_level in range(self.num_resolutions): block = nn.ModuleList() attn = nn.ModuleList() block_in = ch * in_ch_mult[i_level] block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) block_in = block_out down = nn.Module() down.block = block down.attn = attn if i_level != self.num_resolutions - 1: down.downsample = Downsample(block_in) curr_res = curr_res // 2 self.down.append(down) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) self.mid.attn_1 = AttnBlock(block_in) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) # end self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1) def forward(self, x: Tensor) -> Tensor: # downsampling hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): h = self.down[i_level].block[i_block](hs[-1]) if len(self.down[i_level].attn) > 0: h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) # end h = self.norm_out(h) h = swish(h) h = self.conv_out(h) return h class Decoder(nn.Module): def __init__( self, ch: int, out_ch: int, ch_mult: list[int], num_res_blocks: int, in_channels: int, resolution: int, z_channels: int, ): 
super().__init__() self.ch = ch self.num_resolutions = len(ch_mult) self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels self.ffactor = 2 ** (self.num_resolutions - 1) # compute in_ch_mult, block_in and curr_res at lowest res block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) # z to block_in self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in) self.mid.attn_1 = AttnBlock(block_in) self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in) # upsampling self.up = nn.ModuleList() for i_level in reversed(range(self.num_resolutions)): block = nn.ModuleList() attn = nn.ModuleList() block_out = ch * ch_mult[i_level] for _ in range(self.num_res_blocks + 1): block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) block_in = block_out up = nn.Module() up.block = block up.attn = attn if i_level != 0: up.upsample = Upsample(block_in) curr_res = curr_res * 2 self.up.insert(0, up) # prepend to get consistent order # end self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True) self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) def forward(self, z: Tensor) -> Tensor: # get dtype for proper tracing upscale_dtype = next(self.up.parameters()).dtype # z to block_in h = self.conv_in(z) # middle h = self.mid.block_1(h) h = self.mid.attn_1(h) h = self.mid.block_2(h) # cast to proper dtype h = h.to(upscale_dtype) # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): h = self.up[i_level].block[i_block](h) if len(self.up[i_level].attn) > 0: h = self.up[i_level].attn[i_block](h) if i_level != 0: h = self.up[i_level].upsample(h) # end h = self.norm_out(h) h 
= swish(h) h = self.conv_out(h) return h class DiagonalGaussian(nn.Module): def __init__(self, sample: bool = True, chunk_dim: int = 1): super().__init__() self.sample = sample self.chunk_dim = chunk_dim def forward(self, z: Tensor) -> Tensor: mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim) if self.sample: std = torch.exp(0.5 * logvar) return mean + std * torch.randn_like(mean) else: return mean class AutoEncoder(nn.Module): def __init__(self, params: AutoEncoderParams, sample_z: bool = False): super().__init__() self.params = params self.encoder = Encoder( resolution=params.resolution, in_channels=params.in_channels, ch=params.ch, ch_mult=params.ch_mult, num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, ) self.decoder = Decoder( resolution=params.resolution, in_channels=params.in_channels, ch=params.ch, out_ch=params.out_ch, ch_mult=params.ch_mult, num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, ) self.reg = DiagonalGaussian(sample=sample_z) self.scale_factor = params.scale_factor self.shift_factor = params.shift_factor def encode(self, x: Tensor) -> Tensor: z = self.reg(self.encoder(x)) z = self.scale_factor * (z - self.shift_factor) return z def decode(self, z: Tensor) -> Tensor: z = z / self.scale_factor + self.shift_factor return self.decoder(z) def forward(self, x: Tensor) -> Tensor: return self.decode(self.encode(x)) ================================================ FILE: src/flux/modules/conditioner.py ================================================ from torch import Tensor, nn from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer class HFEmbedder(nn.Module): def __init__(self, version: str, max_length: int, **hf_kwargs): super().__init__() self.is_clip = version.startswith("openai") self.max_length = max_length self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" if self.is_clip: self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, 
max_length=max_length) self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs) else: self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length) self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs) self.hf_module = self.hf_module.eval().requires_grad_(False) def forward(self, text: list[str]) -> Tensor: batch_encoding = self.tokenizer( text, truncation=True, max_length=self.max_length, return_length=False, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) outputs = self.hf_module( input_ids=batch_encoding["input_ids"].to(self.hf_module.device), attention_mask=None, output_hidden_states=False, ) return outputs[self.output_key].bfloat16() ================================================ FILE: src/flux/modules/image_embedders.py ================================================ import cv2 import numpy as np import torch from einops import rearrange, repeat from PIL import Image from safetensors.torch import load_file as load_sft from torch import nn from transformers import AutoModelForDepthEstimation, AutoProcessor, SiglipImageProcessor, SiglipVisionModel from flux.util import print_load_warning class DepthImageEncoder: depth_model_name = "LiheYoung/depth-anything-large-hf" def __init__(self, device): self.device = device self.depth_model = AutoModelForDepthEstimation.from_pretrained(self.depth_model_name).to(device) self.processor = AutoProcessor.from_pretrained(self.depth_model_name) def __call__(self, img: torch.Tensor) -> torch.Tensor: hw = img.shape[-2:] img = torch.clamp(img, -1.0, 1.0) img_byte = ((img + 1.0) * 127.5).byte() img = self.processor(img_byte, return_tensors="pt")["pixel_values"] depth = self.depth_model(img.to(self.device)).predicted_depth depth = repeat(depth, "b h w -> b 3 h w") depth = torch.nn.functional.interpolate(depth, hw, mode="bicubic", antialias=True) depth = depth / 127.5 - 1.0 return depth class CannyImageEncoder: def 
__init__( self, device, min_t: int = 50, max_t: int = 200, ): self.device = device self.min_t = min_t self.max_t = max_t def __call__(self, img: torch.Tensor) -> torch.Tensor: assert img.shape[0] == 1, "Only batch size 1 is supported" img = rearrange(img[0], "c h w -> h w c") img = torch.clamp(img, -1.0, 1.0) img_np = ((img + 1.0) * 127.5).numpy().astype(np.uint8) # Apply Canny edge detection canny = cv2.Canny(img_np, self.min_t, self.max_t) # Convert back to torch tensor and reshape canny = torch.from_numpy(canny).float() / 127.5 - 1.0 canny = rearrange(canny, "h w -> 1 1 h w") canny = repeat(canny, "b 1 ... -> b 3 ...") return canny.to(self.device) class ReduxImageEncoder(nn.Module): siglip_model_name = "google/siglip-so400m-patch14-384" def __init__( self, device, redux_path: str, redux_dim: int = 1152, txt_in_features: int = 4096, dtype=torch.bfloat16, ) -> None: super().__init__() self.redux_dim = redux_dim self.device = device if isinstance(device, torch.device) else torch.device(device) self.dtype = dtype with self.device: self.redux_up = nn.Linear(redux_dim, txt_in_features * 3, dtype=dtype) self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features, dtype=dtype) sd = load_sft(redux_path, device=str(device)) missing, unexpected = self.load_state_dict(sd, strict=False, assign=True) print_load_warning(missing, unexpected) self.siglip = SiglipVisionModel.from_pretrained(self.siglip_model_name).to(dtype=dtype) self.normalize = SiglipImageProcessor.from_pretrained(self.siglip_model_name) def __call__(self, x: Image.Image) -> torch.Tensor: imgs = self.normalize.preprocess(images=[x], do_resize=True, return_tensors="pt", do_convert_rgb=True) _encoded_x = self.siglip(**imgs.to(device=self.device, dtype=self.dtype)).last_hidden_state projected_x = self.redux_down(nn.functional.silu(self.redux_up(_encoded_x))) return projected_x ================================================ FILE: src/flux/modules/layers.py ================================================ 
import math from dataclasses import dataclass import torch from einops import rearrange from torch import Tensor, nn from flux.math import attention, rope class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list[int]): super().__init__() self.dim = dim self.theta = theta self.axes_dim = axes_dim def forward(self, ids: Tensor) -> Tensor: n_axes = ids.shape[-1] emb = torch.cat( [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3, ) return emb.unsqueeze(1) def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0): """ Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. """ t = time_factor * t half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( t.device ) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) if torch.is_floating_point(t): embedding = embedding.to(t) return embedding class MLPEmbedder(nn.Module): def __init__(self, in_dim: int, hidden_dim: int): super().__init__() self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) self.silu = nn.SiLU() self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) class RMSNorm(torch.nn.Module): def __init__(self, dim: int): super().__init__() self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x: Tensor): x_dtype = x.dtype x = x.float() rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6) return (x * rrms).to(dtype=x_dtype) * self.scale class QKNorm(torch.nn.Module): def 
__init__(self, dim: int): super().__init__() self.query_norm = RMSNorm(dim) self.key_norm = RMSNorm(dim) def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]: q = self.query_norm(q) k = self.key_norm(k) return q.to(v), k.to(v) class SelfAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) def forward(self, x: Tensor, pe: Tensor) -> Tensor: qkv = self.qkv(x) q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) x = attention(q, k, v, pe=pe) x = self.proj(x) return x @dataclass class ModulationOut: shift: Tensor scale: Tensor gate: Tensor class Modulation(nn.Module): def __init__(self, dim: int, double: bool): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]: out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1) return ( ModulationOut(*out[:3]), ModulationOut(*out[3:]) if self.is_double else None, ) class DoubleStreamBlock(nn.Module): def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size self.img_mod = Modulation(hidden_size, double=True) self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.img_mlp = nn.Sequential( nn.Linear(hidden_size, mlp_hidden_dim, bias=True), nn.GELU(approximate="tanh"), nn.Linear(mlp_hidden_dim, 
hidden_size, bias=True), ) self.txt_mod = Modulation(hidden_size, double=True) self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias) self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.txt_mlp = nn.Sequential( nn.Linear(hidden_size, mlp_hidden_dim, bias=True), nn.GELU(approximate="tanh"), nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]: img_mod1, img_mod2 = self.img_mod(vec) txt_mod1, txt_mod2 = self.txt_mod(vec) # prepare image for attention img_modulated = self.img_norm1(img) img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift img_qkv = self.img_attn.qkv(img_modulated) img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention txt_modulated = self.txt_norm1(txt) txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift txt_qkv = self.txt_attn.qkv(txt_modulated) txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) # run actual attention q = torch.cat((txt_q, img_q), dim=2) k = torch.cat((txt_k, img_k), dim=2) v = torch.cat((txt_v, img_v), dim=2) attn = attention(q, k, v, pe=pe) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img blocks img = img + img_mod1.gate * self.img_attn.proj(img_attn) img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift) # calculate the txt blocks txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift) return img, txt class SingleStreamBlock(nn.Module): """ A DiT block with 
parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation interface. """ def __init__( self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0, qk_scale: float | None = None, ): super().__init__() self.hidden_dim = hidden_size self.num_heads = num_heads head_dim = hidden_size // num_heads self.scale = qk_scale or head_dim**-0.5 self.mlp_hidden_dim = int(hidden_size * mlp_ratio) # qkv and mlp_in self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) # proj and mlp_out self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) self.norm = QKNorm(head_dim) self.hidden_size = hidden_size self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False) def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: mod, _ = self.modulation(vec) x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) # compute attention attn = attention(q, k, v, pe=pe) # compute activation in mlp stream, cat again and run second linear layer output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) return x + mod.gate * output class LastLayer(nn.Module): def __init__(self, hidden_size: int, patch_size: int, out_channels: int): super().__init__() self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)) def forward(self, x: Tensor, vec: Tensor) -> Tensor: shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1) x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] x = 
self.linear(x) return x ================================================ FILE: src/flux/modules/lora.py ================================================ import torch from torch import nn def replace_linear_with_lora( module: nn.Module, max_rank: int, scale: float = 1.0, ) -> None: for name, child in module.named_children(): if isinstance(child, nn.Linear): new_lora = LinearLora( in_features=child.in_features, out_features=child.out_features, bias=child.bias, rank=max_rank, scale=scale, dtype=child.weight.dtype, device=child.weight.device, ) new_lora.weight = child.weight new_lora.bias = child.bias if child.bias is not None else None setattr(module, name, new_lora) else: replace_linear_with_lora( module=child, max_rank=max_rank, scale=scale, ) class LinearLora(nn.Linear): def __init__( self, in_features: int, out_features: int, bias: bool, rank: int, dtype: torch.dtype, device: torch.device, lora_bias: bool = True, scale: float = 1.0, *args, **kwargs, ) -> None: super().__init__( in_features=in_features, out_features=out_features, bias=bias is not None, device=device, dtype=dtype, *args, **kwargs, ) assert isinstance(scale, float), "scale must be a float" self.scale = scale self.rank = rank self.lora_bias = lora_bias self.dtype = dtype self.device = device if rank > (new_rank := min(self.out_features, self.in_features)): self.rank = new_rank self.lora_A = nn.Linear( in_features=in_features, out_features=self.rank, bias=False, dtype=dtype, device=device, ) self.lora_B = nn.Linear( in_features=self.rank, out_features=out_features, bias=self.lora_bias, dtype=dtype, device=device, ) def set_scale(self, scale: float) -> None: assert isinstance(scale, float), "scalar value must be a float" self.scale = scale def forward(self, input: torch.Tensor) -> torch.Tensor: base_out = super().forward(input) _lora_out_B = self.lora_B(self.lora_A(input)) lora_update = _lora_out_B * self.scale return base_out + lora_update ================================================ FILE: 
src/flux/sampling.py ================================================ import math from typing import Callable import numpy as np import torch from einops import rearrange, repeat from PIL import Image from torch import Tensor from .model import Flux from .modules.autoencoder import AutoEncoder from .modules.conditioner import HFEmbedder from .modules.image_embedders import CannyImageEncoder, DepthImageEncoder, ReduxImageEncoder from .util import PREFERED_KONTEXT_RESOLUTIONS def get_noise( num_samples: int, height: int, width: int, device: torch.device, dtype: torch.dtype, seed: int, ): return torch.randn( num_samples, 16, # allow for packing 2 * math.ceil(height / 16), 2 * math.ceil(width / 16), dtype=dtype, generator=torch.Generator(device="cpu").manual_seed(seed), ).to(device) def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str]) -> dict[str, Tensor]: bs, c, h, w = img.shape if bs == 1 and not isinstance(prompt, str): bs = len(prompt) img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img.shape[0] == 1 and bs > 1: img = repeat(img, "1 ... -> bs ...", bs=bs) img_ids = torch.zeros(h // 2, w // 2, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) if isinstance(prompt, str): prompt = [prompt] txt = t5(prompt) if txt.shape[0] == 1 and bs > 1: txt = repeat(txt, "1 ... -> bs ...", bs=bs) txt_ids = torch.zeros(bs, txt.shape[1], 3) vec = clip(prompt) if vec.shape[0] == 1 and bs > 1: vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) return { "img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device), } def prepare_control( t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str], ae: AutoEncoder, encoder: DepthImageEncoder | CannyImageEncoder, img_cond_path: str, ) -> dict[str, Tensor]: # load and encode the conditioning image bs, _, h, w = img.shape if bs == 1 and not isinstance(prompt, str): bs = len(prompt) img_cond = Image.open(img_cond_path).convert("RGB") width = w * 8 height = h * 8 img_cond = img_cond.resize((width, height), Image.Resampling.LANCZOS) img_cond = np.array(img_cond) img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 img_cond = rearrange(img_cond, "h w c -> 1 c h w") with torch.no_grad(): img_cond = encoder(img_cond) img_cond = ae.encode(img_cond) img_cond = img_cond.to(torch.bfloat16) img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img_cond.shape[0] == 1 and bs > 1: img_cond = repeat(img_cond, "1 ... 
-> bs ...", bs=bs) return_dict = prepare(t5, clip, img, prompt) return_dict["img_cond"] = img_cond return return_dict def prepare_fill( t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str], ae: AutoEncoder, img_cond_path: str, mask_path: str, ) -> dict[str, Tensor]: # load and encode the conditioning image and the mask bs, _, _, _ = img.shape if bs == 1 and not isinstance(prompt, str): bs = len(prompt) img_cond = Image.open(img_cond_path).convert("RGB") img_cond = np.array(img_cond) img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 img_cond = rearrange(img_cond, "h w c -> 1 c h w") mask = Image.open(mask_path).convert("L") mask = np.array(mask) mask = torch.from_numpy(mask).float() / 255.0 mask = rearrange(mask, "h w -> 1 1 h w") with torch.no_grad(): img_cond = img_cond.to(img.device) mask = mask.to(img.device) img_cond = img_cond * (1 - mask) img_cond = ae.encode(img_cond) mask = mask[:, 0, :, :] mask = mask.to(torch.bfloat16) mask = rearrange( mask, "b (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8, ) mask = rearrange(mask, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if mask.shape[0] == 1 and bs > 1: mask = repeat(mask, "1 ... -> bs ...", bs=bs) img_cond = img_cond.to(torch.bfloat16) img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img_cond.shape[0] == 1 and bs > 1: img_cond = repeat(img_cond, "1 ... 
-> bs ...", bs=bs) img_cond = torch.cat((img_cond, mask), dim=-1) return_dict = prepare(t5, clip, img, prompt) return_dict["img_cond"] = img_cond.to(img.device) return return_dict def prepare_redux( t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str | list[str], encoder: ReduxImageEncoder, img_cond_path: str, ) -> dict[str, Tensor]: bs, _, h, w = img.shape if bs == 1 and not isinstance(prompt, str): bs = len(prompt) img_cond = Image.open(img_cond_path).convert("RGB") with torch.no_grad(): img_cond = encoder(img_cond) img_cond = img_cond.to(torch.bfloat16) if img_cond.shape[0] == 1 and bs > 1: img_cond = repeat(img_cond, "1 ... -> bs ...", bs=bs) img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img.shape[0] == 1 and bs > 1: img = repeat(img, "1 ... -> bs ...", bs=bs) img_ids = torch.zeros(h // 2, w // 2, 3) img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None] img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :] img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) if isinstance(prompt, str): prompt = [prompt] txt = t5(prompt) txt = torch.cat((txt, img_cond.to(txt)), dim=-2) if txt.shape[0] == 1 and bs > 1: txt = repeat(txt, "1 ... -> bs ...", bs=bs) txt_ids = torch.zeros(bs, txt.shape[1], 3) vec = clip(prompt) if vec.shape[0] == 1 and bs > 1: vec = repeat(vec, "1 ... 
-> bs ...", bs=bs) return { "img": img, "img_ids": img_ids.to(img.device), "txt": txt.to(img.device), "txt_ids": txt_ids.to(img.device), "vec": vec.to(img.device), } def prepare_kontext( t5: HFEmbedder, clip: HFEmbedder, prompt: str | list[str], ae: AutoEncoder, img_cond_path: str, seed: int, device: torch.device, target_width: int | None = None, target_height: int | None = None, bs: int = 1, ) -> tuple[dict[str, Tensor], int, int]: # load and encode the conditioning image if bs == 1 and not isinstance(prompt, str): bs = len(prompt) img_cond = Image.open(img_cond_path).convert("RGB") width, height = img_cond.size aspect_ratio = width / height # Kontext is trained on specific resolutions, using one of them is recommended _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) width = 2 * int(width / 16) height = 2 * int(height / 16) img_cond = img_cond.resize((8 * width, 8 * height), Image.Resampling.LANCZOS) img_cond = np.array(img_cond) img_cond = torch.from_numpy(img_cond).float() / 127.5 - 1.0 img_cond = rearrange(img_cond, "h w c -> 1 c h w") img_cond_orig = img_cond.clone() with torch.no_grad(): img_cond = ae.encode(img_cond.to(device)) img_cond = img_cond.to(torch.bfloat16) img_cond = rearrange(img_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) if img_cond.shape[0] == 1 and bs > 1: img_cond = repeat(img_cond, "1 ... 
-> bs ...", bs=bs) # image ids are the same as base image with the first dimension set to 1 # instead of 0 img_cond_ids = torch.zeros(height // 2, width // 2, 3) img_cond_ids[..., 0] = 1 img_cond_ids[..., 1] = img_cond_ids[..., 1] + torch.arange(height // 2)[:, None] img_cond_ids[..., 2] = img_cond_ids[..., 2] + torch.arange(width // 2)[None, :] img_cond_ids = repeat(img_cond_ids, "h w c -> b (h w) c", b=bs) if target_width is None: target_width = 8 * width if target_height is None: target_height = 8 * height img = get_noise( 1, target_height, target_width, device=device, dtype=torch.bfloat16, seed=seed, ) return_dict = prepare(t5, clip, img, prompt) return_dict["img_cond_seq"] = img_cond return_dict["img_cond_seq_ids"] = img_cond_ids.to(device) return_dict["img_cond_orig"] = img_cond_orig return return_dict, target_height, target_width def time_shift(mu: float, sigma: float, t: Tensor): return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) def get_lin_function( x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15 ) -> Callable[[float], float]: m = (y2 - y1) / (x2 - x1) b = y1 - m * x1 return lambda x: m * x + b def get_schedule( num_steps: int, image_seq_len: int, base_shift: float = 0.5, max_shift: float = 1.15, shift: bool = True, ) -> list[float]: # extra step for zero timesteps = torch.linspace(1, 0, num_steps + 1) # shifting the schedule to favor high timesteps for higher signal images if shift: # estimate mu based on linear estimation between two points mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len) timesteps = time_shift(mu, 1.0, timesteps) return timesteps.tolist() def denoise( model: Flux, # model input img: Tensor, img_ids: Tensor, txt: Tensor, txt_ids: Tensor, vec: Tensor, # sampling parameters timesteps: list[float], guidance: float = 4.0, # extra img tokens (channel-wise) img_cond: Tensor | None = None, # extra img tokens (sequence-wise) img_cond_seq: Tensor | None = None, img_cond_seq_ids: Tensor | None = 
None, ): # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) img_input = img img_input_ids = img_ids if img_cond is not None: img_input = torch.cat((img, img_cond), dim=-1) if img_cond_seq is not None: assert ( img_cond_seq_ids is not None ), "You need to provide either both or neither of the sequence conditioning" img_input = torch.cat((img_input, img_cond_seq), dim=1) img_input_ids = torch.cat((img_input_ids, img_cond_seq_ids), dim=1) pred = model( img=img_input, img_ids=img_input_ids, txt=txt, txt_ids=txt_ids, y=vec, timesteps=t_vec, guidance=guidance_vec, ) if img_input_ids is not None: pred = pred[:, : img.shape[1]] img = img + (t_prev - t_curr) * pred return img def unpack(x: Tensor, height: int, width: int) -> Tensor: return rearrange( x, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=math.ceil(height / 16), w=math.ceil(width / 16), ph=2, pw=2, ) ================================================ FILE: src/flux/trt/__init__.py ================================================ from flux.trt.trt_config import ModuleName from flux.trt.trt_manager import TRTManager __all__ = ["TRTManager", "ModuleName"] ================================================ FILE: src/flux/trt/engine/__init__.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from flux.trt.engine.base_engine import BaseEngine, Engine, SharedMemory from flux.trt.engine.clip_engine import CLIPEngine from flux.trt.engine.t5_engine import T5Engine from flux.trt.engine.transformer_engine import TransformerEngine from flux.trt.engine.vae_engine import VAEDecoder, VAEEncoder, VAEEngine __all__ = [ "BaseEngine", "Engine", "SharedMemory", "CLIPEngine", "TransformerEngine", "T5Engine", "VAEEngine", "VAEDecoder", "VAEEncoder", ] ================================================ FILE: src/flux/trt/engine/base_engine.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import gc
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Dict

import tensorrt as trt
import torch
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import engine_from_bytes

from flux.trt.trt_config import TRTBaseConfig

TRT_LOGGER = trt.Logger(trt.Logger.ERROR)


class SharedMemory(object):
    """Device byte buffer shared by several TensorRT execution contexts.

    Each engine registers the number of bytes it needs under its own name. The buffer
    is grown to the largest registered request, so every context can point its device
    memory at the same allocation.

    Construction returns a singleton. The first call creates the instance, and every
    call (re)runs ``__init__`` on it with the given arguments. That re-run resets
    ``allocations`` and reallocates the buffer.
    """

    def __new__(cls, *args, **kwargs):
        # Python calls ``__init__`` on the returned instance automatically. Calling it
        # here as well ran it twice on first construction, allocating the buffer twice.
        if not hasattr(cls, "instance"):
            cls.instance = super(SharedMemory, cls).__new__(cls)
        return cls.instance

    def __init__(self, size: int, device=torch.device("cuda")):
        """Allocate a ``size``-byte uint8 buffer on ``device`` with no registrations."""
        self.allocations = {}
        self.device = device
        self._buffer = torch.empty(
            size,
            dtype=torch.uint8,
            device=device,
            memory_format=torch.contiguous_format,
        )

    def resize(self, name: str, size: int):
        """Register ``size`` bytes for ``name`` and grow the buffer if it is now too small."""
        self.allocations[name] = size
        required = max(self.allocations.values())
        if required > self._buffer.numel():
            # Grow to the largest registered request, not just ``size``, so every
            # registered context still fits (e.g. after ``deallocate``).
            self._buffer.resize_(required)
            torch.cuda.empty_cache()

    def reset(self, name: str):
        """Drop ``name``'s registration and shrink the buffer to the largest remaining one.

        Unknown names are ignored. With no registrations left the buffer is kept as is.
        Note: ``Tensor.resize_`` does not shrink the underlying storage, so shrinking
        changes the tensor's size but does not by itself return device memory.
        """
        self.allocations.pop(name, None)
        new_max = max(self.allocations.values(), default=0)
        if 0 < new_max < self._buffer.numel():
            self._buffer.resize_(new_max)
            torch.cuda.empty_cache()

    def deallocate(self):
        """Free the buffer, leaving a 1024-byte placeholder; registrations are kept."""
        del self._buffer
        torch.cuda.empty_cache()
        self._buffer = torch.empty(
            1024,
            dtype=torch.uint8,
            device=self.device,
            memory_format=torch.contiguous_format,
        )

    @property
    def shared_device_memory(self):
        """Raw data pointer of the shared buffer."""
        return self._buffer.data_ptr()

    def __str__(self):
        def human_readable_size(size):
            units = ["B", "KiB", "MiB", "GiB", "TiB"]
            for unit in units[:-1]:
                if size < 1024.0:
                    return size, unit
                size /= 1024.0
            # already divided once per smaller unit: the value is in TiB
            return size, units[-1]

        allocations_str = []
        for name, size_bytes in self.allocations.items():
            size, unit = human_readable_size(size_bytes)
            allocations_str.append(f"\t{name}: {size:.2f} {unit}\n")
        allocations_output = "".join(allocations_str)

        size, unit = human_readable_size(self._buffer.numel())
        allocations_buffer = f"{size:.2f} {unit}"
        return f"Shared Memory Allocations: \n{allocations_output} \n\tCurrent: {allocations_buffer}"


TRT_ALLOCATION_POLICY = {"global", "dynamic"}
TRT_OFFLOAD_POLICY = "cpu_buffer" class BaseEngine(ABC): @staticmethod def trt_datatype_to_torch(datatype): datatype_mapping = { trt.DataType.BOOL: torch.bool, trt.DataType.UINT8: torch.uint8, trt.DataType.INT8: torch.int8, trt.DataType.INT32: torch.int32, trt.DataType.INT64: torch.int64, trt.DataType.HALF: torch.float16, trt.DataType.FLOAT: torch.float32, trt.DataType.BF16: torch.bfloat16, } if datatype not in datatype_mapping: raise ValueError(f"No PyTorch equivalent for TensorRT data type: {datatype}") return datatype_mapping[datatype] @abstractmethod def cpu(self) -> "BaseEngine": pass @abstractmethod def cuda(self) -> "BaseEngine": pass @abstractmethod def to(self, device: str | torch.device) -> "BaseEngine": pass class Engine(BaseEngine): def __init__( self, trt_config: TRTBaseConfig, stream: torch.cuda.Stream, context_memory: SharedMemory | None = None, allocation_policy: str = "global", ): self.trt_config = trt_config self.stream = stream self.context = None self.tensors = OrderedDict() self.context_memory = context_memory self.device: torch.device = torch.device("cpu") if TRT_OFFLOAD_POLICY == "cpu_buffer": self.engine: trt.ICudaEngine | bytes = None self.cpu_engine_buffer: bytes = bytes_from_path(self.trt_config.engine_path) else: self.engine: trt.ICudaEngine | bytes = bytes_from_path(self.trt_config.engine_path) assert allocation_policy in TRT_ALLOCATION_POLICY self.allocation_policy = allocation_policy self.current_input_hash = None self.cuda_graph = None @abstractmethod def __call__(self, *args, **Kwargs) -> torch.Tensor | dict[str, torch.Tensor] | tuple[torch.Tensor]: pass def cpu(self) -> "Engine": if self.device == torch.device("cpu"): return self self.deactivate() if TRT_OFFLOAD_POLICY == "cpu_buffer": del self.engine return self self.engine = memoryview(self.engine.serialize()) return self def cuda(self) -> "Engine": if self.device == torch.device("cuda"): return self buffer = self.cpu_engine_buffer if TRT_OFFLOAD_POLICY == "cpu_buffer" else 
self.engine self.engine = engine_from_bytes(buffer) gc.collect() self.context = self.engine.create_execution_context_without_device_memory() self.context_memory.resize(self.__class__.__name__, self.device_memory_size) self.context.device_memory = self.context_memory.shared_device_memory return self def to(self, device: str | torch.device) -> "Engine": if not isinstance(device, torch.device): device = torch.device(device) if self.device == device: return self if device == torch.device("cpu"): self.cpu() else: self.cuda() self.device = device return self def deactivate(self): del self.context self.context = None def allocate_buffers( self, shape_dict: dict[str, tuple], device: str | torch.device = "cuda", ): for binding in range(self.engine.num_io_tensors): tensor_name = self.engine.get_tensor_name(binding) tensor_shape = shape_dict[tensor_name] if tensor_name in self.tensors and self.tensors[tensor_name].shape == tensor_shape: continue if self.engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT: self.context.set_input_shape(tensor_name, tensor_shape) tensor_dtype = self.trt_datatype_to_torch(self.engine.get_tensor_dtype(tensor_name)) tensor = torch.empty( size=tensor_shape, dtype=tensor_dtype, memory_format=torch.contiguous_format, ).to(device=device) self.tensors[tensor_name] = tensor def get_dtype(self, tensor_name: str): return self.trt_datatype_to_torch(self.engine.get_tensor_dtype(tensor_name)) def override_shapes(self, feed_dict: Dict[str, torch.Tensor]): for name, tensor in feed_dict.items(): shape = tensor.shape assert tensor.dtype == self.trt_datatype_to_torch(self.engine.get_tensor_dtype(name)), ( f"Debug: Mismatched data types for tensor '{name}'. 
" f"Expected: {self.trt_datatype_to_torch(self.engine.get_tensor_dtype(name))}, " f"Found: {tensor.dtype} " f"in {self.__class__.__name__}" ) self.context.set_input_shape(name, shape) assert self.context.all_binding_shapes_specified self.context.infer_shapes() for idx in range(self.engine.num_io_tensors): name = self.engine.get_tensor_name(idx) dtype = self.trt_datatype_to_torch(self.engine.get_tensor_dtype(name)) shape = self.context.get_tensor_shape(name) if -1 in shape: raise Exception("Unspecified shape identified for tensor {}: {} ".format(name, shape)) self.tensors[name] = torch.zeros(tuple(shape), dtype=dtype, device=self.device).contiguous() self.context.set_tensor_address(name, self.tensors[name].data_ptr()) if self.allocation_policy == "dynamic": self.context_memory.resize(self.__class__.__name__, self.device_memory_size) self.current_input_hash = self.calculate_input_hash(feed_dict) def deallocate_buffers(self): if len(self.tensors) == 0: return del self.tensors self.tensors = OrderedDict() torch.cuda.empty_cache() @property def device_memory_size(self): if self.allocation_policy == "global": return self.engine.device_memory_size else: if not self.context.all_binding_shapes_specified: return 0 return self.context.update_device_memory_size_for_shapes() @staticmethod def calculate_input_hash(feed_dict: Dict[str, torch.Tensor]): return hash(tuple(feed_dict[key].shape for key in sorted(feed_dict.keys()))) def _capture_cuda_graph(self): self.cuda_graph = torch.cuda.CUDAGraph() s = torch.cuda.Stream() with torch.cuda.graph(self.cuda_graph, stream=s): noerror = self.context.execute_async_v3(s.cuda_stream) if not noerror: raise ValueError("ERROR: inference failed.") # self.cuda_graph.replay() def infer( self, feed_dict: dict[str, torch.Tensor], ): if self.current_input_hash != self.calculate_input_hash(feed_dict): self.override_shapes(feed_dict) self.context.device_memory = self.context_memory.shared_device_memory for name, tensor in feed_dict.items(): 
self.tensors[name].copy_(tensor, non_blocking=True) noerror = self.context.execute_async_v3(self.stream.cuda_stream) if not noerror: raise ValueError("ERROR: inference failed.") return self.tensors def __str__(self): if self.engine is None: return "Engine has not been initialized" out = "" for idx in range(self.engine.num_io_tensors): name = self.engine.get_tensor_name(idx) mode = self.engine.get_tensor_mode(name) dtype = self.trt_datatype_to_torch(self.engine.get_tensor_dtype(name)) shape = self.engine.get_tensor_shape(name) out += f"\t{mode.name}: {name}={shape} {dtype.__str__()}\n" return out ================================================ FILE: src/flux/trt/engine/clip_engine.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch from transformers import CLIPTokenizer from flux.trt.engine import Engine from flux.trt.trt_config import ClipConfig class CLIPEngine(Engine): def __init__(self, trt_config: ClipConfig, stream: torch.cuda.Stream, **kwargs): super().__init__(trt_config=trt_config, stream=stream, **kwargs) self.tokenizer = CLIPTokenizer.from_pretrained( "openai/clip-vit-large-patch14", max_length=self.trt_config.text_maxlen, ) @torch.inference_mode() def __call__( self, prompt: list[str], ) -> torch.Tensor: with torch.inference_mode(): feed_dict = self.tokenizer( prompt, truncation=True, max_length=self.trt_config.text_maxlen, return_length=False, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) feed_dict = {"input_ids": feed_dict["input_ids"].to(dtype=self.get_dtype("input_ids"))} pooled_embeddings = self.infer(feed_dict)["pooled_embeddings"] return pooled_embeddings ================================================ FILE: src/flux/trt/engine/t5_engine.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch from transformers import T5Tokenizer from flux.trt.engine import Engine from flux.trt.trt_config import T5Config class T5Engine(Engine): def __init__(self, trt_config: T5Config, stream: torch.cuda.Stream, **kwargs): super().__init__(trt_config=trt_config, stream=stream, **kwargs) self.tokenizer = T5Tokenizer.from_pretrained( "google/t5-v1_1-xxl", max_length=self.trt_config.text_maxlen, ) @torch.inference_mode() def __call__( self, prompt: list[str], ) -> torch.Tensor: with torch.inference_mode(): feed_dict = self.tokenizer( prompt, truncation=True, max_length=self.trt_config.text_maxlen, return_length=False, return_overflowing_tokens=False, padding="max_length", return_tensors="pt", ) feed_dict = {"input_ids": feed_dict["input_ids"].to(dtype=self.get_dtype("input_ids"))} text_embeddings = self.infer(feed_dict)["text_embeddings"] return text_embeddings ================================================ FILE: src/flux/trt/engine/transformer_engine.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import torch from flux.trt.engine import Engine from flux.trt.trt_config import TransformerConfig class TransformerEngine(Engine): __dd_to_flux__ = { "hidden_states": "img", "img_ids": "img_ids", "encoder_hidden_states": "txt", "pooled_projections": "y", "txt_ids": "txt_ids", "timestep": "timesteps", "guidance": "guidance", "latent": "latent", } __flux_to_dd__ = { "img": "hidden_states", "img_ids": "img_ids", "txt": "encoder_hidden_states", "y": "pooled_projections", "txt_ids": "txt_ids", "timesteps": "timestep", "guidance": "guidance", "latent": "latent", } def __init__(self, trt_config: TransformerConfig, stream: torch.cuda.Stream, **kwargs): super().__init__(trt_config=trt_config, stream=stream, **kwargs) @property def dd_to_flux(self): return TransformerEngine.__dd_to_flux__ @property def flux_to_dd(self): return TransformerEngine.__flux_to_dd__ @torch.inference_mode() def __call__( self, **kwargs, ) -> torch.Tensor: feed_dict = {} if self.trt_config.model_name == "flux-schnell": # remove guidance kwargs.pop("guidance") for tensor_name, tensor_value in kwargs.items(): if tensor_name == "latent": continue dd_name = self.flux_to_dd[tensor_name] feed_dict[dd_name] = tensor_value.to(dtype=self.get_dtype(dd_name)) # remove batch dim to match demo-diffusion feed_dict["img_ids"] = feed_dict["img_ids"][0] feed_dict["txt_ids"] = feed_dict["txt_ids"][0] latent = self.infer(feed_dict=feed_dict)["latent"] return latent ================================================ FILE: src/flux/trt/engine/vae_engine.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import torch from flux.trt.engine.base_engine import BaseEngine, Engine from flux.trt.trt_config import VAEDecoderConfig, VAEEncoderConfig class VAEDecoder(Engine): def __init__(self, trt_config: VAEDecoderConfig, stream: torch.cuda.Stream, **kwargs): super().__init__(trt_config=trt_config, stream=stream, **kwargs) @torch.inference_mode() def __call__( self, z: torch.Tensor, ) -> torch.Tensor: z = z.to(dtype=self.get_dtype("latent")) z = (z / self.trt_config.scale_factor) + self.trt_config.shift_factor feed_dict = {"latent": z} images = self.infer(feed_dict=feed_dict)["images"] return images class VAEEncoder(Engine): def __init__(self, trt_config: VAEEncoderConfig, stream: torch.cuda.Stream, **kwargs): super().__init__(trt_config=trt_config, stream=stream, **kwargs) @torch.inference_mode() def __call__( self, x: torch.Tensor, ) -> torch.Tensor: feed_dict = {"images": x.to(dtype=self.get_dtype("images"))} latent = self.infer(feed_dict=feed_dict)["latent"] latent = self.trt_config.scale_factor * (latent - self.trt_config.shift_factor) return latent class VAEEngine(BaseEngine): def __init__( self, decoder: VAEDecoder, encoder: VAEEncoder | None = None, ): super().__init__() self.decoder = decoder self.encoder = encoder def decode(self, z: torch.Tensor) -> torch.Tensor: return self.decoder(z) def encode(self, x: torch.Tensor) -> torch.Tensor: assert self.encoder is not None, "An encoder is needed to encode an image" return self.encoder(x) def cpu(self): self.decoder = self.decoder.cpu() if self.encoder is not None: self.encoder = self.encoder.cpu() return self def 
cuda(self): self.decoder = self.decoder.cuda() if self.encoder is not None: self.encoder = self.encoder.cuda() return self def to(self, device): self.decoder = self.decoder.to(device) if self.encoder is not None: self.encoder = self.encoder.to(device) return self @property def device_memory_size(self): device_memory = self.decoder.device_memory_size if self.encoder is not None: device_memory = max(device_memory, self.encoder.device_memory_size) return device_memory ================================================ FILE: src/flux/trt/trt_config/__init__.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
from flux.trt.trt_config.base_trt_config import ModuleName, TRTBaseConfig, get_config, register_config from flux.trt.trt_config.clip_trt_config import ClipConfig from flux.trt.trt_config.t5_trt_config import T5Config from flux.trt.trt_config.transformer_trt_config import TransformerConfig from flux.trt.trt_config.vae_trt_config import VAEDecoderConfig, VAEEncoderConfig __all__ = [ "register_config", "get_config", "ModuleName", "TRTBaseConfig", "ClipConfig", "T5Config", "TransformerConfig", "VAEDecoderConfig", "VAEEncoderConfig", ] ================================================ FILE: src/flux/trt/trt_config/base_trt_config.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import subprocess from abc import abstractmethod from collections import defaultdict from dataclasses import dataclass, field from enum import Enum from typing import Any from colored import fore, style from huggingface_hub import snapshot_download from tensorrt import __version__ as trt_version class ModuleName(Enum): CLIP = "clip" T5 = "t5" TRANSFORMER = "transformer" VAE = "vae" VAE_ENCODER = "vae_encoder" registry = {} @dataclass class TRTBaseConfig: engine_dir: str precision: str trt_verbose: bool trt_static_batch: bool trt_static_shape: bool model_name: str module_name: ModuleName onnx_path: str = field(init=False) engine_path: str = field(init=False) trt_tf32: bool trt_bf16: bool trt_fp8: bool trt_fp4: bool trt_build_strongly_typed: bool custom_onnx_path: str | None = None trt_update_output_names: list[str] | None = None trt_enable_all_tactics: bool = False trt_timing_cache: str | None = None trt_native_instancenorm: bool = True trt_builder_optimization_level: int = 3 trt_precision_constraints: str = "none" min_batch: int = 1 max_batch: int = 4 @staticmethod def build_trt_engine( engine_path: str, onnx_path: str, strongly_typed=False, tf32=True, bf16=False, fp8=False, fp4=False, input_profile: dict[str, Any] | None = None, update_output_names: list[str] | None = None, enable_refit=False, enable_all_tactics=False, timing_cache: str | None = None, native_instancenorm=True, builder_optimization_level=3, precision_constraints="none", verbose=False, ): """ Metod used to build a TRT engine from a given set of flags or configurations using polygraphy. Args: engine_path (str): Output path used to store the build engine. onnx_path (str): Path containing an onnx model used to generated the engine. strongly_typed (bool): Flag indicating if the engine should be strongly typed. tf32 (bool): Whether to build the engine with TF32 precision enabled. bf16 (bool): Whether to build the engine with BF16 precision enabled. 
fp8 (bool): Whether to build the engine with FP8 precision enabled. Refer to plain dataype and do not interfer with quantization introduced by modelopt. fp4 (bool): Whether to build the engine with FP4 precision enabled. Refer to plain dataype and do not interfer with quantization introduced by modelopt. input_profile (dict[str, Any]): A set of optimization profiles to add to the configuration. Only needed for networks with dynamic input shapes. update_output_names (list[str]): List of output names to use in the trt engines. enable_refit (bool): Enables the engine to be refitted with new weights after it is built. enable_all_tactics (bool): Enables TRT to leverage all tactics or not. timing_cache (str): A path or file-like object from which to load a tactic timing cache. native_instancenorm (bool): support of instancenorm plugin. builder_optimization_level (int): The builder optimization level. precision_constraints (str): If set to "obey", require that layers execute in specified precisions. If set to "prefer", prefer that layers execute in specified precisions but allow TRT to fall back to other precisions if no implementation exists for the requested precision. Otherwise, precision constraints are ignored. verbose (bool): Weather to support verbose output Returns: dict[str, Any]: A dictionary representing the input profile configuration. 
""" print(f"Building TensorRT engine for {onnx_path}: {engine_path}") # Base command build_command = [f"polygraphy convert {onnx_path} --convert-to trt --output {engine_path}"] # Precision flags build_args = [ "--bf16" if bf16 else "", "--tf32" if tf32 else "", "--fp8" if fp8 else "", "--fp4" if fp4 else "", "--strongly-typed" if strongly_typed else "", ] # Additional arguments build_args.extend( [ "--refittable" if enable_refit else "", "--tactic-sources" if not enable_all_tactics else "", "--onnx-flags native_instancenorm" if native_instancenorm else "", f"--builder-optimization-level {builder_optimization_level}", f"--precision-constraints {precision_constraints}", ] ) # Timing cache if timing_cache: build_args.extend([f"--load-timing-cache {timing_cache}", f"--save-timing-cache {timing_cache}"]) # Verbosity setting verbosity = "extra_verbose" if verbose else "error" build_args.append(f"--verbosity {verbosity}") # Output names if update_output_names: print(f"Updating network outputs to {update_output_names}") build_args.append(f"--trt-outputs {' '.join(update_output_names)}") # Input profiles if input_profile: profile_args = defaultdict(str) for name, dims in input_profile.items(): assert len(dims) == 3 profile_args["--trt-min-shapes"] += f"{name}:{str(list(dims[0])).replace(' ', '')} " profile_args["--trt-opt-shapes"] += f"{name}:{str(list(dims[1])).replace(' ', '')} " profile_args["--trt-max-shapes"] += f"{name}:{str(list(dims[2])).replace(' ', '')} " build_args.extend(f"{k} {v}" for k, v in profile_args.items()) # Filter out empty strings and join command build_args = [arg for arg in build_args if arg] final_command = " \\\n".join(build_command + build_args) # Execute command with improved error handling try: print(f"Engine build command:{fore('yellow')}\n{final_command}\n{style('reset')}") subprocess.run(final_command, check=True, shell=True) except subprocess.CalledProcessError as exc: error_msg = f"Failed to build TensorRT engine. 
Error details:\nCommand: {exc.cmd}\n" raise RuntimeError(error_msg) from exc @classmethod @abstractmethod def from_args(cls, model_name: str, *args, **kwargs) -> Any: raise NotImplementedError("Factory method is missing") @abstractmethod def get_input_profile( self, batch_size: int, image_height: int | None, image_width: int | None, ) -> dict[str, Any]: """ Generate max and min shape that each input of a TRT engine can have. Subclasses must implement this method to return a dictionary that defines the input profile based on the provided parameters. The input profile typically includes details such as the expected shape of input tensors, whether the batch size or image dimensions are fixed, and any additional configuration required by the data processing or model inference pipeline. Args: batch_size (int): The number of images per batch. image_height (int): Default height of each image in pixels. image_width (int): Defailt width of each image in pixels. static_batch (bool): Flag indicating if the batch size is fixed (static). static_shape (bool): Flag indicating if the image dimensions are fixed (static). Returns: dict[str, Any]: A dictionary representing the input profile configuration. Raises: NotImplementedError: If the subclass does not override this abstract method. 
""" pass @abstractmethod def check_dims(self, *args, **kwargs) -> None | tuple[int, int] | int: """helper function that check the dimentions associated to each input of a TRT engine""" pass def _check_batch(self, batch_size): assert ( self.min_batch <= batch_size <= self.max_batch ), f"Batch size {batch_size} must be between {self.min_batch} and {self.max_batch}" def __post_init__(self): self.onnx_path = self._get_onnx_path() self.engine_path = self._get_engine_path() assert os.path.isfile(self.onnx_path), "onnx_path do not exists: {}".format(self.onnx_path) def _get_onnx_path(self) -> str: if self.custom_onnx_path: return self.custom_onnx_path repo_id = self._get_repo_id(self.model_name) snapshot_path = snapshot_download(repo_id, allow_patterns=[f"{self.module_name.value}.opt/*"]) onnx_model_path = os.path.join(snapshot_path, f"{self.module_name.value}.opt/model.onnx") return onnx_model_path def _get_engine_path(self) -> str: return os.path.join( self.engine_dir, self.model_name, f"{self.module_name.value}_{self.precision}.trt_{trt_version}.plan", ) @staticmethod def _get_repo_id(model_name: str) -> str: if model_name == "flux-dev": return "black-forest-labs/FLUX.1-dev-onnx" elif model_name == "flux-schnell": return "black-forest-labs/FLUX.1-schnell-onnx" elif model_name == "flux-dev-canny": return "black-forest-labs/FLUX.1-Canny-dev-onnx" elif model_name == "flux-dev-depth": return "black-forest-labs/FLUX.1-Depth-dev-onnx" elif model_name == "flux-dev-kontext": return "black-forest-labs/FLUX.1-Kontext-dev-onnx" else: raise ValueError(f"Unknown model name: {model_name}") def register_config(module_name: ModuleName, precision: str): """Decorator to register a configuration class with specific flag conditions.""" def decorator(cls): key = f"module={module_name.value}_dtype={precision}" registry[key] = cls return cls return decorator def get_config(module_name: ModuleName, precision: str) -> TRTBaseConfig: """Retrieve the appropriate configuration instance based on 
current flags.""" key = f"module={module_name.value}_dtype={precision}" return registry[key] ================================================ FILE: src/flux/trt/trt_config/clip_trt_config.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from flux.trt.trt_config.base_trt_config import ModuleName, TRTBaseConfig, register_config from flux.util import configs @register_config(module_name=ModuleName.CLIP, precision="bf16") @dataclass class ClipConfig(TRTBaseConfig): text_maxlen: int | None = None hidden_size: int | None = None trt_tf32: bool = True trt_bf16: bool = False trt_fp8: bool = False trt_fp4: bool = False trt_build_strongly_typed: bool = True @classmethod def from_args( cls, model_name: str, **kwargs, ): return cls( text_maxlen=77, hidden_size=configs[model_name].params.vec_in_dim, model_name=model_name, module_name=ModuleName.CLIP, **kwargs, ) def check_dims(self, batch_size: int) -> None: self._check_batch(batch_size) def get_input_profile( self, batch_size: int, image_height=None, image_width=None, ): min_batch = batch_size if self.trt_static_batch else self.min_batch max_batch = batch_size if self.trt_static_batch else self.max_batch self.check_dims(batch_size) return { "input_ids": [ (min_batch, self.text_maxlen), (batch_size, self.text_maxlen), 
(max_batch, self.text_maxlen), ] } ================================================ FILE: src/flux/trt/trt_config/t5_trt_config.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os from dataclasses import dataclass from huggingface_hub import snapshot_download from flux.trt.trt_config.base_trt_config import ModuleName, TRTBaseConfig, register_config from flux.util import configs @register_config(module_name=ModuleName.T5, precision="bf16") @register_config(module_name=ModuleName.T5, precision="fp8") @dataclass class T5Config(TRTBaseConfig): text_maxlen: int | None = None hidden_size: int | None = None trt_tf32: bool = True trt_bf16: bool = False trt_fp8: bool = False trt_fp4: bool = False trt_build_strongly_typed: bool = True @classmethod def from_args( cls, model_name: str, **kwargs, ): return cls( text_maxlen=256 if model_name == "flux-schnell" else 512, hidden_size=configs[model_name].params.context_in_dim, model_name=model_name, module_name=ModuleName.T5, **kwargs, ) def check_dims(self, batch_size: int) -> None: self._check_batch(batch_size) def get_input_profile( self, batch_size: int, image_height=None, image_width=None, ): min_batch = batch_size if self.trt_static_batch else self.min_batch max_batch = batch_size if self.trt_static_batch else self.max_batch 
self.check_dims(batch_size) return { "input_ids": [ (min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen), ] } def _get_onnx_path(self) -> str: if self.custom_onnx_path: return self.custom_onnx_path if self.precision == "fp8": repo_id = self._get_repo_id(self.model_name) snapshot_path = snapshot_download(repo_id, allow_patterns=["t5-fp8.opt/*"]) onnx_model_path = os.path.join(snapshot_path, "t5-fp8.opt/model.onnx") return onnx_model_path else: return super()._get_onnx_path() ================================================ FILE: src/flux/trt/trt_config/transformer_trt_config.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import os import warnings from dataclasses import dataclass, field from math import ceil from huggingface_hub import snapshot_download from flux.trt.trt_config.base_trt_config import ModuleName, TRTBaseConfig, register_config from flux.util import PREFERED_KONTEXT_RESOLUTIONS, configs @register_config(module_name=ModuleName.TRANSFORMER, precision="bf16") @register_config(module_name=ModuleName.TRANSFORMER, precision="fp8") @register_config(module_name=ModuleName.TRANSFORMER, precision="fp4") @dataclass class TransformerConfig(TRTBaseConfig): guidance_embed: bool | None = None vec_in_dim: int | None = None context_in_dim: int | None = None in_channels: int | None = None out_channels: int | None = None min_image_shape: int | None = None max_image_shape: int | None = None default_image_shape: int = 1024 compression_factor: int = 8 text_maxlen: int | None = None min_latent_dim: int = field(init=False) max_latent_dim: int = field(init=False) min_context_latent_dim: int = field(init=False) max_context_latent_dim: int = field(init=False) trt_tf32: bool = True trt_bf16: bool = False trt_fp8: bool = False trt_fp4: bool = False trt_build_strongly_typed: bool = True @classmethod def from_args( cls, model_name, **kwargs, ): if model_name == "flux-dev-kontext" and kwargs["trt_static_shape"]: warnings.warn("Flux-dev-Kontext does not support static shapes for the encoder.") kwargs["trt_static_shape"] = False if model_name == "flux-dev-kontext": min_image_shape = 1008 max_image_shape = 1040 else: min_image_shape = 768 max_image_shape = 1360 return cls( model_name=model_name, module_name=ModuleName.TRANSFORMER, guidance_embed=configs[model_name].params.guidance_embed, vec_in_dim=configs[model_name].params.vec_in_dim, context_in_dim=configs[model_name].params.context_in_dim, in_channels=configs[model_name].params.in_channels, out_channels=configs[model_name].params.out_channels, text_maxlen=256 if model_name == "flux-schnell" else 512, min_image_shape=min_image_shape, 
max_image_shape=max_image_shape, **kwargs, ) def _get_onnx_path(self) -> str: if self.custom_onnx_path: return self.custom_onnx_path repo_id = self._get_repo_id(self.model_name) typed_model_path = os.path.join(f"{self.module_name.value}.opt", self.precision) snapshot_path = snapshot_download(repo_id, allow_patterns=[f"{typed_model_path}/*"]) onnx_model_path = os.path.join(snapshot_path, typed_model_path, "model.onnx") return onnx_model_path @staticmethod def _get_latent(image_dim: int, compression_factor: int) -> int: return ceil(image_dim / (2 * compression_factor)) @staticmethod def _get_context_dim( image_height: int, image_width: int, compression_factor: int, ) -> int: seq_len = TransformerConfig._get_latent( image_dim=image_height, compression_factor=compression_factor, ) * TransformerConfig._get_latent( image_dim=image_width, compression_factor=compression_factor, ) return seq_len def __post_init__(self): min_latent_dim = TransformerConfig._get_context_dim( image_height=self.min_image_shape, image_width=self.min_image_shape, compression_factor=self.compression_factor, ) max_latent_dim = TransformerConfig._get_context_dim( image_height=self.max_image_shape, image_width=self.max_image_shape, compression_factor=self.compression_factor, ) if self.model_name == "flux-dev-kontext": # get min context size _, min_context_height, min_context_width = min( (w * h, w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS ) self.min_context_latent_dim = TransformerConfig._get_context_dim( image_height=min_context_height, image_width=min_context_width, compression_factor=self.compression_factor, ) # get max context size _, max_context_height, max_context_width = max( (w * h, w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS ) self.max_context_latent_dim = TransformerConfig._get_context_dim( image_height=max_context_height, image_width=max_context_width, compression_factor=self.compression_factor, ) else: self.min_context_latent_dim = 0 self.max_context_latent_dim = 0 
self.min_latent_dim = min_latent_dim + self.min_context_latent_dim self.max_latent_dim = max_latent_dim + self.max_context_latent_dim super().__post_init__() def get_minmax_dims( self, batch_size: int, image_height: int, image_width: int, ): min_batch = batch_size if self.trt_static_batch else self.min_batch max_batch = batch_size if self.trt_static_batch else self.max_batch # if a model has context: it is always dynamic. target image can be static # or dynamic for every-model min_latent_dim = ( self._get_context_dim( image_height=image_height, image_width=image_width, compression_factor=self.compression_factor, ) + self.min_context_latent_dim ) max_latent_dim = ( self._get_context_dim( image_height=image_height, image_width=image_width, compression_factor=self.compression_factor, ) + self.max_context_latent_dim ) # static-shape affects only the target image size min_latent_dim = min_latent_dim if self.trt_static_shape else self.min_latent_dim max_latent_dim = max_latent_dim if self.trt_static_shape else self.max_latent_dim return (min_batch, max_batch, min_latent_dim, max_latent_dim) def check_dims( self, batch_size: int, image_height: int, image_width: int, ) -> int: self._check_batch(batch_size) assert ( image_height % self.compression_factor == 0 or image_width % self.compression_factor == 0 ), f"Image dimensions must be divisible by compression factor {self.compression_factor}" latent_dim = self._get_context_dim( image_height=image_height, image_width=image_width, compression_factor=self.compression_factor, ) if self.model_name == "flux-dev-kontext": # for context models, it is assumed that the optimal context image shape is the same # as target image shape latent_dim = 2 * latent_dim assert self.min_latent_dim <= latent_dim <= self.max_latent_dim, "Image resolution out of boundaries." 
return latent_dim def get_input_profile( self, batch_size: int, image_height: int | None, image_width: int | None, ) -> dict[str, list[tuple]]: if self.model_name == "flux-dev-kontext": assert not self.trt_static_shape, "If Flux-dev-kontext then static_shape must be False." else: assert isinstance(image_height, int) and isinstance( image_width, int ), "Only Flux-dev-kontext allows None image shape" image_height = self.default_image_shape if image_height is None else image_height image_width = self.default_image_shape if image_width is None else image_width opt_latent_dim = self.check_dims( batch_size=batch_size, image_height=image_height, image_width=image_width, ) ( min_batch, max_batch, min_latent_dim, max_latent_dim, ) = self.get_minmax_dims( batch_size=batch_size, image_height=image_height, image_width=image_width, ) input_profile = { "hidden_states": [ (min_batch, min_latent_dim, self.in_channels), (batch_size, opt_latent_dim, self.in_channels), (max_batch, max_latent_dim, self.in_channels), ], "encoder_hidden_states": [ (min_batch, self.text_maxlen, self.context_in_dim), (batch_size, self.text_maxlen, self.context_in_dim), (max_batch, self.text_maxlen, self.context_in_dim), ], "pooled_projections": [ (min_batch, self.vec_in_dim), (batch_size, self.vec_in_dim), (max_batch, self.vec_in_dim), ], "img_ids": [ (min_latent_dim, 3), (opt_latent_dim, 3), (max_latent_dim, 3), ], "txt_ids": [ (self.text_maxlen, 3), (self.text_maxlen, 3), (self.text_maxlen, 3), ], "timestep": [(min_batch,), (batch_size,), (max_batch,)], } if self.guidance_embed: input_profile["guidance"] = [(min_batch,), (batch_size,), (max_batch,)] return input_profile ================================================ FILE: src/flux/trt/trt_config/vae_trt_config.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import warnings from dataclasses import dataclass, field from math import ceil from flux.trt.trt_config.base_trt_config import ModuleName, TRTBaseConfig, register_config from flux.util import configs @dataclass class VAEBaseConfig(TRTBaseConfig): z_channels: int | None = None scale_factor: float | None = None shift_factor: float | None = None default_image_shape: int = 1024 compression_factor: int = 8 min_image_shape: int | None = None max_image_shape: int | None = None min_latent_shape: int = field(init=False) max_latent_shape: int = field(init=False) def _get_latent_dim(self, image_dim: int) -> int: return 2 * ceil(image_dim / (2 * self.compression_factor)) def __post_init__(self): self.min_latent_shape = self._get_latent_dim(self.min_image_shape) self.max_latent_shape = self._get_latent_dim(self.max_image_shape) super().__post_init__() def check_dims( self, batch_size: int, image_height: int, image_width: int, ) -> tuple[int, int]: self._check_batch(batch_size) assert ( image_height % self.compression_factor == 0 or image_width % self.compression_factor == 0 ), f"Image dimensions must be divisible by compression factor {self.compression_factor}" latent_height = self._get_latent_dim(image_height) latent_width = self._get_latent_dim(image_width) assert ( self.min_latent_shape <= latent_height <= self.max_latent_shape ), f"Latent height {latent_height} must be between {self.min_latent_shape} and 
{self.max_latent_shape}" assert ( self.min_latent_shape <= latent_width <= self.max_latent_shape ), f"Latent width {latent_width} must be between {self.min_latent_shape} and {self.max_latent_shape}" return latent_height, latent_width @register_config(module_name=ModuleName.VAE, precision="bf16") @dataclass class VAEDecoderConfig(VAEBaseConfig): trt_tf32: bool = True trt_bf16: bool = True trt_fp8: bool = False trt_fp4: bool = False trt_build_strongly_typed: bool = False @classmethod def from_args( cls, model_name: str, **kwargs, ): if model_name == "flux-dev-kontext": min_image_shape = 672 max_image_shape = 1568 else: min_image_shape = 768 max_image_shape = 1360 return cls( model_name=model_name, module_name=ModuleName.VAE, z_channels=configs[model_name].ae_params.z_channels, scale_factor=configs[model_name].ae_params.scale_factor, shift_factor=configs[model_name].ae_params.shift_factor, min_image_shape=min_image_shape, max_image_shape=max_image_shape, **kwargs, ) def get_minmax_dims( self, batch_size: int, image_height: int, image_width: int, ): min_batch = batch_size if self.trt_static_batch else self.min_batch max_batch = batch_size if self.trt_static_batch else self.max_batch latent_height = self._get_latent_dim(image_height) latent_width = self._get_latent_dim(image_width) min_latent_height = latent_height if self.trt_static_shape else self.min_latent_shape max_latent_height = latent_height if self.trt_static_shape else self.max_latent_shape min_latent_width = latent_width if self.trt_static_shape else self.min_latent_shape max_latent_width = latent_width if self.trt_static_shape else self.max_latent_shape return ( min_batch, max_batch, min_latent_height, max_latent_height, min_latent_width, max_latent_width, ) def get_input_profile( self, batch_size: int, image_height: int | None, image_width: int | None, ): assert self.model_name == "flux-dev-kontext" or ( image_height is not None and image_width is not None ), "Only Flux-dev-kontext allows None image shape" 
assert not self.trt_static_shape or ( image_height is not None and image_width is not None ), "If static_shape is True, image_height and image_width must be not None" image_height = self.default_image_shape if image_height is None else image_height image_width = self.default_image_shape if image_width is None else image_width latent_height, latent_width = self.check_dims( batch_size=batch_size, image_height=image_height, image_width=image_width, ) ( min_batch, max_batch, min_latent_height, max_latent_height, min_latent_width, max_latent_width, ) = self.get_minmax_dims( batch_size=batch_size, image_height=image_height, image_width=image_width, ) return { "latent": [ (min_batch, self.z_channels, min_latent_height, min_latent_width), (batch_size, self.z_channels, latent_height, latent_width), (max_batch, self.z_channels, max_latent_height, max_latent_width), ] } @register_config(module_name=ModuleName.VAE_ENCODER, precision="bf16") @dataclass class VAEEncoderConfig(VAEBaseConfig): trt_tf32: bool = True trt_bf16: bool = True trt_fp8: bool = False trt_fp4: bool = False trt_build_strongly_typed: bool = False @classmethod def from_args(cls, model_name: str, **kwargs): if model_name == "flux-dev-kontext" and kwargs["trt_static_shape"]: warnings.warn("Flux-dev-Kontext does not support static shapes for the encoder.") kwargs["trt_static_shape"] = False if model_name == "flux-dev-kontext": min_image_shape = 672 max_image_shape = 1568 else: min_image_shape = 768 max_image_shape = 1360 return cls( model_name=model_name, module_name=ModuleName.VAE_ENCODER, z_channels=configs[model_name].ae_params.z_channels, scale_factor=configs[model_name].ae_params.scale_factor, shift_factor=configs[model_name].ae_params.shift_factor, min_image_shape=min_image_shape, max_image_shape=max_image_shape, **kwargs, ) def get_minmax_dims( self, batch_size: int, image_height: int, image_width: int, ): min_batch = batch_size if self.trt_static_batch else self.min_batch max_batch = batch_size if 
self.trt_static_batch else self.max_batch min_image_height = image_height if self.trt_static_shape else self.min_image_shape max_image_height = image_height if self.trt_static_shape else self.max_image_shape min_image_width = image_width if self.trt_static_shape else self.min_image_shape max_image_width = image_width if self.trt_static_shape else self.max_image_shape return ( min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width, ) def get_input_profile( self, batch_size: int, image_height: int | None, image_width: int | None, ): if self.model_name == "flux-dev-kontext": assert ( not self.trt_static_shape ), "Flux-dev-kontext does not support dynamic shapes for the encoder." else: assert isinstance(image_height, int) and isinstance( image_width, int ), "Only Flux-dev-kontext allows None image shape" image_height = self.default_image_shape if image_height is None else image_height image_width = self.default_image_shape if image_width is None else image_width self.check_dims( batch_size=batch_size, image_height=image_height, image_width=image_width, ) ( min_batch, max_batch, min_image_height, max_image_height, min_image_width, max_image_width, ) = self.get_minmax_dims( batch_size=batch_size, image_height=image_height, image_width=image_width, ) return { "images": [ (min_batch, 3, min_image_height, min_image_width), (batch_size, 3, image_height, image_width), (max_batch, 3, max_image_height, max_image_width), ], } ================================================ FILE: src/flux/trt/trt_manager.py ================================================ # # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import gc import os import sys import warnings import tensorrt as trt import torch from flux.trt.engine import ( BaseEngine, CLIPEngine, Engine, SharedMemory, T5Engine, TransformerEngine, VAEDecoder, VAEEncoder, VAEEngine, ) from flux.trt.trt_config import ( ModuleName, TRTBaseConfig, get_config, ) TRT_LOGGER = trt.Logger() VALID_TRANSFORMER_PRECISIONS = {"bf16", "fp8", "fp4", "fp4_svd32"} VALID_T5_PRECISIONS = {"bf16", "fp8"} class TRTManager: @property def module_to_engine_class(self) -> dict[ModuleName, type[Engine]]: return { ModuleName.CLIP: CLIPEngine, ModuleName.TRANSFORMER: TransformerEngine, ModuleName.T5: T5Engine, ModuleName.VAE: VAEDecoder, ModuleName.VAE_ENCODER: VAEEncoder, } def __init__( self, trt_transformer_precision: str, trt_t5_precision: str, max_batch=2, verbose=False, ): self.max_batch = max_batch self.precisions = self._parse_models_precisions( trt_transformer_precision=trt_transformer_precision, trt_t5_precision=trt_t5_precision, ) self.verbose = verbose self.runtime: trt.Runtime = None self.device_memory = SharedMemory(1024) assert torch.cuda.is_available(), "No cuda device available" @staticmethod def _parse_models_precisions( trt_transformer_precision: str, trt_t5_precision: str ) -> dict[ModuleName, str]: precisions = { ModuleName.CLIP: "bf16", ModuleName.VAE: "bf16", ModuleName.VAE_ENCODER: "bf16", } assert ( trt_transformer_precision in VALID_TRANSFORMER_PRECISIONS ), f"Invalid precision for flux-transformer `{trt_transformer_precision}`. 
Possible value are {VALID_TRANSFORMER_PRECISIONS}" precisions[ModuleName.TRANSFORMER] = ( trt_transformer_precision if trt_transformer_precision != "fp4_svd32" else "fp4" ) assert ( trt_t5_precision in VALID_T5_PRECISIONS ), f"Invalid precision for T5 `{trt_t5_precision}`. Possible value are {VALID_T5_PRECISIONS}" precisions[ModuleName.T5] = trt_t5_precision return precisions @staticmethod def _parse_custom_onnx_path(custom_onnx_paths: str) -> dict[ModuleName, str]: """Parse a string of comma-separated key-value pairs into a dictionary. Args: string (str): A string of comma-separated key-value pairs. Returns: Dict[str, str]: Parsed dictionary of key-value pairs. Example: >>> parse_key_value_pairs("key1:value1,key2:value2") {"key1": "value1", "key2": "value2"} """ parsed = {} for key_value_pair in custom_onnx_paths.split(","): if not key_value_pair: continue key_value_pair = key_value_pair.split(":") if len(key_value_pair) != 2: raise ValueError(f"Invalid key-value pair: {key_value_pair}. 
Must have length 2.") key, value = key_value_pair key = ModuleName(key) parsed[key] = value return parsed @staticmethod def _create_directories(engine_dir: str): print(f"[I] Create directory: {engine_dir} if not existing") os.makedirs(engine_dir, exist_ok=True) def _get_trt_configs( self, model_name: str, module_names: set[ModuleName], engine_dir: str, custom_onnx_paths: dict[ModuleName, str], trt_static_batch: bool, trt_static_shape: bool, trt_enable_all_tactics: bool, trt_timing_cache: str | None, trt_native_instancenorm: bool, trt_builder_optimization_level: int, trt_precision_constraints: str, **kwargs, ) -> dict[ModuleName, TRTBaseConfig]: trt_configs = {} for module_name in module_names: config_cls = get_config(module_name=module_name, precision=self.precisions[module_name]) custom_onnx_path = custom_onnx_paths.get(module_name, None) trt_config = config_cls.from_args( model_name=model_name, max_batch=self.max_batch, custom_onnx_path=custom_onnx_path, engine_dir=engine_dir, trt_verbose=self.verbose, precision=self.precisions[module_name], trt_static_batch=trt_static_batch, trt_static_shape=trt_static_shape, trt_enable_all_tactics=trt_enable_all_tactics, trt_timing_cache=trt_timing_cache, trt_native_instancenorm=trt_native_instancenorm, trt_builder_optimization_level=trt_builder_optimization_level, trt_precision_constraints=trt_precision_constraints, **kwargs, ) trt_configs[module_name] = trt_config if ModuleName.TRANSFORMER in trt_configs and ModuleName.T5 in trt_configs: trt_configs[ModuleName.TRANSFORMER].text_maxlen = trt_configs[ModuleName.T5].text_maxlen else: warnings.warn("`text_maxlen` attribute of flux-trasformer is not update. 
Default value is used.") return trt_configs @staticmethod def _build_engine( trt_config: TRTBaseConfig, batch_size: int, image_height: int | None, image_width: int | None, ): already_build = os.path.exists(trt_config.engine_path) if already_build: return trt_config.build_trt_engine( engine_path=trt_config.engine_path, onnx_path=trt_config.onnx_path, strongly_typed=trt_config.trt_build_strongly_typed, tf32=trt_config.trt_tf32, bf16=trt_config.trt_bf16, fp8=trt_config.trt_fp8, fp4=trt_config.trt_fp4, input_profile=trt_config.get_input_profile( batch_size=batch_size, image_height=image_height, image_width=image_width, ), enable_all_tactics=trt_config.trt_enable_all_tactics, timing_cache=trt_config.trt_timing_cache, update_output_names=trt_config.trt_update_output_names, builder_optimization_level=trt_config.trt_builder_optimization_level, verbose=trt_config.trt_verbose, ) TRTManager._clean_memory() def load_engines( self, model_name: str, module_names: set[ModuleName], engine_dir: str, trt_image_height: int | None, trt_image_width: int | None, trt_batch_size=1, trt_static_batch=True, trt_static_shape=True, trt_enable_all_tactics=False, trt_timing_cache: str | None = None, trt_native_instancenorm=True, trt_builder_optimization_level=3, trt_precision_constraints="none", custom_onnx_paths="", **kwargs, ) -> dict[ModuleName, BaseEngine]: TRTManager._clean_memory() TRTManager._create_directories(engine_dir) custom_onnx_paths = TRTManager._parse_custom_onnx_path(custom_onnx_paths) trt_configs = self._get_trt_configs( model_name, module_names, engine_dir=engine_dir, custom_onnx_paths=custom_onnx_paths, trt_static_batch=trt_static_batch, trt_static_shape=trt_static_shape, trt_enable_all_tactics=trt_enable_all_tactics, trt_timing_cache=trt_timing_cache, trt_native_instancenorm=trt_native_instancenorm, trt_builder_optimization_level=trt_builder_optimization_level, trt_precision_constraints=trt_precision_constraints, **kwargs, ) # Build TRT engines for module_name, trt_config in 
trt_configs.items(): self._build_engine( trt_config=trt_config, batch_size=trt_batch_size, image_height=trt_image_height, image_width=trt_image_width, ) self.init_runtime() # load TRT engines engines = {} for module_name, trt_config in trt_configs.items(): engine_class = self.module_to_engine_class[module_name] engine = engine_class( trt_config=trt_config, stream=self.stream, context_memory=self.device_memory, allocation_policy=os.getenv("TRT_ALLOCATION_POLICY", "global"), ) engines[module_name] = engine if ModuleName.VAE in engines: engines[ModuleName.VAE] = VAEEngine( decoder=engines.pop(ModuleName.VAE), encoder=engines.pop(ModuleName.VAE_ENCODER, None), ) self._clean_memory() return engines @staticmethod def _clean_memory(): gc.collect() torch.cuda.empty_cache() def init_runtime(self): print("[I] Init TRT runtime") self.runtime = trt.Runtime(TRT_LOGGER) enter_fn = type(self.runtime).__enter__ enter_fn(self.runtime) self.stream = torch.cuda.current_stream() def stop_runtime(self): exit_fn = type(self.runtime).__exit__ exit_fn(self.runtime, *sys.exc_info()) del self.stream del self.device_memory print("[I] Stop TRT runtime") ================================================ FILE: src/flux/util.py ================================================ import getpass import math import os from dataclasses import dataclass from pathlib import Path import requests import torch from einops import rearrange from huggingface_hub import hf_hub_download, login from imwatermark import WatermarkEncoder from PIL import ExifTags, Image from safetensors.torch import load_file as load_sft from flux.model import Flux, FluxLoraWrapper, FluxParams from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams from flux.modules.conditioner import HFEmbedder CHECKPOINTS_DIR = Path("checkpoints") CHECKPOINTS_DIR.mkdir(exist_ok=True) BFL_API_KEY = os.getenv("BFL_API_KEY") os.environ.setdefault("TRT_ENGINE_DIR", str(CHECKPOINTS_DIR / "trt_engines")) (CHECKPOINTS_DIR / 
"trt_engines").mkdir(exist_ok=True) def ensure_hf_auth(): hf_token = os.environ.get("HF_TOKEN") if hf_token: print("Trying to authenticate to HuggingFace with the HF_TOKEN environment variable.") try: login(token=hf_token) print("Successfully authenticated with HuggingFace using HF_TOKEN") return True except Exception as e: print(f"Warning: Failed to authenticate with HF_TOKEN: {e}") if os.path.exists(os.path.expanduser("~/.cache/huggingface/token")): print("Already authenticated with HuggingFace") return True return False def prompt_for_hf_auth(): try: token = getpass.getpass("HF Token (hidden input): ").strip() if not token: print("No token provided. Aborting.") return False login(token=token) print("Successfully authenticated!") return True except KeyboardInterrupt: print("\nAuthentication cancelled by user.") return False except Exception as auth_e: print(f"Authentication failed: {auth_e}") print("Tip: You can also run 'huggingface-cli login' or set HF_TOKEN environment variable") return False def get_checkpoint_path(repo_id: str, filename: str, env_var: str) -> Path: """Get the local path for a checkpoint file, downloading if necessary.""" if os.environ.get(env_var) is not None: local_path = os.environ[env_var] if os.path.exists(local_path): return Path(local_path) print( f"Trying to load model {repo_id}, {filename} from environment " f"variable {env_var}. But file {local_path} does not exist. " "Falling back to default location." 
) # Create a safe directory name from repo_id safe_repo_name = repo_id.replace("/", "_") checkpoint_dir = CHECKPOINTS_DIR / safe_repo_name checkpoint_dir.mkdir(exist_ok=True) local_path = checkpoint_dir / filename if not local_path.exists(): print(f"Downloading {filename} from {repo_id} to {local_path}") try: ensure_hf_auth() hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir) except Exception as e: if "gated repo" in str(e).lower() or "restricted" in str(e).lower(): print(f"\nError: Cannot access {repo_id} -- this is a gated repository.") # Try one more time to authenticate if prompt_for_hf_auth(): # Retry the download after authentication print("Retrying download...") hf_hub_download(repo_id=repo_id, filename=filename, local_dir=checkpoint_dir) else: print("Authentication failed or cancelled.") print("You can also run 'huggingface-cli login' or set HF_TOKEN environment variable") raise RuntimeError(f"Authentication required for {repo_id}") else: raise e return local_path def download_onnx_models_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None: """Download ONNX models for TRT to our checkpoints directory""" onnx_repo_map = { "flux-dev": "black-forest-labs/FLUX.1-dev-onnx", "flux-schnell": "black-forest-labs/FLUX.1-schnell-onnx", "flux-dev-canny": "black-forest-labs/FLUX.1-Canny-dev-onnx", "flux-dev-depth": "black-forest-labs/FLUX.1-Depth-dev-onnx", "flux-dev-redux": "black-forest-labs/FLUX.1-Redux-dev-onnx", "flux-dev-fill": "black-forest-labs/FLUX.1-Fill-dev-onnx", "flux-dev-kontext": "black-forest-labs/FLUX.1-Kontext-dev-onnx", } if model_name not in onnx_repo_map: return None # No ONNX repository required for this model repo_id = onnx_repo_map[model_name] safe_repo_name = repo_id.replace("/", "_") onnx_dir = CHECKPOINTS_DIR / safe_repo_name # Map of module names to their ONNX file paths (using specified precision) onnx_file_map = { "clip": "clip.opt/model.onnx", "transformer": 
f"transformer.opt/{trt_transformer_precision}/model.onnx", "transformer_data": f"transformer.opt/{trt_transformer_precision}/backbone.onnx_data", "t5": "t5.opt/model.onnx", "t5_data": "t5.opt/backbone.onnx_data", "vae": "vae.opt/model.onnx", } # If all files exist locally, return the custom_onnx_paths format if onnx_dir.exists(): all_files_exist = True custom_paths = [] for module, onnx_file in onnx_file_map.items(): if module.endswith("_data"): continue # Skip data files local_path = onnx_dir / onnx_file if not local_path.exists(): all_files_exist = False break custom_paths.append(f"{module}:{local_path}") if all_files_exist: print(f"ONNX models ready in {onnx_dir}") return ",".join(custom_paths) # If not all files exist, download them print(f"Downloading ONNX models from {repo_id} to {onnx_dir}") print(f"Using transformer precision: {trt_transformer_precision}") onnx_dir.mkdir(exist_ok=True) # Download all ONNX files for module, onnx_file in onnx_file_map.items(): local_path = onnx_dir / onnx_file if local_path.exists(): continue # Already downloaded # Create parent directories local_path.parent.mkdir(parents=True, exist_ok=True) try: print(f"Downloading {onnx_file}") hf_hub_download(repo_id=repo_id, filename=onnx_file, local_dir=onnx_dir) except Exception as e: if "does not exist" in str(e).lower() or "not found" in str(e).lower(): continue elif "gated repo" in str(e).lower() or "restricted" in str(e).lower(): print(f"Cannot access {repo_id} - requires license acceptance") print("Please follow these steps:") print(f" 1. Visit: https://huggingface.co/{repo_id}") print(" 2. Log in to your HuggingFace account") print(" 3. Accept the license terms and conditions") print(" 4. 
Then retry this command") raise RuntimeError(f"License acceptance required for {model_name}") else: # Re-raise other errors raise print(f"ONNX models ready in {onnx_dir}") # Return the custom_onnx_paths format that TRT expects: "module1:path1,module2:path2" # Note: Only return the actual module paths, not the data file custom_paths = [] for module, onnx_file in onnx_file_map.items(): if module.endswith("_data"): continue # Skip the data file in the return paths full_path = onnx_dir / onnx_file if full_path.exists(): custom_paths.append(f"{module}:{full_path}") return ",".join(custom_paths) def check_onnx_access_for_trt(model_name: str, trt_transformer_precision: str = "bf16") -> str | None: """Check ONNX access and download models for TRT - returns ONNX directory path""" return download_onnx_models_for_trt(model_name, trt_transformer_precision) def track_usage_via_api(name: str, n=1) -> None: """ Track usage of licensed models via the BFL API for commercial licensing compliance. For more information on licensing BFL's models for commercial use and usage reporting, see the README.md or visit: https://dashboard.bfl.ai/licensing/subscriptions?showInstructions=true """ assert BFL_API_KEY is not None, "BFL_API_KEY is not set" model_slug_map = { "flux-dev": "flux-1-dev", "flux-dev-kontext": "flux-1-kontext-dev", "flux-dev-fill": "flux-tools", "flux-dev-depth": "flux-tools", "flux-dev-canny": "flux-tools", "flux-dev-canny-lora": "flux-tools", "flux-dev-depth-lora": "flux-tools", "flux-dev-redux": "flux-tools", "flux-dev-krea": "flux-1-krea-dev", } if name not in model_slug_map: print(f"Skipping tracking usage for {name}, as it cannot be tracked. 
Please check the model name.") return model_slug = model_slug_map[name] url = f"https://api.bfl.ai/v1/licenses/models/{model_slug}/usage" headers = {"x-key": BFL_API_KEY, "Content-Type": "application/json"} payload = {"number_of_generations": n} response = requests.post(url, headers=headers, json=payload) if response.status_code != 200: raise Exception(f"Failed to track usage: {response.status_code} {response.text}") else: print(f"Successfully tracked usage for {name} with {n} generations") def save_image( nsfw_classifier, name: str, output_name: str, idx: int, x: torch.Tensor, add_sampling_metadata: bool, prompt: str, nsfw_threshold: float = 0.85, track_usage: bool = False, ) -> int: fn = output_name.format(idx=idx) print(f"Saving {fn}") # bring into PIL format and save x = x.clamp(-1, 1) x = embed_watermark(x.float()) x = rearrange(x[0], "c h w -> h w c") img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) if nsfw_classifier is not None: nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0] else: nsfw_score = nsfw_threshold - 1.0 if nsfw_score < nsfw_threshold: exif_data = Image.Exif() if name in ["flux-dev", "flux-schnell"]: exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" else: exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux" exif_data[ExifTags.Base.Make] = "Black Forest Labs" exif_data[ExifTags.Base.Model] = name if add_sampling_metadata: exif_data[ExifTags.Base.ImageDescription] = prompt img.save(fn, exif=exif_data, quality=95, subsampling=0) if track_usage: track_usage_via_api(name, 1) idx += 1 else: print("Your generated image may contain NSFW content.") return idx @dataclass class ModelSpec: params: FluxParams ae_params: AutoEncoderParams repo_id: str repo_flow: str repo_ae: str lora_repo_id: str | None = None lora_filename: str | None = None configs = { "flux-dev": ModelSpec( repo_id="black-forest-labs/FLUX.1-dev", repo_flow="flux1-dev.safetensors", repo_ae="ae.safetensors", 
params=FluxParams( in_channels=64, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True, ), ae_params=AutoEncoderParams( resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ), ), "flux-dev-krea": ModelSpec( repo_id="black-forest-labs/FLUX.1-Krea-dev", repo_flow="flux1-krea-dev.safetensors", repo_ae="ae.safetensors", params=FluxParams( in_channels=64, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True, ), ae_params=AutoEncoderParams( resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ), ), "flux-schnell": ModelSpec( repo_id="black-forest-labs/FLUX.1-schnell", repo_flow="flux1-schnell.safetensors", repo_ae="ae.safetensors", params=FluxParams( in_channels=64, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=False, ), ae_params=AutoEncoderParams( resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ), ), "flux-dev-canny": ModelSpec( repo_id="black-forest-labs/FLUX.1-Canny-dev", repo_flow="flux1-canny-dev.safetensors", repo_ae="ae.safetensors", params=FluxParams( in_channels=128, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True, ), ae_params=AutoEncoderParams( resolution=256, 
in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ), ), "flux-dev-canny-lora": ModelSpec( repo_id="black-forest-labs/FLUX.1-dev", repo_flow="flux1-dev.safetensors", repo_ae="ae.safetensors", lora_repo_id="black-forest-labs/FLUX.1-Canny-dev-lora", lora_filename="flux1-canny-dev-lora.safetensors", params=FluxParams( in_channels=128, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True, ), ae_params=AutoEncoderParams( resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ), ), "flux-dev-depth": ModelSpec( repo_id="black-forest-labs/FLUX.1-Depth-dev", repo_flow="flux1-depth-dev.safetensors", repo_ae="ae.safetensors", params=FluxParams( in_channels=128, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True, ), ae_params=AutoEncoderParams( resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ), ), "flux-dev-depth-lora": ModelSpec( repo_id="black-forest-labs/FLUX.1-dev", repo_flow="flux1-dev.safetensors", repo_ae="ae.safetensors", lora_repo_id="black-forest-labs/FLUX.1-Depth-dev-lora", lora_filename="flux1-depth-dev-lora.safetensors", params=FluxParams( in_channels=128, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True, ), ae_params=AutoEncoderParams( resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, 
scale_factor=0.3611, shift_factor=0.1159, ), ), "flux-dev-redux": ModelSpec( repo_id="black-forest-labs/FLUX.1-Redux-dev", repo_flow="flux1-redux-dev.safetensors", repo_ae="ae.safetensors", params=FluxParams( in_channels=64, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True, ), ae_params=AutoEncoderParams( resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ), ), "flux-dev-fill": ModelSpec( repo_id="black-forest-labs/FLUX.1-Fill-dev", repo_flow="flux1-fill-dev.safetensors", repo_ae="ae.safetensors", params=FluxParams( in_channels=384, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True, ), ae_params=AutoEncoderParams( resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ), ), "flux-dev-kontext": ModelSpec( repo_id="black-forest-labs/FLUX.1-Kontext-dev", repo_flow="flux1-kontext-dev.safetensors", repo_ae="ae.safetensors", params=FluxParams( in_channels=64, out_channels=64, vec_in_dim=768, context_in_dim=4096, hidden_size=3072, mlp_ratio=4.0, num_heads=24, depth=19, depth_single_blocks=38, axes_dim=[16, 56, 56], theta=10_000, qkv_bias=True, guidance_embed=True, ), ae_params=AutoEncoderParams( resolution=256, in_channels=3, ch=128, out_ch=3, ch_mult=[1, 2, 4, 4], num_res_blocks=2, z_channels=16, scale_factor=0.3611, shift_factor=0.1159, ), ), } PREFERED_KONTEXT_RESOLUTIONS = [ (672, 1568), (688, 1504), (720, 1456), (752, 1392), (800, 1328), (832, 1248), (880, 1184), (944, 1104), (1024, 1024), (1104, 944), (1184, 880), (1248, 832), (1328, 800), (1392, 752), (1456, 720), (1504, 
688), (1568, 672), ] def aspect_ratio_to_height_width(aspect_ratio: str, area: int = 1024**2) -> tuple[int, int]: width = float(aspect_ratio.split(":")[0]) height = float(aspect_ratio.split(":")[1]) ratio = width / height width = round(math.sqrt(area * ratio)) height = round(math.sqrt(area / ratio)) return 16 * (width // 16), 16 * (height // 16) def print_load_warning(missing: list[str], unexpected: list[str]) -> None: if len(missing) > 0 and len(unexpected) > 0: print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) print("\n" + "-" * 79 + "\n") print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) elif len(missing) > 0: print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing)) elif len(unexpected) > 0: print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected)) def load_flow_model(name: str, device: str | torch.device = "cuda", verbose: bool = True) -> Flux: # Loading Flux print("Init model") config = configs[name] ckpt_path = str(get_checkpoint_path(config.repo_id, config.repo_flow, "FLUX_MODEL")) with torch.device("meta"): if config.lora_repo_id is not None and config.lora_filename is not None: model = FluxLoraWrapper(params=config.params).to(torch.bfloat16) else: model = Flux(config.params).to(torch.bfloat16) print(f"Loading checkpoint: {ckpt_path}") # load_sft doesn't support torch.device sd = load_sft(ckpt_path, device=str(device)) sd = optionally_expand_state_dict(model, sd) missing, unexpected = model.load_state_dict(sd, strict=False, assign=True) if verbose: print_load_warning(missing, unexpected) if config.lora_repo_id is not None and config.lora_filename is not None: print("Loading LoRA") lora_path = str(get_checkpoint_path(config.lora_repo_id, config.lora_filename, "FLUX_LORA")) lora_sd = load_sft(lora_path, device=str(device)) # loading the lora params + overwriting scale values in the norms missing, unexpected = model.load_state_dict(lora_sd, strict=False, assign=True) if verbose: 
print_load_warning(missing, unexpected) return model def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder: # max length 64, 128, 256 and 512 should work (if your sequence is short enough) return HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16).to(device) def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device) def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder: config = configs[name] ckpt_path = str(get_checkpoint_path(config.repo_id, config.repo_ae, "FLUX_AE")) # Loading the autoencoder print("Init AE") with torch.device("meta"): ae = AutoEncoder(config.ae_params) print(f"Loading AE checkpoint: {ckpt_path}") sd = load_sft(ckpt_path, device=str(device)) missing, unexpected = ae.load_state_dict(sd, strict=False, assign=True) print_load_warning(missing, unexpected) return ae def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dict) -> dict: """ Optionally expand the state dict to match the model's parameters shapes. """ for name, param in model.named_parameters(): if name in state_dict: if state_dict[name].shape != param.shape: print( f"Expanding '{name}' with shape {state_dict[name].shape} to model parameter with shape {param.shape}." 
) # expand with zeros: expanded_state_dict_weight = torch.zeros_like(param, device=state_dict[name].device) slices = tuple(slice(0, dim) for dim in state_dict[name].shape) expanded_state_dict_weight[slices] = state_dict[name] state_dict[name] = expanded_state_dict_weight return state_dict class WatermarkEmbedder: def __init__(self, watermark): self.watermark = watermark self.num_bits = len(WATERMARK_BITS) self.encoder = WatermarkEncoder() self.encoder.set_watermark("bits", self.watermark) def __call__(self, image: torch.Tensor) -> torch.Tensor: """ Adds a predefined watermark to the input image Args: image: ([N,] B, RGB, H, W) in range [-1, 1] Returns: same as input but watermarked """ image = 0.5 * image + 0.5 squeeze = len(image.shape) == 4 if squeeze: image = image[None, ...] n = image.shape[0] image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[:, :, :, ::-1] # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] # watermarking libary expects input as cv2 BGR format for k in range(image_np.shape[0]): image_np[k] = self.encoder.encode(image_np[k], "dwtDct") image = torch.from_numpy(rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)).to( image.device ) image = torch.clamp(image / 255, min=0.0, max=1.0) if squeeze: image = image[0] image = 2 * image - 1 return image # A fixed 48-bit message that was chosen at random WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110 # bin(x)[2:] gives bits of x as str, use int to convert them to 0/1 WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]] embed_watermark = WatermarkEmbedder(WATERMARK_BITS)